# -*- coding: utf-8 -*-
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
#
from typing import Optional, Any, Callable
import os
import hashlib
import pickle
import numpy as np
from Muscat.Helpers.IO.TemporaryDirectory import TemporaryDirectory
from Muscat.Helpers.Logger import Info, Warning
#: Global switch to enable or disable the cache.
#: Evaluated at execution time (not at definition time),
#: so the user can turn off the cache for a single call of a cached function.
USE_CACHE: bool = True
[docs]
def CachedResultDecorator(name: Optional[str] = None, path: Optional[str] = None, needDill: Optional[bool] = False):
    """Build a decorator that wraps a function with GetFunctionWithCache.

    Parameters
    ----------
    name : str, optional
        Name associated to the cache; forwarded to GetFunctionWithCache, where
        it is used to build the cache file names. By default None.
    path : str, optional
        Directory where the cache data is stored; forwarded to
        GetFunctionWithCache. By default None.
    needDill : bool, optional
        Forwarded to GetFunctionWithCache to request dill based serialization
        (needed when lambdas are present). By default False.

    Returns
    -------
    Callable
        A decorator taking the function to wrap and returning the result of
        GetFunctionWithCache for it.
    """
    cacheOptions = {"name": name, "path": path, "needDill": needDill}

    def decorate(function):
        # the options are captured once, at decoration time
        return GetFunctionWithCache(function, **cacheOptions)

    return decorate
def __HashFunction(v: Any) -> Any:
    """Compute the sha256 hash object of v.

    Strings are encoded before hashing and None is hashed as the string
    "None". Any other value is first handed directly to hashlib.sha256; if
    hashlib rejects it, the value is serialized with pickle.dumps and the
    resulting bytes are hashed instead.

    Parameters
    ----------
    v : Any
        Value used to compute the hash

    Returns
    -------
    Any
        The hashlib sha256 object of v (call ``.hexdigest()`` to get the
        text form)

    Raises
    ------
    Exception
        Whatever pickle.dumps raises when v can be neither hashed directly
        nor pickled.
    """
    # exact type test kept on purpose: str subclasses keep going through the
    # pickle fallback, so the digests produced for them do not change
    if type(v) is str:
        return hashlib.sha256(v.encode())
    if v is None:
        return __HashFunction("None")
    try:
        return hashlib.sha256(v)
    except Exception:
        # hashlib refused the object: hash its pickled representation.
        # Exception (and not a bare except) so that KeyboardInterrupt and
        # SystemExit are not swallowed here.
        return __HashFunction(pickle.dumps(v, fix_imports=False))
