Source code for Muscat.LinAlg.LinearSolver

# -*- coding: utf-8 -*-
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
#
import numpy as np
import scipy.sparse as sps
import scipy.sparse.linalg as spslin

from Muscat.Types import MuscatFloat, MuscatIndex
from Muscat.Helpers.Logger import Debug, Info
from Muscat.Helpers.TextFormatHelper import TFormat as TF
from Muscat.LinAlg.ConstraintsHolder import ConstraintsHolder
from Muscat.Helpers.Factory import Factory
from Muscat.Helpers.Logger import Warning

# Algorithm used by LinearProblem() when none is given. "CG" (pure scipy) is
# always registered by this module; it is upgraded to "EigenCG" further down
# once the optional native Eigen solvers have been registered successfully.
defaultSolver = "CG"
# Algorithm LinearProblem.SetAlgo falls back to when the requested one cannot
# be created.
defaultIfError = "CG"

class SolverFactory(Factory):
    """Factory holding the linear solvers of this module.

    The catalog attributes are class-level, so every instance (and every
    ``SolverFactory.*`` class call) shares the same registry.
    """

    # registered solvers, keyed by name (read by GetAvailableSolvers);
    # presumably filled by Factory.RegisterClass -- confirm in Muscat.Helpers.Factory
    _Catalog = dict()
    _SetCatalog = set()

    def __init__(self):
        super().__init__()
def GetAvailableSolvers():
    """Return the names under which solvers are registered in SolverFactory (as a list)."""
    catalog = SolverFactory._Catalog
    return [solverName for solverName in catalog]
def RegisterSolverClass(name, classtype, constructor=None, withError = True):
    """Register ``classtype`` in SolverFactory under ``name``.

    ``constructor`` and ``withError`` are forwarded unchanged to
    ``SolverFactory.RegisterClass``, whose return value is returned.
    """
    registration = SolverFactory.RegisterClass(name,
                                               classtype,
                                               constructor=constructor,
                                               withError=withError)
    return registration
def RegisterSolverClassUsingName(cls):
    """Register ``cls`` in SolverFactory under the ``name`` of a default instance.

    ``cls`` must be constructible without arguments: a temporary instance is
    created only to read its ``name`` attribute.

    Returns:
        ``cls`` itself, so the function can also be used as a class decorator.
    """
    RegisterSolverClass(cls().name, cls)
    return cls
