@@ -4,6 +4,7 @@ from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
|
||||
from opentelemetry import trace
|
||||
from pydantic import BaseModel
|
||||
|
||||
from letta.constants import REDIS_DEFAULT_CACHE_PREFIX
|
||||
@@ -13,6 +14,7 @@ from letta.plugins.plugins import get_experimental_checker
|
||||
from letta.settings import settings
|
||||
|
||||
logger = get_logger(__name__)
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
def experimental(feature_name: str, fallback_function: Callable, **kwargs):
|
||||
@@ -109,35 +111,59 @@ def async_redis_cache(
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
redis_client = await get_redis_client()
|
||||
with tracer.start_as_current_span("redis_cache", attributes={"cache.function": func.__name__}) as span:
|
||||
# 1. Get Redis client
|
||||
with tracer.start_as_current_span("redis_cache.get_client"):
|
||||
redis_client = await get_redis_client()
|
||||
|
||||
# Don't bother going through other operations for no reason.
|
||||
if isinstance(redis_client, NoopAsyncRedisClient):
|
||||
return await func(*args, **kwargs)
|
||||
cache_key = get_cache_key(*args, **kwargs)
|
||||
cached_value = await redis_client.get(cache_key)
|
||||
# Don't bother going through other operations for no reason.
|
||||
if isinstance(redis_client, NoopAsyncRedisClient):
|
||||
span.set_attribute("cache.noop", True)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
try:
|
||||
if cached_value is not None:
|
||||
stats.hits += 1
|
||||
if model_class:
|
||||
return model_class.model_validate_json(cached_value)
|
||||
return json.loads(cached_value)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to retrieve value from cache: {e}")
|
||||
cache_key = get_cache_key(*args, **kwargs)
|
||||
span.set_attribute("cache.key", cache_key)
|
||||
|
||||
stats.misses += 1
|
||||
result = await func(*args, **kwargs)
|
||||
try:
|
||||
if model_class:
|
||||
await redis_client.set(cache_key, result.model_dump_json(), ex=ttl_s)
|
||||
elif isinstance(result, (dict, list, str, int, float, bool)):
|
||||
await redis_client.set(cache_key, json.dumps(result), ex=ttl_s)
|
||||
else:
|
||||
logger.warning(f"Cannot cache result of type {type(result).__name__} for {func.__name__}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis cache set failed: {e}")
|
||||
return result
|
||||
# 2. Try cache read
|
||||
with tracer.start_as_current_span("redis_cache.get") as get_span:
|
||||
cached_value = await redis_client.get(cache_key)
|
||||
get_span.set_attribute("cache.hit", cached_value is not None)
|
||||
|
||||
try:
|
||||
if cached_value is not None:
|
||||
stats.hits += 1
|
||||
span.set_attribute("cache.result", "hit")
|
||||
# 3. Deserialize cache hit
|
||||
with tracer.start_as_current_span("redis_cache.deserialize"):
|
||||
if model_class:
|
||||
return model_class.model_validate_json(cached_value)
|
||||
return json.loads(cached_value)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to retrieve value from cache: {e}")
|
||||
span.record_exception(e)
|
||||
|
||||
stats.misses += 1
|
||||
span.set_attribute("cache.result", "miss")
|
||||
|
||||
# 4. Call original function
|
||||
with tracer.start_as_current_span("redis_cache.call_original"):
|
||||
result = await func(*args, **kwargs)
|
||||
|
||||
# 5. Write to cache
|
||||
try:
|
||||
with tracer.start_as_current_span("redis_cache.set") as set_span:
|
||||
if model_class:
|
||||
await redis_client.set(cache_key, result.model_dump_json(), ex=ttl_s)
|
||||
elif isinstance(result, (dict, list, str, int, float, bool)):
|
||||
await redis_client.set(cache_key, json.dumps(result), ex=ttl_s)
|
||||
else:
|
||||
set_span.set_attribute("cache.set_skipped", True)
|
||||
logger.warning(f"Cannot cache result of type {type(result).__name__} for {func.__name__}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis cache set failed: {e}")
|
||||
span.record_exception(e)
|
||||
|
||||
return result
|
||||
|
||||
async def invalidate(*args, **kwargs) -> bool:
|
||||
stats.invalidations += 1
|
||||
|
||||
Reference in New Issue
Block a user