# -*- coding: utf-8 -*-
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
#
import numpy as np
import scipy.sparse as sps
import scipy.sparse.linalg as spslin
from Muscat.Types import MuscatFloat, MuscatIndex
from Muscat.Helpers.Logger import Debug, Info
from Muscat.Helpers.TextFormatHelper import TFormat as TF
from Muscat.LinAlg.ConstraintsHolder import ConstraintsHolder
from Muscat.Helpers.Factory import Factory
from Muscat.Helpers.Logger import Warning
# Name of the solver LinearProblem uses when none is requested explicitly.
# It is assigned again further down, after the optional native Eigen solvers
# have been probed.
defaultSolver = "EigenCG"
# Solver used as a fallback whenever the requested one cannot be created.
defaultIfError = "CG"
class SolverFactory(Factory):
    """Factory holding the catalog of the registered linear solvers.

    The catalogs are class attributes, so all the SolverFactory instances and
    the module-level helpers below share one and the same registry.
    """

    # solver name -> registration entry (the keys are the names returned by
    # GetAvailableSolvers); the entry layout is defined by the Factory base class
    _Catalog = {}
    # auxiliary registry used by the Factory base class
    _SetCatalog = set()

    def __init__(self):
        super().__init__()
def GetAvailableSolvers():
    """Return the names of all the solvers registered in the SolverFactory.

    Returns
    -------
    list
        the registered solver names, in catalog (insertion) order
    """
    catalog = SolverFactory._Catalog
    return [solverName for solverName in catalog]
def RegisterSolverClass(name, classtype, constructor=None, withError = True):
    """Register a solver class in the SolverFactory under ``name``.

    Parameters
    ----------
    name : str
        key under which the solver can later be created (e.g. by LinearProblem.SetAlgo)
    classtype : type
        the solver class
    constructor : callable, optional
        custom construction callable forwarded to the factory, by default None
    withError : bool, optional
        forwarded unchanged to SolverFactory.RegisterClass, by default True

    Returns
    -------
        whatever SolverFactory.RegisterClass returns
    """
    registrationKwargs = {"constructor": constructor, "withError": withError}
    return SolverFactory.RegisterClass(name, classtype, **registrationKwargs)
def RegisterSolverClassUsingName(cls):
    """Register ``cls`` in the SolverFactory under the ``name`` of one of its instances.

    A throw-away instance is built only to read its ``name`` attribute, so the
    class must be constructible without arguments.
    """
    probe = cls()
    RegisterSolverClass(probe.name, cls)
