feat: add redis stream support (#4144)

This commit is contained in:
cthomas
2025-08-24 17:15:34 -07:00
committed by GitHub
parent f918ca0a59
commit 399b633899

View File

@@ -1,6 +1,6 @@
import asyncio
from functools import wraps
from typing import Any, Optional, Set, Union
from typing import Any, Dict, List, Optional, Set, Union
from letta.constants import REDIS_EXCLUDE, REDIS_INCLUDE, REDIS_SET_DEFAULT_VAL
from letta.log import get_logger
@@ -218,6 +218,126 @@ class AsyncRedisClient:
client = await self.get_client()
return await client.decr(key)
# Stream operations
@with_retry()
async def xadd(self, stream: str, fields: Dict[str, Any], id: str = "*", maxlen: Optional[int] = None, approximate: bool = True) -> str:
"""Add entry to a stream.
Args:
stream: Stream name
fields: Dict of field-value pairs to add
id: Entry ID ('*' for auto-generation)
maxlen: Maximum length of the stream
approximate: Whether maxlen is approximate
Returns:
The ID of the added entry
"""
client = await self.get_client()
return await client.xadd(stream, fields, id=id, maxlen=maxlen, approximate=approximate)
@with_retry()
async def xread(self, streams: Dict[str, str], count: Optional[int] = None, block: Optional[int] = None) -> List[Dict]:
"""Read from streams.
Args:
streams: Dict mapping stream names to IDs
count: Maximum number of entries to return
block: Milliseconds to block waiting for data (None = no blocking)
Returns:
List of entries from the streams
"""
client = await self.get_client()
return await client.xread(streams, count=count, block=block)
@with_retry()
async def xrange(self, stream: str, start: str = "-", end: str = "+", count: Optional[int] = None) -> List[Dict]:
"""Read range of entries from a stream.
Args:
stream: Stream name
start: Start ID (inclusive)
end: End ID (inclusive)
count: Maximum number of entries to return
Returns:
List of entries in the specified range
"""
client = await self.get_client()
return await client.xrange(stream, start, end, count=count)
@with_retry()
async def xrevrange(self, stream: str, start: str = "+", end: str = "-", count: Optional[int] = None) -> List[Dict]:
"""Read range of entries from a stream in reverse order.
Args:
stream: Stream name
start: Start ID (inclusive)
end: End ID (inclusive)
count: Maximum number of entries to return
Returns:
List of entries in the specified range in reverse order
"""
client = await self.get_client()
return await client.xrevrange(stream, start, end, count=count)
@with_retry()
async def xlen(self, stream: str) -> int:
"""Get the length of a stream.
Args:
stream: Stream name
Returns:
Number of entries in the stream
"""
client = await self.get_client()
return await client.xlen(stream)
@with_retry()
async def xdel(self, stream: str, *ids: str) -> int:
"""Delete entries from a stream.
Args:
stream: Stream name
ids: IDs of entries to delete
Returns:
Number of entries deleted
"""
client = await self.get_client()
return await client.xdel(stream, *ids)
@with_retry()
async def xinfo_stream(self, stream: str) -> Dict:
"""Get information about a stream.
Args:
stream: Stream name
Returns:
Dict with stream information
"""
client = await self.get_client()
return await client.xinfo_stream(stream)
@with_retry()
async def xtrim(self, stream: str, maxlen: int, approximate: bool = True) -> int:
"""Trim a stream to a maximum length.
Args:
stream: Stream name
maxlen: Maximum length
approximate: Whether maxlen is approximate
Returns:
Number of entries removed
"""
client = await self.get_client()
return await client.xtrim(stream, maxlen=maxlen, approximate=approximate)
async def check_inclusion_and_exclusion(self, member: str, group: str) -> bool:
exclude_key = self._get_group_exclusion_key(group)
include_key = self._get_group_inclusion_key(group)
@@ -290,6 +410,31 @@ class NoopAsyncRedisClient(AsyncRedisClient):
async def srem(self, key: str, *members: Union[str, int, float]) -> int:
return 0
# Stream operations
async def xadd(self, stream: str, fields: Dict[str, Any], id: str = "*", maxlen: Optional[int] = None, approximate: bool = True) -> str:
return ""
async def xread(self, streams: Dict[str, str], count: Optional[int] = None, block: Optional[int] = None) -> List[Dict]:
return []
async def xrange(self, stream: str, start: str = "-", end: str = "+", count: Optional[int] = None) -> List[Dict]:
return []
async def xrevrange(self, stream: str, start: str = "+", end: str = "-", count: Optional[int] = None) -> List[Dict]:
return []
async def xlen(self, stream: str) -> int:
return 0
async def xdel(self, stream: str, *ids: str) -> int:
return 0
async def xinfo_stream(self, stream: str) -> Dict:
return {}
async def xtrim(self, stream: str, maxlen: int, approximate: bool = True) -> int:
return 0
async def get_redis_client() -> AsyncRedisClient:
global _client_instance