[docs]
def GetFunctionWithCache(func: Callable, name: Optional[str] = None, path: Optional[str] = None, needDill: bool = False) -> Callable:
    """Helper function to conserve the output of a function for later invocation
    with the same arguments.

    The function must be a pure function (ie. the function must not rely on
    any persistent or internal state). The user can use a dummy argument to
    force the function to be recalculated.
    All the arguments must support the == operator and a sha256 or pickle.

    For every call two files are used inside ``path``: one holding the
    arguments and one holding the result. The file names are built from
    ``name`` and a hash of the function name and of the arguments. On a later
    call with the same hash, the stored arguments are compared with the
    current ones before the stored result is returned; if anything goes wrong
    (unreadable file, arguments not equal) a warning is emitted and the
    function is evaluated again.

    Parameters
    ----------
    func : Callable
        function to be decorated, must be a pure function
    name : str, optional
        the name associated to the cache, this value is used to build the file name
        to store previous computed values. By default None (``func.__name__`` is used)
    path : str, optional
        Path to store the cache data, if None, all data is written to a temporary
        directory, by default None
    needDill : bool, default False
        tells the cache that we need dill for pickling object (in the case lambdas are present)
        if dill is not present return the original function with no cache

    Returns
    -------
    Callable
        A function with the same arguments as func but with cache capabilities.
        If needDill == True and the module is not available then the original
        function (func) is returned

    Raises
    ------
    RuntimeError
        if func is a lambda and no name is given, or if func is already a
        function produced by GetFunctionWithCache
    """
    if needDill:
        try:
            import dill  # noqa: F401  (imported only to check availability)
            from Muscat.IO.DillTools import LoadData, SaveData
        except ImportError:  # pragma: no cover
            return func
    else:
        from Muscat.IO.PickleTools import LoadData, SaveData

    # a lambda has no usable name to build the cache file names
    if func.__name__ == "<lambda>" and name is None:
        raise RuntimeError("Cant create a cached function of a lambda function with name=None")

    # checking if the user try to pass a already cached function
    if "_withCache" in func.__name__:
        raise RuntimeError("Cant create a cached function of a FuncWithCache function")

    # default name
    if name is None:
        name = func.__name__

    # default path
    if path is None:
        path = TemporaryDirectory.GetTempPath()

    # make sure the path exist and is created
    os.makedirs(path, exist_ok=True)

    def compare(a, b):
        """Return True if the stored value a and the current value b are equal."""
        if isinstance(a, (list, tuple)) and isinstance(b, type(a)):
            # sequences of different length can never be equal (and must not
            # be indexed past the end of the shortest one)
            if len(a) != len(b):
                return False
            return all(compare(x, y) for x, y in zip(a, b))
        aIsArray = isinstance(a, np.ndarray)
        bIsArray = isinstance(b, np.ndarray)
        if aIsArray or bIsArray:
            # arrays holding the same bytes share the same sha256; without
            # these checks broadcasting would make e.g. shapes (4,) and (1, 4)
            # compare equal
            if np.shape(a) != np.shape(b):
                return False
            if aIsArray and bIsArray and a.dtype != b.dtype:
                return False
        return bool(np.all(a == b))

    def FuncWithCache(*args, **kwargs):
        # USE_CACHE is read at call time so the cache can be disabled per call
        if not USE_CACHE:
            return func(*args, **kwargs)

        # generation of the hash to create a unique filename:
        # function name, then every argument, then every keyword argument
        digests = [__HashFunction(func.__name__).hexdigest()]
        digests.extend(__HashFunction(a).hexdigest() for a in args)
        # sorted so the order in which the keywords are given does not matter
        for key in sorted(kwargs):
            digests.append(__HashFunction(key).hexdigest())
            digests.append(__HashFunction(kwargs[key]).hexdigest())
        finalHash = __HashFunction("".join(digests)).hexdigest()

        # os.path.join: works with or without a trailing separator in path
        fileNameInputs = os.path.join(path, "Cached_inputs_" + name + "_" + finalHash + ".cache")
        fileNameOutputs = os.path.join(path, "Cached_outputs_" + name + "_" + finalHash + ".cache")

        # detection of old cache and handling hash clash
        if os.path.exists(fileNameInputs) and os.path.exists(fileNameOutputs):
            try:
                data = LoadData(fileNameInputs)

                if len(data.unnamed) != len(args):
                    raise TypeError("Invalid number of arguments")  # pragma: no cover
                for stored, current in zip(data.unnamed, args):
                    if not compare(stored, current):
                        # really unusual (only if sha256 are equal and not the values)
                        raise RuntimeError("Error on arguments (compare)")  # pragma: no cover

                if len(data.named) != len(kwargs):
                    raise TypeError("Invalid number of keyword arguments")  # pragma: no cover
                for key in data.named:
                    if key not in kwargs or not compare(data.named[key], kwargs[key]):
                        raise RuntimeError("Error on keyword arguments")  # pragma: no cover

                # all ok send back the cached result
                return LoadData(fileNameOutputs).unnamed[0]
            except Exception as e:  # pragma: no cover
                # best effort: any problem with the stored data means recompute
                Warning("Warning: error reading data from disk, or cached data not valid (rebuilding cache) ")
                Warning(fileNameInputs)
                Warning(fileNameOutputs)
                Warning(f"Reason: {e!r}")

        # do the heavy computation
        res = func(*args, **kwargs)

        # try to save the input and the output data; a failure here must not
        # prevent the caller from getting the result
        try:
            SaveData(fileNameInputs, *args, **kwargs)
            SaveData(fileNameOutputs, res)
        except Exception as e:  # pragma: no cover
            Warning("Warning: Error saving cache data (cache will not be available for later)")
            Warning(f"Reason: {e!r}")

        return res

    # change the name of the cached function to a more explicit name
    # (also used above to detect an already cached function)
    FuncWithCache.__name__ = func.__name__ + "_withCache"
    FuncWithCache.__doc__ = func.__doc__
    return FuncWithCache
