Created
November 1, 2024 14:54
-
-
Save mkbabb/a12d604791a35a215177ada9ffa27ccd to your computer and use it in GitHub Desktop.
cache-with-stale-interval
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def cache_with_stale_interval(
    stale_interval: datetime.timedelta | float | int | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """
    Decorator that caches function results on disk with optional expiry.

    Each distinct call maps to two files inside the directory returned by
    ``get_cache_dir()``: a pickle holding the result and a JSON metadata file
    recording the pickle's path and the time it was written. The file names
    are derived from an MD5 hash of the function's qualified name together
    with its positional and keyword arguments, so different functions called
    with identical arguments do not share an entry.

    If ``stale_interval`` is set, a cached result older than that interval is
    considered stale and the function is re-run. A cache entry that cannot be
    read (missing pickle, corrupt JSON, ...) is treated as a miss and is
    overwritten with a freshly computed value.

    Notes:
        * Arguments that are not JSON-serializable are keyed by ``str(arg)``;
          objects whose ``str()`` embeds a memory address will therefore
          never hit the cache across processes.
        * Results are restored with ``pickle.load``. Only point the cache
          directory at a location you trust: unpickling attacker-controlled
          files can execute arbitrary code.

    Args:
        stale_interval: Time after which cached results are considered stale.
            Can be timedelta, seconds (int/float), or None for no expiry.

    Returns:
        A decorator that wraps a function with the on-disk cache.
    """
    if isinstance(stale_interval, (int, float)):
        stale_interval = datetime.timedelta(seconds=stale_interval)

    def decorator(func: Callable[P, R]) -> Callable[P, R]:
        # Part of the cache key: without it, two decorated functions called
        # with the same arguments would read and overwrite each other's entry.
        func_id = f"{func.__module__}.{getattr(func, '__qualname__', func.__name__)}"

        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            # Generate cache key from function identity and inputs
            input_data = {"func": func_id, "args": args, "kwargs": kwargs}
            input_hash = hashlib.md5(
                json.dumps(input_data, default=str, sort_keys=True).encode()
            ).hexdigest()

            cache_dir = get_cache_dir()
            output_path = cache_dir / f"{input_hash}_output.json"
            pickled_output_path = cache_dir / f"{input_hash}_output.pkl"

            logger.debug(f"Cache lookup for {func.__name__} with hash {input_hash}")

            # Check for existing, fresh, readable cached result
            hit, cached_result = _read_cache_entry(
                output_path, stale_interval, func.__name__
            )
            if hit:
                return cached_result

            # Cache miss or stale - compute new value and cache it
            logger.info(f"Cache miss for {func.__name__}, computing new value")
            output_data = func(*args, **kwargs)

            # Write the pickle first: the metadata file acts as the marker
            # that an entry exists, so it must not appear before its payload.
            with open(pickled_output_path, "wb") as pkl_file:
                pickle.dump(output_data, pkl_file)

            with output_path.open("w") as f:
                json.dump(
                    {
                        "pickled_output_path": str(pickled_output_path),
                        "timestamp": datetime.datetime.now().isoformat(),
                    },
                    f,
                    indent=4,
                )

            logger.debug(f"Cached new value for {func.__name__} with hash {input_hash}")
            return output_data

        return wrapper

    return decorator


def _read_cache_entry(
    metadata_path,
    stale_interval: datetime.timedelta | None,
    func_name: str,
) -> tuple[bool, object]:
    """
    Load the cached result described by ``metadata_path`` if it is usable.

    Args:
        metadata_path: Path-like object (must support ``.open``) of the JSON
            metadata file for the cache entry.
        stale_interval: Maximum accepted age of the entry, or None for no
            expiry.
        func_name: Name of the cached function, used in log messages only.

    Returns:
        ``(True, value)`` on a cache hit; ``(False, None)`` when the entry is
        absent, stale, or cannot be read.
    """
    try:
        with metadata_path.open("r") as f:
            cached_data = json.load(f)
        cached_timestamp = datetime.datetime.fromisoformat(cached_data["timestamp"])
        cached_pickle_path = cached_data["pickled_output_path"]
    except FileNotFoundError:
        # Ordinary miss: nothing has been cached for this key yet.
        return False, None
    except (OSError, ValueError, KeyError, TypeError) as exc:
        # Corrupt or truncated metadata (json.JSONDecodeError is a ValueError).
        logger.warning(f"Ignoring unreadable cache metadata for {func_name}: {exc}")
        return False, None

    if stale_interval is None:
        hit_detail = "no stale interval"
    else:
        age = datetime.datetime.now() - cached_timestamp
        if age > stale_interval:
            logger.info(f"Cache stale for {func_name} (age: {age} > {stale_interval})")
            return False, None
        hit_detail = f"age: {age}"

    try:
        # SECURITY: pickle.load can execute arbitrary code; the cache
        # directory must only contain files written by this decorator.
        with open(cached_pickle_path, "rb") as pkl_file:
            cached_result = pickle.load(pkl_file)
    except (
        OSError,
        EOFError,
        pickle.UnpicklingError,
        AttributeError,
        ImportError,
        IndexError,
    ) as exc:
        # Metadata exists but the payload is missing or damaged (e.g. an
        # interrupted write); fall back to recomputing instead of crashing.
        logger.warning(f"Ignoring unreadable cached value for {func_name}: {exc}")
        return False, None

    logger.info(f"Cache hit for {func_name} ({hit_detail})")
    return True, cached_result
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment