feat: redis user caching (#2774)

This commit is contained in:
Andy Li
2025-06-12 17:32:07 -07:00
committed by GitHub
parent 22b640a5dd
commit 336896dc5c
5 changed files with 186 additions and 13 deletions

View File

@@ -337,6 +337,7 @@ WEB_SEARCH_CLIP_CONTENT = False
WEB_SEARCH_INCLUDE_SCORE = False
WEB_SEARCH_SEPARATOR = "\n" + "-" * 40 + "\n"
# Redis set-membership markers; lowercase per redis key conventions.
# (Fix: dropped the stale duplicate assignments "INCLUDE"/"EXCLUDE" that
# immediately preceded these and were dead on arrival.)
REDIS_INCLUDE = "include"
REDIS_EXCLUDE = "exclude"
# Sentinel value written to redis sets; semantics defined at call sites.
REDIS_SET_DEFAULT_VAL = "None"
# Key prefix used by the async_redis_cache decorator.
REDIS_DEFAULT_CACHE_PREFIX = "letta_cache"

View File

@@ -2,12 +2,17 @@ import asyncio
from functools import wraps
from typing import Any, Optional, Set, Union
import redis.asyncio as redis
from redis import RedisError
from letta.constants import REDIS_EXCLUDE, REDIS_INCLUDE, REDIS_SET_DEFAULT_VAL
from letta.log import get_logger
# redis is an optional dependency: fall back to sentinels so this module
# still imports without it (callers then get the Noop client).
try:
    from redis import RedisError
    from redis.asyncio import ConnectionPool, Redis
except ImportError:
    # NOTE(review): RedisError = None means any `except RedisError` clause
    # would raise TypeError on the no-redis path — presumably these names are
    # only touched after a successful import; verify call sites.
    RedisError = None
    Redis = None
    ConnectionPool = None

logger = get_logger(__name__)

# Module-level singleton; presumably populated lazily by get_redis_client() — TODO confirm.
_client_instance = None
@@ -44,7 +49,7 @@ class AsyncRedisClient:
retry_on_timeout: Retry operations on timeout
health_check_interval: Seconds between health checks
"""
self.pool = redis.ConnectionPool(
self.pool = ConnectionPool(
host=host,
port=port,
db=db,
@@ -59,12 +64,12 @@ class AsyncRedisClient:
self._client = None
self._lock = asyncio.Lock()
async def get_client(self) -> Redis:
    """Get or create Redis client instance.

    Returns:
        The lazily-built ``Redis`` client bound to ``self.pool``.
    """
    # Fix: removed the stale duplicate lines referencing the deleted
    # `redis` module alias (`redis.Redis`); only the imported `Redis` is valid.
    # Double-checked locking: re-test under the lock so concurrent first
    # callers cannot each construct a client.
    if self._client is None:
        async with self._lock:
            if self._client is None:
                self._client = Redis(connection_pool=self.pool)
    return self._client
async def close(self):
@@ -213,8 +218,8 @@ class AsyncRedisClient:
return await client.decr(key)
async def check_inclusion_and_exclusion(self, member: str, group: str) -> bool:
exclude_key = f"{group}_{REDIS_EXCLUDE}"
include_key = f"{group}_{REDIS_INCLUDE}"
exclude_key = self._get_group_exclusion_key(group)
include_key = self._get_group_inclusion_key(group)
# 1. if the member IS excluded from the group
if self.exists(exclude_key) and await self.scard(exclude_key) > 1:
return bool(await self.smismember(exclude_key, member))
@@ -231,14 +236,29 @@ class AsyncRedisClient:
@staticmethod
def _get_group_inclusion_key(group: str) -> str:
    """Key of the redis set listing members explicitly included in *group*."""
    # Fix: dropped the stale `_`-separator return that shadowed (made
    # unreachable) the current `:`-separator form.
    return f"{group}:{REDIS_INCLUDE}"
@staticmethod
def _get_group_exclusion_key(group: str) -> str:
    """Key of the redis set listing members explicitly excluded from *group*."""
    # Fix: dropped the stale `_`-separator return that shadowed (made
    # unreachable) the current `:`-separator form.
    return f"{group}:{REDIS_EXCLUDE}"
class NoopAsyncRedisClient(AsyncRedisClient):
    """No-op stand-in used when redis is unavailable: operations do nothing,
    so cache-aware callers degrade gracefully instead of erroring."""

    # noinspection PyMissingConstructor
    def __init__(self):
        # Deliberately skips AsyncRedisClient.__init__: no connection pool is built.
        pass

    async def set(
        self,
        key: str,
        value: Union[str, int, float],
        ex: Optional[int] = None,
        px: Optional[int] = None,
        nx: bool = False,
        xx: bool = False,
    ) -> bool:
        # Always reports failure so callers never assume a value was cached.
        return False

    async def get(self, key: str, default: Any = None) -> Any:
        # Nothing is ever stored; always hand back the caller's default.
        return default

View File

@@ -1,7 +1,13 @@
import inspect
import json
from dataclasses import dataclass
from functools import wraps
from typing import Callable
from pydantic import BaseModel
from letta.constants import REDIS_DEFAULT_CACHE_PREFIX
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
from letta.log import get_logger
from letta.plugins.plugins import get_experimental_checker
from letta.settings import settings
@@ -67,3 +73,88 @@ def deprecated(message: str):
return wrapper
return decorator
@dataclass
class CacheStats:
    """Approximate hit/miss/invalidation counters for one cached function.

    Note: this will be approximate to not add overhead of locking on counters.
    For exact measurements, use redis or track in other places.
    """

    hits: int = 0  # reads answered from the cache
    misses: int = 0  # reads that fell through to the wrapped function
    invalidations: int = 0  # explicit deletions requested via invalidate()
def async_redis_cache(
    key_func: Callable, prefix: str = REDIS_DEFAULT_CACHE_PREFIX, ttl_s: int = 300, model_class: type[BaseModel] | None = None
):
    """
    Decorator for caching async function results in Redis. May be a Noop if redis is not available.
    Will handle pydantic objects and raw values.
    Attempts to write to and retrieve from cache, but does not fail on those cases.

    Args:
        key_func: function to generate cache key (preferably lowercase strings to follow redis convention)
        prefix: cache key prefix
        ttl_s: time to live (s)
        model_class: custom pydantic model class for serialization/deserialization

    TODO (cliandy): move to class with generics for type hints
    """

    def decorator(func):
        stats = CacheStats()

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            redis_client = await get_redis_client()
            # Don't bother going through other operations for no reason.
            if isinstance(redis_client, NoopAsyncRedisClient):
                return await func(*args, **kwargs)

            cache_key = get_cache_key(*args, **kwargs)
            try:
                # Fix: the read itself can raise (e.g. lost connection); it now
                # sits inside the guard so a cache failure never crashes the
                # caller, matching the "does not fail" contract above.
                cached_value = await redis_client.get(cache_key)
                if cached_value is not None:
                    stats.hits += 1
                    if model_class:
                        return model_class.model_validate_json(cached_value)
                    return json.loads(cached_value)
            except Exception as e:
                logger.warning(f"Failed to retrieve value from cache: {e}")

            stats.misses += 1
            result = await func(*args, **kwargs)
            try:
                if model_class:
                    await redis_client.set(cache_key, result.model_dump_json(), ex=ttl_s)
                elif isinstance(result, (dict, list, str, int, float, bool)):
                    await redis_client.set(cache_key, json.dumps(result), ex=ttl_s)
                else:
                    logger.warning(f"Cannot cache result of type {type(result).__name__} for {func.__name__}")
            except Exception as e:
                logger.warning(f"Redis cache set failed: {e}")
            return result

        async def invalidate(*args, **kwargs) -> bool:
            """Delete the cached entry for these args; True if a key was removed."""
            stats.invalidations += 1
            try:
                redis_client = await get_redis_client()
                cache_key = get_cache_key(*args, **kwargs)
                return (await redis_client.delete(cache_key)) > 0
            except Exception as e:
                logger.error(f"Failed to invalidate cache: {e}")
                return False

        def get_cache_key(*args, **kwargs):
            # Namespaced key: "<prefix>:<key_func(...)>".
            return f"{prefix}:{key_func(*args, **kwargs)}"

        # Fix: actually attach invalidate (the assignment was commented out,
        # leaving the inner function dead and forcing callers to rebuild the
        # key + delete by hand).
        async_wrapper.cache_invalidate = invalidate
        async_wrapper.cache_key_func = get_cache_key
        async_wrapper.cache_stats = stats
        return async_wrapper

    return decorator

View File

@@ -3,6 +3,9 @@ from typing import List, Optional
from sqlalchemy import select, text
from letta.constants import DEFAULT_ORG_ID
from letta.data_sources.redis_client import get_redis_client
from letta.helpers.decorators import async_redis_cache
from letta.log import get_logger
from letta.orm.errors import NoResultFound
from letta.orm.organization import Organization as OrganizationModel
from letta.orm.user import User as UserModel
@@ -12,6 +15,8 @@ from letta.schemas.user import UserUpdate
from letta.server.db import db_registry
from letta.utils import enforce_types
logger = get_logger(__name__)
class UserManager:
"""Manager class to handle business logic related to Users."""
@@ -58,6 +63,7 @@ class UserManager:
# If it doesn't exist, make it
actor = UserModel(id=self.DEFAULT_USER_ID, name=self.DEFAULT_USER_NAME, organization_id=org_id)
await actor.create_async(session)
await self._invalidate_actor_cache(self.DEFAULT_USER_ID)
return actor.to_pydantic()
@@ -77,6 +83,7 @@ class UserManager:
async with db_registry.async_session() as session:
new_user = UserModel(**pydantic_user.model_dump(to_orm=True))
await new_user.create_async(session)
await self._invalidate_actor_cache(new_user.id)
return new_user.to_pydantic()
@enforce_types
@@ -111,6 +118,7 @@ class UserManager:
# Commit the updated user
await existing_user.update_async(session)
await self._invalidate_actor_cache(user_update.id)
return existing_user.to_pydantic()
@enforce_types
@@ -132,6 +140,7 @@ class UserManager:
# Delete from user table
user = await UserModel.read_async(db_session=session, identifier=user_id)
await user.hard_delete_async(session)
await self._invalidate_actor_cache(user_id)
@enforce_types
@trace_method
@@ -143,6 +152,7 @@ class UserManager:
@enforce_types
@trace_method
@async_redis_cache(key_func=lambda self, actor_id: f"actor_id:{actor_id}", model_class=PydanticUser)
async def get_actor_by_id_async(self, actor_id: str) -> PydanticUser:
"""Fetch a user by ID asynchronously."""
async with db_registry.async_session() as session:
@@ -225,3 +235,15 @@ class UserManager:
limit=limit,
)
return [user.to_pydantic() for user in users]
async def _invalidate_actor_cache(self, actor_id: str) -> bool:
    """Drop the cached actor entry so the next lookup goes to the database.

    Returns True when a cache key was actually deleted, False otherwise
    (including when redis is unreachable).

    TODO (cliandy): see notes on redis cache decorator
    """
    try:
        client = await get_redis_client()
        # Rebuild the exact key the caching decorator used for this actor.
        key = self.get_actor_by_id_async.cache_key_func(self, actor_id)
        deleted = await client.delete(key)
        return deleted > 0
    except Exception as e:
        logger.error(f"Failed to invalidate cache: {e}")
        return False

View File

@@ -37,10 +37,12 @@ from letta.constants import (
MCP_TOOL_TAG_NAME_PREFIX,
MULTI_AGENT_TOOLS,
)
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
from letta.embeddings import embedding_model
from letta.functions.functions import derive_openai_json_schema, parse_source_code
from letta.functions.mcp_client.types import MCPTool
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import AsyncTimer
from letta.jobs.types import ItemUpdateInfo, RequestStatusUpdateInfo, StepStatusUpdateInfo
from letta.orm import Base, Block
from letta.orm.block_history import BlockHistory
@@ -83,7 +85,7 @@ from letta.schemas.user import UserUpdate
from letta.server.db import db_registry
from letta.server.server import SyncServer
from letta.services.block_manager import BlockManager
from letta.settings import tool_settings
from letta.settings import settings, tool_settings
from tests.helpers.utils import comprehensive_agent_checks, validate_context_window_overview
from tests.utils import random_string
@@ -2733,6 +2735,43 @@ async def test_update_user(server: SyncServer, event_loop):
assert user.organization_id == test_org.id
@pytest.mark.asyncio
async def test_user_caching(server: SyncServer, event_loop, default_user, performance_pct=0.4):
    """Verify cached actor lookups are fast and move the hit/miss counters.

    performance_pct: each cached call must take under this fraction of the
    first (uncached) call's wall time.
    """
    # The decorator is a no-op without redis, so there is nothing to measure.
    if isinstance(await get_redis_client(), NoopAsyncRedisClient):
        pytest.skip("redis not available")
    # Invalidate previous cache behavior.
    await server.user_manager._invalidate_actor_cache(default_user.id)
    # Counters are cumulative across tests; snapshot them to assert deltas.
    before_stats = server.user_manager.get_actor_by_id_async.cache_stats
    before_cache_misses = before_stats.misses
    before_cache_hits = before_stats.hits
    # First call (expected to miss the cache)
    async with AsyncTimer() as timer:
        actor = await server.user_manager.get_actor_by_id_async(default_user.id)
    duration_first = timer.elapsed_ns
    print(f"Call 1: {duration_first:.2e}ns")
    assert actor.id == default_user.id
    assert duration_first > 0  # Sanity check: took non-zero time
    cached_hits = 10
    durations = []
    for i in range(cached_hits):
        async with AsyncTimer() as timer:
            actor_cached = await server.user_manager.get_actor_by_id_async(default_user.id)
        duration = timer.elapsed_ns
        durations.append(duration)
        print(f"Call {i+2}: {duration:.2e}ns")
        assert actor_cached == actor
    # Every cached call must beat the uncached call by the required margin.
    for d in durations:
        assert d < duration_first * performance_pct
    stats = server.user_manager.get_actor_by_id_async.cache_stats
    print(f"Before calls: {before_stats}")
    print(f"After calls: {stats}")
    # Assert cache stats
    assert stats.misses - before_cache_misses == 1
    assert stats.hits - before_cache_hits == cached_hits
# ======================================================================================================================
# ToolManager Tests
# ======================================================================================================================