[docs]
def CheckIntegrity(GUI: bool = False) -> str:
    """Self test of the cache helpers of this module.

    Exercises GetFunctionWithCache and CachedResultDecorator with positional
    and keyword arguments, numpy arrays, nested lists/tuples, a function
    returning None, named lambdas, and with USE_CACHE toggled.

    Parameters
    ----------
    GUI : bool, optional
        not used by this test, by default False

    Returns
    -------
    str
        "ok" when every step ran without raising
    """

    def CountTheNumberOfExecutions(arg):
        import time

        return time.localtime()

    print(CountTheNumberOfExecutions("hola"))
    f = GetFunctionWithCache(CountTheNumberOfExecutions)
    print(f("hola"))
    print(f("hola"))

    a = np.arange(5)

    def plus1(arg):
        return arg + 1

    f = GetFunctionWithCache(plus1)
    f(1)
    f(1)
    # wrap the same function a second time
    GetFunctionWithCache(plus1)

    print("Original Function ", plus1(a))
    print("First call with cache", f(a))
    print("Second call with cache", f(a))
    print("First call with cache new argument", f(1))
    print("Second call with cache new argument ", f(1))

    def return2():
        return 2

    @CachedResultDecorator(name="superFunction")
    def return3():
        return 3

    f = GetFunctionWithCache(return2)
    print("With no args return 2 ", f())
    print("With no args return 3 ", return3())
    print("With no args return 3 ", return3.__name__)

    @CachedResultDecorator(name="superFunction")
    def returnAddWithArgs(toto, tata=4):
        print(f"running function toto {toto}, tata {tata}**********************************")
        return toto + tata

    print("----------------------------------------------------")
    print("With args return (3) ->", returnAddWithArgs(3))
    print("----------------------------------------------------")

    from Muscat.Helpers import Cache

    OLD_USE_CACHE = Cache.USE_CACHE
    Cache.USE_CACHE = not OLD_USE_CACHE
    try:
        print("----------------------------------------------------")
        print("With args return (3) ->", returnAddWithArgs(3))
        print("----------------------------------------------------")
    finally:
        # always restore the global switch, even if the call above fails
        Cache.USE_CACHE = OLD_USE_CACHE

    print("With args return (3,tata=4) ->", returnAddWithArgs(3, tata=4))
    print("----------------------------------------------------")
    print("With args return (3,tata=4) ->", returnAddWithArgs(3, tata=4))
    print("----------------------------------------------------")

    @CachedResultDecorator(name="NoneReturningFunction")
    def NoneReturningFunction(a):
        return None

    NoneReturningFunction(None)

    from Muscat.Helpers.CheckTools import MustFailFunction, MustFailFunctionWith

    # this call fails in pure python because the lambda does not have a name
    # if we compile the python (cythonize) then the lambda is given a "c" name and
    # then the next line does not fail anymore.
    # MustFailFunction(GetFunctionWithCache, lambda x: None)

    # but works if we have a name
    cf = GetFunctionWithCache(lambda x: None, name="lambda func1 ")
    # an already cached function can not be cached again
    MustFailFunction(GetFunctionWithCache, cf)
    MustFailFunctionWith(RuntimeError, GetFunctionWithCache, lambda x: 0)

    # test list of tuples of list ...
    def lllPlus1(arg):
        return arg[0][0][0] + 1

    lllf = GetFunctionWithCache(lllPlus1)
    lllf(([np.array([3])],))
    np.testing.assert_equal(lllf(([np.array([3])],)), 4)
    lllf(([np.array([2])],))

    return "ok"
if __name__ == "__main__":
    # executed as a script: run the self test and show its status
    status = CheckIntegrity()  # pragma: no cover
    print(status)  # pragma: no cover