fix: add redis tracing (#8132)

base
This commit is contained in:
jnjpng
2025-12-29 11:55:46 -08:00
committed by Caren Thomas
parent b6535b7590
commit fa9a98351d

View File

@@ -4,6 +4,7 @@ from dataclasses import dataclass
from functools import wraps
from typing import Callable
from opentelemetry import trace
from pydantic import BaseModel
from letta.constants import REDIS_DEFAULT_CACHE_PREFIX
@@ -13,6 +14,7 @@ from letta.plugins.plugins import get_experimental_checker
from letta.settings import settings
logger = get_logger(__name__)
tracer = trace.get_tracer(__name__)
def experimental(feature_name: str, fallback_function: Callable, **kwargs):
@@ -109,35 +111,59 @@ def async_redis_cache(
@wraps(func)
async def async_wrapper(*args, **kwargs):
redis_client = await get_redis_client()
with tracer.start_as_current_span("redis_cache", attributes={"cache.function": func.__name__}) as span:
# 1. Get Redis client
with tracer.start_as_current_span("redis_cache.get_client"):
redis_client = await get_redis_client()
# Don't bother going through other operations for no reason.
if isinstance(redis_client, NoopAsyncRedisClient):
return await func(*args, **kwargs)
cache_key = get_cache_key(*args, **kwargs)
cached_value = await redis_client.get(cache_key)
# Don't bother going through other operations for no reason.
if isinstance(redis_client, NoopAsyncRedisClient):
span.set_attribute("cache.noop", True)
return await func(*args, **kwargs)
try:
if cached_value is not None:
stats.hits += 1
if model_class:
return model_class.model_validate_json(cached_value)
return json.loads(cached_value)
except Exception as e:
logger.warning(f"Failed to retrieve value from cache: {e}")
cache_key = get_cache_key(*args, **kwargs)
span.set_attribute("cache.key", cache_key)
stats.misses += 1
result = await func(*args, **kwargs)
try:
if model_class:
await redis_client.set(cache_key, result.model_dump_json(), ex=ttl_s)
elif isinstance(result, (dict, list, str, int, float, bool)):
await redis_client.set(cache_key, json.dumps(result), ex=ttl_s)
else:
logger.warning(f"Cannot cache result of type {type(result).__name__} for {func.__name__}")
except Exception as e:
logger.warning(f"Redis cache set failed: {e}")
return result
# 2. Try cache read
with tracer.start_as_current_span("redis_cache.get") as get_span:
cached_value = await redis_client.get(cache_key)
get_span.set_attribute("cache.hit", cached_value is not None)
try:
if cached_value is not None:
stats.hits += 1
span.set_attribute("cache.result", "hit")
# 3. Deserialize cache hit
with tracer.start_as_current_span("redis_cache.deserialize"):
if model_class:
return model_class.model_validate_json(cached_value)
return json.loads(cached_value)
except Exception as e:
logger.warning(f"Failed to retrieve value from cache: {e}")
span.record_exception(e)
stats.misses += 1
span.set_attribute("cache.result", "miss")
# 4. Call original function
with tracer.start_as_current_span("redis_cache.call_original"):
result = await func(*args, **kwargs)
# 5. Write to cache
try:
with tracer.start_as_current_span("redis_cache.set") as set_span:
if model_class:
await redis_client.set(cache_key, result.model_dump_json(), ex=ttl_s)
elif isinstance(result, (dict, list, str, int, float, bool)):
await redis_client.set(cache_key, json.dumps(result), ex=ttl_s)
else:
set_span.set_attribute("cache.set_skipped", True)
logger.warning(f"Cannot cache result of type {type(result).__name__} for {func.__name__}")
except Exception as e:
logger.warning(f"Redis cache set failed: {e}")
span.record_exception(e)
return result
async def invalidate(*args, **kwargs) -> bool:
stats.invalidations += 1