Files
letta-server/letta/data_sources/redis_client.py
Kian Jones f5c4ab50f4 chore: add ty + pre-commit hook and repeal even more ruff rules (#9504)
* auto fixes

* auto fix pt2 and transitive deps and undefined var checking locals()

* manual fixes (ignored or letta-code fixed)

* fix circular import

* remove all ignores, add FastAPI rules and Ruff rules

* add ty and precommit

* ruff stuff

* ty check fixes

* ty check fixes pt 2

* error on invalid
2026-02-24 10:55:11 -08:00

603 lines
19 KiB
Python

import asyncio
from functools import wraps
from typing import Any, Dict, List, Optional, Set, Union
from letta.constants import (
CONVERSATION_LOCK_PREFIX,
CONVERSATION_LOCK_TTL_SECONDS,
MEMORY_REPO_LOCK_PREFIX,
MEMORY_REPO_LOCK_TTL_SECONDS,
REDIS_EXCLUDE,
REDIS_INCLUDE,
REDIS_SET_DEFAULT_VAL,
)
from letta.errors import ConversationBusyError, MemoryRepoBusyError
from letta.log import get_logger
from letta.settings import settings
try:
from redis import RedisError
from redis.asyncio import ConnectionPool, Redis
from redis.asyncio.lock import Lock
except ImportError:
RedisError = None
Redis = None
ConnectionPool = None
Lock = None
# Module-level logger shared by every client in this file.
logger = get_logger(__name__)
# Process-wide client singleton; stays None until get_redis_client() fills it in.
_client_instance = None
class AsyncRedisClient:
"""Async Redis client with connection pooling and error handling"""
def __init__(
self,
host: str = "localhost",
port: int = 6379,
db: int = 0,
password: Optional[str] = None,
max_connections: int = 50,
decode_responses: bool = True,
socket_timeout: int = 5,
socket_connect_timeout: int = 5,
retry_on_timeout: bool = True,
health_check_interval: int = 30,
):
"""
Initialize Redis client with connection pool.
Args:
host: Redis server hostname
port: Redis server port
db: Database number
password: Redis password if required
max_connections: Maximum number of connections in pool
decode_responses: Decode byte responses to strings
socket_timeout: Socket timeout in seconds
socket_connect_timeout: Socket connection timeout
retry_on_timeout: Retry operations on timeout
health_check_interval: Seconds between health checks
"""
self.pool = ConnectionPool(
host=host,
port=port,
db=db,
password=password,
max_connections=max_connections,
decode_responses=decode_responses,
socket_timeout=socket_timeout,
socket_connect_timeout=socket_connect_timeout,
retry_on_timeout=retry_on_timeout,
health_check_interval=health_check_interval,
)
self._client = None
self._lock = asyncio.Lock()
async def get_client(self) -> Redis:
"""Get or create Redis client instance."""
if self._client is None:
async with self._lock:
if self._client is None:
self._client = Redis(connection_pool=self.pool)
return self._client
async def close(self):
"""Close Redis connection and cleanup."""
if self._client:
await self._client.close()
await self.pool.disconnect()
self._client = None
async def __aenter__(self):
"""Async context manager entry."""
await self.get_client()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit."""
await self.close()
# Health check and connection management
async def ping(self) -> bool:
"""Check if Redis is accessible."""
try:
client = await self.get_client()
await client.ping()
return True
except RedisError:
logger.exception("Redis ping failed")
return False
async def wait_for_ready(self, timeout: int = 30, interval: float = 0.5):
"""Wait for Redis to be ready."""
start_time = asyncio.get_event_loop().time()
while (asyncio.get_event_loop().time() - start_time) < timeout:
if await self.ping():
return
await asyncio.sleep(interval)
raise ConnectionError(f"Redis not ready after {timeout} seconds")
# Retry decorator for resilience
def with_retry(max_attempts: int = 3, delay: float = 0.1):
"""Decorator to retry Redis operations on failure."""
def decorator(func):
@wraps(func)
async def wrapper(self, *args, **kwargs):
last_error = None
for attempt in range(max_attempts):
try:
return await func(self, *args, **kwargs)
except (ConnectionError, TimeoutError) as e:
last_error = e
if attempt < max_attempts - 1:
await asyncio.sleep(delay * (2**attempt))
logger.warning(f"Retry {attempt + 1}/{max_attempts} for {func.__name__}: {e}")
raise last_error
return wrapper
return decorator
# Basic operations with error handling
@with_retry()
async def get(self, key: str, default: Any = None) -> Any:
"""Get value by key."""
try:
client = await self.get_client()
return await client.get(key)
except Exception:
return default
@with_retry()
async def set(
self,
key: str,
value: Union[str, int, float],
ex: Optional[int] = None,
px: Optional[int] = None,
nx: bool = False,
xx: bool = False,
) -> bool:
"""
Set key-value with options.
Args:
key: Redis key
value: Value to store
ex: Expire time in seconds
px: Expire time in milliseconds
nx: Only set if key doesn't exist
xx: Only set if key exists
"""
client = await self.get_client()
return await client.set(key, value, ex=ex, px=px, nx=nx, xx=xx)
@with_retry()
async def delete(self, *keys: str) -> int:
"""Delete one or more keys."""
client = await self.get_client()
return await client.delete(*keys)
async def acquire_conversation_lock(
self,
conversation_id: str,
token: str,
) -> Optional["Lock"]:
"""
Acquire a distributed lock for a conversation.
Args:
conversation_id: The ID for the conversation
token: Unique identifier for the lock holder (for debugging/tracing)
Returns:
Lock object if acquired, raises ConversationBusyError if in use
"""
if Lock is None:
return None
client = await self.get_client()
lock_key = f"{CONVERSATION_LOCK_PREFIX}{conversation_id}"
lock = Lock(
client,
lock_key,
timeout=CONVERSATION_LOCK_TTL_SECONDS,
blocking=False,
thread_local=False, # We manage token explicitly
raise_on_release_error=False, # We handle release errors ourselves
)
if await lock.acquire(token=token):
return lock
lock_holder_token = await client.get(lock_key)
raise ConversationBusyError(
conversation_id=conversation_id,
lock_holder_token=lock_holder_token,
)
async def release_conversation_lock(self, conversation_id: str) -> bool:
"""
Release a conversation lock by conversation_id.
Args:
conversation_id: The conversation ID to release the lock for
Returns:
True if lock was released, False if release failed
"""
try:
client = await self.get_client()
lock_key = f"{CONVERSATION_LOCK_PREFIX}{conversation_id}"
await client.delete(lock_key)
return True
except Exception as e:
logger.warning(f"Failed to release conversation lock for conversation {conversation_id}: {e}")
return False
async def acquire_memory_repo_lock(
self,
agent_id: str,
token: str,
) -> Optional["Lock"]:
"""
Acquire a distributed lock for a memory repository.
Prevents concurrent modifications to an agent's git-based memory.
Args:
agent_id: The agent ID whose memory is being modified
token: Unique identifier for the lock holder (for debugging/tracing)
Returns:
Lock object if acquired, raises MemoryRepoBusyError if in use
"""
if Lock is None:
return None
client = await self.get_client()
lock_key = f"{MEMORY_REPO_LOCK_PREFIX}{agent_id}"
lock = Lock(
client,
lock_key,
timeout=MEMORY_REPO_LOCK_TTL_SECONDS,
blocking=False,
thread_local=False,
raise_on_release_error=False,
)
if await lock.acquire(token=token):
return lock
lock_holder_token = await client.get(lock_key)
raise MemoryRepoBusyError(
agent_id=agent_id,
lock_holder_token=lock_holder_token,
)
async def release_memory_repo_lock(self, agent_id: str) -> bool:
"""
Release a memory repo lock by agent_id.
Args:
agent_id: The agent ID to release the lock for
Returns:
True if lock was released, False if release failed
"""
try:
client = await self.get_client()
lock_key = f"{MEMORY_REPO_LOCK_PREFIX}{agent_id}"
await client.delete(lock_key)
return True
except Exception as e:
logger.warning(f"Failed to release memory repo lock for agent {agent_id}: {e}")
return False
@with_retry()
async def exists(self, *keys: str) -> int:
"""Check if keys exist."""
client = await self.get_client()
return await client.exists(*keys)
# Set operations
async def sadd(self, key: str, *members: Union[str, int, float]) -> int:
"""Add members to set."""
client = await self.get_client()
return await client.sadd(key, *members)
async def smembers(self, key: str) -> Set[str]:
"""Get all set members."""
client = await self.get_client()
return await client.smembers(key)
@with_retry()
async def smismember(self, key: str, values: list[Any] | Any) -> list[int] | int:
"""clever!: set member is member"""
try:
client = await self.get_client()
result = await client.smismember(key, values)
return result if isinstance(values, list) else result[0]
except Exception:
return [0] * len(values) if isinstance(values, list) else 0
async def srem(self, key: str, *members: Union[str, int, float]) -> int:
"""Remove members from set."""
client = await self.get_client()
return await client.srem(key, *members)
async def scard(self, key: str) -> int:
client = await self.get_client()
return await client.scard(key)
# Atomic operations
async def incr(self, key: str) -> int:
"""Increment key value."""
client = await self.get_client()
return await client.incr(key)
async def decr(self, key: str) -> int:
"""Decrement key value."""
client = await self.get_client()
return await client.decr(key)
# Stream operations
@with_retry()
async def xadd(self, stream: str, fields: Dict[str, Any], id: str = "*", maxlen: Optional[int] = None, approximate: bool = True) -> str:
"""Add entry to a stream.
Args:
stream: Stream name
fields: Dict of field-value pairs to add
id: Entry ID ('*' for auto-generation)
maxlen: Maximum length of the stream
approximate: Whether maxlen is approximate
Returns:
The ID of the added entry
"""
client = await self.get_client()
return await client.xadd(stream, fields, id=id, maxlen=maxlen, approximate=approximate)
@with_retry()
async def xread(self, streams: Dict[str, str], count: Optional[int] = None, block: Optional[int] = None) -> List[Dict]:
"""Read from streams.
Args:
streams: Dict mapping stream names to IDs
count: Maximum number of entries to return
block: Milliseconds to block waiting for data (None = no blocking)
Returns:
List of entries from the streams
"""
client = await self.get_client()
return await client.xread(streams, count=count, block=block)
@with_retry()
async def xrange(self, stream: str, start: str = "-", end: str = "+", count: Optional[int] = None) -> List[Dict]:
"""Read range of entries from a stream.
Args:
stream: Stream name
start: Start ID (inclusive)
end: End ID (inclusive)
count: Maximum number of entries to return
Returns:
List of entries in the specified range
"""
client = await self.get_client()
return await client.xrange(stream, start, end, count=count)
@with_retry()
async def xrevrange(self, stream: str, start: str = "+", end: str = "-", count: Optional[int] = None) -> List[Dict]:
"""Read range of entries from a stream in reverse order.
Args:
stream: Stream name
start: Start ID (inclusive)
end: End ID (inclusive)
count: Maximum number of entries to return
Returns:
List of entries in the specified range in reverse order
"""
client = await self.get_client()
return await client.xrevrange(stream, start, end, count=count)
@with_retry()
async def xlen(self, stream: str) -> int:
"""Get the length of a stream.
Args:
stream: Stream name
Returns:
Number of entries in the stream
"""
client = await self.get_client()
return await client.xlen(stream)
@with_retry()
async def xdel(self, stream: str, *ids: str) -> int:
"""Delete entries from a stream.
Args:
stream: Stream name
ids: IDs of entries to delete
Returns:
Number of entries deleted
"""
client = await self.get_client()
return await client.xdel(stream, *ids)
@with_retry()
async def xinfo_stream(self, stream: str) -> Dict:
"""Get information about a stream.
Args:
stream: Stream name
Returns:
Dict with stream information
"""
client = await self.get_client()
return await client.xinfo_stream(stream)
@with_retry()
async def xtrim(self, stream: str, maxlen: int, approximate: bool = True) -> int:
"""Trim a stream to a maximum length.
Args:
stream: Stream name
maxlen: Maximum length
approximate: Whether maxlen is approximate
Returns:
Number of entries removed
"""
client = await self.get_client()
return await client.xtrim(stream, maxlen=maxlen, approximate=approximate)
async def check_inclusion_and_exclusion(self, member: str, group: str) -> bool:
exclude_key = self._get_group_exclusion_key(group)
include_key = self._get_group_inclusion_key(group)
# 1. if the member IS excluded from the group
if self.exists(exclude_key) and await self.scard(exclude_key) > 1:
return bool(await self.smismember(exclude_key, member))
# 2. if the group HAS an include set, is the member in that set?
if self.exists(include_key) and await self.scard(include_key) > 1:
return bool(await self.smismember(include_key, member))
# 3. if the group does NOT HAVE an include set and member NOT excluded
return True
async def create_inclusion_exclusion_keys(self, group: str) -> None:
redis_client = await self.get_client()
await redis_client.sadd(self._get_group_inclusion_key(group), REDIS_SET_DEFAULT_VAL)
await redis_client.sadd(self._get_group_exclusion_key(group), REDIS_SET_DEFAULT_VAL)
@staticmethod
def _get_group_inclusion_key(group: str) -> str:
return f"{group}:{REDIS_INCLUDE}"
@staticmethod
def _get_group_exclusion_key(group: str) -> str:
return f"{group}:{REDIS_EXCLUDE}"
class NoopAsyncRedisClient(AsyncRedisClient):
    """Stand-in client used when Redis is not configured or unreachable.

    Every operation returns an empty or falsy value without touching the
    network. The parent initialiser is deliberately skipped, so no connection
    pool or client handle exists on instances of this class; every inherited
    method that would dereference them is overridden here, except
    ``get_client`` and ``wait_for_ready``, which are left as inherited.
    """

    # noinspection PyMissingConstructor
    def __init__(self):
        pass

    # Lifecycle and health: the inherited versions read self._client and
    # self.pool, which are never set on a noop instance.
    async def close(self):
        """Nothing to close."""
        return None

    async def __aenter__(self):
        """Async context manager entry; returns self without connecting."""
        return self

    async def ping(self) -> bool:
        """Always False: there is no Redis server behind this client."""
        return False

    async def set(
        self,
        key: str,
        value: Union[str, int, float],
        ex: Optional[int] = None,
        px: Optional[int] = None,
        nx: bool = False,
        xx: bool = False,
    ) -> bool:
        """Store nothing and report that nothing was set."""
        return False

    async def get(self, key: str, default: Any = None) -> Any:
        """Always return ``default``."""
        return default

    async def exists(self, *keys: str) -> int:
        """No key ever exists."""
        return 0

    async def sadd(self, key: str, *members: Union[str, int, float]) -> int:
        """Add nothing; zero members added."""
        return 0

    async def smismember(self, key: str, values: list[Any] | Any) -> list[int] | int:
        """Report every value as not a member, mirroring the input shape."""
        return [0] * len(values) if isinstance(values, list) else 0

    async def delete(self, *keys: str) -> int:
        """Delete nothing; zero keys removed."""
        return 0

    async def acquire_conversation_lock(
        self,
        conversation_id: str,
        token: str,
    ) -> Optional["Lock"]:
        """No locking without Redis; always returns None."""
        return None

    async def release_conversation_lock(self, conversation_id: str) -> bool:
        """Nothing to release; always returns False."""
        return False

    async def acquire_memory_repo_lock(
        self,
        agent_id: str,
        token: str,
    ) -> Optional["Lock"]:
        """No locking without Redis; always returns None."""
        return None

    async def release_memory_repo_lock(self, agent_id: str) -> bool:
        """Nothing to release; always returns False."""
        return False

    async def check_inclusion_and_exclusion(self, member: str, group: str) -> bool:
        """Always returns False."""
        return False

    async def create_inclusion_exclusion_keys(self, group: str) -> None:
        """Create nothing."""
        return None

    async def scard(self, key: str) -> int:
        """Every set is empty."""
        return 0

    async def smembers(self, key: str) -> Set[str]:
        """Every set is empty."""
        return set()

    async def srem(self, key: str, *members: Union[str, int, float]) -> int:
        """Remove nothing; zero members removed."""
        return 0

    # Atomic operations: the inherited versions would fail on the missing
    # client handle.
    async def incr(self, key: str) -> int:
        """No counter is kept; always returns 0."""
        return 0

    async def decr(self, key: str) -> int:
        """No counter is kept; always returns 0."""
        return 0

    # Stream operations
    async def xadd(
        self,
        stream: str,
        fields: Dict[str, Any],
        id: str = "*",
        maxlen: Optional[int] = None,
        approximate: bool = True,
    ) -> str:
        """Append nothing; returns an empty entry ID."""
        return ""

    async def xread(self, streams: Dict[str, str], count: Optional[int] = None, block: Optional[int] = None) -> List[Dict]:
        """No entries to read."""
        return []

    async def xrange(self, stream: str, start: str = "-", end: str = "+", count: Optional[int] = None) -> List[Dict]:
        """No entries in any range."""
        return []

    async def xrevrange(self, stream: str, start: str = "+", end: str = "-", count: Optional[int] = None) -> List[Dict]:
        """No entries in any range."""
        return []

    async def xlen(self, stream: str) -> int:
        """Every stream is empty."""
        return 0

    async def xdel(self, stream: str, *ids: str) -> int:
        """Delete nothing; zero entries removed."""
        return 0

    async def xinfo_stream(self, stream: str) -> Dict:
        """No stream information available."""
        return {}

    async def xtrim(self, stream: str, maxlen: int, approximate: bool = True) -> int:
        """Trim nothing; zero entries removed."""
        return 0
async def get_redis_client() -> AsyncRedisClient:
    """Return the process-wide Redis client, creating it on first use.

    A noop client is used when ``settings.redis_host`` or
    ``settings.redis_port`` is unset, or when the real client cannot be
    initialised (including Redis not answering a ping within 5 seconds).
    The choice is cached in the module-level singleton, so initialisation is
    not retried on later calls.

    Returns:
        The shared ``AsyncRedisClient`` (possibly a ``NoopAsyncRedisClient``)
    """
    global _client_instance
    if _client_instance is None:
        try:
            # If Redis settings are not configured, use noop client
            if settings.redis_host is None or settings.redis_port is None:
                logger.info("Redis not configured, using noop client")
                _client_instance = NoopAsyncRedisClient()
            else:
                _client_instance = AsyncRedisClient(
                    host=settings.redis_host,
                    port=settings.redis_port,
                )
                await _client_instance.wait_for_ready(timeout=5)
                logger.info("Redis client initialized")
        except Exception as e:
            logger.warning(f"Failed to initialize Redis: {e}")
            failed_client = _client_instance
            _client_instance = NoopAsyncRedisClient()
            # The real client may already own a pool and connections; release
            # them instead of abandoning the object. Cleanup is best-effort
            # and must not mask the fallback.
            if failed_client is not None and not isinstance(failed_client, NoopAsyncRedisClient):
                try:
                    await failed_client.close()
                except Exception as close_error:
                    logger.debug(f"Failed to close unusable Redis client: {close_error}")
    return _client_instance