feat: redis user caching (#2774)
This commit is contained in:
@@ -337,6 +337,7 @@ WEB_SEARCH_CLIP_CONTENT = False
|
||||
WEB_SEARCH_INCLUDE_SCORE = False
|
||||
WEB_SEARCH_SEPARATOR = "\n" + "-" * 40 + "\n"
|
||||
|
||||
REDIS_INCLUDE = "INCLUDE"
|
||||
REDIS_EXCLUDE = "EXCLUDE"
|
||||
REDIS_INCLUDE = "include"
|
||||
REDIS_EXCLUDE = "exclude"
|
||||
REDIS_SET_DEFAULT_VAL = "None"
|
||||
REDIS_DEFAULT_CACHE_PREFIX = "letta_cache"
|
||||
|
||||
@@ -2,12 +2,17 @@ import asyncio
|
||||
from functools import wraps
|
||||
from typing import Any, Optional, Set, Union
|
||||
|
||||
import redis.asyncio as redis
|
||||
from redis import RedisError
|
||||
|
||||
from letta.constants import REDIS_EXCLUDE, REDIS_INCLUDE, REDIS_SET_DEFAULT_VAL
|
||||
from letta.log import get_logger
|
||||
|
||||
try:
|
||||
from redis import RedisError
|
||||
from redis.asyncio import ConnectionPool, Redis
|
||||
except ImportError:
|
||||
RedisError = None
|
||||
Redis = None
|
||||
ConnectionPool = None
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
_client_instance = None
|
||||
@@ -44,7 +49,7 @@ class AsyncRedisClient:
|
||||
retry_on_timeout: Retry operations on timeout
|
||||
health_check_interval: Seconds between health checks
|
||||
"""
|
||||
self.pool = redis.ConnectionPool(
|
||||
self.pool = ConnectionPool(
|
||||
host=host,
|
||||
port=port,
|
||||
db=db,
|
||||
@@ -59,12 +64,12 @@ class AsyncRedisClient:
|
||||
self._client = None
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def get_client(self) -> redis.Redis:
|
||||
async def get_client(self) -> Redis:
|
||||
"""Get or create Redis client instance."""
|
||||
if self._client is None:
|
||||
async with self._lock:
|
||||
if self._client is None:
|
||||
self._client = redis.Redis(connection_pool=self.pool)
|
||||
self._client = Redis(connection_pool=self.pool)
|
||||
return self._client
|
||||
|
||||
async def close(self):
|
||||
@@ -213,8 +218,8 @@ class AsyncRedisClient:
|
||||
return await client.decr(key)
|
||||
|
||||
async def check_inclusion_and_exclusion(self, member: str, group: str) -> bool:
|
||||
exclude_key = f"{group}_{REDIS_EXCLUDE}"
|
||||
include_key = f"{group}_{REDIS_INCLUDE}"
|
||||
exclude_key = self._get_group_exclusion_key(group)
|
||||
include_key = self._get_group_inclusion_key(group)
|
||||
# 1. if the member IS excluded from the group
|
||||
if self.exists(exclude_key) and await self.scard(exclude_key) > 1:
|
||||
return bool(await self.smismember(exclude_key, member))
|
||||
@@ -231,14 +236,29 @@ class AsyncRedisClient:
|
||||
|
||||
@staticmethod
|
||||
def _get_group_inclusion_key(group: str) -> str:
|
||||
return f"{group}_{REDIS_INCLUDE}"
|
||||
return f"{group}:{REDIS_INCLUDE}"
|
||||
|
||||
@staticmethod
|
||||
def _get_group_exclusion_key(group: str) -> str:
|
||||
return f"{group}_{REDIS_EXCLUDE}"
|
||||
return f"{group}:{REDIS_EXCLUDE}"
|
||||
|
||||
|
||||
class NoopAsyncRedisClient(AsyncRedisClient):
|
||||
# noinspection PyMissingConstructor
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
async def set(
|
||||
self,
|
||||
key: str,
|
||||
value: Union[str, int, float],
|
||||
ex: Optional[int] = None,
|
||||
px: Optional[int] = None,
|
||||
nx: bool = False,
|
||||
xx: bool = False,
|
||||
) -> bool:
|
||||
return False
|
||||
|
||||
async def get(self, key: str, default: Any = None) -> Any:
|
||||
return default
|
||||
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
import inspect
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from letta.constants import REDIS_DEFAULT_CACHE_PREFIX
|
||||
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
|
||||
from letta.log import get_logger
|
||||
from letta.plugins.plugins import get_experimental_checker
|
||||
from letta.settings import settings
|
||||
@@ -67,3 +73,88 @@ def deprecated(message: str):
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheStats:
|
||||
"""Note: this will be approximate to not add overhead of locking on counters.
|
||||
For exact measurements, use redis or track in other places.
|
||||
"""
|
||||
|
||||
hits: int = 0
|
||||
misses: int = 0
|
||||
invalidations: int = 0
|
||||
|
||||
|
||||
def async_redis_cache(
|
||||
key_func: Callable, prefix: str = REDIS_DEFAULT_CACHE_PREFIX, ttl_s: int = 300, model_class: type[BaseModel] | None = None
|
||||
):
|
||||
"""
|
||||
Decorator for caching async function results in Redis. May be a Noop if redis is not available.
|
||||
Will handle pydantic objects and raw values.
|
||||
|
||||
Attempts to write to and retrieve from cache, but does not fail on those cases
|
||||
|
||||
Args:
|
||||
key_func: function to generate cache key (preferably lowercase strings to follow redis convention)
|
||||
prefix: cache key prefix
|
||||
ttl_s: time to live (s)
|
||||
model_class: custom pydantic model class for serialization/deserialization
|
||||
|
||||
TODO (cliandy): move to class with generics for type hints
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
stats = CacheStats()
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
redis_client = await get_redis_client()
|
||||
|
||||
# Don't bother going through other operations for no reason.
|
||||
if isinstance(redis_client, NoopAsyncRedisClient):
|
||||
return await func(*args, **kwargs)
|
||||
cache_key = get_cache_key(*args, **kwargs)
|
||||
cached_value = await redis_client.get(cache_key)
|
||||
|
||||
try:
|
||||
if cached_value is not None:
|
||||
stats.hits += 1
|
||||
if model_class:
|
||||
return model_class.model_validate_json(cached_value)
|
||||
return json.loads(cached_value)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to retrieve value from cache: {e}")
|
||||
|
||||
stats.misses += 1
|
||||
result = await func(*args, **kwargs)
|
||||
try:
|
||||
if model_class:
|
||||
await redis_client.set(cache_key, result.model_dump_json(), ex=ttl_s)
|
||||
elif isinstance(result, (dict, list, str, int, float, bool)):
|
||||
await redis_client.set(cache_key, json.dumps(result), ex=ttl_s)
|
||||
else:
|
||||
logger.warning(f"Cannot cache result of type {type(result).__name__} for {func.__name__}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis cache set failed: {e}")
|
||||
return result
|
||||
|
||||
async def invalidate(*args, **kwargs) -> bool:
|
||||
stats.invalidations += 1
|
||||
try:
|
||||
redis_client = await get_redis_client()
|
||||
cache_key = get_cache_key(*args, **kwargs)
|
||||
return (await redis_client.delete(cache_key)) > 0
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to invalidate cache: {e}")
|
||||
return False
|
||||
|
||||
def get_cache_key(*args, **kwargs):
|
||||
return f"{prefix}:{key_func(*args, **kwargs)}"
|
||||
|
||||
# async_wrapper.cache_invalidate = invalidate
|
||||
async_wrapper.cache_key_func = get_cache_key
|
||||
async_wrapper.cache_stats = stats
|
||||
return async_wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -3,6 +3,9 @@ from typing import List, Optional
|
||||
from sqlalchemy import select, text
|
||||
|
||||
from letta.constants import DEFAULT_ORG_ID
|
||||
from letta.data_sources.redis_client import get_redis_client
|
||||
from letta.helpers.decorators import async_redis_cache
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.organization import Organization as OrganizationModel
|
||||
from letta.orm.user import User as UserModel
|
||||
@@ -12,6 +15,8 @@ from letta.schemas.user import UserUpdate
|
||||
from letta.server.db import db_registry
|
||||
from letta.utils import enforce_types
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class UserManager:
|
||||
"""Manager class to handle business logic related to Users."""
|
||||
@@ -58,6 +63,7 @@ class UserManager:
|
||||
# If it doesn't exist, make it
|
||||
actor = UserModel(id=self.DEFAULT_USER_ID, name=self.DEFAULT_USER_NAME, organization_id=org_id)
|
||||
await actor.create_async(session)
|
||||
await self._invalidate_actor_cache(self.DEFAULT_USER_ID)
|
||||
|
||||
return actor.to_pydantic()
|
||||
|
||||
@@ -77,6 +83,7 @@ class UserManager:
|
||||
async with db_registry.async_session() as session:
|
||||
new_user = UserModel(**pydantic_user.model_dump(to_orm=True))
|
||||
await new_user.create_async(session)
|
||||
await self._invalidate_actor_cache(new_user.id)
|
||||
return new_user.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@@ -111,6 +118,7 @@ class UserManager:
|
||||
|
||||
# Commit the updated user
|
||||
await existing_user.update_async(session)
|
||||
await self._invalidate_actor_cache(user_update.id)
|
||||
return existing_user.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@@ -132,6 +140,7 @@ class UserManager:
|
||||
# Delete from user table
|
||||
user = await UserModel.read_async(db_session=session, identifier=user_id)
|
||||
await user.hard_delete_async(session)
|
||||
await self._invalidate_actor_cache(user_id)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -143,6 +152,7 @@ class UserManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@async_redis_cache(key_func=lambda self, actor_id: f"actor_id:{actor_id}", model_class=PydanticUser)
|
||||
async def get_actor_by_id_async(self, actor_id: str) -> PydanticUser:
|
||||
"""Fetch a user by ID asynchronously."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -225,3 +235,15 @@ class UserManager:
|
||||
limit=limit,
|
||||
)
|
||||
return [user.to_pydantic() for user in users]
|
||||
|
||||
async def _invalidate_actor_cache(self, actor_id: str) -> bool:
|
||||
"""Invalidates the actor cache on CRUD operations.
|
||||
TODO (cliandy): see notes on redis cache decorator
|
||||
"""
|
||||
try:
|
||||
redis_client = await get_redis_client()
|
||||
cache_key = self.get_actor_by_id_async.cache_key_func(self, actor_id)
|
||||
return (await redis_client.delete(cache_key)) > 0
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to invalidate cache: {e}")
|
||||
return False
|
||||
|
||||
@@ -37,10 +37,12 @@ from letta.constants import (
|
||||
MCP_TOOL_TAG_NAME_PREFIX,
|
||||
MULTI_AGENT_TOOLS,
|
||||
)
|
||||
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
|
||||
from letta.embeddings import embedding_model
|
||||
from letta.functions.functions import derive_openai_json_schema, parse_source_code
|
||||
from letta.functions.mcp_client.types import MCPTool
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.datetime_helpers import AsyncTimer
|
||||
from letta.jobs.types import ItemUpdateInfo, RequestStatusUpdateInfo, StepStatusUpdateInfo
|
||||
from letta.orm import Base, Block
|
||||
from letta.orm.block_history import BlockHistory
|
||||
@@ -83,7 +85,7 @@ from letta.schemas.user import UserUpdate
|
||||
from letta.server.db import db_registry
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.settings import tool_settings
|
||||
from letta.settings import settings, tool_settings
|
||||
from tests.helpers.utils import comprehensive_agent_checks, validate_context_window_overview
|
||||
from tests.utils import random_string
|
||||
|
||||
@@ -2733,6 +2735,43 @@ async def test_update_user(server: SyncServer, event_loop):
|
||||
assert user.organization_id == test_org.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_caching(server: SyncServer, event_loop, default_user, performance_pct=0.4):
|
||||
if isinstance(await get_redis_client(), NoopAsyncRedisClient):
|
||||
pytest.skip("redis not available")
|
||||
# Invalidate previous cache behavior.
|
||||
await server.user_manager._invalidate_actor_cache(default_user.id)
|
||||
before_stats = server.user_manager.get_actor_by_id_async.cache_stats
|
||||
before_cache_misses = before_stats.misses
|
||||
before_cache_hits = before_stats.hits
|
||||
|
||||
# First call (expected to miss the cache)
|
||||
async with AsyncTimer() as timer:
|
||||
actor = await server.user_manager.get_actor_by_id_async(default_user.id)
|
||||
duration_first = timer.elapsed_ns
|
||||
print(f"Call 1: {duration_first:.2e}ns")
|
||||
assert actor.id == default_user.id
|
||||
assert duration_first > 0 # Sanity check: took non-zero time
|
||||
cached_hits = 10
|
||||
durations = []
|
||||
for i in range(cached_hits):
|
||||
async with AsyncTimer() as timer:
|
||||
actor_cached = await server.user_manager.get_actor_by_id_async(default_user.id)
|
||||
duration = timer.elapsed_ns
|
||||
durations.append(duration)
|
||||
print(f"Call {i+2}: {duration:.2e}ns")
|
||||
assert actor_cached == actor
|
||||
for d in durations:
|
||||
assert d < duration_first * performance_pct
|
||||
stats = server.user_manager.get_actor_by_id_async.cache_stats
|
||||
|
||||
print(f"Before calls: {before_stats}")
|
||||
print(f"After calls: {stats}")
|
||||
# Assert cache stats
|
||||
assert stats.misses - before_cache_misses == 1
|
||||
assert stats.hits - before_cache_hits == cached_hits
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# ToolManager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
Reference in New Issue
Block a user