# -*- coding: utf-8 -*-
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
#
from Muscat.FE.WeakForms.ASTWeakForm import ASTWeakForm
import sympy
from sympy import Expr, nsimplify, Function
from typing import List
import Muscat.FE.WeakForms.RPNWeakForm as rpnwf
from Muscat.FE.SymWeakForm import GetTestSufixChar
def getPartialTrees(expr: Expr, testFields: List, knownFields: List[str], ipFields: List[str], constants) -> tuple[List[ASTWeakForm], List]:
    """Split a symbolic weak form into one AST per test function.

    The expression is expanded and passed through ``nsimplify``. Every applied
    function whose name ends with the test-suffix character
    (``GetTestSufixChar()``) is collected. The expression is then
    differentiated with respect to each of those test functions, in the order
    given by ``testFields``, and each derivative is turned into a tree with
    ``getTree``.

    Parameters
    ----------
    expr : sympy.Expr
        Symbolic weak form.
    testFields : list
        Test fields; each item's ``name`` must equal ``str(f.func)`` for some
        test function ``f`` appearing in the simplified ``expr``.
    knownFields, ipFields, constants
        Forwarded unchanged to ``getTree``.

    Returns
    -------
    trees : list of ASTWeakForm
        One tree per entry of ``testFields``, in the same order.
    testFunctionsOrdered : list
        The matching sympy test-function applications, in the same order.

    Raises
    ------
    ValueError
        If an entry of ``testFields`` has no matching test function in the
        simplified expression.
    """
    testChar = GetTestSufixChar()
    testFunctions = []

    def addTestFunctions(subexp):
        """Collect, without duplicates, applied functions named ``...<testChar>``."""
        # Children first, then the node itself (post-order traversal).
        for arg in getattr(subexp, "args", ()):
            addTestFunctions(arg)
        if isinstance(subexp, Function) and str(subexp.func)[-1] == testChar:
            if subexp not in testFunctions:
                testFunctions.append(subexp)

    expr = nsimplify(expr.expand())
    addTestFunctions(expr)

    # Reorder the collected test functions to follow the order of testFields.
    testFunctionsOrdered = []
    for tfield in testFields:
        match = next((f for f in testFunctions if str(f.func) == tfield.name), None)
        if match is None:
            # A bare StopIteration from next() is cryptic and can silently end
            # an enclosing iterator (e.g. map); report the missing field instead.
            raise ValueError(f"No test function named {tfield.name!r} found in the weak form")
        testFunctionsOrdered.append(match)

    # d(expr)/d(testFn) isolates the part of the weak form tested by testFn.
    trees = [
        getTree(expr.diff(testFn), knownFields, ipFields, constants=constants)
        for testFn in testFunctionsOrdered
    ]
    return trees, testFunctionsOrdered
def getTree(expr: sympy.Expr, knownFields: List[str], ipFields: List[str], constants) -> ASTWeakForm:
    """Build an expression tree from the RPN form of a symbolic weak form.

    ``rpnwf.SymWeakToRPN`` converts ``expr`` into a postfix (RPN) token list.
    The tokens are folded with an operand stack:
    - every token becomes an ``ASTWeakForm`` node;
    - binary operators take two operands from the stack;
    - unary operators take one;
    - the node is then pushed back.

    Parameters
    ----------
    expr : sympy.Expr
        Expression to convert.
    knownFields, ipFields, constants
        Forwarded to ``rpnwf.SymWeakToRPN``.

    Returns
    -------
    ASTWeakForm
        Root node of the resulting tree.
    """
    _, postfix = rpnwf.SymWeakToRPN(expr, knownFields, ipFields, constants=constants)
    stack = []
    for token in postfix:
        kind = token.type
        node = ASTWeakForm(token)
        if kind.isBinaryOperator():
            assert len(stack) >= 2
            # The most recently pushed operand (the second one in RPN order)
            # becomes leftChild.
            # NOTE(review): confirm evaluators expect this for non-commutative ops.
            # TODO check if needed (eg sqrt)
            node.leftChild, node.rightChild = stack.pop(), stack.pop()
        elif kind.isUnaryOperator():
            assert len(stack) >= 1
            node.leftChild = stack.pop()
        stack.append(node)
    # A well-formed RPN sequence reduces to exactly one root.
    assert len(stack) == 1
    return stack.pop()
