# Source code for Muscat.Helpers.Cache

# -*- coding: utf-8 -*-
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
#
from typing import Optional, Any, Callable
import os
import hashlib
import pickle


import numpy as np

from Muscat.Helpers.IO.TemporaryDirectory import TemporaryDirectory
from Muscat.Helpers.Logger import Info, Warning

#: Global on/off switch for the disk cache of cached functions (on by default).
#: The flag is read each time a cached function is called, not when it is created,
#: so it can be set to False temporarily (e.g. around a single call) to bypass the cache.
USE_CACHE: bool = True

def CachedResultDecorator(name: Optional[str] = None, path: Optional[str] = None, needDill: bool = False) -> Callable[[Callable], Callable]:
    """Decorator to add cache capabilities to a function.

    The decorated function is wrapped with ``GetFunctionWithCache`` using the
    arguments given here.

    Parameters
    ----------
    name : str, optional
        the name associated to the cache, this value is used to build the file
        names that store previously computed values, by default None (the name
        of the decorated function is then used)
    path : str, optional
        Directory where the cache data is stored; if None, all data is written
        to a temporary directory, by default None
    needDill : bool, optional
        tells the cache that dill is needed for pickling objects (e.g. when
        lambdas are present); if dill is not available the original function is
        returned without cache, by default False

    Returns
    -------
    Callable
        a decorator that takes a function and returns its cached version
    """

    def decorator(function: Callable) -> Callable:
        return GetFunctionWithCache(function, name=name, path=path, needDill=needDill)

    return decorator
def __HashFunction(v: Any) -> Any:
    """Function to compute a sha256 hash of v.

    - ``str`` values are utf-8 encoded and hashed.
    - ``None`` is hashed as the string ``"None"``.
    - numpy arrays and numpy scalars (non-object dtype) are hashed together with
      their dtype and shape, using the C-ordered bytes of the data.
    - other objects exposing a contiguous buffer (e.g. ``bytes``) are hashed directly.
    - anything else is serialized with ``pickle.dumps`` and the bytes are hashed.

    Collisions are possible (e.g. ``None`` and ``"None"``, ``"a"`` and ``b"a"``),
    so callers must still compare the actual values before trusting a match.

    Parameters
    ----------
    v : Any
        Value used to compute the hash

    Returns
    -------
    Any
        a ``hashlib`` sha256 object (use ``.hexdigest()`` to get a string)
    """
    if type(v) is str:
        return hashlib.sha256(v.encode())
    if v is None:
        return __HashFunction("None")

    if isinstance(v, (np.ndarray, np.generic)):
        array = np.asarray(v)
        if array.dtype.hasobject:
            # the buffer of an object array holds pointers, not values: hash the pickled content
            return __HashFunction(pickle.dumps(v, fix_imports=False))
        # The raw buffer alone ignores dtype and shape (np.zeros(3) and np.zeros(3, dtype=int)
        # share the same bytes), so include them in the hash. tobytes() returns C-ordered data,
        # so non-contiguous views hash like their contiguous copies.
        h = hashlib.sha256(repr((str(array.dtype), array.shape)).encode())
        h.update(array.tobytes())
        return h

    try:
        return hashlib.sha256(v)
    except Exception:
        # v does not expose a usable buffer: hash its pickled representation instead
        return __HashFunction(pickle.dumps(v, fix_imports=False))