class LinearSolverBase():
    """Common base of the linear solvers of this module.

    It stores the operator, applies the optional constraints (held by a
    ConstraintsHolder) to the operator and to the right-hand side, and passes
    the solution back through the constraints. Concrete solvers implement
    ``_setop_imp(op)`` (called by SetOp) and ``_solve_imp(rhs, u0)`` (called
    by Solve).
    """

    def __init__(self):
        super().__init__()
        self.op = None  # the operator to solve (after the constraints treatment, if any)
        self.originalOp = None  # the operator exactly as given to SetOp
        self.solver = None
        self.name = ''
        self.u = None  # last solution returned by Solve
        self.constraints = ConstraintsHolder()
        self._can_use_u0 = False
        # must be initialised here: Solve() reads this flag whenever an initial
        # guess is given to a solver that cannot use it (the message is shown once)
        self.__print_warning_u0_ignored = True

    def GetConstraints(self) -> ConstraintsHolder:
        """Return the constraints holder used by this solver."""
        return self.constraints

    def SetConstraints(self, constraints: ConstraintsHolder):
        """Replace the constraints holder used by SetOp and Solve."""
        self.constraints = constraints

    def HasConstraints(self) -> bool:
        """Return True if the constraints holder contains at least one equation."""
        return self.constraints.numberOfEquations > 0

    def GetNumberOfDofs(self) -> MuscatIndex:
        """Return the number of rows of the operator stored by SetOp.

        When constraints are present this is the size of the constrained operator.
        """
        return self.op.shape[0]

    def ComputeProjector(self, op):
        """Apply the constraints to ``op`` and return the constrained operator.

        Parameters
        ----------
        op : matrix-like
            the original operator; its number of columns is used as number of dofs

        Returns
        -------
            the operator returned by ``self.constraints.GetCOp()``
        """
        Info(" With constraints (" + str(self.constraints.numberOfEquations) + ")")
        Info(" Treating constraints using " + str(type(self.constraints.method)))
        self.constraints.SetNumberOfDofs(op.shape[1])
        Debug(" Setting Op")
        self.constraints.SetOp(op)
        Debug(" UpdateCSystem()")
        self.constraints.UpdateCSystem()
        Debug(" GetCOp ")
        op = self.constraints.GetCOp()
        Info('Constraints treatment Done ')
        return op

    def SetOp(self, op):
        """Set the operator of the system (the constraints are applied first, if any)."""
        Info('In SetOp (type:' + str(self.name) + ')')
        self.originalOp = op
        if self.HasConstraints():
            op = self.ComputeProjector(op)
        self.op = op
        Debug('going in LinearSolver._setop_imp(op)')
        self._setop_imp(op)
        Debug('In LinearSolver.SetOp(...) Done')

    def Solve(self, rhs, u0=None):
        """Solve the system for ``rhs`` and return the solution (also kept in ``self.u``).

        Parameters
        ----------
        rhs : numpy array
            right-hand side; singleton dimensions are squeezed
        u0 : numpy array, optional
            initial guess, only used by solvers that accept one. With constraints
            and no u0, the previous solution (if its size still matches the
            operator) is used as initial guess.

        Returns
        -------
        numpy array
            the solution, passed through ``constraints.RestoreSolution`` when
            constraints are present
        """
        if self.HasConstraints():
            rhs = self.constraints.GetCRhs(rhs.squeeze())
        rhs = np.atleast_1d(rhs.squeeze())

        # forget a previous solution whose size does not match the current operator
        if self.u is not None and self.originalOp is not None and len(self.u) != self.originalOp.shape[1]:
            self.u = None

        Debug('In LinearProblem Solve ' + self.name)
        if self._can_use_u0:
            if self.HasConstraints():
                if u0 is not None:
                    u0 = self.constraints.RestrictSolution(u0)
                elif self.u is not None:
                    # warm start from the previous solution
                    u0 = self.constraints.RestrictSolution(self.u)
            if u0 is None:
                u0 = np.zeros_like(rhs)
            u0 = np.atleast_1d(u0.squeeze())
            u = self._solve_imp(rhs, u0=u0)
        else:
            if u0 is not None and self.__print_warning_u0_ignored:
                print("u0 ignored for direct solvers")
                self.__print_warning_u0_ignored = False
            u = self._solve_imp(rhs, u0=None)
        Debug("Done Linear solver " + str(u.shape))

        if self.HasConstraints():
            self.u = self.constraints.RestoreSolution(u)
        else:
            self.u = u
        return self.u

    def _setop_imp(self, op):
        """Hook called by SetOp with the (constrained) operator; does nothing by default."""
        pass

    def _solve_imp(self, rhs, u0=None):
        """Hook called by Solve; must be implemented by every concrete solver."""
        raise NotImplementedError(f"{type(self).__name__} does not implement _solve_imp")
class LinearSolverIterativeBase(LinearSolverBase):
    """Base class adding a tolerance and enabling the initial-guess path of Solve()."""

    def __init__(self):
        super().__init__()
        # Solve() builds an initial guess and forwards it to _solve_imp
        self._can_use_u0 = True
        self.tol = 1.e-6  # default tolerance

    def SetTolerance(self, tol):
        """Set the tolerance of the solver.

        Parameters
        ----------
        tol : float
            the new tolerance
        """
        self.tol = tol
