feat: plugin system and backend runtime flags (#2543)
This commit is contained in:
@@ -19,6 +19,10 @@ ENV LETTA_ENVIRONMENT=${LETTA_ENVIRONMENT} \
|
||||
POETRY_VIRTUALENVS_CREATE=1 \
|
||||
POETRY_CACHE_DIR=/tmp/poetry_cache
|
||||
|
||||
# Set for other builds
|
||||
ARG LETTA_VERSION
|
||||
ENV LETTA_VERSION=${LETTA_VERSION}
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Create and activate virtual environment
|
||||
@@ -70,7 +74,6 @@ ENV LETTA_ENVIRONMENT=${LETTA_ENVIRONMENT} \
|
||||
POSTGRES_DB=letta \
|
||||
COMPOSIO_DISABLE_VERSION_CHECK=true
|
||||
|
||||
# Set for other builds
|
||||
ARG LETTA_VERSION
|
||||
ENV LETTA_VERSION=${LETTA_VERSION}
|
||||
|
||||
|
||||
@@ -330,3 +330,7 @@ RESERVED_FILENAMES = {"CON", "PRN", "AUX", "NUL", "COM1", "COM2", "LPT1", "LPT2"
|
||||
WEB_SEARCH_CLIP_CONTENT = False
|
||||
WEB_SEARCH_INCLUDE_SCORE = False
|
||||
WEB_SEARCH_SEPARATOR = "\n" + "-" * 40 + "\n"
|
||||
|
||||
REDIS_INCLUDE = "INCLUDE"
|
||||
REDIS_EXCLUDE = "EXCLUDE"
|
||||
REDIS_SET_DEFAULT_VAL = "None"
|
||||
|
||||
0
letta/data_sources/__init__.py
Normal file
0
letta/data_sources/__init__.py
Normal file
282
letta/data_sources/redis_client.py
Normal file
282
letta/data_sources/redis_client.py
Normal file
@@ -0,0 +1,282 @@
|
||||
import asyncio
|
||||
from functools import wraps
|
||||
from typing import Any, Optional, Set, Union
|
||||
|
||||
import redis.asyncio as redis
|
||||
from redis import RedisError
|
||||
|
||||
from letta.constants import REDIS_EXCLUDE, REDIS_INCLUDE, REDIS_SET_DEFAULT_VAL
|
||||
from letta.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
_client_instance = None
|
||||
|
||||
|
||||
class AsyncRedisClient:
|
||||
"""Async Redis client with connection pooling and error handling"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str = "localhost",
|
||||
port: int = 6379,
|
||||
db: int = 0,
|
||||
password: Optional[str] = None,
|
||||
max_connections: int = 50,
|
||||
decode_responses: bool = True,
|
||||
socket_timeout: int = 5,
|
||||
socket_connect_timeout: int = 5,
|
||||
retry_on_timeout: bool = True,
|
||||
health_check_interval: int = 30,
|
||||
):
|
||||
"""
|
||||
Initialize Redis client with connection pool.
|
||||
|
||||
Args:
|
||||
host: Redis server hostname
|
||||
port: Redis server port
|
||||
db: Database number
|
||||
password: Redis password if required
|
||||
max_connections: Maximum number of connections in pool
|
||||
decode_responses: Decode byte responses to strings
|
||||
socket_timeout: Socket timeout in seconds
|
||||
socket_connect_timeout: Socket connection timeout
|
||||
retry_on_timeout: Retry operations on timeout
|
||||
health_check_interval: Seconds between health checks
|
||||
"""
|
||||
self.pool = redis.ConnectionPool(
|
||||
host=host,
|
||||
port=port,
|
||||
db=db,
|
||||
password=password,
|
||||
max_connections=max_connections,
|
||||
decode_responses=decode_responses,
|
||||
socket_timeout=socket_timeout,
|
||||
socket_connect_timeout=socket_connect_timeout,
|
||||
retry_on_timeout=retry_on_timeout,
|
||||
health_check_interval=health_check_interval,
|
||||
)
|
||||
self._client = None
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def get_client(self) -> redis.Redis:
|
||||
"""Get or create Redis client instance."""
|
||||
if self._client is None:
|
||||
async with self._lock:
|
||||
if self._client is None:
|
||||
self._client = redis.Redis(connection_pool=self.pool)
|
||||
return self._client
|
||||
|
||||
async def close(self):
|
||||
"""Close Redis connection and cleanup."""
|
||||
if self._client:
|
||||
await self._client.close()
|
||||
await self.pool.disconnect()
|
||||
self._client = None
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry."""
|
||||
await self.get_client()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Async context manager exit."""
|
||||
await self.close()
|
||||
|
||||
# Health check and connection management
|
||||
async def ping(self) -> bool:
|
||||
"""Check if Redis is accessible."""
|
||||
try:
|
||||
client = await self.get_client()
|
||||
await client.ping()
|
||||
return True
|
||||
except RedisError:
|
||||
logger.exception("Redis ping failed")
|
||||
return False
|
||||
|
||||
async def wait_for_ready(self, timeout: int = 30, interval: float = 0.5):
|
||||
"""Wait for Redis to be ready."""
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
while (asyncio.get_event_loop().time() - start_time) < timeout:
|
||||
if await self.ping():
|
||||
return
|
||||
await asyncio.sleep(interval)
|
||||
raise ConnectionError(f"Redis not ready after {timeout} seconds")
|
||||
|
||||
# Retry decorator for resilience
|
||||
def with_retry(max_attempts: int = 3, delay: float = 0.1):
|
||||
"""Decorator to retry Redis operations on failure."""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
last_error = None
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
return await func(self, *args, **kwargs)
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
last_error = e
|
||||
if attempt < max_attempts - 1:
|
||||
await asyncio.sleep(delay * (2**attempt))
|
||||
logger.warning(f"Retry {attempt + 1}/{max_attempts} for {func.__name__}: {e}")
|
||||
raise last_error
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
# Basic operations with error handling
|
||||
@with_retry()
|
||||
async def get(self, key: str, default: Any = None) -> Any:
|
||||
"""Get value by key."""
|
||||
try:
|
||||
client = await self.get_client()
|
||||
return await client.get(key)
|
||||
except:
|
||||
return default
|
||||
|
||||
@with_retry()
|
||||
async def set(
|
||||
self,
|
||||
key: str,
|
||||
value: Union[str, int, float],
|
||||
ex: Optional[int] = None,
|
||||
px: Optional[int] = None,
|
||||
nx: bool = False,
|
||||
xx: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Set key-value with options.
|
||||
|
||||
Args:
|
||||
key: Redis key
|
||||
value: Value to store
|
||||
ex: Expire time in seconds
|
||||
px: Expire time in milliseconds
|
||||
nx: Only set if key doesn't exist
|
||||
xx: Only set if key exists
|
||||
"""
|
||||
client = await self.get_client()
|
||||
return await client.set(key, value, ex=ex, px=px, nx=nx, xx=xx)
|
||||
|
||||
@with_retry()
|
||||
async def delete(self, *keys: str) -> int:
|
||||
"""Delete one or more keys."""
|
||||
client = await self.get_client()
|
||||
return await client.delete(*keys)
|
||||
|
||||
@with_retry()
|
||||
async def exists(self, *keys: str) -> int:
|
||||
"""Check if keys exist."""
|
||||
client = await self.get_client()
|
||||
return await client.exists(*keys)
|
||||
|
||||
# Set operations
|
||||
async def sadd(self, key: str, *members: Union[str, int, float]) -> int:
|
||||
"""Add members to set."""
|
||||
client = await self.get_client()
|
||||
return await client.sadd(key, *members)
|
||||
|
||||
async def smembers(self, key: str) -> Set[str]:
|
||||
"""Get all set members."""
|
||||
client = await self.get_client()
|
||||
return await client.smembers(key)
|
||||
|
||||
@with_retry()
|
||||
async def smismember(self, key: str, values: list[Any] | Any) -> list[int] | int:
|
||||
"""clever!: set member is member"""
|
||||
try:
|
||||
client = await self.get_client()
|
||||
result = await client.smismember(key, values)
|
||||
return result if isinstance(values, list) else result[0]
|
||||
except:
|
||||
return [0] * len(values) if isinstance(values, list) else 0
|
||||
|
||||
async def srem(self, key: str, *members: Union[str, int, float]) -> int:
|
||||
"""Remove members from set."""
|
||||
client = await self.get_client()
|
||||
return await client.srem(key, *members)
|
||||
|
||||
async def scard(self, key: str) -> int:
|
||||
client = await self.get_client()
|
||||
return await client.scard(key)
|
||||
|
||||
# Atomic operations
|
||||
async def incr(self, key: str) -> int:
|
||||
"""Increment key value."""
|
||||
client = await self.get_client()
|
||||
return await client.incr(key)
|
||||
|
||||
async def decr(self, key: str) -> int:
|
||||
"""Decrement key value."""
|
||||
client = await self.get_client()
|
||||
return await client.decr(key)
|
||||
|
||||
async def check_inclusion_and_exclusion(self, member: str, group: str) -> bool:
|
||||
exclude_key = f"{group}_{REDIS_EXCLUDE}"
|
||||
include_key = f"{group}_{REDIS_INCLUDE}"
|
||||
# 1. if the member IS excluded from the group
|
||||
if self.exists(exclude_key) and await self.scard(exclude_key) > 1:
|
||||
return bool(await self.smismember(exclude_key, member))
|
||||
# 2. if the group HAS an include set, is the member in that set?
|
||||
if self.exists(include_key) and await self.scard(include_key) > 1:
|
||||
return bool(await self.smismember(include_key, member))
|
||||
# 3. if the group does NOT HAVE an include set and member NOT excluded
|
||||
return True
|
||||
|
||||
async def create_inclusion_exclusion_keys(self, group: str) -> None:
|
||||
redis_client = await self.get_client()
|
||||
await redis_client.sadd(self._get_group_inclusion_key(group), REDIS_SET_DEFAULT_VAL)
|
||||
await redis_client.sadd(self._get_group_exclusion_key(group), REDIS_SET_DEFAULT_VAL)
|
||||
|
||||
@staticmethod
|
||||
def _get_group_inclusion_key(group: str) -> str:
|
||||
return f"{group}_{REDIS_INCLUDE}"
|
||||
|
||||
@staticmethod
|
||||
def _get_group_exclusion_key(group: str) -> str:
|
||||
return f"{group}_{REDIS_EXCLUDE}"
|
||||
|
||||
|
||||
class NoopAsyncRedisClient(AsyncRedisClient):
|
||||
async def get(self, key: str, default: Any = None) -> Any:
|
||||
return default
|
||||
|
||||
async def exists(self, *keys: str) -> int:
|
||||
return 0
|
||||
|
||||
async def sadd(self, key: str, *members: Union[str, int, float]) -> int:
|
||||
return 0
|
||||
|
||||
async def smismember(self, key: str, values: list[Any] | Any) -> list[int] | int:
|
||||
return [0] * len(values) if isinstance(values, list) else 0
|
||||
|
||||
async def delete(self, *keys: str) -> int:
|
||||
return 0
|
||||
|
||||
async def check_inclusion_and_exclusion(self, member: str, group: str) -> bool:
|
||||
return False
|
||||
|
||||
async def create_inclusion_exclusion_keys(self, group: str) -> None:
|
||||
return None
|
||||
|
||||
async def scard(self, key: str) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
async def get_redis_client() -> AsyncRedisClient:
|
||||
global _client_instance
|
||||
if _client_instance is None:
|
||||
try:
|
||||
from letta.settings import settings
|
||||
|
||||
_client_instance = AsyncRedisClient(
|
||||
host=settings.redis_host or "localhost",
|
||||
port=settings.redis_port or 6379,
|
||||
)
|
||||
await _client_instance.wait_for_ready(timeout=5)
|
||||
logger.info("Redis client initialized")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize Redis: {e}")
|
||||
_client_instance = NoopAsyncRedisClient()
|
||||
return _client_instance
|
||||
69
letta/helpers/decorators.py
Normal file
69
letta/helpers/decorators.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.plugins.plugins import get_experimental_checker
|
||||
from letta.settings import settings
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def experimental(feature_name: str, fallback_function: Callable, **kwargs):
|
||||
"""Decorator that runs a fallback function if experimental feature is not enabled.
|
||||
|
||||
- kwargs from the decorator will be combined with function kwargs and overwritten only for experimental evaluation.
|
||||
- if the decorated function, fallback_function, or experimental checker function is async, the whole call will be async
|
||||
"""
|
||||
|
||||
def decorator(f):
|
||||
experimental_checker = get_experimental_checker()
|
||||
is_f_async = inspect.iscoroutinefunction(f)
|
||||
is_fallback_async = inspect.iscoroutinefunction(fallback_function)
|
||||
is_experimental_checker_async = inspect.iscoroutinefunction(experimental_checker)
|
||||
|
||||
async def call_function(func, is_async, *args, **_kwargs):
|
||||
if is_async:
|
||||
return await func(*args, **_kwargs)
|
||||
return func(*args, **_kwargs)
|
||||
|
||||
# asynchronous wrapper if any function is async
|
||||
if any((is_f_async, is_fallback_async, is_experimental_checker_async)):
|
||||
|
||||
@wraps(f)
|
||||
async def async_wrapper(*args, **_kwargs):
|
||||
result = await call_function(experimental_checker, is_experimental_checker_async, feature_name, **dict(_kwargs, **kwargs))
|
||||
if result:
|
||||
return await call_function(f, is_f_async, *args, **_kwargs)
|
||||
else:
|
||||
return await call_function(fallback_function, is_fallback_async, *args, **_kwargs)
|
||||
|
||||
return async_wrapper
|
||||
|
||||
else:
|
||||
|
||||
@wraps(f)
|
||||
def wrapper(*args, **_kwargs):
|
||||
if experimental_checker(feature_name, **dict(_kwargs, **kwargs)):
|
||||
return f(*args, **_kwargs)
|
||||
else:
|
||||
return fallback_function(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def deprecated(message: str):
|
||||
"""Simple decorator that marks a method as deprecated."""
|
||||
|
||||
def decorator(f):
|
||||
@wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
if settings.debug:
|
||||
logger.warning(f"Function {f.__name__} is deprecated: {message}.")
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
@@ -1,3 +1,4 @@
|
||||
import inspect
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
@@ -24,6 +25,7 @@ logger = get_logger(__name__)
|
||||
|
||||
def handle_db_timeout(func):
|
||||
"""Decorator to handle SQLAlchemy TimeoutError and wrap it in a custom exception."""
|
||||
if not inspect.iscoroutinefunction(func):
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
@@ -34,6 +36,17 @@ def handle_db_timeout(func):
|
||||
raise DatabaseTimeoutError(message=f"Timeout occurred in {func.__name__}.", original_exception=e)
|
||||
|
||||
return wrapper
|
||||
else:
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except TimeoutError as e:
|
||||
logger.error(f"Timeout while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
|
||||
raise DatabaseTimeoutError(message=f"Timeout occurred in {func.__name__}.", original_exception=e)
|
||||
|
||||
return async_wrapper
|
||||
|
||||
|
||||
class AccessType(str, Enum):
|
||||
|
||||
22
letta/plugins/README.md
Normal file
22
letta/plugins/README.md
Normal file
@@ -0,0 +1,22 @@
|
||||
### Plugins
|
||||
|
||||
Plugins enable plug and play for various components.
|
||||
|
||||
Plugin configurations can be set in `letta.settings.settings`.
|
||||
|
||||
The plugins will take a delimited list of consisting of individual plugin configs:
|
||||
|
||||
`<plugin_name>.<config_name>=<class_or_function>`
|
||||
|
||||
joined by `;`
|
||||
|
||||
In the default configuration, the top level keys have values `plugin_name`,
|
||||
the `config_name` is nested under and the `class_or_function` is defined
|
||||
after in format `<module_path>:<name>`.
|
||||
|
||||
```
|
||||
DEFAULT_PLUGINS = {
|
||||
"experimental_check": {
|
||||
"default": "letta.plugins.defaults:is_experimental_enabled",
|
||||
...
|
||||
```
|
||||
0
letta/plugins/__init__.py
Normal file
0
letta/plugins/__init__.py
Normal file
11
letta/plugins/defaults.py
Normal file
11
letta/plugins/defaults.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from letta.settings import settings
|
||||
|
||||
|
||||
def is_experimental_enabled(feature_name: str, **kwargs) -> bool:
|
||||
if feature_name in ("async_agent_loop", "summarize"):
|
||||
if not (kwargs.get("eligibility", False) and settings.use_experimental):
|
||||
return False
|
||||
return True
|
||||
|
||||
# Err on safety here, disabling experimental if not handled here.
|
||||
return False
|
||||
72
letta/plugins/plugins.py
Normal file
72
letta/plugins/plugins.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import importlib
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from letta.settings import settings
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SummarizerProtocol(Protocol):
|
||||
"""What a summarizer must implement"""
|
||||
|
||||
async def summarize(self, text: str) -> str: ...
|
||||
def get_name(self) -> str: ...
|
||||
|
||||
|
||||
# Currently this supports one of each plugin type. This can be expanded in the future.
|
||||
DEFAULT_PLUGINS = {
|
||||
"experimental_check": {
|
||||
"protocol": None,
|
||||
"target": "letta.plugins.defaults:is_experimental_enabled",
|
||||
},
|
||||
"summarizer": {
|
||||
"protocol": SummarizerProtocol,
|
||||
"target": "letta.services.summarizer.summarizer:Summarizer",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_plugin(plugin_type: str):
|
||||
"""Get a plugin instance"""
|
||||
plugin_register = dict(DEFAULT_PLUGINS, **settings.plugin_register_dict)
|
||||
if plugin_type in plugin_register:
|
||||
impl_path = plugin_register[plugin_type]["target"]
|
||||
module_path, name = impl_path.split(":")
|
||||
module = importlib.import_module(module_path)
|
||||
plugin = getattr(module, name)
|
||||
if type(plugin).__name__ == "function":
|
||||
return plugin
|
||||
elif type(plugin).__name__ == "class":
|
||||
if plugin_register["protocol"] and not isinstance(plugin, type(plugin_register["protocol"])):
|
||||
raise TypeError(f'{plugin} does not implement {type(plugin_register["protocol"]).__name__}')
|
||||
return plugin()
|
||||
raise TypeError("Unknown plugin type")
|
||||
|
||||
|
||||
_experimental_checker = None
|
||||
_summarizer = None
|
||||
|
||||
|
||||
# TODO handle coroutines
|
||||
# Convenience functions
|
||||
def get_experimental_checker():
|
||||
global _experimental_checker
|
||||
if _experimental_checker is None:
|
||||
_experimental_checker = get_plugin("experimental_check")
|
||||
return _experimental_checker
|
||||
|
||||
|
||||
def get_summarizer():
|
||||
global _summarizer
|
||||
if _summarizer is None:
|
||||
_summarizer = get_plugin("summarizer")
|
||||
return _summarizer
|
||||
|
||||
|
||||
def reset_experimental_checker():
|
||||
global _experimental_checker
|
||||
_experimental_checker = None
|
||||
|
||||
|
||||
def reset_summarizer():
|
||||
global _summarizer
|
||||
_summarizer = None
|
||||
@@ -333,7 +333,7 @@ def start_server(
|
||||
if (os.getenv("LOCAL_HTTPS") == "true") or "--localhttps" in sys.argv:
|
||||
print(f"▶ Server running at: https://{host or 'localhost'}:{port or REST_DEFAULT_PORT}")
|
||||
print(f"▶ View using ADE at: https://app.letta.com/development-servers/local/dashboard\n")
|
||||
if importlib.util.find_spec("granian") is not None and settings.use_uvloop:
|
||||
if importlib.util.find_spec("granian") is not None and settings.use_granian:
|
||||
from granian import Granian
|
||||
|
||||
# Experimental Granian engine
|
||||
|
||||
@@ -958,10 +958,8 @@ async def summarize_agent_conversation(
|
||||
This endpoint summarizes the current message history for a given agent,
|
||||
truncating and compressing it down to the specified `max_message_length`.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
# user_eligible = actor.organization_id not in ["org-4a3af5dd-4c6a-48cb-ac13-3f73ecaaa4bf", "org-4ab3f6e8-9a44-4bee-aeb6-c681cbbc7bf6"]
|
||||
# TODO: This is redundant, remove soon
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor, include_relationships=["multi_agent_group"])
|
||||
agent_eligible = agent.enable_sleeptime or agent.agent_type == AgentType.sleeptime_agent or not agent.multi_agent_group
|
||||
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex"]
|
||||
|
||||
@@ -152,7 +152,10 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
|
||||
print(f"Auto-generated code for debugging:\n\n{code}")
|
||||
raise e
|
||||
finally:
|
||||
# Clean up the temp file
|
||||
# Clean up the temp file if not debugging
|
||||
from letta.settings import settings
|
||||
|
||||
if not settings.debug:
|
||||
os.remove(temp_file_path)
|
||||
|
||||
async def _prepare_venv(self, local_configs, venv_path: str, env: Dict[str, str]):
|
||||
|
||||
@@ -192,13 +192,17 @@ class Settings(BaseSettings):
|
||||
pool_use_lifo: bool = True
|
||||
disable_sqlalchemy_pooling: bool = False
|
||||
|
||||
redis_host: Optional[str] = None
|
||||
redis_port: Optional[int] = None
|
||||
|
||||
plugin_register: Optional[str] = None
|
||||
|
||||
# multi agent settings
|
||||
multi_agent_send_message_max_retries: int = 3
|
||||
multi_agent_send_message_timeout: int = 20 * 60
|
||||
multi_agent_concurrent_sends: int = 50
|
||||
|
||||
# telemetry logging
|
||||
verbose_telemetry_logging: bool = False
|
||||
otel_exporter_otlp_endpoint: Optional[str] = None # otel default: "http://localhost:4317"
|
||||
disable_tracing: bool = False
|
||||
llm_api_logging: bool = True
|
||||
@@ -259,6 +263,15 @@ class Settings(BaseSettings):
|
||||
else:
|
||||
return None
|
||||
|
||||
@property
|
||||
def plugin_register_dict(self) -> dict:
|
||||
plugins = {}
|
||||
if self.plugin_register:
|
||||
for plugin in self.plugin_register.split(";"):
|
||||
name, target = plugin.split("=")
|
||||
plugins[name] = {"target": target}
|
||||
return plugins
|
||||
|
||||
|
||||
class TestSettings(Settings):
|
||||
model_config = SettingsConfigDict(env_prefix="letta_test_", extra="ignore")
|
||||
@@ -266,9 +279,15 @@ class TestSettings(Settings):
|
||||
letta_dir: Optional[Path] = Field(Path.home() / ".letta/test", env="LETTA_TEST_DIR")
|
||||
|
||||
|
||||
class LogSettings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_prefix="letta_logging_", extra="ignore")
|
||||
verbose_telemetry_logging: bool = False
|
||||
|
||||
|
||||
# singleton
|
||||
settings = Settings(_env_parse_none_str="None")
|
||||
test_settings = TestSettings()
|
||||
model_settings = ModelSettings()
|
||||
tool_settings = ToolSettings()
|
||||
summarizer_settings = SummarizerSettings()
|
||||
log_settings = LogSettings()
|
||||
|
||||
@@ -515,6 +515,11 @@ def is_optional_type(hint):
|
||||
|
||||
|
||||
def enforce_types(func):
|
||||
"""Enforces that values passed in match the expected types.
|
||||
|
||||
Technically will handle coroutines as well.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# Get type hints, excluding the return type hint
|
||||
@@ -1078,9 +1083,9 @@ def log_telemetry(logger: Logger, event: str, **kwargs):
|
||||
:param event: A string describing the event.
|
||||
:param kwargs: Additional key-value pairs for logging metadata.
|
||||
"""
|
||||
from letta.settings import settings
|
||||
from letta.settings import log_settings
|
||||
|
||||
if settings.verbose_telemetry_logging:
|
||||
if log_settings.verbose_telemetry_logging:
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S,%f UTC") # More readable timestamp
|
||||
extra_data = " | ".join(f"{key}={value}" for key, value in kwargs.items() if value is not None)
|
||||
logger.info(f"[{timestamp}] EVENT: {event} | {extra_data}")
|
||||
|
||||
551
poetry.lock
generated
551
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -96,10 +96,12 @@ async-lru = "^2.0.5"
|
||||
mistralai = "^1.8.1"
|
||||
uvloop = {version = "^0.21.0", optional = true}
|
||||
granian = {version = "^2.3.2", extras = ["uvloop", "reload"], optional = true}
|
||||
redis = {version = "^6.2.0", optional = true}
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
postgres = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "asyncpg"]
|
||||
redis = ["redis"]
|
||||
dev = ["pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "pyright", "pytest-order", "autoflake", "isort", "locust"]
|
||||
experimental = ["uvloop", "granian"]
|
||||
server = ["websockets", "fastapi", "uvicorn"]
|
||||
@@ -110,7 +112,7 @@ tests = ["wikipedia"]
|
||||
bedrock = ["boto3"]
|
||||
google = ["google-genai"]
|
||||
desktop = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pyright", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust"]
|
||||
all = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "pyright", "pytest-order", "autoflake", "isort", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust", "uvloop", "granian"]
|
||||
all = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "pyright", "pytest-order", "autoflake", "isort", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust", "uvloop", "granian", "redis"]
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
black = "^24.4.2"
|
||||
|
||||
20
tests/helpers/plugins_helper.py
Normal file
20
tests/helpers/plugins_helper.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from letta.data_sources.redis_client import get_redis_client
|
||||
from letta.services.agent_manager import AgentManager
|
||||
|
||||
|
||||
async def is_experimental_okay(feature_name: str, **kwargs) -> bool:
|
||||
print(feature_name, kwargs)
|
||||
if feature_name == "test_pass_with_kwarg":
|
||||
return isinstance(kwargs["agent_manager"], AgentManager)
|
||||
if feature_name == "test_just_pass":
|
||||
return True
|
||||
if feature_name == "test_fail":
|
||||
return False
|
||||
if feature_name == "test_override_kwarg":
|
||||
return kwargs["bool_val"]
|
||||
if feature_name == "test_redis_flag":
|
||||
client = await get_redis_client()
|
||||
user_id = kwargs["user_id"]
|
||||
return await client.check_inclusion_and_exclusion(member=user_id, group="TEST_GROUP")
|
||||
# Err on safety here, disabling experimental if not handled here.
|
||||
return False
|
||||
92
tests/test_plugins.py
Normal file
92
tests/test_plugins.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import pytest
|
||||
|
||||
from letta.data_sources.redis_client import get_redis_client
|
||||
from letta.helpers.decorators import experimental
|
||||
from letta.settings import settings
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_experimental_decorator(event_loop):
|
||||
settings.plugin_register = "experimental_check=tests.helpers.plugins_helper:is_experimental_okay"
|
||||
|
||||
@experimental("test_just_pass", fallback_function=lambda: False, kwarg1=3)
|
||||
def _return_true():
|
||||
return True
|
||||
|
||||
assert _return_true()
|
||||
settings.plugin_register = ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_overwrite_arg_success(event_loop):
|
||||
settings.plugin_register = "experimental_check=tests.helpers.plugins_helper:is_experimental_okay"
|
||||
|
||||
@experimental("test_override_kwarg", fallback_function=lambda *args, **kwargs: False, bool_val=True)
|
||||
async def _return_true(a_val: bool, bool_val: bool):
|
||||
assert bool_val is False
|
||||
return True
|
||||
|
||||
assert _return_true(False, False)
|
||||
settings.plugin_register = ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_overwrite_arg_fail(event_loop):
|
||||
# Should fallback to lambda
|
||||
settings.plugin_register = "experimental_check=tests.helpers.plugins_helper:is_experimental_okay"
|
||||
|
||||
@experimental("test_override_kwarg", fallback_function=lambda *args, **kwargs: True, bool_val=False)
|
||||
async def _return_false(a_val: bool, bool_val: bool):
|
||||
assert bool_val is True
|
||||
return False
|
||||
|
||||
assert _return_false(False, True)
|
||||
|
||||
@experimental("test_override_kwarg", fallback_function=lambda *args, **kwargs: False, bool_val=True)
|
||||
async def _return_true(a_val: bool, bool_val: bool):
|
||||
assert bool_val is False
|
||||
return True
|
||||
|
||||
assert _return_true(False, bool_val=False)
|
||||
|
||||
@experimental("test_override_kwarg", fallback_function=lambda *args, **kwargs: True)
|
||||
async def _get_true(a_val: bool, bool_val: bool):
|
||||
return True
|
||||
|
||||
assert await _get_true(True, bool_val=True)
|
||||
with pytest.raises(Exception):
|
||||
# kwarg must be included in either experimental flag or function call
|
||||
assert await _get_true(True, True)
|
||||
settings.plugin_register = ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_flag(event_loop):
|
||||
settings.plugin_register = "experimental_check=tests.helpers.plugins_helper:is_experimental_okay"
|
||||
|
||||
@experimental("test_redis_flag", fallback_function=lambda *args, **kwargs: _raise())
|
||||
async def _new_feature(user_id: str) -> str:
|
||||
return "new_feature"
|
||||
|
||||
def _raise():
|
||||
raise Exception()
|
||||
|
||||
redis_client = await get_redis_client()
|
||||
|
||||
group_name = "TEST_GROUP"
|
||||
include_key = redis_client._get_group_inclusion_key(group_name)
|
||||
exclude_key = redis_client._get_group_exclusion_key(group_name)
|
||||
test_user = "user123"
|
||||
# reset
|
||||
for member in await redis_client.smembers(include_key):
|
||||
await redis_client.srem(include_key, member)
|
||||
for member in await redis_client.smembers(exclude_key):
|
||||
await redis_client.srem(exclude_key, member)
|
||||
|
||||
await redis_client.create_inclusion_exclusion_keys(group=group_name)
|
||||
await redis_client.sadd(include_key, test_user)
|
||||
|
||||
assert await _new_feature(user_id=test_user) == "new_feature"
|
||||
with pytest.raises(Exception):
|
||||
assert await _new_feature(user_id=test_user + "1")
|
||||
print("members: ", await redis_client.smembers(include_key))
|
||||
26
tests/test_redis_client.py
Normal file
26
tests/test_redis_client.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import pytest
|
||||
|
||||
from letta.data_sources.redis_client import get_redis_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_client(event_loop):
|
||||
test_values = {"LETTA_TEST_0": [1, 2, 3], "LETTA_TEST_1": ["apple", "pear", "banana"], "LETTA_TEST_2": ["{}", 3.2, "cat"]}
|
||||
redis_client = await get_redis_client()
|
||||
|
||||
# Clear out keys
|
||||
await redis_client.delete(*test_values.keys())
|
||||
|
||||
# Add items
|
||||
for k, v in test_values.items():
|
||||
assert await redis_client.sadd(k, *v) == 3
|
||||
|
||||
# Check Membership
|
||||
for k, v in test_values.items():
|
||||
assert await redis_client.smembers(k) == set(str(val) for val in v)
|
||||
|
||||
for k, v in test_values.items():
|
||||
assert await redis_client.smismember(k, "invalid") == 0
|
||||
assert await redis_client.smismember(k, v[0]) == 1
|
||||
assert await redis_client.smismember(k, v[:2]) == [1, 1]
|
||||
assert await redis_client.smismember(k, v[2:] + ["invalid"]) == [1, 0]
|
||||
Reference in New Issue
Block a user