class LinearSolverBase():
    """Common driver shared by all the linear solvers of this module.

    ``SetOp`` stores the operator (projected through the constraints when there
    are any) and calls the ``_setop_imp`` hook; ``Solve`` prepares the
    right-hand side / initial guess and calls ``_solve_imp``, which concrete
    solvers implement.
    """

    def __init__(self):
        super().__init__()
        self.op = None  # the operator to solve (after constraint projection, if any)
        self.originalOp = None  # the operator exactly as passed to SetOp
        self.solver = None
        self.name = ''
        self.u = None  # last solution returned by Solve
        self.constraints = ConstraintsHolder()
        self._can_use_u0 = False
        # the "u0 ignored" message of Solve is printed only once per instance;
        # this flag was read by Solve but never initialised (AttributeError)
        self.__print_warning_u0_ignored = True

    def GetConstraints(self) -> ConstraintsHolder:
        return self.constraints

    def SetConstraints(self, constraints: ConstraintsHolder):
        self.constraints = constraints

    def HasConstraints(self) -> bool:
        return self.constraints.numberOfEquations > 0

    def GetNumberOfDofs(self) -> MuscatIndex:
        """Number of rows of the (possibly projected) operator set by SetOp."""
        return self.op.shape[0]

    def ComputeProjector(self, op):
        """Pass ``op`` to the constraints holder and return its constrained operator.

        Calls, in order, ``SetNumberOfDofs``, ``SetOp``, ``UpdateCSystem`` and
        ``GetCOp`` on ``self.constraints``.
        """
        Info(" With constraints (" + str(self.constraints.numberOfEquations) + ")")
        Info(" Treating constraints using " + str(type(self.constraints.method)))
        self.constraints.SetNumberOfDofs(op.shape[1])
        Debug(" Setting Op")
        self.constraints.SetOp(op)
        Debug(" UpdateCSystem()")
        self.constraints.UpdateCSystem()
        Debug(" GetCOp ")
        op = self.constraints.GetCOp()
        Info('Constraints treatment Done ')
        return op

    def SetOp(self, op):
        """Set the operator; it is projected first when constraints are present."""
        Info('In SetOp (type:' + str(self.name) + ')')
        self.originalOp = op
        if self.HasConstraints():
            op = self.ComputeProjector(op)
        self.op = op
        Debug('going in LinearSolver._setop_imp(op)')
        self._setop_imp(op)
        Debug('In LinearSolver.SetOp(...) Done')

    def Solve(self, rhs, u0=None):
        """Solve the system for ``rhs`` and return the solution (also kept in ``self.u``).

        Parameters
        ----------
        rhs : array
            right-hand side; squeezed to 1D (and passed through the constraints
            when there are any).
        u0 : array, optional
            initial guess, only used by solvers with ``_can_use_u0``. With
            constraints it is restricted through them; when it is not given,
            the previous solution (if any) is used instead.
        """
        if self.HasConstraints():
            rhs = self.constraints.GetCRhs(rhs.squeeze())

        rhs = np.atleast_1d(rhs.squeeze())

        # forget the previous solution if the operator size changed since
        if self.u is not None and self.originalOp is not None and len(self.u) != self.originalOp.shape[1]:
            self.u = None

        Debug('In LinearProblem Solve ' + self.name)

        if self._can_use_u0:
            if self.HasConstraints():
                if u0 is not None:
                    u0 = self.constraints.RestrictSolution(u0)
                elif self.u is not None:
                    u0 = self.constraints.RestrictSolution(self.u)
            if u0 is None:
                u0 = np.zeros_like(rhs)
            u0 = np.atleast_1d(u0.squeeze())
            u = self._solve_imp(rhs, u0=u0)
        else:
            if u0 is not None and self.__print_warning_u0_ignored:
                print("u0 ignored for direct solvers")
                self.__print_warning_u0_ignored = False
            u = self._solve_imp(rhs, u0=None)

        Debug("Done Linear solver " + str(u.shape))
        if self.HasConstraints():
            self.u = self.constraints.RestoreSolution(u)
        else:
            self.u = u
        return self.u

    def _setop_imp(self, op):
        """Hook called by SetOp with the (possibly projected) operator; no-op by default."""
        pass
class LinearSolverIterativeBase(LinearSolverBase):
    """Base class of the solvers that accept an initial guess ``u0``.

    Adds the tolerance ``tol`` (1e-6 unless changed with SetTolerance) that the
    concrete solvers hand to their backend.
    """

    def __init__(self):
        super().__init__()
        self._can_use_u0 = True
        self.tol = 1.e-6

    def SetTolerance(self, tol):
        """Store the tolerance read by the concrete solver."""
        self.tol = tol
class LinearSolverCG(LinearSolverIterativeBase):
    """Conjugate gradient (``scipy.sparse.linalg.cg``) with a Jacobi preconditioner.

    The system is solved for ``rhs / ||rhs||`` and the result scaled back, so
    ``self.tol`` (used as both ``rtol`` and ``atol``) applies to a unit-norm
    right-hand side.
    """

    def __init__(self):
        super().__init__()
        self.name = "CG"

    def _solve_imp(self, rhs, u0):
        """Solve ``self.op @ u = rhs`` starting from ``u0`` (may be None)."""
        norm = np.linalg.norm(rhs)
        if norm == 0:
            # zero right-hand side: the normalisation below would divide by
            # zero and return NaNs; return the trivial solution instead
            return np.zeros(np.shape(rhs), dtype=np.result_type(rhs, MuscatFloat))

        # copy: .diagonal() may be a read-only view (e.g. for a dense operator)
        diag = np.array(self.op.diagonal())
        diag[diag == 0] = 1.0  # keep the inverse-diagonal preconditioner finite
        M = sps.dia_matrix((1. / diag, 0), shape=self.op.shape)

        x0 = None if u0 is None else u0 / norm
        x, info = spslin.cg(self.op, rhs / norm, M=M, x0=x0, rtol=self.tol, atol=self.tol)

        if info > 0:
            Warning(TF.InYellowBackGround(TF.InRed("Convergence to tolerance not achieved")))  # pragma: no cover
        if info < 0:
            Warning(TF.InYellowBackGround(TF.InRed("Illegal input or breakdown")))  # pragma: no cover
        return x * norm