class LinearSolverCG(LinearSolverIterativeBase):
    """Jacobi-preconditioned conjugate gradient (scipy.sparse.linalg.cg)."""

    def __init__(self):
        super().__init__()
        self.name = "CG"

    def _solve_imp(self, rhs, u0):
        """Solve ``op u = rhs`` with a diagonal (Jacobi) preconditioner.

        The system is solved for ``rhs / ||rhs||`` and the result is rescaled, so
        ``atol = tol`` acts relatively to the right-hand side.
        """
        norm = np.linalg.norm(rhs)
        if norm == 0:
            # the scaling below would compute 0/0; u = 0 solves op u = 0
            return np.zeros(rhs.shape, dtype=np.result_type(rhs.dtype, np.float64))

        # work on a writable copy: for a dense ndarray, diagonal() is a read-only view
        diag = np.array(self.op.diagonal())
        diag[diag == 0] = 1.0  # avoid divisions by zero in the preconditioner
        M = sps.dia_matrix((1. / diag, 0), shape=self.op.shape)

        x0 = None if u0 is None else u0 / norm
        u, info = spslin.cg(self.op, rhs / norm, M=M, x0=x0, rtol=self.tol, atol=self.tol)
        if info > 0:
            Warning(TF.InYellowBackGround(TF.InRed("Convergence to tolerance not achieved")))  # pragma: no cover
        if info < 0:
            Warning(TF.InYellowBackGround(TF.InRed("Illegal input or breakdown")))  # pragma: no cover
        return u * norm


RegisterSolverClassUsingName(LinearSolverCG)
class LinearSolvergmres(LinearSolverIterativeBase):
    """Jacobi-preconditioned GMRES (scipy.sparse.linalg.gmres)."""

    def __init__(self):
        super().__init__()
        self.name = "gmres"

    def _solve_imp(self, rhs, u0):
        """Solve ``op u = rhs`` with a diagonal (Jacobi) preconditioner."""
        # work on a writable copy: for a dense ndarray, diagonal() is a read-only view
        diag = np.array(self.op.diagonal())
        diag[diag == 0] = 1.0  # avoid divisions by zero in the preconditioner
        M = sps.dia_matrix((1. / diag, 0), shape=self.op.shape)
        u, info = spslin.gmres(self.op, rhs, x0=u0, M=M, rtol=self.tol, atol=self.tol)
        # report a non-zero status instead of silently returning the last iterate
        if info > 0:
            Warning(TF.InYellowBackGround(TF.InRed("Convergence to tolerance not achieved")))  # pragma: no cover
        if info < 0:
            Warning(TF.InYellowBackGround(TF.InRed("Illegal input or breakdown")))  # pragma: no cover
        return u


RegisterSolverClassUsingName(LinearSolvergmres)
class LinearSolverlsqr(LinearSolverIterativeBase):
    """Iterative least-squares solver (scipy.sparse.linalg.lsqr)."""

    def __init__(self):
        super().__init__()
        self.name = "lsqr"

    def _solve_imp(self, rhs, u0):
        # lsqr returns a tuple of diagnostics whose first entry is the solution
        solution, *_ = spslin.lsqr(self.op, rhs, atol=self.tol, btol=self.tol, x0=u0)
        return solution


RegisterSolverClassUsingName(LinearSolverlsqr)
class LinearSolverlgmres(LinearSolverIterativeBase):
    """Jacobi-preconditioned LGMRES (scipy.sparse.linalg.lgmres)."""

    def __init__(self):
        super().__init__()
        self.name = "lgmres"

    def _solve_imp(self, rhs, u0):
        """Solve ``op u = rhs`` with a diagonal (Jacobi) preconditioner."""
        # work on a writable copy: for a dense ndarray, diagonal() is a read-only view
        diag = np.array(self.op.diagonal())
        diag[diag == 0] = 1.0  # avoid divisions by zero in the preconditioner
        M = sps.dia_matrix((1. / diag, 0), shape=self.op.shape)
        u, info = spslin.lgmres(self.op, rhs, x0=u0, M=M, rtol=self.tol, atol=self.tol)
        # report a non-zero status instead of silently returning the last iterate
        if info > 0:
            Warning(TF.InYellowBackGround(TF.InRed("Convergence to tolerance not achieved")))  # pragma: no cover
        if info < 0:
            Warning(TF.InYellowBackGround(TF.InRed("Illegal input or breakdown")))  # pragma: no cover
        return u


