# Source code for Muscat.FE.WeakForms.ASTOperations

# -*- coding: utf-8 -*-
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
#
from Muscat.FE.WeakForms.ASTWeakForm import ASTWeakForm
import sympy
from sympy import Expr, nsimplify, Function

from typing import List
import Muscat.FE.WeakForms.RPNWeakForm as rpnwf
from Muscat.FE.SymWeakForm import GetTestSufixChar


def getPartialTrees(
    expr: Expr,
    testFields: List,
    knownFields: List[str],
    ipFields: List[str],
    constants,
) -> tuple[List[ASTWeakForm], List]:
    """Build one AST per test function by differentiating a weak form.

    ``expr`` is expanded and passed through ``nsimplify``. Every applied
    function in it whose name ends with the test suffix character
    (``GetTestSufixChar()``) is collected. For each entry of ``testFields``,
    the derivative of ``expr`` with respect to the matching test function is
    converted into an ``ASTWeakForm`` with ``getTree``.

    Parameters
    ----------
    expr : Expr
        Symbolic weak form.
    testFields : List
        Test fields fixing the output order. Each item must expose ``.name``,
        which is compared with ``str(testFunction.func)``.
    knownFields, ipFields : List[str]
        Forwarded unchanged to ``getTree``.
    constants :
        Forwarded unchanged to ``getTree``.

    Returns
    -------
    tuple[List[ASTWeakForm], List]
        The trees and the test functions they were derived from, both in the
        order of ``testFields``.

    Raises
    ------
    ValueError
        If a test field has no matching test function in ``expr``.
    """
    testChar = GetTestSufixChar()
    testFunctions = []

    def addTestFunctions(subexp):
        # Post-order walk collecting unique test-function applications in
        # order of first appearance.
        for arg in getattr(subexp, "args", ()):
            addTestFunctions(arg)
        if isinstance(subexp, Function) and str(subexp.func)[-1] == testChar:
            if subexp not in testFunctions:
                testFunctions.append(subexp)

    expr = nsimplify(expr.expand())
    addTestFunctions(expr)

    # Reorder the discovered test functions to follow testFields.
    testFunctionsOrdered = []
    for tfield in testFields:
        match = next((tf for tf in testFunctions if str(tf.func) == tfield.name), None)
        if match is None:
            # A bare next() would leak StopIteration here, which is turned into
            # an opaque RuntimeError when this is called from a generator.
            raise ValueError(f"test field {tfield.name!r} does not appear in the weak form")
        testFunctionsOrdered.append(match)

    trees = [
        getTree(expr.diff(testFn), knownFields, ipFields, constants=constants)
        for testFn in testFunctionsOrdered
    ]
    return trees, testFunctionsOrdered
def getTree(expr: sympy.Expr, knownFields: List[str], ipFields: List[str], constants) -> ASTWeakForm:
    """Convert a symbolic expression into an ``ASTWeakForm`` tree.

    ``rpnwf.SymWeakToRPN`` first translates ``expr`` into a reverse Polish
    sequence. The tree is then rebuilt with an operand stack:

    * a binary operator pops two nodes; the most recently pushed one becomes
      ``leftChild`` and the one below it becomes ``rightChild``;
    * a unary operator pops one node into ``leftChild``;
    * every token, once wired, is pushed back as a new node.

    Exactly one node must remain at the end; it is the returned root.
    """
    _, rpnSequence = rpnwf.SymWeakToRPN(expr, knownFields, ipFields, constants=constants)
    stack = []
    for token in rpnSequence:
        kind = token.type
        node = ASTWeakForm(token)
        if kind.isBinaryOperator():
            assert len(stack) >= 2
            node.leftChild = stack.pop()
            node.rightChild = stack.pop()  # TODO check if needed (eg sqrt)
        elif kind.isUnaryOperator():
            assert len(stack) >= 1
            node.leftChild = stack.pop()
        stack.append(node)

    # A well-formed RPN sequence reduces to a single root node.
    assert len(stack) == 1
    return stack.pop()
def leftBalance(tree: ASTWeakForm) -> int:
    """Reorder commutative nodes so the deeper subtree sits on the left.

    The tree is processed bottom-up and modified in place. At every node
    whose ``op.isCommutative()`` is true and whose right subtree is deeper
    than its left one, ``swapChildren()`` is called.

    Parameters
    ----------
    tree : ASTWeakForm or None
        Root of the (sub)tree to balance.

    Returns
    -------
    int
        Depth of ``tree``: 0 for ``None``, 1 for a node without children.
    """
    if tree is None:
        return 0
    depthLeft = leftBalance(tree.leftChild)
    depthRight = leftBalance(tree.rightChild)
    if depthLeft < depthRight and tree.op.isCommutative():
        tree.swapChildren()
    # The subtree's depth is set by its deeper child, whichever side it ends
    # up on. Returning only depthLeft + 1 under-reported it after a swap (and
    # for non-commutative nodes with a deeper right child), which skewed the
    # comparisons made by ancestor nodes.
    return max(depthLeft, depthRight) + 1
def amountTemporaryValues(tree: ASTWeakForm):
    """Count nodes whose two children are both operators.

    A node is counted when it has a left and a right child and both children
    report ``op.isOperator()``. Every node reachable from ``tree`` is visited
    in pre-order (node, then left subtree, then right subtree). Presumably
    each counted node needs one temporary value to hold an intermediate
    result while the other operand is computed — TODO confirm against the
    code generator.

    Returns 0 for ``None``.
    """
    total = 0
    pending = [tree]
    while pending:
        node = pending.pop()
        if node is None:
            continue
        left, right = node.leftChild, node.rightChild
        bothOperators = (
            left is not None
            and right is not None
            and left.op.isOperator()
            and right.op.isOperator()
        )
        if bothOperators:
            total += 1
        # Push right first so the left subtree is handled first.
        pending.append(right)
        pending.append(left)
    return total