RegisterSolverClassUsingName(LinearSolverCG)
class LinearSolvergmres(LinearSolverIterativeBase):
    """GMRES (``scipy.sparse.linalg.gmres``) with a Jacobi (inverse-diagonal) preconditioner."""

    def __init__(self):
        super().__init__()
        self.name = "gmres"

    def _solve_imp(self, rhs, u0):
        pivots = self.op.diagonal()
        pivots[pivots == 0] = 1.0  # avoid dividing by a zero diagonal entry
        jacobi = sps.dia_matrix((1. / pivots, 0), shape=self.op.shape)
        solution, _info = spslin.gmres(self.op, rhs, x0=u0, M=jacobi, rtol=self.tol, atol=self.tol)
        return solution


RegisterSolverClassUsingName(LinearSolvergmres)
class LinearSolverlsqr(LinearSolverIterativeBase):
    """Least-squares solver ``scipy.sparse.linalg.lsqr`` (no preconditioner).

    ``self.tol`` is used for both the ``atol`` and ``btol`` stopping criteria.
    """

    def __init__(self):
        super().__init__()
        self.name = "lsqr"

    def _solve_imp(self, rhs, u0):
        outcome = spslin.lsqr(self.op, rhs, atol=self.tol, btol=self.tol, x0=u0)
        solution = outcome[0]
        return solution


RegisterSolverClassUsingName(LinearSolverlsqr)
class LinearSolverlgmres(LinearSolverIterativeBase):
    """LGMRES (``scipy.sparse.linalg.lgmres``) with a Jacobi (inverse-diagonal) preconditioner."""

    def __init__(self):
        super().__init__()
        self.name = "lgmres"

    def _solve_imp(self, rhs, u0):
        pivots = self.op.diagonal()
        pivots[pivots == 0] = 1.0  # avoid dividing by a zero diagonal entry
        jacobi = sps.dia_matrix((1. / pivots, 0), shape=self.op.shape)
        solution, _info = spslin.lgmres(self.op, rhs, x0=u0, M=jacobi, rtol=self.tol, atol=self.tol)
        return solution


RegisterSolverClassUsingName(LinearSolverlgmres)
class LinearSolverlAMG(LinearSolverIterativeBase):
    """Algebraic multigrid solver (``pyamg.ruge_stuben_solver``); pyamg is optional."""

    def __init__(self):
        super().__init__()
        self.name = "AMG"

    def _setop_imp(self, op):
        import pyamg
        # the pyamg solver is built once per operator (SetOp) and reused by _solve_imp
        self._internal_solver = pyamg.ruge_stuben_solver(op.tocsr())

    def _solve_imp(self, rhs, u0):
        return self._internal_solver.solve(rhs, x0=u0, tol=self.tol)


# Register the AMG solver only when pyamg can be imported. Only the import is
# guarded, so genuine registration errors are no longer silently swallowed.
try:
    import pyamg
except ImportError:
    pass
else:
    RegisterSolverClassUsingName(LinearSolverlAMG)
class LinearSolverDirect(LinearSolverIterativeBase):
    """Sparse direct solver built on ``scipy.sparse.linalg.factorized``.

    The factorization is computed once in ``_setop_imp`` and reused for every
    right-hand side; the initial guess ``u0`` is accepted but ignored.
    """

    def __init__(self):
        super().__init__()
        self.name = "Direct"

    def _setop_imp(self, op):
        csc_op = op.tocsc()  # CSC is the storage factorized handles most efficiently
        self._internal_solver = spslin.factorized(csc_op)

    def _solve_imp(self, rhs, u0=None):
        solve = self._internal_solver
        return solve(rhs)


RegisterSolverClassUsingName(LinearSolverDirect)
class LinearSolverCholesky(LinearSolverDirect):
    """Sparse Cholesky solver from scikit-sparse (``sksparse.cholmod``); optional.

    The factor is computed once in ``_setop_imp`` and applied to each
    right-hand side; ``u0`` is ignored.
    """

    def __init__(self):
        super().__init__()
        self.name = "cholesky"

    def _setop_imp(self, op):
        from sksparse.cholmod import cholesky
        # CHOLMOD works on CSC matrices; convert explicitly (as LinearSolverDirect
        # does) instead of relying on an implicit conversion with a warning
        self._internal_solver = cholesky(op.tocsc())

    def _solve_imp(self, rhs, u0=None):
        return self._internal_solver(rhs)


# Register the Cholesky solver only when scikit-sparse can be imported.
try:
    from sksparse.cholmod import cholesky
