# Source code for Muscat.Helpers.SymExpr

# -*- coding: utf-8 -*-
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
#
from typing import List, Dict,Optional
import numpy as np
from sympy import lambdify
from sympy import symbols as getsymbols

from Muscat.Types import ArrayLike


class SymExprBase():
    """Read a symbolic expression from a string and store it.

    The expression is parsed with sympy and turned into a callable with
    ``lambdify``. When ``derivatives`` is True, the first and second
    derivatives with respect to every free variable are also computed and
    lambdified each time the expression is set.
    """

    def __init__(self, string: Optional[str] = None, symbols: Optional[List[str]] = None, derivatives: bool = True) -> None:
        """Create a symbolic expression and optionally set it.

        Parameters
        ----------
        string : str, optional
            Expression to parse. If None, no expression is set yet.
        symbols : list[str], optional
            Names of the free variables of the expression. Each one is
            registered as a constant with value 0.0. If None, the single
            symbol 't' (for time) is registered with value 0.0.
        derivatives : bool, optional
            If True (default), first and second derivatives are computed
            whenever the expression is set.
        """
        super().__init__()
        self.derivatives = derivatives
        # String representation of the current expression.
        self._expression = ""
        # Values of the free variables, passed as keyword arguments on evaluation.
        self.constants = {}
        if symbols is None:
            self.SetConstant("t", 0.0)
        else:
            for s in symbols:
                self.SetConstant(s, 0.0)
        if string is not None:
            self.SetExpression(string)

    def SetExpression(self, string: str, symbols: Optional[List[str]] = None) -> None:
        """Parse ``string`` and build the evaluation (and derivative) functions.

        Parameters
        ----------
        string : str
            The string representation of the expression to be parsed by sympy.
        symbols : list[str], optional
            Names of the free variables, in the order used for the lambdified
            functions. If None (default), the names of the currently
            registered constants are used.
        """
        from sympy.parsing.sympy_parser import parse_expr

        if symbols is None:
            self.stringSymbols = list(self.constants.keys())
        else:
            # An explicit list used to be ignored, leaving stringSymbols unset
            # (AttributeError) or stale from a previous call.
            self.stringSymbols = list(symbols)
        self._expression = string
        # NOTE: parse_expr evaluates the string with eval(); never pass it
        # untrusted input.
        self.sympy_expr = parse_expr(self._expression)
        self.func = lambdify(getsymbols(self.stringSymbols), self.sympy_expr)
        if self.derivatives:
            self.__ComputeDerivatives()

    def __ComputeDerivatives(self) -> None:
        """Lambdify the first and second derivatives w.r.t. every free variable.

        Fills ``self.dFund[s]`` and ``self.d2Fund2[(s1, s2)]``, where the
        second derivative is taken first w.r.t. ``s1`` and then w.r.t. ``s2``.
        """
        from sympy import Symbol

        args = getsymbols(self.stringSymbols)
        self.dFund = dict()
        self.d2Fund2 = dict()
        for s in self.stringSymbols:
            # Differentiate once and reuse the result for all second derivatives.
            first = self.sympy_expr.diff(Symbol(s))
            self.dFund[s] = lambdify(args, first)
            for s2 in self.stringSymbols:
                self.d2Fund2[(s, s2)] = lambdify(args, first.diff(Symbol(s2)))

    def SetConstant(self, name: str, value: np.number) -> None:
        """Add or update the value of a free variable of the expression.

        Parameters
        ----------
        name : str
            Name of the variable.
        value : np.number
            Value of the variable.
        """
        self.constants[name] = value

    def GetValue(self) -> np.number:
        """Evaluate the expression at the current constant values.

        All registered constants are passed as keyword arguments, so they must
        match the symbols used when the expression was set.

        Returns
        -------
        np.number
            The evaluated expression.
        """
        return self.func(**self.constants)

    def GetValueDerivative(self, symbol: str) -> np.number:
        """Evaluate the first derivative of the expression w.r.t. ``symbol``.

        Requires ``derivatives=True`` when the expression was set.

        Parameters
        ----------
        symbol : str
            Name of the variable to differentiate with respect to.

        Returns
        -------
        np.number
            The derivative evaluated at the current constant values.
        """
        return self.dFund[symbol](**self.constants)

    def GetValueSecondDerivative(self, symbol1: str, symbol2: str) -> np.number:
        """Evaluate d2(expr) / (d symbol1 d symbol2) at the current constants.

        Requires ``derivatives=True`` when the expression was set.

        Parameters
        ----------
        symbol1 : str
            Name of the first differentiation variable.
        symbol2 : str
            Name of the second differentiation variable.

        Returns
        -------
        np.number
            The second derivative evaluated at the current constant values.
        """
        return self.d2Fund2[(symbol1, symbol2)](**self.constants)

    def __call__(self) -> np.number:
        """Shortcut for :meth:`GetValue`."""
        return self.GetValue()
