feat: plugin system and backend runtime flags (#2543)

This commit is contained in:
Andy Li
2025-06-05 18:12:44 -07:00
committed by GitHub
parent d2252f2953
commit eaf5682422
20 changed files with 1151 additions and 79 deletions

View File

@@ -19,6 +19,10 @@ ENV LETTA_ENVIRONMENT=${LETTA_ENVIRONMENT} \
POETRY_VIRTUALENVS_CREATE=1 \
POETRY_CACHE_DIR=/tmp/poetry_cache
# Set for other builds
ARG LETTA_VERSION
ENV LETTA_VERSION=${LETTA_VERSION}
WORKDIR /app
# Create and activate virtual environment
@@ -70,7 +74,6 @@ ENV LETTA_ENVIRONMENT=${LETTA_ENVIRONMENT} \
POSTGRES_DB=letta \
COMPOSIO_DISABLE_VERSION_CHECK=true
# Set for other builds
ARG LETTA_VERSION
ENV LETTA_VERSION=${LETTA_VERSION}

View File

@@ -330,3 +330,7 @@ RESERVED_FILENAMES = {"CON", "PRN", "AUX", "NUL", "COM1", "COM2", "LPT1", "LPT2"
WEB_SEARCH_CLIP_CONTENT = False
WEB_SEARCH_INCLUDE_SCORE = False
WEB_SEARCH_SEPARATOR = "\n" + "-" * 40 + "\n"
REDIS_INCLUDE = "INCLUDE"
REDIS_EXCLUDE = "EXCLUDE"
REDIS_SET_DEFAULT_VAL = "None"

View File

View File

@@ -0,0 +1,282 @@
import asyncio
from functools import wraps
from typing import Any, Optional, Set, Union
import redis.asyncio as redis
from redis import RedisError
from letta.constants import REDIS_EXCLUDE, REDIS_INCLUDE, REDIS_SET_DEFAULT_VAL
from letta.log import get_logger
logger = get_logger(__name__)
_client_instance = None
class AsyncRedisClient:
    """Async Redis client with connection pooling and error handling."""

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        password: Optional[str] = None,
        max_connections: int = 50,
        decode_responses: bool = True,
        socket_timeout: int = 5,
        socket_connect_timeout: int = 5,
        retry_on_timeout: bool = True,
        health_check_interval: int = 30,
    ):
        """
        Initialize Redis client with connection pool.

        Args:
            host: Redis server hostname
            port: Redis server port
            db: Database number
            password: Redis password if required
            max_connections: Maximum number of connections in pool
            decode_responses: Decode byte responses to strings
            socket_timeout: Socket timeout in seconds
            socket_connect_timeout: Socket connection timeout
            retry_on_timeout: Retry operations on timeout
            health_check_interval: Seconds between health checks
        """
        self.pool = redis.ConnectionPool(
            host=host,
            port=port,
            db=db,
            password=password,
            max_connections=max_connections,
            decode_responses=decode_responses,
            socket_timeout=socket_timeout,
            socket_connect_timeout=socket_connect_timeout,
            retry_on_timeout=retry_on_timeout,
            health_check_interval=health_check_interval,
        )
        self._client = None
        # Guards lazy client creation against concurrent first callers.
        self._lock = asyncio.Lock()

    async def get_client(self) -> redis.Redis:
        """Get or create the Redis client instance (double-checked locking)."""
        if self._client is None:
            async with self._lock:
                if self._client is None:
                    self._client = redis.Redis(connection_pool=self.pool)
        return self._client

    async def close(self):
        """Close Redis connection and release pooled connections."""
        if self._client:
            await self._client.close()
            await self.pool.disconnect()
            self._client = None

    async def __aenter__(self):
        """Async context manager entry: eagerly create the client."""
        await self.get_client()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: always close the connection."""
        await self.close()

    # Health check and connection management
    async def ping(self) -> bool:
        """Check if Redis is accessible. Never raises; returns False on error."""
        try:
            client = await self.get_client()
            await client.ping()
            return True
        except RedisError:
            logger.exception("Redis ping failed")
            return False

    async def wait_for_ready(self, timeout: int = 30, interval: float = 0.5):
        """Poll ping() until Redis responds.

        Raises:
            ConnectionError: if Redis is not reachable within `timeout` seconds.
        """
        start_time = asyncio.get_event_loop().time()
        while (asyncio.get_event_loop().time() - start_time) < timeout:
            if await self.ping():
                return
            await asyncio.sleep(interval)
        raise ConnectionError(f"Redis not ready after {timeout} seconds")

    # Retry decorator for resilience. NOTE: this is a plain function in the
    # class namespace (consumed below as @with_retry()), not an instance method.
    def with_retry(max_attempts: int = 3, delay: float = 0.1):
        """Decorator to retry Redis operations on failure with exponential backoff."""

        def decorator(func):
            @wraps(func)
            async def wrapper(self, *args, **kwargs):
                last_error = None
                for attempt in range(max_attempts):
                    try:
                        return await func(self, *args, **kwargs)
                    except (ConnectionError, TimeoutError) as e:
                        last_error = e
                        if attempt < max_attempts - 1:
                            # Exponential backoff: delay, 2*delay, 4*delay, ...
                            await asyncio.sleep(delay * (2**attempt))
                            logger.warning(f"Retry {attempt + 1}/{max_attempts} for {func.__name__}: {e}")
                # All attempts exhausted; surface the last failure to the caller.
                raise last_error

            return wrapper

        return decorator

    # Basic operations with error handling
    @with_retry()
    async def get(self, key: str, default: Any = None) -> Any:
        """Get value by key, returning `default` on any Redis failure.

        BUG FIX: previously used a bare `except:` which also swallowed
        CancelledError and programming errors; now only Redis errors are
        treated as "use the default".
        """
        try:
            client = await self.get_client()
            return await client.get(key)
        except RedisError:
            logger.exception(f"Redis GET failed for key {key}")
            return default

    @with_retry()
    async def set(
        self,
        key: str,
        value: Union[str, int, float],
        ex: Optional[int] = None,
        px: Optional[int] = None,
        nx: bool = False,
        xx: bool = False,
    ) -> bool:
        """
        Set key-value with options.

        Args:
            key: Redis key
            value: Value to store
            ex: Expire time in seconds
            px: Expire time in milliseconds
            nx: Only set if key doesn't exist
            xx: Only set if key exists
        """
        client = await self.get_client()
        return await client.set(key, value, ex=ex, px=px, nx=nx, xx=xx)

    @with_retry()
    async def delete(self, *keys: str) -> int:
        """Delete one or more keys; returns the number of keys removed."""
        client = await self.get_client()
        return await client.delete(*keys)

    @with_retry()
    async def exists(self, *keys: str) -> int:
        """Return how many of the given keys exist."""
        client = await self.get_client()
        return await client.exists(*keys)

    # Set operations
    async def sadd(self, key: str, *members: Union[str, int, float]) -> int:
        """Add members to set; returns the number of newly added members."""
        client = await self.get_client()
        return await client.sadd(key, *members)

    async def smembers(self, key: str) -> Set[str]:
        """Get all set members."""
        client = await self.get_client()
        return await client.smembers(key)

    @with_retry()
    async def smismember(self, key: str, values: list[Any] | Any) -> list[int] | int:
        """Batch membership check (SMISMEMBER).

        Returns a list of 0/1 flags when `values` is a list, otherwise a single
        0/1. On Redis failure, reports "not a member" for everything rather
        than raising (was a bare `except:`; narrowed to RedisError).
        """
        try:
            client = await self.get_client()
            result = await client.smismember(key, values)
            return result if isinstance(values, list) else result[0]
        except RedisError:
            logger.exception(f"Redis SMISMEMBER failed for key {key}")
            return [0] * len(values) if isinstance(values, list) else 0

    async def srem(self, key: str, *members: Union[str, int, float]) -> int:
        """Remove members from set; returns the number removed."""
        client = await self.get_client()
        return await client.srem(key, *members)

    async def scard(self, key: str) -> int:
        """Return the cardinality (member count) of the set at `key`."""
        client = await self.get_client()
        return await client.scard(key)

    # Atomic operations
    async def incr(self, key: str) -> int:
        """Atomically increment key value; returns the new value."""
        client = await self.get_client()
        return await client.incr(key)

    async def decr(self, key: str) -> int:
        """Atomically decrement key value; returns the new value."""
        client = await self.get_client()
        return await client.decr(key)

    async def check_inclusion_and_exclusion(self, member: str, group: str) -> bool:
        """Return True if `member` is enabled for `group`.

        Rules:
          1. If the group has an exclusion set with real members, an excluded
             member is disabled.
          2. Otherwise, if the group has an include set with real members,
             membership there decides.
          3. Otherwise everyone is enabled by default.

        The `scard(...) > 1` checks treat a set holding only the
        REDIS_SET_DEFAULT_VAL sentinel (seeded by
        create_inclusion_exclusion_keys) as empty.

        BUG FIX: `self.exists(...)` was not awaited (a coroutine object is
        always truthy, so the guard was a no-op), and membership in the
        *exclusion* set previously returned True (i.e. excluded members were
        reported as enabled).
        """
        exclude_key = self._get_group_exclusion_key(group)
        include_key = self._get_group_inclusion_key(group)
        # 1. Explicit exclusion wins.
        if await self.exists(exclude_key) and await self.scard(exclude_key) > 1:
            if await self.smismember(exclude_key, member):
                return False
        # 2. If the group HAS an include set, is the member in that set?
        if await self.exists(include_key) and await self.scard(include_key) > 1:
            return bool(await self.smismember(include_key, member))
        # 3. The group does NOT HAVE an include set and member is NOT excluded.
        return True

    async def create_inclusion_exclusion_keys(self, group: str) -> None:
        """Seed both group sets with the sentinel value so the keys exist."""
        redis_client = await self.get_client()
        await redis_client.sadd(self._get_group_inclusion_key(group), REDIS_SET_DEFAULT_VAL)
        await redis_client.sadd(self._get_group_exclusion_key(group), REDIS_SET_DEFAULT_VAL)

    @staticmethod
    def _get_group_inclusion_key(group: str) -> str:
        return f"{group}_{REDIS_INCLUDE}"

    @staticmethod
    def _get_group_exclusion_key(group: str) -> str:
        return f"{group}_{REDIS_EXCLUDE}"
class NoopAsyncRedisClient(AsyncRedisClient):
    """Drop-in stand-in used when Redis is unavailable.

    Every operation is a successful no-op reporting "not present" /
    "not enabled", so callers never need to check Redis availability.

    BUG FIX: `set`, `srem`, `smembers`, `incr`, `decr`, and `ping` were
    previously inherited from AsyncRedisClient and would attempt a real
    connection; they are now overridden with safe defaults as well.
    """

    async def get(self, key: str, default: Any = None) -> Any:
        return default

    async def set(
        self,
        key: str,
        value: Union[str, int, float],
        ex: Optional[int] = None,
        px: Optional[int] = None,
        nx: bool = False,
        xx: bool = False,
    ) -> bool:
        # Report "not stored" without touching Redis.
        return False

    async def exists(self, *keys: str) -> int:
        return 0

    async def sadd(self, key: str, *members: Union[str, int, float]) -> int:
        return 0

    async def smembers(self, key: str) -> Set[str]:
        return set()

    async def smismember(self, key: str, values: list[Any] | Any) -> list[int] | int:
        return [0] * len(values) if isinstance(values, list) else 0

    async def srem(self, key: str, *members: Union[str, int, float]) -> int:
        return 0

    async def scard(self, key: str) -> int:
        return 0

    async def incr(self, key: str) -> int:
        return 0

    async def decr(self, key: str) -> int:
        return 0

    async def ping(self) -> bool:
        # A noop client never claims the server is reachable.
        return False

    async def delete(self, *keys: str) -> int:
        return 0

    async def check_inclusion_and_exclusion(self, member: str, group: str) -> bool:
        # Without Redis there is no group data: treat features as disabled.
        return False

    async def create_inclusion_exclusion_keys(self, group: str) -> None:
        return None
async def get_redis_client() -> AsyncRedisClient:
    """Return the process-wide Redis client singleton.

    Builds an AsyncRedisClient from settings on first call and waits up to
    5 seconds for the server; on any failure the singleton becomes a
    NoopAsyncRedisClient and the error is only logged.
    """
    global _client_instance
    # NOTE(review): no lock here — two concurrent first callers may both run
    # the initialization; confirm that is acceptable (last writer wins).
    if _client_instance is None:
        try:
            # Local import — presumably deferred to avoid an import cycle with
            # letta.settings; confirm before hoisting to module level.
            from letta.settings import settings

            _client_instance = AsyncRedisClient(
                host=settings.redis_host or "localhost",
                port=settings.redis_port or 6379,
            )
            # Fail fast: raises ConnectionError if Redis is not reachable.
            await _client_instance.wait_for_ready(timeout=5)
            logger.info("Redis client initialized")
        except Exception as e:
            # NOTE(review): the noop fallback is sticky — once set, Redis is
            # never retried for the lifetime of the process.
            logger.warning(f"Failed to initialize Redis: {e}")
            _client_instance = NoopAsyncRedisClient()
    return _client_instance

View File

@@ -0,0 +1,69 @@
import inspect
from functools import wraps
from typing import Callable
from letta.log import get_logger
from letta.plugins.plugins import get_experimental_checker
from letta.settings import settings
logger = get_logger(__name__)
def experimental(feature_name: str, fallback_function: Callable, **kwargs):
    """Decorator that runs a fallback function if experimental feature is not enabled.

    - kwargs given to the decorator are merged over the call's kwargs, but only
      for the experimental-check evaluation; the decorated function (or the
      fallback) always receives the caller's original arguments.
    - if the decorated function, fallback_function, or experimental checker
      function is async, the whole call will be async.

    Args:
        feature_name: Name passed to the experimental checker plugin.
        fallback_function: Called (with the caller's args) when the feature is
            not enabled.
        **kwargs: Overrides applied to the checker's kwargs only.
    """

    def decorator(f):
        # Resolved once at decoration time, not per call.
        experimental_checker = get_experimental_checker()
        is_f_async = inspect.iscoroutinefunction(f)
        is_fallback_async = inspect.iscoroutinefunction(fallback_function)
        is_experimental_checker_async = inspect.iscoroutinefunction(experimental_checker)

        async def call_function(func, is_async, *args, **_kwargs):
            # Uniformly invoke a sync or async callable from async context.
            if is_async:
                return await func(*args, **_kwargs)
            return func(*args, **_kwargs)

        # Asynchronous wrapper if any participating function is async.
        if any((is_f_async, is_fallback_async, is_experimental_checker_async)):

            @wraps(f)
            async def async_wrapper(*args, **_kwargs):
                # Decorator kwargs override call kwargs for the check only.
                result = await call_function(experimental_checker, is_experimental_checker_async, feature_name, **dict(_kwargs, **kwargs))
                if result:
                    return await call_function(f, is_f_async, *args, **_kwargs)
                else:
                    return await call_function(fallback_function, is_fallback_async, *args, **_kwargs)

            return async_wrapper
        else:

            @wraps(f)
            def wrapper(*args, **_kwargs):
                if experimental_checker(feature_name, **dict(_kwargs, **kwargs)):
                    return f(*args, **_kwargs)
                else:
                    # BUG FIX: previously passed the decorator's kwargs
                    # (**kwargs) to the fallback instead of the caller's
                    # (**_kwargs), unlike the async path above.
                    return fallback_function(*args, **_kwargs)

            return wrapper

    return decorator
def deprecated(message: str):
    """Mark a function as deprecated.

    When debug mode is on, each call logs a warning that includes `message`;
    the wrapped function is then invoked unchanged.
    """

    def decorator(func):
        @wraps(func)
        def _warn_then_call(*args, **kwargs):
            # Only surface the deprecation notice in debug mode.
            if settings.debug:
                logger.warning(f"Function {func.__name__} is deprecated: {message}.")
            return func(*args, **kwargs)

        return _warn_then_call

    return decorator

View File

@@ -1,3 +1,4 @@
import inspect
from datetime import datetime
from enum import Enum
from functools import wraps
@@ -24,16 +25,28 @@ logger = get_logger(__name__)
def handle_db_timeout(func):
"""Decorator to handle SQLAlchemy TimeoutError and wrap it in a custom exception."""
if not inspect.iscoroutinefunction(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except TimeoutError as e:
logger.error(f"Timeout while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
raise DatabaseTimeoutError(message=f"Timeout occurred in {func.__name__}.", original_exception=e)
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except TimeoutError as e:
logger.error(f"Timeout while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
raise DatabaseTimeoutError(message=f"Timeout occurred in {func.__name__}.", original_exception=e)
return wrapper
return wrapper
else:
@wraps(func)
async def async_wrapper(*args, **kwargs):
try:
return await func(*args, **kwargs)
except TimeoutError as e:
logger.error(f"Timeout while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
raise DatabaseTimeoutError(message=f"Timeout occurred in {func.__name__}.", original_exception=e)
return async_wrapper
class AccessType(str, Enum):

22
letta/plugins/README.md Normal file
View File

@@ -0,0 +1,22 @@
### Plugins
Plugins enable plug and play for various components.
Plugin configurations can be set in `letta.settings.settings`.
The plugins will take a delimited list consisting of individual plugin configs:
`<plugin_name>.<config_name>=<class_or_function>`
joined by `;`
In the default configuration, the top level keys have values `plugin_name`,
the `config_name` is nested under and the `class_or_function` is defined
after in format `<module_path>:<name>`.
```
DEFAULT_PLUGINS = {
"experimental_check": {
"default": "letta.plugins.defaults:is_experimental_enabled",
...
```

View File

11
letta/plugins/defaults.py Normal file
View File

@@ -0,0 +1,11 @@
from letta.settings import settings
def is_experimental_enabled(feature_name: str, **kwargs) -> bool:
    """Default experimental-check plugin.

    Known gated features require both an `eligibility=True` kwarg and the
    global `settings.use_experimental` flag; everything else is disabled.
    """
    if feature_name not in ("async_agent_loop", "summarize"):
        # Err on safety: unrecognized features are always disabled.
        return False
    # Short-circuits before reading settings when eligibility is absent/falsy.
    return bool(kwargs.get("eligibility", False) and settings.use_experimental)

72
letta/plugins/plugins.py Normal file
View File

@@ -0,0 +1,72 @@
import importlib
from typing import Protocol, runtime_checkable
from letta.settings import settings
@runtime_checkable
class SummarizerProtocol(Protocol):
    """What a summarizer must implement"""

    # Asynchronously condense `text` and return the summary string.
    async def summarize(self, text: str) -> str: ...

    # Human-readable identifier for the summarizer implementation.
    def get_name(self) -> str: ...
# Currently this supports one of each plugin type. This can be expanded in the future.
# Registry of plugin type -> {"protocol": optional Protocol the implementation
# must satisfy (None = unchecked), "target": "<module_path>:<name>" import
# string}. Entries can be overridden via settings (see get_plugin below).
DEFAULT_PLUGINS = {
    "experimental_check": {
        "protocol": None,
        "target": "letta.plugins.defaults:is_experimental_enabled",
    },
    "summarizer": {
        "protocol": SummarizerProtocol,
        "target": "letta.services.summarizer.summarizer:Summarizer",
    },
}
def get_plugin(plugin_type: str):
    """Resolve the plugin registered for `plugin_type` and return it.

    The registry is DEFAULT_PLUGINS overlaid with user entries from
    `settings.plugin_register_dict` (user entries win). Each entry's "target"
    is an import string "<module_path>:<name>". A function (or other callable)
    target is returned as-is; a class target is instantiated, after verifying
    it implements the entry's protocol when one is declared.

    Raises:
        TypeError: if `plugin_type` is unknown, the target is not callable, or
            the class does not implement the required protocol.
    """
    plugin_register = dict(DEFAULT_PLUGINS, **settings.plugin_register_dict)
    if plugin_type in plugin_register:
        entry = plugin_register[plugin_type]
        module_path, name = entry["target"].split(":")
        module = importlib.import_module(module_path)
        plugin = getattr(module, name)
        if isinstance(plugin, type):
            # BUG FIX: the class branch previously tested
            # `type(plugin).__name__ == "class"` (never true — a class's type
            # is `type`), read "protocol" off the whole registry instead of
            # this entry (KeyError for user-registered plugins), and used an
            # isinstance() comparison that could never validate the class.
            protocol = entry.get("protocol")
            if protocol is not None and not issubclass(plugin, protocol):
                raise TypeError(f"{plugin} does not implement {protocol.__name__}")
            return plugin()
        if callable(plugin):
            # Plain functions (and other callables) are used directly.
            return plugin
    raise TypeError("Unknown plugin type")
_experimental_checker = None
_summarizer = None
# TODO handle coroutines
# Convenience functions
def get_experimental_checker():
    """Lazily resolve and memoize the experimental-check plugin."""
    global _experimental_checker
    if _experimental_checker is not None:
        return _experimental_checker
    _experimental_checker = get_plugin("experimental_check")
    return _experimental_checker
def get_summarizer():
    """Lazily resolve and memoize the summarizer plugin."""
    global _summarizer
    if _summarizer is not None:
        return _summarizer
    _summarizer = get_plugin("summarizer")
    return _summarizer
def reset_experimental_checker():
    """Drop the cached checker so the next get_experimental_checker() reloads it."""
    global _experimental_checker
    _experimental_checker = None
def reset_summarizer():
    """Drop the cached summarizer so the next get_summarizer() reloads it."""
    global _summarizer
    _summarizer = None

View File

@@ -333,7 +333,7 @@ def start_server(
if (os.getenv("LOCAL_HTTPS") == "true") or "--localhttps" in sys.argv:
print(f"▶ Server running at: https://{host or 'localhost'}:{port or REST_DEFAULT_PORT}")
print(f"▶ View using ADE at: https://app.letta.com/development-servers/local/dashboard\n")
if importlib.util.find_spec("granian") is not None and settings.use_uvloop:
if importlib.util.find_spec("granian") is not None and settings.use_granian:
from granian import Granian
# Experimental Granian engine

View File

@@ -958,10 +958,8 @@ async def summarize_agent_conversation(
This endpoint summarizes the current message history for a given agent,
truncating and compressing it down to the specified `max_message_length`.
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
# user_eligible = actor.organization_id not in ["org-4a3af5dd-4c6a-48cb-ac13-3f73ecaaa4bf", "org-4ab3f6e8-9a44-4bee-aeb6-c681cbbc7bf6"]
# TODO: This is redundant, remove soon
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor, include_relationships=["multi_agent_group"])
agent_eligible = agent.enable_sleeptime or agent.agent_type == AgentType.sleeptime_agent or not agent.multi_agent_group
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex"]

View File

@@ -152,8 +152,11 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
print(f"Auto-generated code for debugging:\n\n{code}")
raise e
finally:
# Clean up the temp file
os.remove(temp_file_path)
# Clean up the temp file if not debugging
from letta.settings import settings
if not settings.debug:
os.remove(temp_file_path)
async def _prepare_venv(self, local_configs, venv_path: str, env: Dict[str, str]):
"""

View File

@@ -192,13 +192,17 @@ class Settings(BaseSettings):
pool_use_lifo: bool = True
disable_sqlalchemy_pooling: bool = False
redis_host: Optional[str] = None
redis_port: Optional[int] = None
plugin_register: Optional[str] = None
# multi agent settings
multi_agent_send_message_max_retries: int = 3
multi_agent_send_message_timeout: int = 20 * 60
multi_agent_concurrent_sends: int = 50
# telemetry logging
verbose_telemetry_logging: bool = False
otel_exporter_otlp_endpoint: Optional[str] = None # otel default: "http://localhost:4317"
disable_tracing: bool = False
llm_api_logging: bool = True
@@ -259,6 +263,15 @@ class Settings(BaseSettings):
else:
return None
@property
def plugin_register_dict(self) -> dict:
    """Parse `plugin_register` into a plugin registry mapping.

    The setting is a ';'-delimited list of "name=<module_path>:<attr>"
    entries; the result maps each name to {"target": "<module_path>:<attr>"}.
    Returns an empty dict when the setting is unset or empty.
    """
    plugins = {}
    if self.plugin_register:
        for plugin in self.plugin_register.split(";"):
            # Split on the first '=' only, so the target may contain '='.
            name, target = plugin.split("=", 1)
            plugins[name] = {"target": target}
    return plugins
class TestSettings(Settings):
model_config = SettingsConfigDict(env_prefix="letta_test_", extra="ignore")
@@ -266,9 +279,15 @@ class TestSettings(Settings):
letta_dir: Optional[Path] = Field(Path.home() / ".letta/test", env="LETTA_TEST_DIR")
class LogSettings(BaseSettings):
model_config = SettingsConfigDict(env_prefix="letta_logging_", extra="ignore")
verbose_telemetry_logging: bool = False
# singleton
settings = Settings(_env_parse_none_str="None")
test_settings = TestSettings()
model_settings = ModelSettings()
tool_settings = ToolSettings()
summarizer_settings = SummarizerSettings()
log_settings = LogSettings()

View File

@@ -515,6 +515,11 @@ def is_optional_type(hint):
def enforce_types(func):
"""Enforces that values passed in match the expected types.
Technically will handle coroutines as well.
"""
@wraps(func)
def wrapper(*args, **kwargs):
# Get type hints, excluding the return type hint
@@ -1078,9 +1083,9 @@ def log_telemetry(logger: Logger, event: str, **kwargs):
:param event: A string describing the event.
:param kwargs: Additional key-value pairs for logging metadata.
"""
from letta.settings import settings
from letta.settings import log_settings
if settings.verbose_telemetry_logging:
if log_settings.verbose_telemetry_logging:
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S,%f UTC") # More readable timestamp
extra_data = " | ".join(f"{key}={value}" for key, value in kwargs.items() if value is not None)
logger.info(f"[{timestamp}] EVENT: {event} | {extra_data}")

551
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -96,10 +96,12 @@ async-lru = "^2.0.5"
mistralai = "^1.8.1"
uvloop = {version = "^0.21.0", optional = true}
granian = {version = "^2.3.2", extras = ["uvloop", "reload"], optional = true}
redis = {version = "^6.2.0", optional = true}
[tool.poetry.extras]
postgres = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "asyncpg"]
redis = ["redis"]
dev = ["pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "pyright", "pytest-order", "autoflake", "isort", "locust"]
experimental = ["uvloop", "granian"]
server = ["websockets", "fastapi", "uvicorn"]
@@ -110,7 +112,7 @@ tests = ["wikipedia"]
bedrock = ["boto3"]
google = ["google-genai"]
desktop = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pyright", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust"]
all = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "pyright", "pytest-order", "autoflake", "isort", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust", "uvloop", "granian"]
all = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "pyright", "pytest-order", "autoflake", "isort", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust", "uvloop", "granian", "redis"]
[tool.poetry.group.dev.dependencies]
black = "^24.4.2"

View File

@@ -0,0 +1,20 @@
from letta.data_sources.redis_client import get_redis_client
from letta.services.agent_manager import AgentManager
async def is_experimental_okay(feature_name: str, **kwargs) -> bool:
print(feature_name, kwargs)
if feature_name == "test_pass_with_kwarg":
return isinstance(kwargs["agent_manager"], AgentManager)
if feature_name == "test_just_pass":
return True
if feature_name == "test_fail":
return False
if feature_name == "test_override_kwarg":
return kwargs["bool_val"]
if feature_name == "test_redis_flag":
client = await get_redis_client()
user_id = kwargs["user_id"]
return await client.check_inclusion_and_exclusion(member=user_id, group="TEST_GROUP")
# Err on safety here, disabling experimental if not handled here.
return False

92
tests/test_plugins.py Normal file
View File

@@ -0,0 +1,92 @@
import pytest
from letta.data_sources.redis_client import get_redis_client
from letta.helpers.decorators import experimental
from letta.settings import settings
@pytest.mark.asyncio
async def test_default_experimental_decorator(event_loop):
    """An approved feature runs the decorated function, not the fallback."""
    settings.plugin_register = "experimental_check=tests.helpers.plugins_helper:is_experimental_okay"

    @experimental("test_just_pass", fallback_function=lambda: False, kwarg1=3)
    def _return_true():
        return True

    # BUG FIX: the checker plugin is async, so the decorator produces an async
    # wrapper even for this sync function — the previous `assert _return_true()`
    # asserted on a coroutine object (always truthy) and could never fail.
    assert await _return_true()
    settings.plugin_register = ""
@pytest.mark.asyncio
async def test_overwrite_arg_success(event_loop):
    """Decorator kwargs override call kwargs for the eligibility check only;
    the wrapped function still receives the caller's original arguments."""
    settings.plugin_register = "experimental_check=tests.helpers.plugins_helper:is_experimental_okay"

    @experimental("test_override_kwarg", fallback_function=lambda *args, **kwargs: False, bool_val=True)
    async def _return_true(a_val: bool, bool_val: bool):
        # The decorator's bool_val=True is used only by the checker; the
        # function sees the caller-supplied False.
        assert bool_val is False
        return True

    # BUG FIX: the coroutine was previously asserted without being awaited
    # (a coroutine object is always truthy, so the test could never fail).
    assert await _return_true(False, False)
    settings.plugin_register = ""
@pytest.mark.asyncio
async def test_overwrite_arg_fail(event_loop):
    """When the (possibly overridden) check fails, the fallback runs instead."""
    # Should fallback to lambda
    settings.plugin_register = "experimental_check=tests.helpers.plugins_helper:is_experimental_okay"

    @experimental("test_override_kwarg", fallback_function=lambda *args, **kwargs: True, bool_val=False)
    async def _return_false(a_val: bool, bool_val: bool):
        assert bool_val is True
        return False

    # BUG FIX: these first two assertions previously did not await the
    # coroutines (a coroutine object is always truthy, making them vacuous).
    # Decorator bool_val=False fails the check, so the fallback returns True.
    assert await _return_false(False, True)

    @experimental("test_override_kwarg", fallback_function=lambda *args, **kwargs: False, bool_val=True)
    async def _return_true(a_val: bool, bool_val: bool):
        assert bool_val is False
        return True

    # Decorator bool_val=True overrides the call's bool_val=False for the check.
    assert await _return_true(False, bool_val=False)

    @experimental("test_override_kwarg", fallback_function=lambda *args, **kwargs: True)
    async def _get_true(a_val: bool, bool_val: bool):
        return True

    assert await _get_true(True, bool_val=True)
    with pytest.raises(Exception):
        # kwarg must be included in either experimental flag or function call
        assert await _get_true(True, True)
    settings.plugin_register = ""
@pytest.mark.asyncio
async def test_redis_flag(event_loop):
    """End-to-end check of the Redis-backed experimental gate: a user in the
    include set gets the feature; anyone else hits the raising fallback."""
    settings.plugin_register = "experimental_check=tests.helpers.plugins_helper:is_experimental_okay"

    @experimental("test_redis_flag", fallback_function=lambda *args, **kwargs: _raise())
    async def _new_feature(user_id: str) -> str:
        return "new_feature"

    def _raise():
        raise Exception()

    # NOTE(review): requires a live Redis instance; with the noop fallback
    # client these assertions would not hold — confirm test environment.
    redis_client = await get_redis_client()
    group_name = "TEST_GROUP"
    include_key = redis_client._get_group_inclusion_key(group_name)
    exclude_key = redis_client._get_group_exclusion_key(group_name)
    test_user = "user123"
    # reset: drop any members left over from previous runs
    for member in await redis_client.smembers(include_key):
        await redis_client.srem(include_key, member)
    for member in await redis_client.smembers(exclude_key):
        await redis_client.srem(exclude_key, member)
    # Re-seed both sets with the sentinel default value so the keys exist.
    await redis_client.create_inclusion_exclusion_keys(group=group_name)
    await redis_client.sadd(include_key, test_user)
    # Included user runs the experimental implementation...
    assert await _new_feature(user_id=test_user) == "new_feature"
    # ...while a user missing from the include set falls back and raises.
    with pytest.raises(Exception):
        assert await _new_feature(user_id=test_user + "1")
    print("members: ", await redis_client.smembers(include_key))

View File

@@ -0,0 +1,26 @@
import pytest
from letta.data_sources.redis_client import get_redis_client
@pytest.mark.asyncio
async def test_redis_client(event_loop):
    """Round-trip Redis set operations (DELETE/SADD/SMEMBERS/SMISMEMBER)."""
    test_values = {"LETTA_TEST_0": [1, 2, 3], "LETTA_TEST_1": ["apple", "pear", "banana"], "LETTA_TEST_2": ["{}", 3.2, "cat"]}
    redis_client = await get_redis_client()
    # Start from a clean slate.
    await redis_client.delete(*test_values.keys())
    # Every member is new, so SADD reports three additions per key.
    for key, members in test_values.items():
        assert await redis_client.sadd(key, *members) == 3
    # Redis stores set members as strings.
    for key, members in test_values.items():
        assert await redis_client.smembers(key) == {str(member) for member in members}
    for key, members in test_values.items():
        assert await redis_client.smismember(key, "invalid") == 0
        assert await redis_client.smismember(key, members[0]) == 1
        assert await redis_client.smismember(key, members[:2]) == [1, 1]
        assert await redis_client.smismember(key, members[2:] + ["invalid"]) == [1, 0]