* wait I forgot to comit locally * cp the entire core directory and then rm the .git subdir
161 lines
5.8 KiB
Python
161 lines
5.8 KiB
Python
import inspect
|
|
import json
|
|
from dataclasses import dataclass
|
|
from functools import wraps
|
|
from typing import Callable
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from letta.constants import REDIS_DEFAULT_CACHE_PREFIX
|
|
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
|
|
from letta.log import get_logger
|
|
from letta.plugins.plugins import get_experimental_checker
|
|
from letta.settings import settings
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
def experimental(feature_name: str, fallback_function: Callable, **kwargs):
|
|
"""Decorator that runs a fallback function if experimental feature is not enabled.
|
|
|
|
- kwargs from the decorator will be combined with function kwargs and overwritten only for experimental evaluation.
|
|
- if the decorated function, fallback_function, or experimental checker function is async, the whole call will be async
|
|
"""
|
|
|
|
def decorator(f):
|
|
experimental_checker = get_experimental_checker()
|
|
is_f_async = inspect.iscoroutinefunction(f)
|
|
is_fallback_async = inspect.iscoroutinefunction(fallback_function)
|
|
is_experimental_checker_async = inspect.iscoroutinefunction(experimental_checker)
|
|
|
|
async def call_function(func, is_async, *args, **_kwargs):
|
|
if is_async:
|
|
return await func(*args, **_kwargs)
|
|
return func(*args, **_kwargs)
|
|
|
|
# asynchronous wrapper if any function is async
|
|
if any((is_f_async, is_fallback_async, is_experimental_checker_async)):
|
|
|
|
@wraps(f)
|
|
async def async_wrapper(*args, **_kwargs):
|
|
result = await call_function(experimental_checker, is_experimental_checker_async, feature_name, **dict(_kwargs, **kwargs))
|
|
if result:
|
|
return await call_function(f, is_f_async, *args, **_kwargs)
|
|
else:
|
|
return await call_function(fallback_function, is_fallback_async, *args, **_kwargs)
|
|
|
|
return async_wrapper
|
|
|
|
else:
|
|
|
|
@wraps(f)
|
|
def wrapper(*args, **_kwargs):
|
|
if experimental_checker(feature_name, **dict(_kwargs, **kwargs)):
|
|
return f(*args, **_kwargs)
|
|
else:
|
|
return fallback_function(*args, **kwargs)
|
|
|
|
return wrapper
|
|
|
|
return decorator
|
|
|
|
|
|
def deprecated(message: str):
|
|
"""Simple decorator that marks a method as deprecated."""
|
|
|
|
def decorator(f):
|
|
@wraps(f)
|
|
def wrapper(*args, **kwargs):
|
|
if settings.debug:
|
|
logger.warning(f"Function {f.__name__} is deprecated: {message}.")
|
|
return f(*args, **kwargs)
|
|
|
|
return wrapper
|
|
|
|
return decorator
|
|
|
|
|
|
@dataclass
|
|
class CacheStats:
|
|
"""Note: this will be approximate to not add overhead of locking on counters.
|
|
For exact measurements, use redis or track in other places.
|
|
"""
|
|
|
|
hits: int = 0
|
|
misses: int = 0
|
|
invalidations: int = 0
|
|
|
|
|
|
def async_redis_cache(
|
|
key_func: Callable, prefix: str = REDIS_DEFAULT_CACHE_PREFIX, ttl_s: int = 600, model_class: type[BaseModel] | None = None
|
|
):
|
|
"""
|
|
Decorator for caching async function results in Redis. May be a Noop if redis is not available.
|
|
Will handle pydantic objects and raw values.
|
|
|
|
Attempts to write to and retrieve from cache, but does not fail on those cases
|
|
|
|
Args:
|
|
key_func: function to generate cache key (preferably lowercase strings to follow redis convention)
|
|
prefix: cache key prefix
|
|
ttl_s: time to live (s)
|
|
model_class: custom pydantic model class for serialization/deserialization
|
|
|
|
TODO (cliandy): move to class with generics for type hints
|
|
"""
|
|
|
|
def decorator(func):
|
|
stats = CacheStats()
|
|
|
|
@wraps(func)
|
|
async def async_wrapper(*args, **kwargs):
|
|
redis_client = await get_redis_client()
|
|
|
|
# Don't bother going through other operations for no reason.
|
|
if isinstance(redis_client, NoopAsyncRedisClient):
|
|
return await func(*args, **kwargs)
|
|
cache_key = get_cache_key(*args, **kwargs)
|
|
cached_value = await redis_client.get(cache_key)
|
|
|
|
try:
|
|
if cached_value is not None:
|
|
stats.hits += 1
|
|
if model_class:
|
|
return model_class.model_validate_json(cached_value)
|
|
return json.loads(cached_value)
|
|
except Exception as e:
|
|
logger.warning(f"Failed to retrieve value from cache: {e}")
|
|
|
|
stats.misses += 1
|
|
result = await func(*args, **kwargs)
|
|
try:
|
|
if model_class:
|
|
await redis_client.set(cache_key, result.model_dump_json(), ex=ttl_s)
|
|
elif isinstance(result, (dict, list, str, int, float, bool)):
|
|
await redis_client.set(cache_key, json.dumps(result), ex=ttl_s)
|
|
else:
|
|
logger.warning(f"Cannot cache result of type {type(result).__name__} for {func.__name__}")
|
|
except Exception as e:
|
|
logger.warning(f"Redis cache set failed: {e}")
|
|
return result
|
|
|
|
async def invalidate(*args, **kwargs) -> bool:
|
|
stats.invalidations += 1
|
|
try:
|
|
redis_client = await get_redis_client()
|
|
cache_key = get_cache_key(*args, **kwargs)
|
|
return (await redis_client.delete(cache_key)) > 0
|
|
except Exception as e:
|
|
logger.error(f"Failed to invalidate cache: {e}")
|
|
return False
|
|
|
|
def get_cache_key(*args, **kwargs):
|
|
return f"{prefix}:{key_func(*args, **kwargs)}"
|
|
|
|
async_wrapper.cache_invalidate = invalidate
|
|
async_wrapper.cache_key_func = get_cache_key
|
|
async_wrapper.cache_stats = stats
|
|
return async_wrapper
|
|
|
|
return decorator
|