def getOptimizedTree(exp, knownFields=None, ipFields=None, constants=None):
    """Build the AST of ``exp``, left-balance it and count its temporaries.

    Parameters
    ----------
    exp :
        Symbolic expression forwarded to ``getTree``.
    knownFields, ipFields : list, optional
        Forwarded to ``getTree``; a new empty list is used when omitted.
    constants : dict, optional
        Forwarded to ``getTree``; a new empty dict is used when omitted.

    Returns
    -------
    tuple
        ``(amountTemporaryValues(tree), tree)``, where ``tree`` has already
        been rearranged in place by ``leftBalance``.
    """
    # Use fresh containers on every call. The previous mutable defaults
    # ([] / {}) were created once and shared across calls while being handed
    # to getTree, so any downstream mutation would leak into later calls.
    if knownFields is None:
        knownFields = []
    if ipFields is None:
        ipFields = []
    if constants is None:
        constants = {}

    tree = getTree(exp, knownFields=knownFields, ipFields=ipFields, constants=constants)
    leftBalance(tree)
    return amountTemporaryValues(tree), tree
def CheckASTWeakForm(GUI: bool = False) -> str:
    """Self-check: temporary-value counts of optimized trees.

    For each sample expression, ``getOptimizedTree`` is run with numeric
    values for the four scalar symbols, and the returned temporary count is
    compared with the expected one. ``getTree`` is also called directly with
    explicit empty field lists as a smoke test.

    ``GUI`` is accepted for the ``CheckIntegrity`` calling convention and is
    not used here.

    Returns "OK"; a failing comparison raises from ``numpy.testing``.
    """
    from Muscat.FE.SymWeakForm import Symbol, GetField, GetTestField
    import numpy as np

    u = GetField("u", 1, sdim=3)
    ut = GetTestField("u", 1, sdim=3)
    a = Symbol("alpha")
    b = Symbol("Beta")
    g = Symbol("Gamma")
    m = Symbol("Mu")
    x = Symbol("x")
    constantValues = {"alpha": 1, "Beta": 2, "Gamma": 3, "Mu": 0.3}

    wfToTest = [
        (a, 0),
        (a + b, 0),
        (a + b + g, 0),
        (a * b + g, 0),
        (a * b + g * g, 1),
        (a * b + g**0.5, 1),
        (a * b / (g**0.5 + 5 + m), 0),
        (a * b / (g**0.5 + 5 + m) + a * b, 1),
        (u[0], 0),
        (u[0].diff(x), 0),
        (u[0].diff(x) * ut[0].diff(x), 0),
    ]

    for wf, expectedTmpVals in wfToTest:
        print("=== TEST CASE ==================================================================")
        print(wf)
        # Pass a fresh copy each time (as the original literals did) in case
        # the callee mutates the mapping.
        tmpVals, _ = getOptimizedTree(wf, constants=dict(constantValues))
        print("GOT", tmpVals, "EXPECTED", expectedTmpVals)
        np.testing.assert_equal(tmpVals, expectedTmpVals)
        # Smoke test of getTree with explicit arguments.
        getTree(wf, knownFields=[], ipFields=[], constants=dict(constantValues))
    return "OK"
def CheckDiffWeakForm(GUI: bool = False) -> str:
    """Self-check: ``getPartialTrees`` on the weak form ``Mu * u * ut``.

    Checks that exactly one tree (and one test function) is produced for the
    single test field, that the tree's ``getDepth()`` is 2, and that it needs
    no temporary values.

    ``GUI`` is accepted for the ``CheckIntegrity`` calling convention and is
    not used here.

    Returns "OK"; a failing comparison raises from ``numpy.testing``.
    """
    from Muscat.FE.SymWeakForm import Symbol, GetField, GetTestField
    import numpy as np

    u = GetField("u", 1, 3)
    ut = GetTestField("u", 1, 3)
    m = Symbol("Mu")

    wf = m * u * ut
    print("================================================================")
    print(f"{wf=}")
    trees, testFunctionsOrdered = getPartialTrees(wf, [ut[0]], [], [], constants={"Mu": 0.3})
    print(trees)
    print("-*" * 32)
    np.testing.assert_equal(len(trees), 1)
    np.testing.assert_equal(len(testFunctionsOrdered), 1)
    print(trees[0])
    np.testing.assert_equal(trees[0].getDepth(), 2)
    np.testing.assert_equal(amountTemporaryValues(trees[0]), 0)
    print("Done CheckDiffWeakForm")
    return "OK"
def CheckIntegrity(GUI: bool = False):
    """Run this module's self-checks through Muscat's check runner.

    The ``GUI`` flag is accepted for interface compatibility and is not
    forwarded.
    """
    from Muscat.Helpers.CheckTools import RunListOfCheckIntegrities

    moduleChecks = [CheckASTWeakForm, CheckDiffWeakForm]
    return RunListOfCheckIntegrities(moduleChecks)
if __name__ == "__main__":  # pragma: no cover
    # Executing the module directly runs its self-checks and prints the result.
    checkResult = CheckIntegrity(GUI=True)
    print(checkResult)