RegisterSolverClassUsingName(LinearSolverlgmres)
class LinearSolverlAMG(LinearSolverIterativeBase):
    """Algebraic multigrid solver built on pyamg's Ruge-Stuben hierarchy.

    Only registered when pyamg can be imported.
    """

    def __init__(self):
        super().__init__()
        self.name = "AMG"

    def _setop_imp(self, op):
        # imported here so that this module stays importable without pyamg
        import pyamg
        self._internal_solver = pyamg.ruge_stuben_solver(op.tocsr())

    def _solve_imp(self, rhs, u0):
        return self._internal_solver.solve(rhs, x0=u0, tol=self.tol)


try:
    import pyamg
    RegisterSolverClassUsingName(LinearSolverlAMG)
except Exception:
    # optional dependency: skip this solver if pyamg cannot be loaded
    # (unlike a bare except, this does not swallow KeyboardInterrupt/SystemExit)
    pass
class LinearSolverDirect(LinearSolverIterativeBase):
    """Sparse direct solver using scipy.sparse.linalg.factorized.

    The operator is pre-factorized once in SetOp and the factorization is
    reused by every Solve; the initial guess is ignored.
    """

    def __init__(self):
        super().__init__()
        self.name = "Direct"

    def _setop_imp(self, op):
        cscOp = op.tocsc()
        self._internal_solver = spslin.factorized(cscOp)

    def _solve_imp(self, rhs, u0=None):
        applyInverse = self._internal_solver
        return applyInverse(rhs)


RegisterSolverClassUsingName(LinearSolverDirect)
class LinearSolverCholesky(LinearSolverDirect):
    """Sparse Cholesky solver based on scikit-sparse (sksparse.cholmod).

    Only registered when sksparse.cholmod can be imported. A Cholesky
    factorization presumes a symmetric positive-definite operator.
    """

    def __init__(self):
        super().__init__()
        self.name = "cholesky"

    def _setop_imp(self, op):
        from sksparse.cholmod import cholesky
        # CHOLMOD expects CSC input (sksparse converts other formats with a
        # warning); convert explicitly, like LinearSolverDirect does
        self._internal_solver = cholesky(op.tocsc())

    def _solve_imp(self, rhs, u0=None):
        return self._internal_solver(rhs)


try:
    from sksparse.cholmod import cholesky
    RegisterSolverClassUsingName(LinearSolverCholesky)
except Exception:
    # optional dependency: skip this solver if scikit-sparse cannot be loaded
    pass
class LinearSolverPardiso(LinearSolverDirect):
    """Direct solver based on MKL Pardiso, through pypardiso.

    Only registered when pypardiso can be imported.
    """

    def __init__(self):
        super().__init__()
        self.name = "Pardiso"
        self._internal_solver = None
        self._internal_solver_allocated = False

    def _setop_imp(self, op):
        # release the previous factorization before computing a new one
        self.FreeMemory()
        self._internal_solver = FastInitPyPardisoSolver()
        self._internal_solver.factorize(op)
        self._internal_solver_allocated = True

    def _solve_imp(self, rhs, u0=None):
        return self._internal_solver.solve(self.op, rhs).squeeze()

    def __del__(self):
        self.FreeMemory()

    def FreeMemory(self):
        """Release the memory held by the Pardiso factorization, if any."""
        # getattr: __del__ may run on an object whose __init__ did not complete
        if getattr(self, "_internal_solver_allocated", False):
            self._internal_solver.free_memory(everything=True)
            self._internal_solver_allocated = False