except ImportError:
    pass
else:
    RegisterSolverClassUsingName(LinearSolverCholesky)
class LinearSolverPardiso(LinearSolverDirect):
    """Direct solver using MKL PARDISO through pypardiso (optional dependency).

    A new ``FastInitPyPardisoSolver`` is created and factorized for every
    operator; the memory of the previous one is released first.
    """

    def __init__(self):
        super().__init__()
        self.name = "Pardiso"
        self._internal_solver = None
        self._internal_solver_allocated = False

    def _setop_imp(self, op):
        self.FreeMemory()
        self._internal_solver = FastInitPyPardisoSolver()
        self._internal_solver.factorize(op)
        self._internal_solver_allocated = True

    def _solve_imp(self, rhs, u0=None):
        return self._internal_solver.solve(self.op, rhs).squeeze()

    def __del__(self):
        self.FreeMemory()

    def FreeMemory(self):
        """Release the PARDISO memory of the current internal solver, if any."""
        # getattr: __del__ may run on an instance whose __init__ did not complete
        if getattr(self, "_internal_solver_allocated", False):
            self._internal_solver.free_memory(everything=True)
            self._internal_solver_allocated = False


# pypardiso (and the MKL runtime it loads) is optional: only guard the import,
# so errors in the code below are not silently hidden.
try:
    import pypardiso
except (ImportError, OSError):
    pass
else:
    class FastInitPyPardisoSolver(pypardiso.PyPardisoSolver):
        """Fast-initialising variant of ``pypardiso.PyPardisoSolver``.

        The MKL library lookup of ``pypardiso.PyPardisoSolver.__init__`` is
        done only once: the first instance builds a regular PyPardisoSolver,
        stored in the class attribute ``backend``, and every instance reuses
        its library handles. The remaining per-instance state is initialised
        as in ``PyPardisoSolver.__init__``.
        """

        backend = None  # shared PyPardisoSolver providing the MKL handles

        def __init__(self, mtype=11, phase=13, size_limit_storage: int = 1000):
            """Initialise the solver state, reusing the shared MKL backend.

            Args:
                mtype (int, optional): PARDISO matrix type. pypardiso is only
                    tested for mtype=11 (real and nonsymmetric). Defaults to 11.
                phase (int, optional): PARDISO phase code to execute. Defaults to 13.
                size_limit_storage (int, optional): stored as
                    ``self.size_limit_storage`` for pypardiso's own use (see the
                    pypardiso documentation). Defaults to 1000.

            Note:
                ``backend`` is created with the arguments of the first instance only.
            """
            # super().__init__() is deliberately not called: it would redo the MKL lookup
            if FastInitPyPardisoSolver.backend is None:
                FastInitPyPardisoSolver.backend = pypardiso.PyPardisoSolver(mtype=mtype, phase=phase, size_limit_storage=size_limit_storage)

            self.libmkl = FastInitPyPardisoSolver.backend.libmkl
            self._mkl_pardiso = FastInitPyPardisoSolver.backend.libmkl.pardiso
            self._pt_type = FastInitPyPardisoSolver.backend._pt_type

            ## this part is a copy of the init of PyPardisoSolver
            self.pt = np.zeros(64, dtype=self._pt_type[1])
            self.iparm = np.zeros(64, dtype=np.int32)
            self.perm = np.zeros(0, dtype=np.int32)

            self.mtype = mtype
            self.phase = phase
            self.msglvl = False

            self.factorized_A = sps.csr_matrix((0, 0))
            self.size_limit_storage = size_limit_storage
            self._solve_transposed = False

    RegisterSolverClassUsingName(LinearSolverPardiso)
