From 399b633899d0dd086191fe233c7f2f70c814a5ea Mon Sep 17 00:00:00 2001 From: cthomas Date: Sun, 24 Aug 2025 17:15:34 -0700 Subject: [PATCH] feat: add redis stream support (#4144) --- letta/data_sources/redis_client.py | 147 ++++++++++++++++++++++++++++- 1 file changed, 146 insertions(+), 1 deletion(-) diff --git a/letta/data_sources/redis_client.py b/letta/data_sources/redis_client.py index 9aa8effe..be149ab2 100644 --- a/letta/data_sources/redis_client.py +++ b/letta/data_sources/redis_client.py @@ -1,6 +1,6 @@ import asyncio from functools import wraps -from typing import Any, Optional, Set, Union +from typing import Any, Dict, List, Optional, Set, Union from letta.constants import REDIS_EXCLUDE, REDIS_INCLUDE, REDIS_SET_DEFAULT_VAL from letta.log import get_logger @@ -218,6 +218,126 @@ class AsyncRedisClient: client = await self.get_client() return await client.decr(key) + # Stream operations + @with_retry() + async def xadd(self, stream: str, fields: Dict[str, Any], id: str = "*", maxlen: Optional[int] = None, approximate: bool = True) -> str: + """Add entry to a stream. + + Args: + stream: Stream name + fields: Dict of field-value pairs to add + id: Entry ID ('*' for auto-generation) + maxlen: Maximum length of the stream + approximate: Whether maxlen is approximate + + Returns: + The ID of the added entry + """ + client = await self.get_client() + return await client.xadd(stream, fields, id=id, maxlen=maxlen, approximate=approximate) + + @with_retry() + async def xread(self, streams: Dict[str, str], count: Optional[int] = None, block: Optional[int] = None) -> List[Dict]: + """Read from streams. + + Args: + streams: Dict mapping stream names to IDs + count: Maximum number of entries to return + block: Milliseconds to block waiting for data (None = no blocking) + + Returns: + List of entries from the streams + """ + client = await self.get_client() + return await client.xread(streams, count=count, block=block) + + @with_retry() + async def xrange(self, stream: str, start: str = "-", end: str = "+", count: Optional[int] = None) -> List[Dict]: + """Read range of entries from a stream. + + Args: + stream: Stream name + start: Start ID (inclusive) + end: End ID (inclusive) + count: Maximum number of entries to return + + Returns: + List of entries in the specified range + """ + client = await self.get_client() + return await client.xrange(stream, start, end, count=count) + + @with_retry() + async def xrevrange(self, stream: str, start: str = "+", end: str = "-", count: Optional[int] = None) -> List[Dict]: + """Read range of entries from a stream in reverse order. + + Args: + stream: Stream name + start: Start ID (inclusive) + end: End ID (inclusive) + count: Maximum number of entries to return + + Returns: + List of entries in the specified range in reverse order + """ + client = await self.get_client() + return await client.xrevrange(stream, start, end, count=count) + + @with_retry() + async def xlen(self, stream: str) -> int: + """Get the length of a stream. + + Args: + stream: Stream name + + Returns: + Number of entries in the stream + """ + client = await self.get_client() + return await client.xlen(stream) + + @with_retry() + async def xdel(self, stream: str, *ids: str) -> int: + """Delete entries from a stream. + + Args: + stream: Stream name + ids: IDs of entries to delete + + Returns: + Number of entries deleted + """ + client = await self.get_client() + return await client.xdel(stream, *ids) + + @with_retry() + async def xinfo_stream(self, stream: str) -> Dict: + """Get information about a stream. + + Args: + stream: Stream name + + Returns: + Dict with stream information + """ + client = await self.get_client() + return await client.xinfo_stream(stream) + + @with_retry() + async def xtrim(self, stream: str, maxlen: int, approximate: bool = True) -> int: + """Trim a stream to a maximum length. + + Args: + stream: Stream name + maxlen: Maximum length + approximate: Whether maxlen is approximate + + Returns: + Number of entries removed + """ + client = await self.get_client() + return await client.xtrim(stream, maxlen=maxlen, approximate=approximate) + async def check_inclusion_and_exclusion(self, member: str, group: str) -> bool: exclude_key = self._get_group_exclusion_key(group) include_key = self._get_group_inclusion_key(group) @@ -290,6 +410,31 @@ class NoopAsyncRedisClient(AsyncRedisClient): async def srem(self, key: str, *members: Union[str, int, float]) -> int: return 0 + # Stream operations + async def xadd(self, stream: str, fields: Dict[str, Any], id: str = "*", maxlen: Optional[int] = None, approximate: bool = True) -> str: + return "" + + async def xread(self, streams: Dict[str, str], count: Optional[int] = None, block: Optional[int] = None) -> List[Dict]: + return [] + + async def xrange(self, stream: str, start: str = "-", end: str = "+", count: Optional[int] = None) -> List[Dict]: + return [] + + async def xrevrange(self, stream: str, start: str = "+", end: str = "-", count: Optional[int] = None) -> List[Dict]: + return [] + + async def xlen(self, stream: str) -> int: + return 0 + + async def xdel(self, stream: str, *ids: str) -> int: + return 0 + + async def xinfo_stream(self, stream: str) -> Dict: + return {} + + async def xtrim(self, stream: str, maxlen: int, approximate: bool = True) -> int: + return 0 + async def get_redis_client() -> AsyncRedisClient: global _client_instance