try:
    import pypardiso

    class FastInitPyPardisoSolver(pypardiso.PyPardisoSolver):
        """Fast-initialisation variant of pypardiso.PyPardisoSolver.

        Locating and loading the MKL library is done only once, by a shared
        ``backend`` instance stored as a class attribute; every new solver
        reuses the library handles of that backend.
        """

        backend = None  # shared PyPardisoSolver providing the MKL handles

        def __init__(self, mtype=11, phase=13, size_limit_storage: int = 1000):
            """Create a solver reusing the MKL library loaded by the shared backend.

            PyPardisoSolver.__init__ is deliberately not called (it would search
            for and load MKL again); its attribute initialisation is reproduced.

            Args:
                mtype (int, optional): Pardiso matrix type; pypardiso is only tested
                    for mtype=11 (real and nonsymmetric). Defaults to 11.
                phase (int, optional): Pardiso phase stored on the solver. Defaults to 13.
                size_limit_storage (int, optional): forwarded as the pypardiso
                    ``size_limit_storage`` attribute. Defaults to 1000.
            """
            if FastInitPyPardisoSolver.backend is None:
                # first instance: let pypardiso locate and load MKL
                FastInitPyPardisoSolver.backend = pypardiso.PyPardisoSolver(mtype=mtype, phase=phase, size_limit_storage=size_limit_storage)
            self.libmkl = FastInitPyPardisoSolver.backend.libmkl
            self._mkl_pardiso = FastInitPyPardisoSolver.backend.libmkl.pardiso
            self._pt_type = FastInitPyPardisoSolver.backend._pt_type
            ## this part is a copy of the init of PyPardisoSolver
            self.pt = np.zeros(64, dtype=self._pt_type[1])
            self.iparm = np.zeros(64, dtype=np.int32)
            self.perm = np.zeros(0, dtype=np.int32)
            self.mtype = mtype
            self.phase = phase
            self.msglvl = False
            self.factorized_A = sps.csr_matrix((0, 0))
            self.size_limit_storage = size_limit_storage
            self._solve_transposed = False

    RegisterSolverClassUsingName(LinearSolverPardiso)
except Exception:
    # optional dependency: skip this solver if pypardiso cannot be loaded
    pass
class LinearSolverEigen(LinearSolverIterativeBase):
    """Linear solvers provided by the native extension Muscat.LinAlg.NativeEigenSolver.

    ``subtype`` is one of GetAvailableSolvers(); the solver name is "Eigen" + subtype.
    """

    def __init__(self, subtype):
        super().__init__()
        self.SetSolver(subtype)
        import Muscat.LinAlg.NativeEigenSolver as NativeEigenSolver
        self.solver = NativeEigenSolver.CEigenSolvers()
        from Muscat.Helpers.CPU import GetNumberOfAvailableCores
        self.solver.ForceNumberOfThreads(GetNumberOfAvailableCores())

    def SetSolver(self, subtype):
        """Select the Eigen solver type; the native solver is updated at the next SetOp.

        The native solver object is kept: resetting ``self.solver`` here would make
        every later SetOp fail on a ``None`` solver.
        """
        self.name = "Eigen" + subtype
        self.subtype = subtype

    def _setop_imp(self, op):
        self.solver.SetSolverType(self.subtype)
        self.solver.SetTolerance(self.tol)
        self.solver.SetOp(op)

    def _solve_imp(self, rhs, u0=None):
        # for the eigen solver we must allocate on the python side
        if u0 is None:
            u0 = np.zeros_like(rhs)
        return self.solver.Solve(rhs, u0)

    def GetSPQRRank(self):
        """Return the rank computed by the native solver (presumably only meaningful for the 'SPQR' subtype, after SetOp)."""
        return self.solver.GetSPQRRank()

    def GetSPQR_Q(self):
        """Return the Q factor computed by the native solver (presumably 'SPQR' subtype only)."""
        return self.solver.GetSPQR_Q()

    def GetSPQR_R(self):
        """Return the R factor computed by the native solver (presumably 'SPQR' subtype only)."""
        return self.solver.GetSPQR_R()

    def GetSPQR_P(self):
        """Return the column permutation computed by the native solver (presumably 'SPQR' subtype only)."""
        return self.solver.GetSPQR_P()

    @classmethod
    def GetAvailableSolvers(cls):
        """Return the supported Eigen subtypes (registered as "Eigen" + subtype)."""
        return ['CG', 'LU', 'BiCGSTAB', 'SPQR']