class SymExprWithPos(SymExprBase):
    """Symbolic expression that depends implicitly on the position (x, y, z).

    The symbols ``x``, ``y`` and ``z`` are always added as free variables and
    are fed from the columns of the ``pos`` array at evaluation time. The
    other free variables come from the registered constants. First and
    second derivatives are computed when ``derivatives`` is True.
    """

    def __init__(self, string: Optional[str] = None, symbols=None, derivatives: bool = True) -> None:
        super().__init__(string=string, symbols=symbols, derivatives=derivatives)

    def SetExpression(self, string: str) -> None:
        """Parse ``string`` using the registered constants plus x, y, z as free variables."""
        self.stringSymbols = list(self.constants.keys())
        self.stringSymbols.extend("xyz")
        super().SetExpression(string, self.stringSymbols)

    def _EvaluateAtPos(self, func, pos: ArrayLike) -> np.ndarray:
        """Evaluate a lambdified function at each row of ``pos``, one value per point."""
        # Accept any array-like (e.g. nested lists), as the type hint promises.
        pos = np.asarray(pos)
        res = func(x=pos[:, 0], y=pos[:, 1], z=pos[:, 2], **self.constants)
        if isinstance(res, np.ndarray):
            # An expression such as "x" lambdifies to pos[:, 0], a view into
            # the input; copy it so callers cannot corrupt pos through it.
            return res.copy() if np.may_share_memory(res, pos) else res
        # The expression does not depend on the position: broadcast the scalar.
        return np.full((pos.shape[0],), fill_value=res)

    def GetValue(self, pos: ArrayLike) -> np.ndarray:
        """Evaluate the expression at every point.

        Parameters
        ----------
        pos : ArrayLike
            Point coordinates, one point per row; columns 0, 1, 2 are x, y, z.

        Returns
        -------
        np.ndarray
            One value per row of ``pos``.
        """
        return self._EvaluateAtPos(self.func, pos)

    def GetValueDerivative(self, symbol: str, pos: ArrayLike) -> np.ndarray:
        """Evaluate the first derivative w.r.t. ``symbol`` at every point.

        Parameters
        ----------
        symbol : str
            Name of the differentiation variable (a constant or x, y, z).
        pos : ArrayLike
            Point coordinates, one point per row; columns 0, 1, 2 are x, y, z.

        Returns
        -------
        np.ndarray
            One value per row of ``pos``.
        """
        return self._EvaluateAtPos(self.dFund[symbol], pos)

    def GetValueSecondDerivative(self, symbol1: str, symbol2: str, pos: ArrayLike) -> np.ndarray:
        """Evaluate d2(expr) / (d symbol1 d symbol2) at every point.

        Parameters
        ----------
        symbol1 : str
            Name of the first differentiation variable.
        symbol2 : str
            Name of the second differentiation variable.
        pos : ArrayLike
            Point coordinates, one point per row; columns 0, 1, 2 are x, y, z.

        Returns
        -------
        np.ndarray
            One value per row of ``pos``.
        """
        return self._EvaluateAtPos(self.d2Fund2[(symbol1, symbol2)], pos)

    def __call__(self, pos: ArrayLike) -> np.ndarray:
        """Shortcut for :meth:`GetValue` (the inherited no-argument version cannot work here)."""
        return self.GetValue(pos)

    def __str__(self) -> str:
        res = f"SymExprWithPos('{self._expression}') "
        return res