def leftBalance(tree: ASTWeakForm) -> int:
    """Reorder commutative nodes so the deeper subtree sits on the left.

    The tree is walked bottom-up. At every node whose operator is commutative
    (``tree.op.isCommutative()``) and whose right subtree is deeper than its
    left one, the children are swapped in place via ``swapChildren``.

    Parameters
    ----------
    tree : ASTWeakForm or None
        Root of the (sub)tree to rebalance; ``None`` denotes an empty subtree.

    Returns
    -------
    int
        Height of ``tree``: 0 for ``None``, 1 for a leaf.
    """
    if tree is None:
        return 0
    depthLeft = leftBalance(tree.leftChild)
    depthRight = leftBalance(tree.rightChild)
    if depthLeft < depthRight and tree.op.isCommutative():
        tree.swapChildren()
    # Report the true height (deeper child + 1), not only the left child's
    # depth. After a swap, and for non-commutative nodes with a deeper right
    # child, the left depth under-reports it, and the parent's comparison
    # above relies on the real value.
    return max(depthLeft, depthRight) + 1
def amountTemporaryValues(tree: ASTWeakForm):
    """Count the nodes of ``tree`` whose two children are both operators.

    A node contributes 1 when it has a left and a right child and both
    children's ``op.isOperator()`` is truthy. Presumably each such node needs
    an extra temporary value during evaluation (TODO confirm against the
    evaluator).

    Parameters
    ----------
    tree : ASTWeakForm or None
        Root of the (sub)tree to inspect; ``None`` counts as 0.

    Returns
    -------
    int
        Number of qualifying nodes in the whole subtree.
    """
    if tree is None:
        return 0
    left, right = tree.leftChild, tree.rightChild
    bothOperators = (
        left is not None
        and right is not None
        and left.op.isOperator()
        and right.op.isOperator()
    )
    here = 1 if bothOperators else 0
    return here + amountTemporaryValues(left) + amountTemporaryValues(right)
def getOptimizedTree(exp, knownFields=None, ipFields=None, constants=None):
    """Build the tree of ``exp``, left-balance it and count its temporaries.

    Parameters
    ----------
    exp : sympy.Expr
        Symbolic expression to convert with ``getTree``.
    knownFields, ipFields : list of str, optional
        Forwarded to ``getTree``; each defaults to a fresh empty list.
    constants : dict, optional
        Forwarded to ``getTree``; defaults to a fresh empty dict.

    Returns
    -------
    tuple
        ``(amountTemporaryValues(tree), tree)``, where ``tree`` has already
        been reordered in place by ``leftBalance``.
    """
    # Create fresh containers per call. Mutable default arguments are shared
    # across calls, and these objects are handed to downstream code that may
    # mutate them.
    knownFields = [] if knownFields is None else knownFields
    ipFields = [] if ipFields is None else ipFields
    constants = {} if constants is None else constants

    tree = getTree(exp, knownFields=knownFields, ipFields=ipFields, constants=constants)
    leftBalance(tree)
    return amountTemporaryValues(tree), tree
def CheckIntegrity(GUI: bool = False):
    """Run this module's integrity checks through Muscat's check runner.

    Parameters
    ----------
    GUI : bool, optional
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    The value returned by ``RunListOfCheckIntegrities`` for the check list.
    """
    from Muscat.Helpers.CheckTools import RunListOfCheckIntegrities

    # NOTE(review): CheckASTWeakForm and CheckDiffWeakForm are not defined in
    # the visible part of this module; confirm they exist, otherwise this
    # line raises NameError.
    checks = [CheckASTWeakForm, CheckDiffWeakForm]
    return RunListOfCheckIntegrities(checks)
if __name__ == "__main__":
    # Executing the module directly runs its integrity checks and prints the result.
    report = CheckIntegrity(GUI=True)  # pragma: no cover
    print(report)  # pragma: no cover