class LinearSolverEigen(LinearSolverIterativeBase):
    """Solvers provided by the native Eigen wrapper (Muscat.LinAlg.NativeEigenSolver).

    ``subtype`` is one of ``LinearSolverEigen.GetAvailableSolvers()``; the
    solver name is ``"Eigen" + subtype``. The subtype and tolerance are pushed
    to the native solver in ``_setop_imp``.
    """

    def __init__(self, subtype):
        super().__init__()
        self.SetSolver(subtype)
        import Muscat.LinAlg.NativeEigenSolver as NativeEigenSolver
        self.solver = NativeEigenSolver.CEigenSolvers()
        from Muscat.Helpers.CPU import GetNumberOfAvailableCores
        self.solver.ForceNumberOfThreads(GetNumberOfAvailableCores())

    def SetSolver(self, subtype):
        """Select the Eigen solver subtype; it is applied to the native solver at the next SetOp.

        The native solver object is intentionally kept (not reset to None), so
        changing the subtype of an existing instance does not make SetOp fail.
        """
        self.name = "Eigen" + subtype
        self.subtype = subtype

    def _setop_imp(self, op):
        self.solver.SetSolverType(self.subtype)
        self.solver.SetTolerance(self.tol)
        self.solver.SetOp(op)

    def _solve_imp(self, rhs, u0=None):
        # for the eigen solver we must allocate on the python side
        if u0 is None:
            u0 = np.zeros_like(rhs)
        return self.solver.Solve(rhs, u0)

    def GetSPQRRank(self):
        """Rank found by the SPQR factorization (SPQR subtype, after SetOp)."""
        return self.solver.GetSPQRRank()

    def GetSPQR_Q(self):
        """Q factor of the SPQR factorization."""
        return self.solver.GetSPQR_Q()

    def GetSPQR_R(self):
        """R factor of the SPQR factorization."""
        return self.solver.GetSPQR_R()

    def GetSPQR_P(self):
        """Column permutation of the SPQR factorization."""
        return self.solver.GetSPQR_P()

    @classmethod
    def GetAvailableSolvers(cls):
        """Eigen subtypes supported by this wrapper (without the "Eigen" prefix)."""
        return ['CG', 'LU', 'BiCGSTAB', 'SPQR']
def GenerateEigenConstructor(type):
    """Return a SolverFactory constructor that builds ``LinearSolverEigen(type)``.

    The returned callable takes one positional argument and ignores it
    (presumably the options given to SolverFactory.Create -- TODO confirm in
    Muscat.Helpers.Factory). Going through this function binds ``type`` per
    call, avoiding the late-binding pitfall of a lambda created in a loop.
    """
    return lambda x: LinearSolverEigen(type)


# The native Eigen wrapper is optional: register one factory entry per Eigen
# subtype ("Eigen" + subtype) when it can be loaded.
try:
    import Muscat.LinAlg.NativeEigenSolver as NativeEigenSolver
    for eigenSubTypes in LinearSolverEigen.GetAvailableSolvers():
        RegisterSolverClass("Eigen" + eigenSubTypes, LinearSolverEigen, GenerateEigenConstructor(eigenSubTypes))
    defaultSolver = "EigenCG"
    defaultIfError = "CG"
except Exception:
    # make the announced default the one LinearProblem will actually use
    defaultSolver = defaultIfError
    print(f"WARNING! Error loading Eigen linear solver using {defaultSolver} as default ")

# Defaults of LinearSolverDispatcher. The membership tests use the names
# registered in SolverFactory ("EigenLU", "EigenCG"); the bare subtypes of
# LinearSolverEigen.GetAvailableSolvers() ("LU", "CG") never carry the prefix.
linearSolverDispatcherDefaultThreshold = MuscatIndex(1000)
_registeredSolverNames = GetAvailableSolvers()
if "EigenLU" in _registeredSolverNames:
    linearSolverDispatcherDefaultLowerSolver = "EigenLU"
else:
    linearSolverDispatcherDefaultLowerSolver = "Direct"

if "EigenCG" in _registeredSolverNames:
    linearSolverDispatcherDefaultUpperSolver = "EigenCG"
else:
    linearSolverDispatcherDefaultUpperSolver = defaultSolver
