# -*- coding: utf-8 -*-
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
#
from typing import Optional, List
import numpy as np
from Muscat.Types import MuscatIndex, MuscatFloat
from Muscat.MeshContainers.Mesh import Mesh
# Key under which the global numbering is stored in each partition's
# nodeFields (global node ids) and elemFields (global element ids).
GlobalIdsFieldName = "GlobalIDs"
class PartitionedMesh:
    """
    Class to store a partitioned mesh.

    Every partition is a regular ``Mesh`` kept, in insertion order, in ``self.storage``.
    Two fields are used to store the global numbering of elements and nodes. These fields are
    stored in every mesh's nodeFields['GlobalIDs'] and elemFields['GlobalIDs'].
    """

    def __init__(self) -> None:
        super().__init__()
        # Partitions, in the order they were added.
        self.storage: List[Mesh] = []

    def AddMesh(self, mesh: Mesh, nodesGlobalIDs: Optional[np.ndarray] = None,
                elementsGlobalIDs: Optional[np.ndarray] = None) -> None:
        """
        Add a mesh to the PartitionedMesh collection.

        Parameters
        ----------
        mesh : Mesh
            The mesh to be added. It is stored by reference (not copied) and its
            'GlobalIDs' fields are overwritten when explicit arrays are given.
        nodesGlobalIDs : Optional[np.ndarray]
            An array of node global IDs. If None, must already exist in mesh.nodeFields.
        elementsGlobalIDs : Optional[np.ndarray]
            An array of element global IDs. If None, must already exist in mesh.elemFields.

        Raises
        ------
        RuntimeError
            If required global IDs are not provided and not found in the mesh.
            In that case neither the mesh nor the collection is modified.
        """
        # Validate both requirements before touching the mesh, so that a failing
        # call does not leave the mesh half-updated.
        if nodesGlobalIDs is None and GlobalIdsFieldName not in mesh.nodeFields:
            raise RuntimeError("globalNodesIds not provided and absent from nodeFields")
        if elementsGlobalIDs is None and GlobalIdsFieldName not in mesh.elemFields:
            raise RuntimeError("globalElementIds not provided and absent from elemFields")

        if nodesGlobalIDs is not None:
            mesh.nodeFields[GlobalIdsFieldName] = nodesGlobalIDs
        if elementsGlobalIDs is not None:
            mesh.elemFields[GlobalIdsFieldName] = elementsGlobalIDs
        self.storage.append(mesh)

    def GetNumberOfNodes(self) -> int:
        """Return the total number of nodes in the partitioned mesh.

        This is the plain sum over the partitions: nodes shared by several
        partitions are counted once per partition.
        """
        return int(sum(m.GetNumberOfNodes() for m in self.storage))

    def ComputeNumberOfGlobalNodes(self) -> int:
        """Return the maximum global node id found in all meshes.

        NOTE(review): despite the method name, the value returned is the largest
        global id, not a count. If ids are contiguous and 0-based the number of
        global nodes would be this value + 1 -- confirm the intended contract.

        Raises
        ------
        ValueError
            If the collection holds no mesh (the maximum is undefined).
        """
        if not self.storage:
            raise ValueError("cannot compute the maximum global node id of an empty PartitionedMesh")
        # np.max must receive an array per partition; reducing a generator with
        # np.max does not iterate over it, so the outer reduction uses builtin max.
        return int(max(np.max(m.nodeFields[GlobalIdsFieldName]) for m in self.storage))

    def GetNumberOfElements(self) -> int:
        """Return the total number of elements in the partitioned mesh."""
        return int(sum(m.GetNumberOfElements() for m in self.storage))

    def PrepareForOutput(self, verifyIntegrity: Optional[bool] = True) -> None:
        """Prepare all meshes for output, optionally verifying their integrity."""
        for mesh in self.storage:
            mesh.PrepareForOutput(verifyIntegrity)

    def GetMonolithicMesh(self) -> Mesh:
        """Return a mesh composed of all partitioned meshes merged together."""
        resultMesh = Mesh()
        for mesh in self.storage:
            resultMesh.Merge(mesh)
        return resultMesh

    def __str__(self) -> str:
        """Return the description of every partition, each preceded by a centered 'Part i' header."""
        from Muscat.Helpers.TextFormatHelper import TFormat as TF
        return "".join(
            TF.Center(f"Part {i}") + "\n" + mesh.__str__()
            for i, mesh in enumerate(self.storage)
        )

    def GenerateGlobalIdsBasedOnNodePositions(self) -> None:
        """
        Assigns unique global IDs to elements and nodes based on node positions.

        Elements are numbered contiguously, partition after partition.

        For the nodes, partitions are processed in storage order. In each partition
        the nodes listed in the ``nodesTags["SurfaceIds"]`` tag are compared, by
        position, with the surface nodes of the partitions already processed: a node
        found within the tolerance reuses the global id of its match, every other
        node of the partition receives a new consecutive id. Only surface nodes are
        compared, so nodes shared between partitions must belong to that tag.

        The results are written to ``nodeFields['GlobalIDs']`` and
        ``elemFields['GlobalIDs']`` of every partition. Nothing is done when the
        collection is empty.
        """
        if not self.storage:
            return

        from scipy.spatial import KDTree

        # Elements are never shared: a running counter is enough.
        elementCpt = 0
        for mesh in self.storage:
            nbElements = mesh.GetNumberOfElements()
            mesh.elemFields[GlobalIdsFieldName] = np.arange(elementCpt, elementCpt + nbElements, dtype=MuscatIndex)
            elementCpt += nbElements

        # Buffers holding the positions / global ids of the surface nodes already
        # numbered. Sized with the total node count, an upper bound of what is needed.
        totalNodes = self.GetNumberOfNodes()
        allSurfaceNodes = np.empty((totalNodes, self.storage[0].GetPointsDimensionality()), dtype=MuscatFloat)
        allSurfaceIds = np.empty(totalNodes, dtype=MuscatIndex)
        nodesCpt = 0  # number of valid rows in the two buffers above
        offset = 0    # next free global node id

        # Matching tolerance: relative to the bounding box diagonal of the first partition.
        bbmin, bbmax = self.storage[0].ComputeBoundingBox()
        tol = np.linalg.norm(bbmax - bbmin) * 1e-15

        for mesh in self.storage:
            nbNodes = mesh.GetNumberOfNodes()
            surfaceIds = mesh.nodesTags["SurfaceIds"].GetIds()
            surfaceNodes = mesh.nodes[surfaceIds, :]

            if nodesCpt:
                # Query all points on the surface with respect to the skin of the others.
                # Points without a neighbor closer than tol come back with an infinite distance.
                kdtree = KDTree(allSurfaceNodes[0:nodesCpt, :])
                d, idx = kdtree.query(surfaceNodes, k=1, distance_upper_bound=tol, workers=-1)
                found = np.isfinite(d)
                duplicatedLocalIds = surfaceIds[found]
                duplicatedGlobalIds = allSurfaceIds[idx[found]]
            else:
                # Nothing has been numbered yet: every node of this partition is new.
                duplicatedLocalIds = surfaceIds[:0]
                duplicatedGlobalIds = allSurfaceIds[:0]

            # mask to keep track of duplicated nodes: True -> new node, False -> duplicated node
            mask = np.ones(nbNodes, dtype=bool)
            mask[duplicatedLocalIds] = False
            # number of new points to number
            nnp = int(np.sum(mask))

            globalIds = np.full(nbNodes, -1, dtype=MuscatIndex)
            # numbering of new nodes
            globalIds[mask] = np.arange(offset, offset + nnp)
            # duplicated points inherit the id of the node they match
            globalIds[duplicatedLocalIds] = duplicatedGlobalIds

            # Register the surface of this partition for the next ones.
            nbSurfaceNodes = len(surfaceIds)
            allSurfaceNodes[nodesCpt:nodesCpt + nbSurfaceNodes, :] = surfaceNodes
            allSurfaceIds[nodesCpt:nodesCpt + nbSurfaceNodes] = globalIds[surfaceIds]
            nodesCpt += nbSurfaceNodes
            offset += nnp

            mesh.nodeFields[GlobalIdsFieldName] = globalIds