try:
    import Muscat.LinAlg.NativeEigenSolver as NativeEigenSolver

    def GenerateEigenConstructor(type):
        """Return a factory constructor building a LinearSolverEigen of the given subtype.

        A dedicated function makes each lambda capture its own subtype; a lambda
        written directly in the loop would see only the last loop value.
        """
        return lambda x: LinearSolverEigen(type)

    for eigenSubTypes in LinearSolverEigen.GetAvailableSolvers():
        RegisterSolverClass("Eigen" + eigenSubTypes, LinearSolverEigen, GenerateEigenConstructor(eigenSubTypes))
    defaultSolver = "EigenCG"
    defaultIfError = "CG"
except Exception:
    # the Eigen solvers are unavailable: do not keep an Eigen solver as default
    defaultSolver = defaultIfError
    print(f"WARNING! Error loading Eigen linear solver using {defaultSolver} as default ")

# default settings of LinearSolverDispatcher
linearSolverDispatcherDefaultThreshold = MuscatIndex(1000)
# small systems: Eigen's LU when it was registered, scipy's factorization otherwise
# (query the factory catalog: LinearSolverEigen.GetAvailableSolvers() lists the
# subtypes 'LU', 'CG', ..., never the registered names "EigenLU", "EigenCG")
if "EigenLU" in GetAvailableSolvers():
    linearSolverDispatcherDefaultLowerSolver = "EigenLU"
else:
    linearSolverDispatcherDefaultLowerSolver = "Direct"
# big systems: Eigen's CG when it was registered, the module default otherwise
if "EigenCG" in GetAvailableSolvers():
    linearSolverDispatcherDefaultUpperSolver = "EigenCG"
else:
    linearSolverDispatcherDefaultUpperSolver = defaultSolver
class LinearSolverDispatcher(LinearSolverIterativeBase):
    """Class to select a linear solver depending on the size of the problem (number of dofs) at runtime

    Systems with fewer rows than ``dofThreshold`` go to the "small size" solver,
    the others to the "big size" solver. The initial settings come from the
    module-level ``linearSolverDispatcherDefault*`` values.
    """

    def __init__(self):
        super().__init__()
        self.name = "Dispatcher"
        self.dofThreshold = MuscatIndex(linearSolverDispatcherDefaultThreshold)
        # use the module-level defaults instead of hard-coding the optional Eigen solvers
        self.lowerSolverName = linearSolverDispatcherDefaultLowerSolver
        self.lowerSolverOps = {}
        self.upperSolverName = linearSolverDispatcherDefaultUpperSolver
        self.upperSolverOps = {}
        self._internal_solver = None

    def SetSmallSizeSolver(self, solverName: str, ops=None):
        """Set the small solver name and the option to be used at creation

        Parameters
        ----------
        solverName : str
            the name of the solver to be used for "small" systems
        ops : dict, optional
            option to be passed to the solver, by default None (no option)
        """
        self.lowerSolverName = solverName
        self.lowerSolverOps = {} if ops is None else dict(ops)

    def SetBigSizeSolver(self, solverName: str, ops=None):
        """Set the big solver name and the option to be used at creation

        Parameters
        ----------
        solverName : str
            the name of the solver to be used for "big" systems
        ops : dict, optional
            option to be passed to the solver, by default None (no option)
        """
        self.upperSolverName = solverName
        self.upperSolverOps = {} if ops is None else dict(ops)

    def SetDofThreshold(self, threshold: MuscatIndex):
        """Set the threshold value to select the small/big solver.

        Parameters
        ----------
        threshold : MuscatIndex
            threshold value to select the small or the big solver.
        """
        self.dofThreshold = MuscatIndex(threshold)

    def _setop_imp(self, op):
        if op.shape[0] < self.dofThreshold:
            solverName, solverOps = self.lowerSolverName, self.lowerSolverOps
        else:
            solverName, solverOps = self.upperSolverName, self.upperSolverOps
        # reuse the current internal solver when it is already of the right kind
        if self._internal_solver is None or self._internal_solver.name != solverName:
            self._internal_solver = SolverFactory.Create(solverName, ops=solverOps)
        # Use the public SetOp so that the internal solver also stores the operator
        # (e.g. LinearSolverCG and LinearSolverPardiso read self.op in _solve_imp).
        # The constraints were already applied by this dispatcher and the internal
        # solver is never given any.
        # NOTE(review): this dispatcher's own tolerance (SetTolerance) is not
        # forwarded to the internal solver.
        self._internal_solver.SetOp(op)

    def _solve_imp(self, rhs, u0=None):
        return self._internal_solver._solve_imp(rhs, u0=u0)