class LinearSolverDispatcher(LinearSolverIterativeBase):
    """Class to select a linear solver depending on the size of the problem (number of dofs) at runtime.

    Operators with fewer than ``dofThreshold`` rows go to the "small size"
    solver, the others to the "big size" solver. The delegate instance is
    reused as long as the selected solver name does not change.
    """

    def __init__(self):
        super().__init__()
        self.name = "Dispatcher"
        self.dofThreshold = MuscatIndex(linearSolverDispatcherDefaultThreshold)
        # module-level defaults, computed at import time from the registered solvers
        self.lowerSolverName = linearSolverDispatcherDefaultLowerSolver
        self.lowerSolverOps = {}
        self.upperSolverName = linearSolverDispatcherDefaultUpperSolver
        self.upperSolverOps = {}
        self._internal_solver = None

    def SetSmallSizeSolver(self, solverName: str, ops=None):
        """Set the small solver name and the option to be used at creation

        Parameters
        ----------
        solverName : str
            the name of the solver to be used for "small" systems
        ops : dict, optional
            option to be passed to the solver (copied), by default no options
        """
        self.lowerSolverName = solverName
        self.lowerSolverOps = {} if ops is None else dict(ops)

    def SetBigSizeSolver(self, solverName: str, ops=None):
        """Set the big solver name and the option to be used at creation

        Parameters
        ----------
        solverName : str
            the name of the solver to be used for "big" systems
        ops : dict, optional
            option to be passed to the solver (copied), by default no options
        """
        self.upperSolverName = solverName
        self.upperSolverOps = {} if ops is None else dict(ops)

    def SetDofThreshold(self, threshold: MuscatIndex):
        """Set the threshold value to select the small/big solver.

        Parameters
        ----------
        threshold : MuscatIndex
            threshold value to select the small or the big solver.
        """
        self.dofThreshold = MuscatIndex(threshold)

    def _setop_imp(self, op):
        nbdoffs = op.shape[0]
        if nbdoffs < self.dofThreshold:
            solverName, solverOps = self.lowerSolverName, self.lowerSolverOps
        else:
            solverName, solverOps = self.upperSolverName, self.upperSolverOps

        if self._internal_solver is None or self._internal_solver.name != solverName:
            self._internal_solver = SolverFactory.Create(solverName, ops=solverOps)

        # _setop_imp bypasses the delegate's public SetOp, so its operator must
        # be given explicitly: CG, gmres, lsqr, lgmres and Pardiso read self.op
        # in _solve_imp.
        self._internal_solver.op = op
        self._internal_solver._setop_imp(op)

    def _solve_imp(self, rhs, u0=None):
        return self._internal_solver._solve_imp(rhs, u0=u0)


RegisterSolverClassUsingName(LinearSolverDispatcher)
###############################################################################################################################
class LinearProblem():
    """User-facing wrapper around a solver created by ``SolverFactory``.

    ``SetAlgo`` selects (or replaces) the underlying ``realsolver`` -- the
    constructor does it with ``defaultSolver`` -- and the other methods forward
    to it. The constraints are carried over when the algorithm is changed.
    """

    def __init__(self):
        super().__init__()
        self.realsolver = None
        self.SetAlgo(defaultSolver)

    def _GetRealSolver(self):
        """Return ``realsolver``, raising if no solver has been set."""
        if self.realsolver is None:  # pragma: no cover
            raise Exception("Please set the solver type first")
        return self.realsolver

    def SetTolerance(self, tol):
        self._GetRealSolver().SetTolerance(tol)

    def HasConstraints(self):
        return self._GetRealSolver().HasConstraints()

    def GetNumberOfDofs(self):
        return self._GetRealSolver().GetNumberOfDofs()

    def ComputeProjector(self, op):
        """Project ``op`` through the constraints of the underlying solver and return the result.

        The underlying ``ComputeProjector`` takes a single operator; forwarding
        two arguments to it (the former ``mesh, fields`` signature) always
        raised a TypeError.
        """
        return self._GetRealSolver().ComputeProjector(op)

    # you must set SetAlgo before setting the Op
    def SetOp(self, op):
        self._GetRealSolver().SetOp(op)

    def SetAlgo(self, name, ops=None, withErrorIfNotFound=False):
        """Create solver ``name`` (with options ``ops``) and make it the current solver.

        The constraints of the previous solver, if any, are transferred. When
        creation fails, ``defaultIfError`` is used instead unless
        ``withErrorIfNotFound`` is True, in which case the error is re-raised.
        """
        try:
            newSolver = SolverFactory.Create(name, ops=ops)
            if self.realsolver is not None:
                newSolver.SetConstraints(self.realsolver.GetConstraints())
            self.realsolver = newSolver
        except Exception:
            if withErrorIfNotFound:  # pragma: no cover
                raise
            print(f"Solver {name} unavailable, falling back to solver {defaultIfError} instead.")
            self.SetAlgo(defaultIfError, withErrorIfNotFound=True)

    def Solve(self, rhs, u0=None):
        """Solve for ``rhs``; ``u0`` is an optional initial guess forwarded to the solver."""
        return self._GetRealSolver().Solve(rhs, u0=u0)

    @property
    def constraints(self) -> ConstraintsHolder:
        return self._GetRealSolver().GetConstraints()

    @constraints.setter
    def constraints(self, constraints):
        self._GetRealSolver().SetConstraints(constraints)