def GetFunctionWithCache(func: Callable, name: Optional[str] = None, path: Optional[str] = None, needDill: bool = False) -> Callable:
    """Helper function to conserve the output of a function for later invocation with the same arguments.

    The function must be a pure function (i.e. the function must not rely on any
    persistent or internal state). The user can use a dummy argument to force the
    function to be recalculated. All the arguments must support the == operator
    and be hashable with sha256 or picklable.

    For every call, the inputs and the output are written to two files in ``path``.
    On a later call whose arguments produce the same hash, the stored inputs are
    compared with the current ones and, if they match, the stored output is returned
    without calling ``func``. When the module-level ``USE_CACHE`` is False at call
    time, ``func`` is called directly.

    Parameters
    ----------
    func : Callable
        function to be decorated, must be a pure function
    name : str, optional
        the name associated to the cache, this value is used to build the file
        names that store previously computed values, by default None
        (``func.__name__`` is used)
    path : str, optional
        Directory where the cache data is stored; if None, all data is written to
        a temporary directory, by default None
    needDill : bool, default False
        tells the cache that dill is needed for pickling objects (e.g. when lambdas
        are present); if dill is not available the original function is returned
        with no cache

    Returns
    -------
    Callable
        A function with the same arguments as func but, if needDill == True and the
        module is not available, the original function (func) is returned

    Raises
    ------
    RuntimeError
        if func is a lambda and name is None, or if func is already a cached function
    """
    import functools

    if needDill:
        try:
            import dill  # noqa: F401 (only checks that dill is installed)
            from Muscat.IO.DillTools import LoadData, SaveData
        except ImportError:  # pragma: no cover
            return func
    else:
        from Muscat.IO.PickleTools import LoadData, SaveData

    # every lambda is named "<lambda>": an explicit name is needed to tell their caches apart
    if func.__name__ == "<lambda>" and name is None:
        raise RuntimeError("Cant create a cached function of a lambda function with name=None")

    # refuse to put a cache on top of an already cached function
    if "_withCache" in func.__name__:
        raise RuntimeError("Cant create a cached function of a FuncWithCache function")

    if name is None:
        name = func.__name__

    if path is None:
        path = TemporaryDirectory.GetTempPath()

    # make sure the directory exists
    os.makedirs(path, exist_ok=True)

    def compare(a, b) -> bool:
        """Return True if the stored argument a equals the current argument b."""
        if isinstance(a, (list, tuple)) and isinstance(b, type(a)):
            if len(a) != len(b):
                return False
            return all(compare(x, y) for x, y in zip(a, b))
        if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
            # a == b broadcasts, so arrays of different shape (or dtype) could look equal
            if a.shape != b.shape or a.dtype != b.dtype:
                return False
        return bool(np.all(a == b))

    @functools.wraps(func)
    def FuncWithCache(*args, **kwargs):
        # read at call time so the cache can be switched off between calls
        if not USE_CACHE:
            return func(*args, **kwargs)

        # unique key from the function name, each positional argument and each
        # (keyword, value) pair; keywords are sorted so their order does not matter
        parts = [__HashFunction(func.__name__).hexdigest()]
        for a in args:
            parts.append(__HashFunction(a).hexdigest())
        for k in sorted(kwargs):
            parts.append(__HashFunction(k).hexdigest())
            parts.append(__HashFunction(kwargs[k]).hexdigest())
        finalHash = __HashFunction("".join(parts)).hexdigest()

        # os.path.join works whether or not path ends with a separator
        fileNameInputs = os.path.join(path, "Cached_inputs_" + name + "_" + finalHash + ".cache")
        fileNameOutputs = os.path.join(path, "Cached_outputs_" + name + "_" + finalHash + ".cache")

        # detection of old cache and handling hash clash
        if os.path.exists(fileNameInputs) and os.path.exists(fileNameOutputs):
            try:
                data = LoadData(fileNameInputs)
                if len(data.unnamed) != len(args):
                    raise TypeError("Invalid number of arguments")  # pragma: no cover
                for v0, v1 in zip(data.unnamed, args):
                    if not compare(v0, v1):
                        # really unusual (only if sha256 are equal and not the values)
                        raise RuntimeError("Error on arguments (compare)")  # pragma: no cover
                if len(data.named) != len(kwargs):
                    raise TypeError("Invalid number of keyword arguments")  # pragma: no cover
                for k in data.named:
                    if k not in kwargs or not compare(data.named[k], kwargs[k]):
                        raise RuntimeError("Error on keyword arguments")  # pragma: no cover
                # all ok send back the cached result
                return LoadData(fileNameOutputs).unnamed[0]
            except Exception as e:  # pragma: no cover
                Warning("Warning: error reading data from disk, or cached data not valid (rebuilding cache) ")
                Warning(fileNameInputs)
                Warning(fileNameOutputs)
                Warning(f"Reason: {e!r}")

        # do the heavy computation
        res = func(*args, **kwargs)

        # best effort: a failure to save only means the result is not cached
        try:
            SaveData(fileNameInputs, *args, **kwargs)
            SaveData(fileNameOutputs, res)
        except Exception:  # pragma: no cover
            Warning("Warning: Error saving cache data (cache will not be available for later)")
        return res

    # explicit name, also used above to detect already cached functions
    FuncWithCache.__name__ = func.__name__ + "_withCache"
    return FuncWithCache