RegisterSolverClassUsingName(LinearSolverDispatcher)
###############################################################################################################################
class LinearProblem():
    """Front end holding a concrete solver (chosen with SetAlgo) and forwarding calls to it.

    You must call SetAlgo (done with ``defaultSolver`` at construction) before
    setting the operator.
    """

    def __init__(self):
        super().__init__()
        self.realsolver = None
        self.SetAlgo(defaultSolver)

    def _GetSolver(self):
        """Return the underlying solver, raising if none has been set."""
        if self.realsolver is None:  # pragma: no cover
            raise Exception("Please set the solver type first")
        return self.realsolver

    def SetTolerance(self, tol):
        """Set the tolerance of the underlying solver."""
        self._GetSolver().SetTolerance(tol)

    def HasConstraints(self):
        """Return True if the underlying solver has constraints."""
        return self._GetSolver().HasConstraints()

    def GetNumberOfDofs(self):
        """Return the number of dofs of the operator set in the underlying solver."""
        return self._GetSolver().GetNumberOfDofs()

    def ComputeProjector(self, mesh, fields):
        # NOTE(review): LinearSolverBase.ComputeProjector takes a single operator
        # argument, so this two-argument call raises a TypeError -- confirm intent
        self._GetSolver().ComputeProjector(mesh, fields)

    def SetOp(self, op):
        """Set the operator of the system (SetAlgo must have been called before)."""
        self._GetSolver().SetOp(op)

    def SetAlgo(self, name, ops=None, withErrorIfNotFound=False):
        """Select the solver ``name``, keeping the constraints of the previous solver.

        Parameters
        ----------
        name : str
            name of a solver registered in the SolverFactory
        ops : optional
            options forwarded to SolverFactory.Create
        withErrorIfNotFound : bool, optional
            if True, re-raise the creation error; otherwise fall back to
            ``defaultIfError``, by default False
        """
        try:
            newSolver = SolverFactory.Create(name, ops=ops)
            if self.realsolver is not None:
                # transfer the constraints before switching, so a failure here
                # leaves the previous solver (and its constraints) in place
                newSolver.SetConstraints(self.realsolver.GetConstraints())
            self.realsolver = newSolver
        except Exception:
            if withErrorIfNotFound:  # pragma: no cover
                raise
            print(f"Solver {name} unavailable, falling back to solver {defaultIfError} instead.")
            self.SetAlgo(defaultIfError, withErrorIfNotFound=True)

    def Solve(self, rhs, u0=None):
        """Solve the system for ``rhs``; ``u0`` is an optional initial guess."""
        return self._GetSolver().Solve(rhs, u0=u0)

    @property
    def constraints(self) -> ConstraintsHolder:
        return self._GetSolver().GetConstraints()

    @constraints.setter
    def constraints(self, constraints):
        self._GetSolver().SetConstraints(constraints)