def CheckSolver(GUI, solver):
    """Solve a 2x2 diagonal system twice with ``solver`` and check the result.

    The system is ``diag(0.5, 1) u = (1, 2)``, whose exact solution is
    ``u = (2, 2)``.

    Raises
    ------
    Exception
        if a component of the solution differs from 2 by more than 1e-15.
    """
    print("Solver " + str(solver))
    LS = LinearProblem()
    LS.SetAlgo(solver)
    print("Number of dofs")
    LS.SetOp(sps.csc_matrix(np.array([[0.5, 0], [0, 1]]), dtype=MuscatFloat))
    print("HasConstraints : ", LS.HasConstraints())
    print("Number of dofs : ", LS.GetNumberOfDofs())
    sol = LS.Solve(np.array([[1.], [2.]]))
    # second run with the same operator
    sol = LS.Solve(np.array([[1.], [2.]]))
    for i in range(2):
        if abs(sol[i] - 2.) > 1e-15:  # pragma: no cover
            print(sol[i] - 2)
            print("sol : ", sol)
            raise Exception(f"Solver {solver}: component {i} of the solution is {sol[i]}, expected 2.0")
def CheckSPQR(GUI):
    """Factorize a rank-deficient test matrix with EigenSPQR and check the factors.

    Prints Q, R, P and several products, then checks that ``A[:, P[:rank]]``
    matches the rank-truncated ``Q R`` within 1e-10.

    Raises
    ------
    Exception
        if the reconstruction error is above 1e-10.
    """
    realsolver = LinearSolverEigen("SPQR")
    A = sps.csc_matrix(np.array([[0, 0, 0, 1, 1],
                                 [0, 0, 0, 1, 1],
                                 [0.51, 0.5, 0.5, 0, 10],
                                 [0.5, 0.5, 0.5, 0, 10],
                                 [0, 1, 0, 0, 5],
                                 ]), dtype=MuscatFloat)

    def QRTest(A):
        realsolver.SetTolerance(1e-5)
        print(A.toarray())
        realsolver.SetOp(A)
        rank = realsolver.GetSPQRRank()
        print("Rank", rank)
        Q = realsolver.GetSPQR_Q()
        print("Q")
        print(Q.toarray())
        R = realsolver.GetSPQR_R()
        print("R")
        print(R.toarray())
        P = realsolver.GetSPQR_P()
        print("P", P)
        print("-------------------------------------------------")
        print("AP")
        print(A.toarray()[:, P])
        print("QR")
        print(Q.tocsr().dot(R.tocsr().toarray()))
        print("norm A[:,P]-QR")
        print(np.linalg.norm(A.toarray()[:, P] - Q.tocsr()[:, 0:rank].dot(R.tocsr()[0:rank, :].toarray())))
        print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^")
        print("Q réduit")
        print(Q.tocsr()[:, 0:rank].toarray())
        print("R réduit")
        print(R.tocsr()[0:rank, 0:rank].toarray())
        print("QR reduit")
        print(Q.tocsr()[:, 0:rank].dot(R.tocsr()[0:rank, 0:rank].toarray()))
        print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^")
        error = np.linalg.norm(A.toarray()[:, P[0:rank]] - Q.tocsr()[:, 0:rank].dot(R.tocsr()[0:rank, 0:rank].toarray()))
        if abs(error) > 1e-10:
            # a bare `raise` here has no active exception and would only report
            # "No active exception to reraise"
            raise Exception(f"EigenSPQR check failed: ||A[:,P[:rank]] - QR|| = {error} > 1e-10")
        print(error)
        print("OK EigenSPQR")

    QRTest(A.T)
def CheckIntegrity(GUI: bool = False):
    """Run CheckSolver on every registered solver, plus the SPQR check when available.

    Returns
    -------
    str
        "ok" when every check passed (failures raise).
    """
    SolverFactory()  # smoke test of the factory constructor
    solvers = GetAvailableSolvers()
    for s in solvers:
        CheckSolver(GUI, s)
    # CheckSPQR needs the optional native Eigen backend
    if "EigenSPQR" in solvers:
        CheckSPQR(GUI)
    return "ok"
if __name__ == '__main__':
    # running the module directly executes its self-checks
    status = CheckIntegrity()  # pragma: no cover
    print(status)  # pragma: no cover