def CheckIntegrity(GUI: bool = False) -> str:
    """
    Checks core integrity of PartitionedMesh operations.

    Runs several assertions:
      - Adds mesh with proper 'GlobalIDs' fields
      - Adds mesh with explicit global IDs arrays
      - Checks responses for missing fields (should raise exception)
      - Checks that node and element counts are preserved

    Parameters
    ----------
    GUI : bool
        Not used; for compatibility.

    Returns
    -------
    str
        "OK" if all tests pass, otherwise returns an error message
    """
    from Muscat.MeshContainers.MeshGraphTools import PartitionMesh
    from Muscat.MeshContainers.Filters.FilterObjects import ElementFilter
    from Muscat.MeshTools.MeshInspectionTools import ExtractElementsByElementFilter
    from Muscat.MeshTools.MeshCreationTools import CreateCube

    # Create a test mesh and filter to 3D elements
    myMesh = CreateCube(dimensions=[10, 10, 10], spacing=[1, 1, 1], origin=[0, 0, 0], ofTetras=False)
    myMesh = ExtractElementsByElementFilter(myMesh, ElementFilter(dimensionality=3))
    # The partitioning is only exercised here; its result is not inspected by this check.
    PartitionMesh(myMesh, 4)

    # Assign global IDs to this mesh
    myMesh.nodeFields[GlobalIdsFieldName] = np.arange(myMesh.GetNumberOfNodes())
    myMesh.elemFields[GlobalIdsFieldName] = np.arange(myMesh.GetNumberOfElements())

    pm = PartitionedMesh()

    # Test: Add mesh with IDs inside the mesh (should succeed)
    try:
        pm.AddMesh(myMesh)
    except Exception as e:
        return f"FAILED (unexpected error adding mesh with fields): {repr(e)}"

    # Test: Add with explicit IDs passed (should succeed)
    try:
        pm.AddMesh(
            myMesh,
            nodesGlobalIDs=myMesh.nodeFields[GlobalIdsFieldName],
            elementsGlobalIDs=myMesh.elemFields[GlobalIdsFieldName]
        )
    except Exception as e:
        return f"FAILED (unexpected error adding mesh with explicit IDs): {repr(e)}"

    # The same mesh was added twice, so both totals must be doubled.
    nNodes = pm.GetNumberOfNodes()
    nElements = pm.GetNumberOfElements()
    if nNodes != 2 * myMesh.GetNumberOfNodes():
        return "FAILED: Node count mismatch in partitioned mesh"
    if nElements != 2 * myMesh.GetNumberOfElements():
        return "FAILED: Element count mismatch in partitioned mesh"

    # Test: Add mesh with missing IDs
    badMesh = CreateCube(dimensions=[2, 2, 2], spacing=[1, 1, 1], origin=[0, 0, 0], ofTetras=False)
    badPm = PartitionedMesh()

    # Should raise for missing nodeFields
    try:
        badPm.AddMesh(badMesh)
        return "FAILED: Adding mesh with missing nodeFields did not raise"
    except RuntimeError:
        pass  # Expected

    # Add nodeFields, but not elemFields: should raise
    badMesh.nodeFields[GlobalIdsFieldName] = np.arange(badMesh.GetNumberOfNodes())
    try:
        badPm.AddMesh(badMesh)
        return "FAILED: Adding mesh with missing elemFields did not raise"
    except RuntimeError:
        pass  # Expected

    # Now supply both: should succeed
    badMesh.elemFields[GlobalIdsFieldName] = np.arange(badMesh.GetNumberOfElements())
    try:
        badPm.AddMesh(badMesh)
    except Exception as e:
        return f"FAILED (unexpected error adding mesh after correcting fields): {repr(e)}"

    return "OK"
# Script entry point: run the self-check and print its verdict.
if __name__ == '__main__':
    print(CheckIntegrity(True))