def CheckSolver(GUI, solver):
    """Solve a small diagonal system with ``solver`` and check the result.

    The system diag(0.5, 1) u = (1, 2) has the exact solution u = (2, 2). It is
    solved twice, to also exercise a second Solve on the same operator.

    Raises
    ------
    Exception
        if a component of the solution differs from 2 by more than 1e-15
    """
    print("Solver " + str(solver))
    problem = LinearProblem()
    problem.SetAlgo(solver)
    print("Number of dofs")
    operator = sps.csc_matrix(np.array([[0.5, 0], [0, 1]]), dtype=MuscatFloat)
    problem.SetOp(operator)
    print("HasConstraints : ", problem.HasConstraints())
    print("Number of dofs : ", problem.GetNumberOfDofs())
    for _ in range(2):
        sol = problem.Solve(np.array([[1.], [2.]]))

    firstError = abs(sol[0] - 2.)
    if firstError > 1e-15:
        print(sol[0] - 2)  # pragma: no cover
        print("sol : ", sol)  # pragma: no cover
        raise Exception()  # pragma: no cover
    if abs(sol[1] - 2.) > 1e-15:
        raise Exception()
def CheckSPQR(GUI):
    """Check the rank-revealing QR factorization of the Eigen "SPQR" solver.

    The transpose of a 5x5 matrix with two identical rows is factorized, and the
    reduced factors are compared with the column-permuted matrix:
    ``M[:, P[:rank]]`` must match ``Q[:, :rank] R[:rank, :rank]``.

    Raises
    ------
    RuntimeError
        if the reconstruction error exceeds 1e-10
    """
    realsolver = LinearSolverEigen("SPQR")
    A = sps.csc_matrix(np.array([[0, 0, 0, 1, 1],
                                 [0, 0, 0, 1, 1],
                                 [0.51, 0.5, 0.5, 0, 10],
                                 [0.5, 0.5, 0.5, 0, 10],
                                 [0, 1, 0, 0, 5],
                                 ]), dtype=MuscatFloat)

    def QRTest(A):
        realsolver.SetTolerance(1e-5)
        print(A.toarray())
        realsolver.SetOp(A)
        rank = realsolver.GetSPQRRank()
        print("Rank", rank)
        Q = realsolver.GetSPQR_Q()
        print("Q")
        print(Q.toarray())
        R = realsolver.GetSPQR_R()
        print("R")
        print(R.toarray())
        P = realsolver.GetSPQR_P()
        print("P", P)
        print("-------------------------------------------------")
        print("AP")
        print(A.toarray()[:, P])
        print("QR")
        print(Q.tocsr().dot(R.tocsr().toarray()))
        print("norm A[:,P]-QR")
        print(np.linalg.norm(A.toarray()[:, P] - Q.tocsr()[:, 0:rank].dot(R.tocsr()[0:rank, :].toarray())))
        print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^")
        print("Q réduit")
        print(Q.tocsr()[:, 0:rank].toarray())
        print("R réduit")
        print(R.tocsr()[0:rank, 0:rank].toarray())
        print("QR reduit")
        print(Q.tocsr()[:, 0:rank].dot(R.tocsr()[0:rank, 0:rank].toarray()))
        print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^")
        error = np.linalg.norm(A.toarray()[:, P[0:rank]] - Q.tocsr()[:, 0:rank].dot(R.tocsr()[0:rank, 0:rank].toarray()))
        if abs(error) > 1e-10:
            # a bare ``raise`` outside an except clause only reports
            # "No active exception to reraise"; report the actual failure
            raise RuntimeError(f"EigenSPQR check failed: ||A[:,P] - QR|| = {error}")
        print(error)
        print("OK EigenSPQR")

    QRTest(A.T)
def CheckIntegrity(GUI: bool = False):
    """Run CheckSolver on every registered solver and the SPQR check when available.

    Returns
    -------
    str
        "ok" when all the checks pass
    """
    SolverFactory()  # built only to check that the factory can be instantiated
    solvers = GetAvailableSolvers()
    for s in solvers:
        CheckSolver(GUI, s)
    # the SPQR check needs the optional native Eigen solvers
    if "EigenSPQR" in solvers:
        CheckSPQR(GUI)
    return "ok"
if __name__ == '__main__':
    # run the self-checks when the module is executed as a script
    checkResult = CheckIntegrity()  # pragma: no cover
    print(checkResult)  # pragma: no cover