def CheckIntegrity(GUI: bool = False) -> str:
    """Smoke test of the cache helpers of this module.

    Exercises GetFunctionWithCache and CachedResultDecorator with and without
    arguments, with the USE_CACHE switch toggled, with lambdas and with nested
    list/tuple arguments, printing intermediate results.

    Parameters
    ----------
    GUI : bool, optional
        not used by this function, by default False

    Returns
    -------
    str
        "ok" when the end of the test is reached
    """

    def CountTheNumberOfExecutions(arg):
        import time

        return time.localtime()

    print(CountTheNumberOfExecutions("hola"))
    f = GetFunctionWithCache(CountTheNumberOfExecutions)
    print(f("hola"))
    print(f("hola"))

    a = np.arange(5)

    def plus1(arg):
        return arg + 1

    f = GetFunctionWithCache(plus1)
    f(1)
    f(1)
    # wrapping the same (non cached) function a second time must also work
    GetFunctionWithCache(plus1)
    print("Original Function ", plus1(a))
    print("First call with cache", f(a))
    print("Second call with cache", f(a))
    print("First call with cache new argument", f(1))
    print("Second call with cache new argument ", f(1))

    def return2():
        return 2

    @CachedResultDecorator(name="superFunction")
    def return3():
        return 3

    f = GetFunctionWithCache(return2)
    print("With no args return 2 ", f())
    print("With no args return 3 ", return3())
    print("With no args return 3 ", return3.__name__)

    @CachedResultDecorator(name="superFunction")
    def returnAddWithArgs(toto, tata=4):
        print(f"running function toto {toto}, tata {tata}**********************************")
        return toto + tata

    print("----------------------------------------------------")
    print("With args return (3) ->", returnAddWithArgs(3))
    print("----------------------------------------------------")

    from Muscat.Helpers import Cache

    OLD_USE_CACHE = Cache.USE_CACHE
    Cache.USE_CACHE = not OLD_USE_CACHE
    try:
        print("----------------------------------------------------")
        print("With args return (3) ->", returnAddWithArgs(3))
        print("----------------------------------------------------")
    finally:
        # always restore the global switch, even if the call above fails
        Cache.USE_CACHE = OLD_USE_CACHE

    print("With args return (3,tata=4) ->", returnAddWithArgs(3, tata=4))
    print("----------------------------------------------------")
    print("With args return (3,tata=4) ->", returnAddWithArgs(3, tata=4))
    print("----------------------------------------------------")

    @CachedResultDecorator(name="NoneReturningFunction")
    def NoneReturningFunction(a):
        return None

    NoneReturningFunction(None)

    from Muscat.Helpers.CheckTools import MustFailFunction, MustFailFunctionWith

    # In pure python this call fails because the lambda has no name. If the python is
    # compiled (cythonized), the lambda is given a "c" name and the next line does not
    # fail anymore, so it is disabled:
    # MustFailFunction(GetFunctionWithCache, lambda x: None)
    # but it works if we give a name
    cf = GetFunctionWithCache(lambda x: None, name="lambda func1 ")
    MustFailFunction(GetFunctionWithCache, cf)
    MustFailFunctionWith(RuntimeError, GetFunctionWithCache, lambda x: 0)

    # test list of tuples of list ...
    def lllPlus1(arg):
        return arg[0][0][0] + 1

    lllf = GetFunctionWithCache(lllPlus1)
    lllf(([np.array([3])],))
    np.testing.assert_equal(lllf(([np.array([3])],)), 4)
    lllf(([np.array([2])],))
    return "ok"
# run the self test when this file is executed as a script
if __name__ == "__main__":  # pragma: no cover
    print(CheckIntegrity())