Pre/Post deep learning¶
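Some deep learning workflows applied to physics problems require projecting fields defined on an unstructured mesh onto a rectilinear grid, and vice versa. This example reads a solver output and projects it onto a regular grid (preprocessing). It then projects the grid field back onto the original mesh (postprocessing) and measures the round-trip error in the $L^2(\Omega)$ norm.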
Includes¶
[1]:
# PyVista is the 3D rendering backend used by PlotMesh below.
import pyvista
pyvista.global_theme._jupyter_backend = 'panel' # remove this line to get interactive 3D plots; NOTE(review): `_jupyter_backend` is a private attribute — newer PyVista exposes `pyvista.set_jupyter_backend(...)`, confirm against the installed version
import numpy as np
# PlotMesh renders Muscat meshes in this notebook; MeshToPyVista is imported but not used in the cells shown.
from Muscat.Bridges.PyVistaBridge import PlotMesh, MeshToPyVista
# Provides the temporary output location used when exporting the projected field.
from Muscat.Helpers.IO.TemporaryDirectory import TemporaryDirectory
PREPROCESSING¶
Read the solution generated by a physics solver¶
[2]:
from Muscat.IO import XdmfReader as XR

# Parse the XDMF file produced by the physics solver
reader = XR.XdmfReader(filename="PrePostDeepLearning_Input.xmf")
reader.Read()

# The first grid of the first domain carries both the mesh and its nodal fields
grid = reader.xdmf.GetDomain(0).GetGrid(0)
uMesh = grid.GetSupport()

# Nodal solution "U": keep only columns 0 and 1
indexU = grid.GetPointFieldsNames().index("U")
U = grid.GetPointFields()[indexU][:, :2]
uMesh.nodeFields["U"] = U

PlotMesh(uMesh, scalars="U", show_edges=True, cpos="xy", show_scalar_bar=False)
Create a 48×48 rectilinear mesh covering the part of the domain where $x \geq 0$ (the left part, $x<0$, is excluded)¶
[3]:
from Muscat.Containers.ConstantRectilinearMeshTools import CreateConstantRectilinearMesh

# Number of grid nodes in each direction
Nx, Ny = 48, 48

boundingMin, boundingMax = uMesh.ComputeBoundingBox()
# The grid starts at x = 0, so the x < 0 part of the domain is left out
xStart = 0.
Lx = boundingMax[0] - xStart
Ly = boundingMax[1] - boundingMin[1]

# Nx-1 (resp. Ny-1) intervals span Lx (resp. Ly) exactly
unstructuredRectMesh = CreateConstantRectilinearMesh(
    dimensions=[Nx, Ny],
    spacing=[Lx / (Nx - 1), Ly / (Ny - 1)],
    origin=[xStart, boundingMin[1]],
)
PlotMesh(unstructuredRectMesh, show_edges=True, cpos="xy")
Compute the projection operator from unstructured mesh to structured mesh¶
[4]:
from Muscat.FE.FETools import PrepareFEComputation
from Muscat.Containers.MeshFieldOperations import GetFieldTransferOp
from Muscat.FE.Fields.FEField import FEField

# Source: the field "U" living on the unstructured mesh.
# Targets: the nodes of the rectilinear grid.
# NOTE(review): "Interp/Clamp" presumably interpolates inside the source mesh and
# clamps to the closest element outside it — confirm in the Muscat documentation.
inputFEField = FEField(name="U", mesh=uMesh)
meshToGridOp, status, entities = GetFieldTransferOp(
    inputFEField,
    unstructuredRectMesh.nodes,
    method="Interp/Clamp",
    verbose=True,
)

# Apply the transfer matrix to the nodal values to get U on the grid
projectedU = meshToGridOp.dot(U)
unstructuredRectMesh.nodeFields["projectedU"] = projectedU
PlotMesh(unstructuredRectMesh, scalars="projectedU", show_edges=True, cpos="xy", show_scalar_bar=False)
Starting C++ FieldTransfer
Done c++ FieldTransfer
Export the structured mesh and projected field in xdmf format (in ASCII)¶
[5]:
# Export the rectilinear grid and the projected field to an XDMF file in a temporary directory.
# To visualize this xdmf file you can use ParaView (downloadable from https://www.paraview.org/)
from Muscat.IO import XdmfWriter as XW

outputPath = TemporaryDirectory.GetTempPath() + 'PrePostDeepLearning_OutputI.xdmf'
writer = XW.XdmfWriter(outputPath)
# Per the section title, data is written as ASCII rather than to an HDF5 companion file
writer.SetHdf5(False)
writer.Open()
try:
    # The projected field is stored under the name "U" in the output file
    writer.Write(unstructuredRectMesh, PointFields=[projectedU], PointFieldsNames=["U"])
finally:
    # Always close, even if Write raises, so the output file is not left open
    writer.Close()
POSTPROCESSING¶
[6]:
# Inverse projection: transfer the grid field back onto the nodes of the unstructured mesh.
# Source: field "U" on the rectilinear grid; targets: nodes of uMesh.
inputFEField = FEField(name="U", mesh=unstructuredRectMesh)
gridToMeshOp, status, entities = GetFieldTransferOp(
    inputFEField,
    uMesh.nodes,
    method="Interp/Clamp",
    verbose=True,
)

# Round trip: U (unstructured) -> projectedU (grid) -> back to the unstructured nodes
inverseProjected_ProjectedU = gridToMeshOp.dot(projectedU)
uMesh.nodeFields["inverseProjected_ProjectedU"] = inverseProjected_ProjectedU
PlotMesh(uMesh, scalars="inverseProjected_ProjectedU", show_edges=False, cpos="xy", show_scalar_bar=False)
Starting C++ FieldTransfer
Done c++ FieldTransfer
CHECK ACCURACY¶
[7]:
# Build meshClipped: the x >= 0 part of the unstructured mesh (the left part, x < 0, is removed),
# carrying the original field, the round-trip field and their difference.
from Muscat.Containers.Filters.FilterObjects import ElementFilter
from Muscat.FE.FETools import ComputePhiAtIntegPoint, ComputeL2ScalarProductMatrix
from Muscat.Containers.MeshInspectionTools import ExtractElementsByElementFilter
from Muscat.Containers.MeshModificationTools import CleanLonelyNodes
from Muscat.Containers.MeshFieldOperations import CopyFieldsFromOriginalMeshToTargetMesh

# Round-trip error at every node of the unstructured mesh
uMesh.nodeFields['delta'] = U - inverseProjected_ProjectedU

# Zone function -x; NOTE(review): presumably elements with zone value <= 0 (i.e. x >= 0) are kept — confirm
elFilter = ElementFilter(zone=lambda points: -points[:, 0])
meshClipped = ExtractElementsByElementFilter(uMesh, elFilter)
# Presumably drops nodes no longer referenced by any element, then the nodal fields are copied over
CleanLonelyNodes(meshClipped)
CopyFieldsFromOriginalMeshToTargetMesh(uMesh, meshClipped)

nbeOfNodes = meshClipped.GetNumberOfNodes()
deltaClippedMesh = meshClipped.nodeFields['delta']
PlotMesh(meshClipped, scalars="delta", show_edges=False, cpos="xy", show_scalar_bar=True)
Computation of the $L^2(\Omega)$ norm of delta (square root of the integral of its squared magnitude)¶
[8]:
# Compute the L2 norm of the round-trip error `deltaClippedMesh` (2 nodal components on
# meshClipped) with four independent finite-element integration approaches, and time each one.
from Muscat.Helpers.TextFormatHelper import TFormat
from Muscat.Helpers.Timer import Timer
from Muscat.FE.IntegrationTools import IntegrateField
from Muscat.FE.Integration import IntegrateGeneral
from Muscat.FE.SymWeakForm import GetField, GetTestField
from Muscat.FE.Spaces.FESpaces import LagrangeSpaceP1, ConstantSpaceGlobal
from Muscat.FE.DofNumbering import ComputeDofNumbering

print("Compute the L2(Omega) norm of delta by applying numerical quadrature from Lagrange P1")
print("Finite element integration using four different methods")

#### Method 1 #######################################################################################################################
# Evaluate delta at the integration points (phi . nodal values), then sum w_j * |delta(x_j)|^2.
print(TFormat.Center(TFormat.InRed("Method 1: ")+TFormat.InBlue("'by hand'")))
timer = Timer("Duration of method 1")
# Run method 1 twice so the timer reports a mean over several calls
nbRepetitions = 2
for repetition in range(nbRepetitions):
    timer.Start()
    integrationWeights, phiAtIntegPoint = ComputePhiAtIntegPoint(meshClipped)
    vectDeltaAtIntegPoints = np.empty((2, phiAtIntegPoint.shape[0]))
    # Distinct loop variable: do not shadow the repetition counter
    for component in range(2):
        vectDeltaAtIntegPoints[component] = phiAtIntegPoint.dot(deltaClippedMesh[:, component])
    # sum over components i and integration points j of delta_i(x_j)^2 * w_j
    normDeltaMethod1 = np.sqrt(np.einsum('ij,ij,j', vectDeltaAtIntegPoints, vectDeltaAtIntegPoints, integrationWeights, optimize=True))
    timer.Stop()
print("norm(Delta) =", normDeltaMethod1)

#### Method 2 #######################################################################################################################
# norm^2 = delta^T M delta with M the (2-component) L2 scalar-product (mass) matrix.
print(TFormat.Center(TFormat.InRed("Method 2: ")+TFormat.InBlue("using the mass matrix")))
timer = Timer("Duration of method 2").Start()
massMatrix = ComputeL2ScalarProductMatrix(meshClipped, 2)
# Column-major ravel stacks all component-0 values, then all component-1 values;
# NOTE(review): assumed to match the dof ordering of massMatrix — confirm
deltaFlat = deltaClippedMesh.ravel(order='F')
normDeltaMethod2 = np.sqrt(np.dot(deltaFlat, massMatrix.dot(deltaFlat)))
timer.Stop()
print("norm(Delta) =", normDeltaMethod2)

#### Method 3 #######################################################################################################################
# Integrate the FE field |delta|^2 = delta_0^2 + delta_1^2 over the 2D elements.
# NOTE(review): in the recorded output this value differs slightly from the other methods;
# presumably IntegrateField uses a different quadrature rule than Lagrange P1 — confirm.
print(TFormat.Center(TFormat.InRed("Method 3: ")+TFormat.InBlue("using Muscat integration tool")))
timer = Timer("Duration of method 3").Start()
deltaSquared = (FEField("delta_U0", meshClipped, data=deltaClippedMesh[:, 0].flatten())**2
                + FEField("delta_U1", meshClipped, data=deltaClippedMesh[:, 1].flatten())**2)
integralDeltaSquared = IntegrateField(deltaSquared, ElementFilter(dimensionality=2))
timer.Stop()
normDeltaMethod3 = np.sqrt(integralDeltaSquared)
print("norm(Delta) =", normDeltaMethod3)

#### Method 4 #######################################################################################################################
# Assemble the weak form  U^T U * T  with a global constant test function T,
# so the assembled right-hand side F[0] is the integral of |delta|^2.
print(TFormat.Center(TFormat.InRed("Method 4: ")+TFormat.InBlue("using the weak form engine")))
timer = Timer("Duration of method 4").Start()
# Symbolic 2-component field named "U" (matched to the FEFields "U_0"/"U_1" below).
# Bound to `deltaSym` so the nodal numpy array `U` from the preprocessing cells is not overwritten.
deltaSym = GetField("U", 2)
Tt = GetTestField("T", 1)
wf = deltaSym.T*deltaSym*Tt
field1 = FEField("U_0", mesh=meshClipped, data=deltaClippedMesh[:, 0].flatten())
field2 = FEField("U_1", mesh=meshClipped, data=deltaClippedMesh[:, 1].flatten())
# Presumably a single global dof: the assembled vector then has one entry, the integral
numbering = ComputeDofNumbering(meshClipped, ConstantSpaceGlobal)
unknownField = FEField("T", mesh=meshClipped, space=ConstantSpaceGlobal, numbering=numbering)
elFilter = ElementFilter()
K, F = IntegrateGeneral(mesh=meshClipped,
                        wform=wf,
                        constants={},
                        fields=[field1, field2],
                        unknownFields=[unknownField],
                        integrationRuleName="LagrangeP1Quadrature",
                        elementFilter=elFilter)
normDeltaMethod4 = np.sqrt(F[0])
timer.Stop()
# Report this method's own result (previously method 2's value was printed here)
print("norm(Delta) =", normDeltaMethod4)

# Methods 1 and 2 use the same Lagrange P1 quadrature and must agree up to round-off
assert np.isclose(normDeltaMethod1, normDeltaMethod2, rtol=1e-8), (normDeltaMethod1, normDeltaMethod2)

print(Timer.PrintTimes())
Compute the L2(Omega) norm of delta by applying numerical quadrature from Lagrange P1
Finite element integration using three different methods
******************** Method 1: 'by hand' *******************
norm(Delta) = 0.0001466583268059577
************** Method 2: using the mass matrix *************
norm(Delta) = 0.00014665832680595773
********** Method 3: using Muscat integration tool *********
norm(Delta) = 0.0001555718367085269
*********** Method 3: using the weak form engine ***********
norm(Delta) = 0.00014665832680595773
Duration of method 1: (2) : 6.878e+00 s (mean 3.439e+00 s/call)
Duration of method 2: (1) : 3.859e-02 s (mean 3.859e-02 s/call)
Duration of method 3: (1) : 3.744e-01 s (mean 3.744e-01 s/call)
Duration of method 4: (1) : 2.092e-02 s (mean 2.092e-02 s/call)