diff --git a/letta/constants.py b/letta/constants.py index 8c4ce4ac..412c5e19 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -337,6 +337,7 @@ WEB_SEARCH_CLIP_CONTENT = False WEB_SEARCH_INCLUDE_SCORE = False WEB_SEARCH_SEPARATOR = "\n" + "-" * 40 + "\n" -REDIS_INCLUDE = "INCLUDE" -REDIS_EXCLUDE = "EXCLUDE" +REDIS_INCLUDE = "include" +REDIS_EXCLUDE = "exclude" REDIS_SET_DEFAULT_VAL = "None" +REDIS_DEFAULT_CACHE_PREFIX = "letta_cache" diff --git a/letta/data_sources/redis_client.py b/letta/data_sources/redis_client.py index 436a0afb..17461c70 100644 --- a/letta/data_sources/redis_client.py +++ b/letta/data_sources/redis_client.py @@ -2,12 +2,17 @@ import asyncio from functools import wraps from typing import Any, Optional, Set, Union -import redis.asyncio as redis -from redis import RedisError - from letta.constants import REDIS_EXCLUDE, REDIS_INCLUDE, REDIS_SET_DEFAULT_VAL from letta.log import get_logger +try: + from redis import RedisError + from redis.asyncio import ConnectionPool, Redis +except ImportError: + RedisError = None + Redis = None + ConnectionPool = None + logger = get_logger(__name__) _client_instance = None @@ -44,7 +49,7 @@ class AsyncRedisClient: retry_on_timeout: Retry operations on timeout health_check_interval: Seconds between health checks """ - self.pool = redis.ConnectionPool( + self.pool = ConnectionPool( host=host, port=port, db=db, @@ -59,12 +64,12 @@ class AsyncRedisClient: self._client = None self._lock = asyncio.Lock() - async def get_client(self) -> redis.Redis: + async def get_client(self) -> Redis: """Get or create Redis client instance.""" if self._client is None: async with self._lock: if self._client is None: - self._client = redis.Redis(connection_pool=self.pool) + self._client = Redis(connection_pool=self.pool) return self._client async def close(self): @@ -213,8 +218,8 @@ class AsyncRedisClient: return await client.decr(key) async def check_inclusion_and_exclusion(self, member: str, group: str) -> bool: - exclude_key = f"{group}_{REDIS_EXCLUDE}" - include_key = f"{group}_{REDIS_INCLUDE}" + exclude_key = self._get_group_exclusion_key(group) + include_key = self._get_group_inclusion_key(group) # 1. if the member IS excluded from the group if self.exists(exclude_key) and await self.scard(exclude_key) > 1: return bool(await self.smismember(exclude_key, member)) @@ -231,14 +236,29 @@ class AsyncRedisClient: @staticmethod def _get_group_inclusion_key(group: str) -> str: - return f"{group}_{REDIS_INCLUDE}" + return f"{group}:{REDIS_INCLUDE}" @staticmethod def _get_group_exclusion_key(group: str) -> str: - return f"{group}_{REDIS_EXCLUDE}" + return f"{group}:{REDIS_EXCLUDE}" class NoopAsyncRedisClient(AsyncRedisClient): + # noinspection PyMissingConstructor + def __init__(self): + pass + + async def set( + self, + key: str, + value: Union[str, int, float], + ex: Optional[int] = None, + px: Optional[int] = None, + nx: bool = False, + xx: bool = False, + ) -> bool: + return False + async def get(self, key: str, default: Any = None) -> Any: return default diff --git a/letta/helpers/decorators.py b/letta/helpers/decorators.py index ae20b4f6..13648a3b 100644 --- a/letta/helpers/decorators.py +++ b/letta/helpers/decorators.py @@ -1,7 +1,13 @@ import inspect +import json +from dataclasses import dataclass from functools import wraps from typing import Callable +from pydantic import BaseModel + +from letta.constants import REDIS_DEFAULT_CACHE_PREFIX +from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client from letta.log import get_logger from letta.plugins.plugins import get_experimental_checker from letta.settings import settings @@ -67,3 +73,88 @@ def deprecated(message: str): return wrapper return decorator + + +@dataclass +class CacheStats: + """Note: this will be approximate to not add overhead of locking on counters. + For exact measurements, use redis or track in other places. + """ + + hits: int = 0 + misses: int = 0 + invalidations: int = 0 + + +def async_redis_cache( + key_func: Callable, prefix: str = REDIS_DEFAULT_CACHE_PREFIX, ttl_s: int = 300, model_class: type[BaseModel] | None = None +): + """ + Decorator for caching async function results in Redis. May be a Noop if redis is not available. + Will handle pydantic objects and raw values. + + Attempts to write to and retrieve from cache, but does not fail on those cases + + Args: + key_func: function to generate cache key (preferably lowercase strings to follow redis convention) + prefix: cache key prefix + ttl_s: time to live (s) + model_class: custom pydantic model class for serialization/deserialization + + TODO (cliandy): move to class with generics for type hints + """ + + def decorator(func): + stats = CacheStats() + + @wraps(func) + async def async_wrapper(*args, **kwargs): + redis_client = await get_redis_client() + + # Don't bother going through other operations for no reason. + if isinstance(redis_client, NoopAsyncRedisClient): + return await func(*args, **kwargs) + cache_key = get_cache_key(*args, **kwargs) + cached_value = await redis_client.get(cache_key) + + try: + if cached_value is not None: + stats.hits += 1 + if model_class: + return model_class.model_validate_json(cached_value) + return json.loads(cached_value) + except Exception as e: + logger.warning(f"Failed to retrieve value from cache: {e}") + + stats.misses += 1 + result = await func(*args, **kwargs) + try: + if model_class: + await redis_client.set(cache_key, result.model_dump_json(), ex=ttl_s) + elif isinstance(result, (dict, list, str, int, float, bool)): + await redis_client.set(cache_key, json.dumps(result), ex=ttl_s) + else: + logger.warning(f"Cannot cache result of type {type(result).__name__} for {func.__name__}") + except Exception as e: + logger.warning(f"Redis cache set failed: {e}") + return result + + async def invalidate(*args, **kwargs) -> bool: + stats.invalidations += 1 + try: + redis_client = await get_redis_client() + cache_key = get_cache_key(*args, **kwargs) + return (await redis_client.delete(cache_key)) > 0 + except Exception as e: + logger.error(f"Failed to invalidate cache: {e}") + return False + + def get_cache_key(*args, **kwargs): + return f"{prefix}:{key_func(*args, **kwargs)}" + + # async_wrapper.cache_invalidate = invalidate + async_wrapper.cache_key_func = get_cache_key + async_wrapper.cache_stats = stats + return async_wrapper + + return decorator diff --git a/letta/services/user_manager.py b/letta/services/user_manager.py index 7bdf3211..306ec7f0 100644 --- a/letta/services/user_manager.py +++ b/letta/services/user_manager.py @@ -3,6 +3,9 @@ from typing import List, Optional from sqlalchemy import select, text from letta.constants import DEFAULT_ORG_ID +from letta.data_sources.redis_client import get_redis_client +from letta.helpers.decorators import async_redis_cache +from letta.log import get_logger from letta.orm.errors import NoResultFound from letta.orm.organization import Organization as OrganizationModel from letta.orm.user import User as UserModel @@ -12,6 +15,8 @@ from letta.schemas.user import UserUpdate from letta.server.db import db_registry from letta.utils import enforce_types +logger = get_logger(__name__) + class UserManager: """Manager class to handle business logic related to Users.""" @@ -58,6 +63,7 @@ class UserManager: # If it doesn't exist, make it actor = UserModel(id=self.DEFAULT_USER_ID, name=self.DEFAULT_USER_NAME, organization_id=org_id) await actor.create_async(session) + await self._invalidate_actor_cache(self.DEFAULT_USER_ID) return actor.to_pydantic() @@ -77,6 +83,7 @@ class UserManager: async with db_registry.async_session() as session: new_user = UserModel(**pydantic_user.model_dump(to_orm=True)) await new_user.create_async(session) + await self._invalidate_actor_cache(new_user.id) return new_user.to_pydantic() @enforce_types @@ -111,6 +118,7 @@ class UserManager: # Commit the updated user await existing_user.update_async(session) + await self._invalidate_actor_cache(user_update.id) return existing_user.to_pydantic() @enforce_types @@ -132,6 +140,7 @@ class UserManager: # Delete from user table user = await UserModel.read_async(db_session=session, identifier=user_id) await user.hard_delete_async(session) + await self._invalidate_actor_cache(user_id) @enforce_types @trace_method @@ -143,6 +152,7 @@ class UserManager: @enforce_types @trace_method + @async_redis_cache(key_func=lambda self, actor_id: f"actor_id:{actor_id}", model_class=PydanticUser) async def get_actor_by_id_async(self, actor_id: str) -> PydanticUser: """Fetch a user by ID asynchronously.""" async with db_registry.async_session() as session: @@ -225,3 +235,15 @@ class UserManager: limit=limit, ) return [user.to_pydantic() for user in users] + + async def _invalidate_actor_cache(self, actor_id: str) -> bool: + """Invalidates the actor cache on CRUD operations. + TODO (cliandy): see notes on redis cache decorator + """ + try: + redis_client = await get_redis_client() + cache_key = self.get_actor_by_id_async.cache_key_func(self, actor_id) + return (await redis_client.delete(cache_key)) > 0 + except Exception as e: + logger.error(f"Failed to invalidate cache: {e}") + return False diff --git a/tests/test_managers.py b/tests/test_managers.py index c4730c55..ed40aa7a 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -37,10 +37,12 @@ from letta.constants import ( MCP_TOOL_TAG_NAME_PREFIX, MULTI_AGENT_TOOLS, ) +from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client from letta.embeddings import embedding_model from letta.functions.functions import derive_openai_json_schema, parse_source_code from letta.functions.mcp_client.types import MCPTool from letta.helpers import ToolRulesSolver +from letta.helpers.datetime_helpers import AsyncTimer from letta.jobs.types import ItemUpdateInfo, RequestStatusUpdateInfo, StepStatusUpdateInfo from letta.orm import Base, Block from letta.orm.block_history import BlockHistory @@ -83,7 +85,7 @@ from letta.schemas.user import UserUpdate from letta.server.db import db_registry from letta.server.server import SyncServer from letta.services.block_manager import BlockManager -from letta.settings import tool_settings +from letta.settings import settings, tool_settings from tests.helpers.utils import comprehensive_agent_checks, validate_context_window_overview from tests.utils import random_string @@ -2733,6 +2735,43 @@ async def test_update_user(server: SyncServer, event_loop): assert user.organization_id == test_org.id +@pytest.mark.asyncio +async def test_user_caching(server: SyncServer, event_loop, default_user, performance_pct=0.4): + if isinstance(await get_redis_client(), NoopAsyncRedisClient): + pytest.skip("redis not available") + # Invalidate previous cache behavior. + await server.user_manager._invalidate_actor_cache(default_user.id) + before_stats = server.user_manager.get_actor_by_id_async.cache_stats + before_cache_misses = before_stats.misses + before_cache_hits = before_stats.hits + + # First call (expected to miss the cache) + async with AsyncTimer() as timer: + actor = await server.user_manager.get_actor_by_id_async(default_user.id) + duration_first = timer.elapsed_ns + print(f"Call 1: {duration_first:.2e}ns") + assert actor.id == default_user.id + assert duration_first > 0 # Sanity check: took non-zero time + cached_hits = 10 + durations = [] + for i in range(cached_hits): + async with AsyncTimer() as timer: + actor_cached = await server.user_manager.get_actor_by_id_async(default_user.id) + duration = timer.elapsed_ns + durations.append(duration) + print(f"Call {i+2}: {duration:.2e}ns") + assert actor_cached == actor + for d in durations: + assert d < duration_first * performance_pct + stats = server.user_manager.get_actor_by_id_async.cache_stats + + print(f"Before calls: {before_stats}") + print(f"After calls: {stats}") + # Assert cache stats + assert stats.misses - before_cache_misses == 1 + assert stats.hits - before_cache_hits == cached_hits + + # ====================================================================================================================== # ToolManager Tests # ======================================================================================================================