def CreateSymExprWithPos(ops: Dict) -> SymExprWithPos:
    """Build a SymExprWithPos from a dictionary of options.

    Parameters
    ----------
    ops : Dict
        Options dictionary; the expression string is read from ``ops["val"]``
        (for example the attributes of an XML element). Other keys are
        ignored.

    Returns
    -------
    SymExprWithPos
        The position-dependent symbolic expression built from ``ops["val"]``.
    """
    symExpr = SymExprWithPos()
    expressionString = ops["val"]
    symExpr.SetExpression(expressionString)
    return symExpr
def CheckIntegrity(GUI: bool = False):
    """Self-test of SymExprBase and SymExprWithPos.

    Parameters
    ----------
    GUI : bool, optional
        Not used by this test.

    Returns
    -------
    str
        'OK' when every check passes (a failed check raises).
    """
    # Default constructor: the only free variable is 't'.
    obj = SymExprBase("2*t")
    obj.SetConstant("t", 2.0)
    np.testing.assert_equal(obj(), 4.0)
    np.testing.assert_equal(obj.GetValueDerivative("t"), 2)

    obj = SymExprBase("2*x+3*v*t", symbols=["x", "v", "t"])
    print(obj)
    print("f = ", obj._expression)
    obj.SetConstant("x", 10)
    obj.SetConstant("v", -7)
    obj.SetConstant("t", 1)
    np.testing.assert_equal(obj.GetValue(), -1)
    np.testing.assert_equal(obj(), -1)
    np.testing.assert_equal(obj.GetValueDerivative("x"), 2)
    np.testing.assert_equal(obj.GetValueDerivative("t"), -21)
    np.testing.assert_equal(obj.GetValueDerivative("v"), 3)
    np.testing.assert_equal(obj.GetValueSecondDerivative("v", "t"), 3)

    # Expression read from the attributes of an XML element.
    string = """<Pressure eTag="ET2" val="sin(3*t)+x**2" />"""
    import xml.etree.ElementTree as ET
    root = ET.fromstring(string)
    data = root.attrib
    data.pop("id", None)
    obj = CreateSymExprWithPos(data)
    obj.SetConstant("t", 3.14159 / 6.)
    print(obj)
    print("data : ")
    data = np.array([[100.0, 0.1, 0.2], [0, 0.1, 0.2]])
    print(data)
    print("f = ", obj._expression)
    print(obj.GetValue(data))
    print("df/dx :")
    dfdx = obj.GetValueDerivative("x", data)
    print(dfdx)
    np.testing.assert_equal(dfdx, [200.0, 0.0])
    print("df/dt :")
    print(obj.GetValueDerivative("t", data))
    print("d2f/dt2 :")
    print(obj.GetValueSecondDerivative("t", "t", data))
    print("d2f/dx2 :")
    d2fdx2 = obj.GetValueSecondDerivative("x", "x", data)
    print(d2fdx2)
    np.testing.assert_equal(d2fdx2, [2, 2])

    # More complicated case.
    sym = SymExprWithPos()
    sym.SetExpression("2")  # a constant field over the points
    data = np.array([[1, 3, 5], [2, 4, 6]])
    np.testing.assert_equal(sym.GetValue(data), [2, 2])
    sym.SetExpression("2*x**3")  # a cubic field in x
    np.testing.assert_equal(sym.GetValue(data), [2, 16])
    np.testing.assert_equal(sym.GetValueDerivative("x", data), [6, 24])
    np.testing.assert_equal(sym.GetValueSecondDerivative("x", "x", data), [12, 24])
    return 'OK'
if __name__ == '__main__':
    # Run the self-test when the module is executed directly.
    print(CheckIntegrity(GUI=True))  # pragma: no coverage