chore: bump v0.11.5 (#2777)

This commit is contained in:
cthomas
2025-08-26 16:31:54 -07:00
committed by GitHub
92 changed files with 11572 additions and 2684 deletions

View File

@@ -41,6 +41,15 @@ jobs:
--health-interval 10s
--health-timeout 5s
--health-retries 5
redis:
image: redis:7
ports:
- 6379:6379
options: >-
--health-cmd "redis-cli ping"
--health-interval 5s
--health-timeout 5s
--health-retries 10
steps:
# Ensure secrets don't leak
@@ -138,6 +147,8 @@ jobs:
LETTA_PG_PASSWORD: postgres
LETTA_PG_DB: postgres
LETTA_PG_HOST: localhost
LETTA_REDIS_HOST: localhost
LETTA_REDIS_PORT: 6379
LETTA_SERVER_PASS: test_server_token
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}

View File

@@ -5,7 +5,7 @@ try:
__version__ = version("letta")
except PackageNotFoundError:
# Fallback for development installations
__version__ = "0.11.4"
__version__ = "0.11.5"
if os.environ.get("LETTA_VERSION"):
__version__ = os.environ["LETTA_VERSION"]

View File

@@ -42,6 +42,7 @@ from letta.log import get_logger
from letta.memory import summarize_messages
from letta.orm import User
from letta.otel.tracing import log_event, trace_method
from letta.prompts.prompt_generator import PromptGenerator
from letta.schemas.agent import AgentState, AgentStepResponse, UpdateAgent, get_prompt_template_for_agent_type
from letta.schemas.block import BlockUpdate
from letta.schemas.embedding_config import EmbeddingConfig
@@ -59,7 +60,7 @@ from letta.schemas.tool_rule import TerminalToolRule
from letta.schemas.usage import LettaUsageStatistics
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import check_supports_structured_output, compile_memory_metadata_block
from letta.services.helpers.agent_manager_helper import check_supports_structured_output
from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema
from letta.services.job_manager import JobManager
from letta.services.mcp.base_client import AsyncBaseMCPClient
@@ -330,8 +331,13 @@ class Agent(BaseAgent):
return None
allowed_functions = [func for func in agent_state_tool_jsons if func["name"] in allowed_tool_names]
# Extract terminal tool names from tool rules
terminal_tool_names = {rule.tool_name for rule in self.tool_rules_solver.terminal_tool_rules}
allowed_functions = runtime_override_tool_json_schema(
tool_list=allowed_functions, response_format=self.agent_state.response_format, request_heartbeat=True
tool_list=allowed_functions,
response_format=self.agent_state.response_format,
request_heartbeat=True,
terminal_tools=terminal_tool_names,
)
# For the first message, force the initial tool if one is specified
@@ -1246,7 +1252,7 @@ class Agent(BaseAgent):
agent_manager_passage_size = self.agent_manager.passage_size(actor=self.user, agent_id=self.agent_state.id)
message_manager_size = self.message_manager.size(actor=self.user, agent_id=self.agent_state.id)
external_memory_summary = compile_memory_metadata_block(
external_memory_summary = PromptGenerator.compile_memory_metadata_block(
memory_edit_timestamp=get_utc_time(),
timezone=self.agent_state.timezone,
previous_message_count=self.message_manager.size(actor=self.user, agent_id=self.agent_state.id),

View File

@@ -7,6 +7,7 @@ from letta.constants import DEFAULT_MAX_STEPS
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import get_utc_time
from letta.log import get_logger
from letta.prompts.prompt_generator import PromptGenerator
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage
@@ -17,7 +18,6 @@ from letta.schemas.message import Message, MessageCreate, MessageUpdate
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
from letta.services.helpers.agent_manager_helper import get_system_message_from_compiled_memory
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.utils import united_diff
@@ -142,7 +142,7 @@ class BaseAgent(ABC):
if num_archival_memories is None:
num_archival_memories = await self.passage_manager.agent_passage_size_async(actor=self.actor, agent_id=agent_state.id)
new_system_message_str = get_system_message_from_compiled_memory(
new_system_message_str = PromptGenerator.get_system_message_from_compiled_memory(
system_prompt=agent_state.system,
memory_with_sources=curr_memory_str,
in_context_memory_last_edit=memory_edit_timestamp,

View File

@@ -137,6 +137,10 @@ class LettaAgent(BaseAgent):
message_buffer_limit=message_buffer_limit,
message_buffer_min=message_buffer_min,
partial_evict_summarizer_percentage=partial_evict_summarizer_percentage,
agent_manager=self.agent_manager,
message_manager=self.message_manager,
actor=self.actor,
agent_id=self.agent_id,
)
async def _check_run_cancellation(self) -> bool:
@@ -345,16 +349,17 @@ class LettaAgent(BaseAgent):
agent_step_span.end()
# Log LLM Trace
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json=response_data,
step_id=step_id, # Use original step_id for telemetry
organization_id=self.actor.organization_id,
),
)
step_progression = StepProgression.LOGGED_TRACE
if settings.track_provider_trace:
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json=response_data,
step_id=step_id, # Use original step_id for telemetry
organization_id=self.actor.organization_id,
),
)
step_progression = StepProgression.LOGGED_TRACE
# stream step
# TODO: improve TTFT
@@ -642,17 +647,18 @@ class LettaAgent(BaseAgent):
agent_step_span.end()
# Log LLM Trace
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json=response_data,
step_id=step_id, # Use original step_id for telemetry
organization_id=self.actor.organization_id,
),
)
if settings.track_provider_trace:
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json=response_data,
step_id=step_id, # Use original step_id for telemetry
organization_id=self.actor.organization_id,
),
)
step_progression = StepProgression.LOGGED_TRACE
step_progression = StepProgression.LOGGED_TRACE
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
step_progression = StepProgression.FINISHED
@@ -1003,31 +1009,32 @@ class LettaAgent(BaseAgent):
# Log LLM Trace
# We are piecing together the streamed response here.
# Content here does not match the actual response schema as streams come in chunks.
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json={
"content": {
"tool_call": tool_call.model_dump_json(),
"reasoning": [content.model_dump_json() for content in reasoning_content],
if settings.track_provider_trace:
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json={
"content": {
"tool_call": tool_call.model_dump_json(),
"reasoning": [content.model_dump_json() for content in reasoning_content],
},
"id": interface.message_id,
"model": interface.model,
"role": "assistant",
# "stop_reason": "",
# "stop_sequence": None,
"type": "message",
"usage": {
"input_tokens": usage.prompt_tokens,
"output_tokens": usage.completion_tokens,
},
},
"id": interface.message_id,
"model": interface.model,
"role": "assistant",
# "stop_reason": "",
# "stop_sequence": None,
"type": "message",
"usage": {
"input_tokens": usage.prompt_tokens,
"output_tokens": usage.completion_tokens,
},
},
step_id=step_id, # Use original step_id for telemetry
organization_id=self.actor.organization_id,
),
)
step_progression = StepProgression.LOGGED_TRACE
step_id=step_id, # Use original step_id for telemetry
organization_id=self.actor.organization_id,
),
)
step_progression = StepProgression.LOGGED_TRACE
# yields tool response as this is handled from Letta and not the response from the LLM provider
tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
@@ -1352,6 +1359,7 @@ class LettaAgent(BaseAgent):
) -> list[Message]:
# If total tokens is reached, we truncate down
# TODO: This can be broken by bad configs, e.g. lower bound too high, initial messages too fat, etc.
# TODO: `force` and `clear` seem to no longer be used, we should remove
if force or (total_tokens and total_tokens > llm_config.context_window):
self.logger.warning(
f"Total tokens {total_tokens} exceeds configured max tokens {llm_config.context_window}, forcefully clearing message history."
@@ -1363,6 +1371,7 @@ class LettaAgent(BaseAgent):
clear=True,
)
else:
# NOTE (Sarah): Seems like this is doing nothing?
self.logger.info(
f"Total tokens {total_tokens} does not exceed configured max tokens {llm_config.context_window}, passing summarizing w/o force."
)
@@ -1453,8 +1462,10 @@ class LettaAgent(BaseAgent):
force_tool_call = valid_tool_names[0]
allowed_tools = [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)]
# Extract terminal tool names from tool rules
terminal_tool_names = {rule.tool_name for rule in tool_rules_solver.terminal_tool_rules}
allowed_tools = runtime_override_tool_json_schema(
tool_list=allowed_tools, response_format=agent_state.response_format, request_heartbeat=True
tool_list=allowed_tools, response_format=agent_state.response_format, request_heartbeat=True, terminal_tools=terminal_tool_names
)
return (

View File

@@ -13,6 +13,7 @@ from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.tool_execution_helper import add_pre_execution_message, enable_strict_mode, remove_request_heartbeat
from letta.interfaces.openai_chat_completions_streaming_interface import OpenAIChatCompletionsStreamingInterface
from letta.log import get_logger
from letta.prompts.prompt_generator import PromptGenerator
from letta.schemas.agent import AgentState, AgentType
from letta.schemas.enums import MessageRole, ToolType
from letta.schemas.letta_response import LettaResponse
@@ -35,7 +36,6 @@ from letta.server.rest_api.utils import (
)
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import compile_system_message_async
from letta.services.job_manager import JobManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
@@ -144,7 +144,7 @@ class VoiceAgent(BaseAgent):
in_context_messages = await self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=self.actor)
memory_edit_timestamp = get_utc_time()
in_context_messages[0].content[0].text = await compile_system_message_async(
in_context_messages[0].content[0].text = await PromptGenerator.compile_system_message_async(
system_prompt=agent_state.system,
in_context_memory=agent_state.memory,
in_context_memory_last_edit=memory_edit_timestamp,

View File

@@ -1,6 +1,6 @@
import asyncio
from functools import wraps
from typing import Any, Optional, Set, Union
from typing import Any, Dict, List, Optional, Set, Union
from letta.constants import REDIS_EXCLUDE, REDIS_INCLUDE, REDIS_SET_DEFAULT_VAL
from letta.log import get_logger
@@ -218,6 +218,126 @@ class AsyncRedisClient:
client = await self.get_client()
return await client.decr(key)
# Stream operations
@with_retry()
async def xadd(self, stream: str, fields: Dict[str, Any], id: str = "*", maxlen: Optional[int] = None, approximate: bool = True) -> str:
"""Add entry to a stream.
Args:
stream: Stream name
fields: Dict of field-value pairs to add
id: Entry ID ('*' for auto-generation)
maxlen: Maximum length of the stream
approximate: Whether maxlen is approximate
Returns:
The ID of the added entry
"""
client = await self.get_client()
return await client.xadd(stream, fields, id=id, maxlen=maxlen, approximate=approximate)
@with_retry()
async def xread(self, streams: Dict[str, str], count: Optional[int] = None, block: Optional[int] = None) -> List[Dict]:
"""Read from streams.
Args:
streams: Dict mapping stream names to IDs
count: Maximum number of entries to return
block: Milliseconds to block waiting for data (None = no blocking)
Returns:
List of entries from the streams
"""
client = await self.get_client()
return await client.xread(streams, count=count, block=block)
@with_retry()
async def xrange(self, stream: str, start: str = "-", end: str = "+", count: Optional[int] = None) -> List[Dict]:
"""Read range of entries from a stream.
Args:
stream: Stream name
start: Start ID (inclusive)
end: End ID (inclusive)
count: Maximum number of entries to return
Returns:
List of entries in the specified range
"""
client = await self.get_client()
return await client.xrange(stream, start, end, count=count)
@with_retry()
async def xrevrange(self, stream: str, start: str = "+", end: str = "-", count: Optional[int] = None) -> List[Dict]:
"""Read range of entries from a stream in reverse order.
Args:
stream: Stream name
start: Start ID (inclusive)
end: End ID (inclusive)
count: Maximum number of entries to return
Returns:
List of entries in the specified range in reverse order
"""
client = await self.get_client()
return await client.xrevrange(stream, start, end, count=count)
@with_retry()
async def xlen(self, stream: str) -> int:
"""Get the length of a stream.
Args:
stream: Stream name
Returns:
Number of entries in the stream
"""
client = await self.get_client()
return await client.xlen(stream)
@with_retry()
async def xdel(self, stream: str, *ids: str) -> int:
"""Delete entries from a stream.
Args:
stream: Stream name
ids: IDs of entries to delete
Returns:
Number of entries deleted
"""
client = await self.get_client()
return await client.xdel(stream, *ids)
@with_retry()
async def xinfo_stream(self, stream: str) -> Dict:
"""Get information about a stream.
Args:
stream: Stream name
Returns:
Dict with stream information
"""
client = await self.get_client()
return await client.xinfo_stream(stream)
@with_retry()
async def xtrim(self, stream: str, maxlen: int, approximate: bool = True) -> int:
"""Trim a stream to a maximum length.
Args:
stream: Stream name
maxlen: Maximum length
approximate: Whether maxlen is approximate
Returns:
Number of entries removed
"""
client = await self.get_client()
return await client.xtrim(stream, maxlen=maxlen, approximate=approximate)
async def check_inclusion_and_exclusion(self, member: str, group: str) -> bool:
exclude_key = self._get_group_exclusion_key(group)
include_key = self._get_group_inclusion_key(group)
@@ -290,6 +410,31 @@ class NoopAsyncRedisClient(AsyncRedisClient):
async def srem(self, key: str, *members: Union[str, int, float]) -> int:
return 0
# Stream operations
async def xadd(self, stream: str, fields: Dict[str, Any], id: str = "*", maxlen: Optional[int] = None, approximate: bool = True) -> str:
return ""
async def xread(self, streams: Dict[str, str], count: Optional[int] = None, block: Optional[int] = None) -> List[Dict]:
return []
async def xrange(self, stream: str, start: str = "-", end: str = "+", count: Optional[int] = None) -> List[Dict]:
return []
async def xrevrange(self, stream: str, start: str = "+", end: str = "-", count: Optional[int] = None) -> List[Dict]:
return []
async def xlen(self, stream: str) -> int:
return 0
async def xdel(self, stream: str, *ids: str) -> int:
return 0
async def xinfo_stream(self, stream: str) -> Dict:
return {}
async def xtrim(self, stream: str, maxlen: int, approximate: bool = True) -> int:
return 0
async def get_redis_client() -> AsyncRedisClient:
global _client_instance

View File

@@ -76,6 +76,10 @@ class LettaUserNotFoundError(LettaError):
"""Error raised when a user is not found."""
class LettaUnexpectedStreamCancellationError(LettaError):
"""Error raised when a streaming request is terminated unexpectedly."""
class LLMError(LettaError):
pass

View File

@@ -21,8 +21,8 @@ async def open_files(agent_state: "AgentState", file_requests: List[FileOpenRequ
Open multiple files with different view ranges:
file_requests = [
FileOpenRequest(file_name="project_utils/config.py", offset=1, length=50), # Lines 1-50
FileOpenRequest(file_name="project_utils/main.py", offset=100, length=100), # Lines 100-199
FileOpenRequest(file_name="project_utils/config.py", offset=0, length=50), # Lines 1-50
FileOpenRequest(file_name="project_utils/main.py", offset=100, length=100), # Lines 101-200
FileOpenRequest(file_name="project_utils/utils.py") # Entire file
]

View File

@@ -1,156 +0,0 @@
from letta.log import get_logger
logger = get_logger(__name__)
# class BaseMCPClient:
# def __init__(self, server_config: BaseServerConfig):
# self.server_config = server_config
# self.session: Optional[ClientSession] = None
# self.stdio = None
# self.write = None
# self.initialized = False
# self.loop = asyncio.new_event_loop()
# self.cleanup_funcs = []
#
# def connect_to_server(self):
# asyncio.set_event_loop(self.loop)
# success = self._initialize_connection(self.server_config, timeout=tool_settings.mcp_connect_to_server_timeout)
#
# if success:
# try:
# self.loop.run_until_complete(
# asyncio.wait_for(self.session.initialize(), timeout=tool_settings.mcp_connect_to_server_timeout)
# )
# self.initialized = True
# except asyncio.TimeoutError:
# raise MCPTimeoutError("initializing session", self.server_config.server_name, tool_settings.mcp_connect_to_server_timeout)
# else:
# raise RuntimeError(
# f"Connecting to MCP server failed. Please review your server config: {self.server_config.model_dump_json(indent=4)}"
# )
#
# def _initialize_connection(self, server_config: BaseServerConfig, timeout: float) -> bool:
# raise NotImplementedError("Subclasses must implement _initialize_connection")
#
# def list_tools(self) -> List[MCPTool]:
# self._check_initialized()
# try:
# response = self.loop.run_until_complete(
# asyncio.wait_for(self.session.list_tools(), timeout=tool_settings.mcp_list_tools_timeout)
# )
# return response.tools
# except asyncio.TimeoutError:
# logger.error(
# f"Timed out while listing tools for MCP server {self.server_config.server_name} (timeout={tool_settings.mcp_list_tools_timeout}s)."
# )
# raise MCPTimeoutError("listing tools", self.server_config.server_name, tool_settings.mcp_list_tools_timeout)
#
# def execute_tool(self, tool_name: str, tool_args: dict) -> Tuple[str, bool]:
# self._check_initialized()
# try:
# result = self.loop.run_until_complete(
# asyncio.wait_for(self.session.call_tool(tool_name, tool_args), timeout=tool_settings.mcp_execute_tool_timeout)
# )
#
# parsed_content = []
# for content_piece in result.content:
# if isinstance(content_piece, TextContent):
# parsed_content.append(content_piece.text)
# print("parsed_content (text)", parsed_content)
# else:
# parsed_content.append(str(content_piece))
# print("parsed_content (other)", parsed_content)
#
# if len(parsed_content) > 0:
# final_content = " ".join(parsed_content)
# else:
# # TODO move hardcoding to constants
# final_content = "Empty response from tool"
#
# return final_content, result.isError
# except asyncio.TimeoutError:
# logger.error(
# f"Timed out while executing tool '{tool_name}' for MCP server {self.server_config.server_name} (timeout={tool_settings.mcp_execute_tool_timeout}s)."
# )
# raise MCPTimeoutError(f"executing tool '{tool_name}'", self.server_config.server_name, tool_settings.mcp_execute_tool_timeout)
#
# def _check_initialized(self):
# if not self.initialized:
# logger.error("MCPClient has not been initialized")
# raise RuntimeError("MCPClient has not been initialized")
#
# def cleanup(self):
# try:
# for cleanup_func in self.cleanup_funcs:
# cleanup_func()
# self.initialized = False
# if not self.loop.is_closed():
# self.loop.close()
# except Exception as e:
# logger.warning(e)
# finally:
# logger.info("Cleaned up MCP clients on shutdown.")
#
#
# class BaseAsyncMCPClient:
# def __init__(self, server_config: BaseServerConfig):
# self.server_config = server_config
# self.session: Optional[ClientSession] = None
# self.stdio = None
# self.write = None
# self.initialized = False
# self.cleanup_funcs = []
#
# async def connect_to_server(self):
#
# success = await self._initialize_connection(self.server_config, timeout=tool_settings.mcp_connect_to_server_timeout)
#
# if success:
# self.initialized = True
# else:
# raise RuntimeError(
# f"Connecting to MCP server failed. Please review your server config: {self.server_config.model_dump_json(indent=4)}"
# )
#
# async def list_tools(self) -> List[MCPTool]:
# self._check_initialized()
# response = await self.session.list_tools()
# return response.tools
#
# async def execute_tool(self, tool_name: str, tool_args: dict) -> Tuple[str, bool]:
# self._check_initialized()
# result = await self.session.call_tool(tool_name, tool_args)
#
# parsed_content = []
# for content_piece in result.content:
# if isinstance(content_piece, TextContent):
# parsed_content.append(content_piece.text)
# else:
# parsed_content.append(str(content_piece))
#
# if len(parsed_content) > 0:
# final_content = " ".join(parsed_content)
# else:
# # TODO move hardcoding to constants
# final_content = "Empty response from tool"
#
# return final_content, result.isError
#
# def _check_initialized(self):
# if not self.initialized:
# logger.error("MCPClient has not been initialized")
# raise RuntimeError("MCPClient has not been initialized")
#
# async def cleanup(self):
# try:
# for cleanup_func in self.cleanup_funcs:
# cleanup_func()
# self.initialized = False
# if not self.loop.is_closed():
# self.loop.close()
# except Exception as e:
# logger.warning(e)
# finally:
# logger.info("Cleaned up MCP clients on shutdown.")
#

View File

@@ -1,51 +0,0 @@
# import asyncio
#
# from mcp import ClientSession
# from mcp.client.sse import sse_client
#
# from letta.functions.mcp_client.base_client import BaseAsyncMCPClient, BaseMCPClient
# from letta.functions.mcp_client.types import SSEServerConfig
# from letta.log import get_logger
#
## see: https://modelcontextprotocol.io/quickstart/user
#
# logger = get_logger(__name__)
# class SSEMCPClient(BaseMCPClient):
# def _initialize_connection(self, server_config: SSEServerConfig, timeout: float) -> bool:
# try:
# sse_cm = sse_client(url=server_config.server_url)
# sse_transport = self.loop.run_until_complete(asyncio.wait_for(sse_cm.__aenter__(), timeout=timeout))
# self.stdio, self.write = sse_transport
# self.cleanup_funcs.append(lambda: self.loop.run_until_complete(sse_cm.__aexit__(None, None, None)))
#
# session_cm = ClientSession(self.stdio, self.write)
# self.session = self.loop.run_until_complete(asyncio.wait_for(session_cm.__aenter__(), timeout=timeout))
# self.cleanup_funcs.append(lambda: self.loop.run_until_complete(session_cm.__aexit__(None, None, None)))
# return True
# except asyncio.TimeoutError:
# logger.error(f"Timed out while establishing SSE connection (timeout={timeout}s).")
# return False
# except Exception:
# logger.exception("Exception occurred while initializing SSE client session.")
# return False
#
#
# class AsyncSSEMCPClient(BaseAsyncMCPClient):
#
# async def _initialize_connection(self, server_config: SSEServerConfig, timeout: float) -> bool:
# try:
# sse_cm = sse_client(url=server_config.server_url)
# sse_transport = await sse_cm.__aenter__()
# self.stdio, self.write = sse_transport
# self.cleanup_funcs.append(lambda: sse_cm.__aexit__(None, None, None))
#
# session_cm = ClientSession(self.stdio, self.write)
# self.session = await session_cm.__aenter__()
# self.cleanup_funcs.append(lambda: session_cm.__aexit__(None, None, None))
# return True
# except Exception:
# logger.exception("Exception occurred while initializing SSE client session.")
# return False
#

View File

@@ -1,109 +0,0 @@
# import asyncio
# import sys
# from contextlib import asynccontextmanager
#
# import anyio
# import anyio.lowlevel
# import mcp.types as types
# from anyio.streams.text import TextReceiveStream
# from mcp import ClientSession, StdioServerParameters
# from mcp.client.stdio import get_default_environment
#
# from letta.functions.mcp_client.base_client import BaseMCPClient
# from letta.functions.mcp_client.types import StdioServerConfig
# from letta.log import get_logger
#
# logger = get_logger(__name__)
# class StdioMCPClient(BaseMCPClient):
# def _initialize_connection(self, server_config: StdioServerConfig, timeout: float) -> bool:
# try:
# server_params = StdioServerParameters(command=server_config.command, args=server_config.args, env=server_config.env)
# stdio_cm = forked_stdio_client(server_params)
# stdio_transport = self.loop.run_until_complete(asyncio.wait_for(stdio_cm.__aenter__(), timeout=timeout))
# self.stdio, self.write = stdio_transport
# self.cleanup_funcs.append(lambda: self.loop.run_until_complete(stdio_cm.__aexit__(None, None, None)))
#
# session_cm = ClientSession(self.stdio, self.write)
# self.session = self.loop.run_until_complete(asyncio.wait_for(session_cm.__aenter__(), timeout=timeout))
# self.cleanup_funcs.append(lambda: self.loop.run_until_complete(session_cm.__aexit__(None, None, None)))
# return True
# except asyncio.TimeoutError:
# logger.error(f"Timed out while establishing stdio connection (timeout={timeout}s).")
# return False
# except Exception:
# logger.exception("Exception occurred while initializing stdio client session.")
# return False
#
#
# @asynccontextmanager
# async def forked_stdio_client(server: StdioServerParameters):
# """
# Client transport for stdio: this will connect to a server by spawning a
# process and communicating with it over stdin/stdout.
# """
# read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
# write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
#
# try:
# process = await anyio.open_process(
# [server.command, *server.args],
# env=server.env or get_default_environment(),
# stderr=sys.stderr, # Consider logging stderr somewhere instead of silencing it
# )
# except OSError as exc:
# raise RuntimeError(f"Failed to spawn process: {server.command} {server.args}") from exc
#
# async def stdout_reader():
# assert process.stdout, "Opened process is missing stdout"
# buffer = ""
# try:
# async with read_stream_writer:
# async for chunk in TextReceiveStream(
# process.stdout,
# encoding=server.encoding,
# errors=server.encoding_error_handler,
# ):
# lines = (buffer + chunk).split("\n")
# buffer = lines.pop()
# for line in lines:
# try:
# message = types.JSONRPCMessage.model_validate_json(line)
# except Exception as exc:
# await read_stream_writer.send(exc)
# continue
# await read_stream_writer.send(message)
# except anyio.ClosedResourceError:
# await anyio.lowlevel.checkpoint()
#
# async def stdin_writer():
# assert process.stdin, "Opened process is missing stdin"
# try:
# async with write_stream_reader:
# async for message in write_stream_reader:
# json = message.model_dump_json(by_alias=True, exclude_none=True)
# await process.stdin.send(
# (json + "\n").encode(
# encoding=server.encoding,
# errors=server.encoding_error_handler,
# )
# )
# except anyio.ClosedResourceError:
# await anyio.lowlevel.checkpoint()
#
# async def watch_process_exit():
# returncode = await process.wait()
# if returncode != 0:
# raise RuntimeError(f"Subprocess exited with code {returncode}. Command: {server.command} {server.args}")
#
# async with anyio.create_task_group() as tg, process:
# tg.start_soon(stdout_reader)
# tg.start_soon(stdin_writer)
# tg.start_soon(watch_process_exit)
#
# with anyio.move_on_after(0.2):
# await anyio.sleep_forever()
#
# yield read_stream, write_stream
#

View File

@@ -148,9 +148,21 @@ class SSEServerConfig(BaseServerConfig):
custom_headers: Optional[dict[str, str]] = Field(None, description="Custom HTTP headers to include with SSE requests")
def resolve_token(self) -> Optional[str]:
if self.auth_token and self.auth_token.startswith(f"{MCP_AUTH_TOKEN_BEARER_PREFIX} "):
return self.auth_token[len(f"{MCP_AUTH_TOKEN_BEARER_PREFIX} ") :]
return self.auth_token
"""
Extract token for storage if auth_header/auth_token are provided
and not already in custom_headers.
Returns:
The resolved token (without Bearer prefix) if it should be stored separately, None otherwise
"""
if self.auth_token and self.auth_header:
# Check if custom_headers already has the auth header
if not self.custom_headers or self.auth_header not in self.custom_headers:
# Strip Bearer prefix if present
if self.auth_token.startswith(f"{MCP_AUTH_TOKEN_BEARER_PREFIX} "):
return self.auth_token[len(f"{MCP_AUTH_TOKEN_BEARER_PREFIX} ") :]
return self.auth_token
return None
def resolve_environment_variables(self, environment_variables: Optional[Dict[str, str]] = None) -> None:
if self.auth_token and super().is_templated_tool_variable(self.auth_token):
@@ -217,9 +229,21 @@ class StreamableHTTPServerConfig(BaseServerConfig):
custom_headers: Optional[dict[str, str]] = Field(None, description="Custom HTTP headers to include with streamable HTTP requests")
def resolve_token(self) -> Optional[str]:
if self.auth_token and self.auth_token.startswith(f"{MCP_AUTH_TOKEN_BEARER_PREFIX} "):
return self.auth_token[len(f"{MCP_AUTH_TOKEN_BEARER_PREFIX} ") :]
return self.auth_token
"""
Extract token for storage if auth_header/auth_token are provided
and not already in custom_headers.
Returns:
The resolved token (without Bearer prefix) if it should be stored separately, None otherwise
"""
if self.auth_token and self.auth_header:
# Check if custom_headers already has the auth header
if not self.custom_headers or self.auth_header not in self.custom_headers:
# Strip Bearer prefix if present
if self.auth_token.startswith(f"{MCP_AUTH_TOKEN_BEARER_PREFIX} "):
return self.auth_token[len(f"{MCP_AUTH_TOKEN_BEARER_PREFIX} ") :]
return self.auth_token
return None
def resolve_environment_variables(self, environment_variables: Optional[Dict[str, str]] = None) -> None:
if self.auth_token and super().is_templated_tool_variable(self.auth_token):

View File

@@ -608,13 +608,58 @@ def generate_tool_schema_for_mcp(
# Normalise so downstream code can treat it consistently.
parameters_schema.setdefault("required", [])
# Process properties to handle anyOf types and make optional fields strict-compatible
if "properties" in parameters_schema:
for field_name, field_props in parameters_schema["properties"].items():
# Handle anyOf types by flattening to type array
if "anyOf" in field_props and "type" not in field_props:
types = []
format_value = None
for option in field_props["anyOf"]:
if "type" in option:
types.append(option["type"])
# Capture format if present (e.g., uuid format for strings)
if "format" in option and not format_value:
format_value = option["format"]
if types:
# Deduplicate types using set
field_props["type"] = list(set(types))
# Only add format if the field is not optional (doesn't have null type)
if format_value and len(field_props["type"]) == 1 and "null" not in field_props["type"]:
field_props["format"] = format_value
# Remove the anyOf since we've flattened it
del field_props["anyOf"]
# For strict mode: heal optional fields by making them required with null type
if strict and field_name not in parameters_schema["required"]:
# Field is optional - add it to required array
parameters_schema["required"].append(field_name)
# Ensure the field can accept null to maintain optionality
if "type" in field_props:
if isinstance(field_props["type"], list):
# Already an array of types - add null if not present
if "null" not in field_props["type"]:
field_props["type"].append("null")
# Deduplicate
field_props["type"] = list(set(field_props["type"]))
elif field_props["type"] != "null":
# Single type - convert to array with null
field_props["type"] = list(set([field_props["type"], "null"]))
elif "anyOf" in field_props:
# If there's still an anyOf, ensure null is one of the options
has_null = any(opt.get("type") == "null" for opt in field_props["anyOf"])
if not has_null:
field_props["anyOf"].append({"type": "null"})
# Add the optional heartbeat parameter
if append_heartbeat:
parameters_schema["properties"][REQUEST_HEARTBEAT_PARAM] = {
"type": "boolean",
"description": REQUEST_HEARTBEAT_DESCRIPTION,
}
parameters_schema["required"].append(REQUEST_HEARTBEAT_PARAM)
if REQUEST_HEARTBEAT_PARAM not in parameters_schema["required"]:
parameters_schema["required"].append(REQUEST_HEARTBEAT_PARAM)
# Return the final schema
if strict:

View File

@@ -116,15 +116,21 @@ def validate_complete_json_schema(schema: Dict[str, Any]) -> Tuple[SchemaHealth,
required = node.get("required")
if required is None:
# TODO: @jnjpng skip this check for now, seems like OpenAI strict mode doesn't enforce this
# Only mark as non-strict for nested objects, not root
if not is_root:
mark_non_strict(f"{path}: 'required' not specified for object")
# if not is_root:
# mark_non_strict(f"{path}: 'required' not specified for object")
required = []
elif not isinstance(required, list):
mark_invalid(f"{path}: 'required' must be a list if present")
required = []
# OpenAI strict-mode extra checks:
# NOTE: We no longer flag properties not in required array as non-strict
# because we can heal these schemas by adding null to the type union
# This allows MCP tools with optional fields to be used with strict mode
# The healing happens in generate_tool_schema_for_mcp() when strict=True
for req_key in required:
if props and req_key not in props:
mark_invalid(f"{path}: required contains '{req_key}' not found in properties")
@@ -161,6 +167,15 @@ def validate_complete_json_schema(schema: Dict[str, Any]) -> Tuple[SchemaHealth,
# These are generally fine, but check for specific constraints
pass
# TYPE ARRAYS (e.g., ["string", "null"] for optional fields)
elif isinstance(node_type, list):
# Type arrays are allowed in OpenAI strict mode
# They represent union types (e.g., string | null)
for t in node_type:
# TODO: @jnjpng handle enum types?
if t not in ["string", "number", "integer", "boolean", "null", "array", "object"]:
mark_invalid(f"{path}: Invalid type '{t}' in type array")
# UNION TYPES
for kw in ("anyOf", "oneOf", "allOf"):
if kw in node:

View File

@@ -11,7 +11,7 @@ class SearchTask(BaseModel):
class FileOpenRequest(BaseModel):
file_name: str = Field(description="Name of the file to open")
offset: Optional[int] = Field(
default=None, description="Optional starting line number (1-indexed). If not specified, starts from beginning of file."
default=None, description="Optional offset for starting line number (0-indexed). If not specified, starts from beginning of file."
)
length: Optional[int] = Field(
default=None, description="Optional number of lines to view from offset (inclusive). If not specified, views to end of file."

View File

@@ -39,12 +39,10 @@ def enable_strict_mode(tool_schema: Dict[str, Any]) -> Dict[str, Any]:
# Ensure parameters is a valid dictionary
parameters = schema.get("parameters", {})
if isinstance(parameters, dict) and parameters.get("type") == "object":
# Set additionalProperties to False
parameters["additionalProperties"] = False
schema["parameters"] = parameters
# Remove the metadata fields from the schema
schema.pop(MCP_TOOL_METADATA_SCHEMA_STATUS, None)
schema.pop(MCP_TOOL_METADATA_SCHEMA_WARNINGS, None)

View File

@@ -287,12 +287,34 @@ class AnthropicClient(LLMClientBase):
else:
anthropic_tools = None
thinking_enabled = False
if messages and len(messages) > 0:
# Check if the last assistant message starts with a thinking block
# Find the last assistant message
last_assistant_message = None
for message in reversed(messages):
if message.get("role") == "assistant":
last_assistant_message = message
break
if (
last_assistant_message
and isinstance(last_assistant_message.get("content"), list)
and len(last_assistant_message["content"]) > 0
and last_assistant_message["content"][0].get("type") == "thinking"
):
thinking_enabled = True
try:
result = await client.beta.messages.count_tokens(
model=model or "claude-3-7-sonnet-20250219",
messages=messages or [{"role": "user", "content": "hi"}],
tools=anthropic_tools or [],
)
count_params = {
"model": model or "claude-3-7-sonnet-20250219",
"messages": messages or [{"role": "user", "content": "hi"}],
"tools": anthropic_tools or [],
}
if thinking_enabled:
count_params["thinking"] = {"type": "enabled", "budget_tokens": 16000}
result = await client.beta.messages.count_tokens(**count_params)
except:
raise

View File

@@ -0,0 +1,97 @@
import os
from typing import List, Optional
from openai import AsyncOpenAI, AsyncStream, OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta.llm_api.deepseek import convert_deepseek_response_to_chatcompletion, map_messages_to_deepseek_format
from letta.llm_api.openai_client import OpenAIClient
from letta.otel.tracing import trace_method
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.settings import model_settings
class DeepseekClient(OpenAIClient):
    """OpenAI-compatible client for the DeepSeek API.

    DeepSeek serves an OpenAI-style chat-completions endpoint, so request
    construction is inherited from OpenAIClient; this subclass only patches
    the places where DeepSeek diverges (notably: deepseek-reasoner / R1 has
    no native function calling, so tool schemas are injected into the system
    prompt instead).
    """

    def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
        # DeepSeek does not need tool_choice forced to "auto".
        return False

    def supports_structured_output(self, llm_config: LLMConfig) -> bool:
        # No OpenAI-style structured-output / strict JSON schema support.
        return False

    @trace_method
    def build_request_data(
        self,
        messages: List[PydanticMessage],
        llm_config: LLMConfig,
        tools: Optional[List[dict]] = None,
        force_tool_call: Optional[str] = None,
    ) -> dict:
        """Build the request payload, applying DeepSeek-specific adjustments.

        Args:
            messages: Conversation history to send.
            llm_config: Model configuration. NOTE: mutated in place —
                ``put_inner_thoughts_in_kwargs`` is forced to False for DeepSeek.
            tools: Optional tool/function schemas (OpenAI-style dicts).
            force_tool_call: Passed through to the base class.

        Returns:
            The request payload dict ready to send to the DeepSeek API.
        """
        import json  # local import: only needed for the R1 prompt-injection path below

        # Override put_inner_thoughts_in_kwargs to False for DeepSeek
        llm_config.put_inner_thoughts_in_kwargs = False
        data = super().build_request_data(messages, llm_config, tools, force_tool_call)

        def add_functions_to_system_message(system_message):
            # BUGFIX: this helper previously referenced an undefined name
            # `functions` (NameError whenever the deepseek-reasoner path ran)
            # and annotated its parameter with the unimported `ChatMessage`
            # type (NameError at definition time). Serialize the tool schemas
            # that were actually passed to this method instead.
            available = tools or []
            system_message.content += f"<available functions> {''.join(json.dumps(f) for f in available)} </available functions>"
            system_message.content += 'Select best function to call simply respond with a single json block with the fields "name" and "arguments". Use double quotes around the arguments.'

        if llm_config.model == "deepseek-reasoner":  # R1 currently doesn't support function calling natively
            add_functions_to_system_message(
                data["messages"][0]
            )  # Inject additional instructions to the system prompt with the available functions

        data["messages"] = map_messages_to_deepseek_format(data["messages"])
        return data

    @trace_method
    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs underlying synchronous request to OpenAI API and returns raw response dict.
        """
        api_key = model_settings.deepseek_api_key or os.environ.get("DEEPSEEK_API_KEY")
        client = OpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
        response: ChatCompletion = client.chat.completions.create(**request_data)
        return response.model_dump()

    @trace_method
    async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs underlying asynchronous request to OpenAI API and returns raw response dict.
        """
        api_key = model_settings.deepseek_api_key or os.environ.get("DEEPSEEK_API_KEY")
        client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
        response: ChatCompletion = await client.chat.completions.create(**request_data)
        return response.model_dump()

    @trace_method
    async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
        """
        Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator.
        """
        api_key = model_settings.deepseek_api_key or os.environ.get("DEEPSEEK_API_KEY")
        client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
        response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
            **request_data, stream=True, stream_options={"include_usage": True}
        )
        return response_stream

    @trace_method
    def convert_response_to_chat_completion(
        self,
        response_data: dict,
        input_messages: List[PydanticMessage],  # Included for consistency, maybe used later
        llm_config: LLMConfig,
    ) -> ChatCompletionResponse:
        """
        Converts raw OpenAI response dict into the ChatCompletionResponse Pydantic model.
        Handles potential extraction of inner thoughts if they were added via kwargs.
        """
        response = ChatCompletionResponse(**response_data)
        return convert_deepseek_response_to_chatcompletion(response)

View File

@@ -0,0 +1,79 @@
import os
from typing import List, Optional
from openai import AsyncOpenAI, AsyncStream, OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta.llm_api.openai_client import OpenAIClient
from letta.otel.tracing import trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.settings import model_settings
class GroqClient(OpenAIClient):
    """Client for Groq's chat-completions API.

    Groq implements the OpenAI wire protocol, so request construction is
    inherited from OpenAIClient; only Groq-specific payload restrictions are
    applied on top.
    """

    def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
        return False

    def supports_structured_output(self, llm_config: LLMConfig) -> bool:
        return True

    @trace_method
    def build_request_data(
        self,
        messages: List[PydanticMessage],
        llm_config: LLMConfig,
        tools: Optional[List[dict]] = None,
        force_tool_call: Optional[str] = None,
    ) -> dict:
        """Build the request payload, stripping fields Groq rejects."""
        payload = super().build_request_data(messages, llm_config, tools, force_tool_call)
        # Groq validation - these fields are not supported and will cause 400 errors
        # https://console.groq.com/docs/openai
        payload.pop("top_logprobs", None)
        payload.pop("logit_bias", None)
        payload["logprobs"] = False
        payload["n"] = 1
        return payload

    @trace_method
    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs underlying synchronous request to Groq API and returns raw response dict.
        """
        key = model_settings.groq_api_key or os.environ.get("GROQ_API_KEY")
        sync_client = OpenAI(api_key=key, base_url=llm_config.model_endpoint)
        completion: ChatCompletion = sync_client.chat.completions.create(**request_data)
        return completion.model_dump()

    @trace_method
    async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs underlying asynchronous request to Groq API and returns raw response dict.
        """
        key = model_settings.groq_api_key or os.environ.get("GROQ_API_KEY")
        async_client = AsyncOpenAI(api_key=key, base_url=llm_config.model_endpoint)
        completion: ChatCompletion = await async_client.chat.completions.create(**request_data)
        return completion.model_dump()

    @trace_method
    async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
        """Request embeddings given texts and embedding config"""
        key = model_settings.groq_api_key or os.environ.get("GROQ_API_KEY")
        async_client = AsyncOpenAI(api_key=key, base_url=embedding_config.embedding_endpoint)
        result = await async_client.embeddings.create(model=embedding_config.embedding_model, input=inputs)
        # TODO: add total usage
        return [item.embedding for item in result.data]

    @trace_method
    async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
        raise NotImplementedError("Streaming not supported for Groq.")

View File

@@ -133,7 +133,6 @@ def convert_to_structured_output(openai_function: dict, allow_optional: bool = F
structured_output["parameters"]["required"] = list(structured_output["parameters"]["properties"].keys())
else:
raise NotImplementedError("Optional parameter handling is not implemented.")
return structured_output

View File

@@ -8,7 +8,7 @@ import requests
from letta.constants import CLI_WARNING_PREFIX
from letta.errors import LettaConfigurationError, RateLimitExceededError
from letta.llm_api.deepseek import build_deepseek_chat_completions_request, convert_deepseek_response_to_chatcompletion
from letta.llm_api.helpers import add_inner_thoughts_to_functions, unpack_all_inner_thoughts_from_kwargs
from letta.llm_api.helpers import unpack_all_inner_thoughts_from_kwargs
from letta.llm_api.openai import (
build_openai_chat_completions_request,
openai_chat_completions_process_stream,
@@ -16,14 +16,13 @@ from letta.llm_api.openai import (
prepare_openai_payload,
)
from letta.local_llm.chat_completion_proxy import get_chat_completion
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.orm.user import User
from letta.otel.tracing import log_event, trace_method
from letta.schemas.enums import ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.schemas.provider_trace import ProviderTraceCreate
from letta.services.telemetry_manager import TelemetryManager
@@ -246,116 +245,6 @@ def create(
return response
elif llm_config.model_endpoint_type == "xai":
api_key = model_settings.xai_api_key
if function_call is None and functions is not None and len(functions) > 0:
# force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
function_call = "required"
data = build_openai_chat_completions_request(
llm_config,
messages,
user_id,
functions,
function_call,
use_tool_naming,
put_inner_thoughts_first=put_inner_thoughts_first,
use_structured_output=False, # NOTE: not supported atm for xAI
)
# Specific bug for the mini models (as of Apr 14, 2025)
# 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: presencePenalty'}
# 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: frequencyPenalty'}
if "grok-3-mini-" in llm_config.model:
data.presence_penalty = None
data.frequency_penalty = None
if stream: # Client requested token streaming
data.stream = True
assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance(
stream_interface, AgentRefreshStreamingInterface
), type(stream_interface)
response = openai_chat_completions_process_stream(
url=llm_config.model_endpoint,
api_key=api_key,
chat_completion_request=data,
stream_interface=stream_interface,
name=name,
# TODO turn on to support reasoning content from xAI reasoners:
# https://docs.x.ai/docs/guides/reasoning#reasoning
expect_reasoning_content=False,
)
else: # Client did not request token streaming (expect a blocking backend response)
data.stream = False
if isinstance(stream_interface, AgentChunkStreamingInterface):
stream_interface.stream_start()
try:
response = openai_chat_completions_request(
url=llm_config.model_endpoint,
api_key=api_key,
chat_completion_request=data,
)
finally:
if isinstance(stream_interface, AgentChunkStreamingInterface):
stream_interface.stream_end()
if llm_config.put_inner_thoughts_in_kwargs:
response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG)
return response
elif llm_config.model_endpoint_type == "groq":
if stream:
raise NotImplementedError("Streaming not yet implemented for Groq.")
if model_settings.groq_api_key is None and llm_config.model_endpoint == "https://api.groq.com/openai/v1/chat/completions":
raise LettaConfigurationError(message="Groq key is missing from letta config file", missing_fields=["groq_api_key"])
# force to true for groq, since they don't support 'content' is non-null
if llm_config.put_inner_thoughts_in_kwargs:
functions = add_inner_thoughts_to_functions(
functions=functions,
inner_thoughts_key=INNER_THOUGHTS_KWARG,
inner_thoughts_description=INNER_THOUGHTS_KWARG_DESCRIPTION,
)
tools = [{"type": "function", "function": f} for f in functions] if functions is not None else None
data = ChatCompletionRequest(
model=llm_config.model,
messages=[m.to_openai_dict(put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs) for m in messages],
tools=tools,
tool_choice=function_call,
user=str(user_id),
)
# https://console.groq.com/docs/openai
# "The following fields are currently not supported and will result in a 400 error (yikes) if they are supplied:"
assert data.top_logprobs is None
assert data.logit_bias is None
assert data.logprobs == False
assert data.n == 1
# They mention that none of the messages can have names, but it seems to not error out (for now)
data.stream = False
if isinstance(stream_interface, AgentChunkStreamingInterface):
stream_interface.stream_start()
try:
# groq uses the openai chat completions API, so this component should be reusable
response = openai_chat_completions_request(
url=llm_config.model_endpoint,
api_key=model_settings.groq_api_key,
chat_completion_request=data,
)
finally:
if isinstance(stream_interface, AgentChunkStreamingInterface):
stream_interface.stream_end()
if llm_config.put_inner_thoughts_in_kwargs:
response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG)
return response
elif llm_config.model_endpoint_type == "deepseek":
if model_settings.deepseek_api_key is None and llm_config.model_endpoint == "":
# only is a problem if we are *not* using an openai proxy

View File

@@ -79,5 +79,26 @@ class LLMClient:
put_inner_thoughts_first=put_inner_thoughts_first,
actor=actor,
)
case ProviderType.xai:
from letta.llm_api.xai_client import XAIClient
return XAIClient(
put_inner_thoughts_first=put_inner_thoughts_first,
actor=actor,
)
case ProviderType.groq:
from letta.llm_api.groq_client import GroqClient
return GroqClient(
put_inner_thoughts_first=put_inner_thoughts_first,
actor=actor,
)
case ProviderType.deepseek:
from letta.llm_api.deepseek_client import DeepseekClient
return DeepseekClient(
put_inner_thoughts_first=put_inner_thoughts_first,
actor=actor,
)
case _:
return None

View File

@@ -15,6 +15,7 @@ from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.schemas.provider_trace import ProviderTraceCreate
from letta.services.telemetry_manager import TelemetryManager
from letta.settings import settings
if TYPE_CHECKING:
from letta.orm import User
@@ -90,15 +91,16 @@ class LLMClientBase:
try:
log_event(name="llm_request_sent", attributes=request_data)
response_data = await self.request_async(request_data, llm_config)
await telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json=response_data,
step_id=step_id,
organization_id=self.actor.organization_id,
),
)
if settings.track_provider_trace and telemetry_manager:
await telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json=response_data,
step_id=step_id,
organization_id=self.actor.organization_id,
),
)
log_event(name="llm_response_received", attributes=response_data)
except Exception as e:

View File

@@ -146,6 +146,9 @@ class OpenAIClient(LLMClientBase):
def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
return requires_auto_tool_choice(llm_config)
def supports_structured_output(self, llm_config: LLMConfig) -> bool:
return supports_structured_output(llm_config)
@trace_method
def build_request_data(
self,

View File

@@ -0,0 +1,85 @@
import os
from typing import List, Optional
from openai import AsyncOpenAI, AsyncStream, OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta.llm_api.openai_client import OpenAIClient
from letta.otel.tracing import trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.settings import model_settings
class XAIClient(OpenAIClient):
    """Client for xAI's Grok models, served over an OpenAI-compatible endpoint."""

    def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
        return False

    def supports_structured_output(self, llm_config: LLMConfig) -> bool:
        return False

    @trace_method
    def build_request_data(
        self,
        messages: List[PydanticMessage],
        llm_config: LLMConfig,
        tools: Optional[List[dict]] = None,
        force_tool_call: Optional[str] = None,
    ) -> dict:
        """Build the request payload, removing arguments the mini models reject."""
        payload = super().build_request_data(messages, llm_config, tools, force_tool_call)
        # Specific bug for the mini models (as of Apr 14, 2025)
        # 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: presencePenalty'}
        # 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: frequencyPenalty'}
        if "grok-3-mini-" in llm_config.model:
            for unsupported_key in ("presence_penalty", "frequency_penalty"):
                payload.pop(unsupported_key, None)
        return payload

    @trace_method
    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs underlying synchronous request to OpenAI API and returns raw response dict.
        """
        key = model_settings.xai_api_key or os.environ.get("XAI_API_KEY")
        sync_client = OpenAI(api_key=key, base_url=llm_config.model_endpoint)
        completion: ChatCompletion = sync_client.chat.completions.create(**request_data)
        return completion.model_dump()

    @trace_method
    async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs underlying asynchronous request to OpenAI API and returns raw response dict.
        """
        key = model_settings.xai_api_key or os.environ.get("XAI_API_KEY")
        async_client = AsyncOpenAI(api_key=key, base_url=llm_config.model_endpoint)
        completion: ChatCompletion = await async_client.chat.completions.create(**request_data)
        return completion.model_dump()

    @trace_method
    async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
        """
        Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator.
        """
        key = model_settings.xai_api_key or os.environ.get("XAI_API_KEY")
        async_client = AsyncOpenAI(api_key=key, base_url=llm_config.model_endpoint)
        stream: AsyncStream[ChatCompletionChunk] = await async_client.chat.completions.create(
            **request_data, stream=True, stream_options={"include_usage": True}
        )
        return stream

    @trace_method
    async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
        """Request embeddings given texts and embedding config"""
        key = model_settings.xai_api_key or os.environ.get("XAI_API_KEY")
        async_client = AsyncOpenAI(api_key=key, base_url=embedding_config.embedding_endpoint)
        result = await async_client.embeddings.create(model=embedding_config.embedding_model, input=inputs)
        # TODO: add total usage
        return [item.embedding for item in result.data]

View File

@@ -0,0 +1,190 @@
from datetime import datetime
from typing import List, Literal, Optional
from letta.constants import IN_CONTEXT_MEMORY_KEYWORD
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import format_datetime, get_local_time_fast
from letta.otel.tracing import trace_method
from letta.schemas.memory import Memory
class PromptGenerator:
    """Static helpers that assemble the full system message fed to the LLM:
    base system prompt + compiled in-context memory + memory-metadata block."""
    # TODO: This code is kind of wonky and deserves a rewrite
    # NOTE(review): @trace_method is stacked *outside* @staticmethod below, so the
    # tracer receives a staticmethod object rather than a plain function — confirm
    # trace_method handles that (staticmethod is only directly callable on 3.10+).
    @trace_method
    @staticmethod
    def compile_memory_metadata_block(
        memory_edit_timestamp: datetime,
        timezone: str,
        previous_message_count: int = 0,
        archival_memory_size: Optional[int] = 0,
    ) -> str:
        """
        Generate a memory metadata block for the agent's system prompt.
        This creates a structured metadata section that informs the agent about
        the current state of its memory systems, including timing information
        and memory counts. This helps the agent understand what information
        is available through its tools.
        Args:
            memory_edit_timestamp: When memory blocks were last modified
            timezone: The timezone to use for formatting timestamps (e.g., 'America/Los_Angeles')
            previous_message_count: Number of messages in recall memory (conversation history)
            archival_memory_size: Number of items in archival memory (long-term storage)
        Returns:
            A formatted string containing the memory metadata block with XML-style tags
        Example Output:
            <memory_metadata>
            - The current time is: 2024-01-15 10:30 AM PST
            - Memory blocks were last modified: 2024-01-15 09:00 AM PST
            - 42 previous messages between you and the user are stored in recall memory (use tools to access them)
            - 156 total memories you created are stored in archival memory (use tools to access them)
            </memory_metadata>
        """
        # Put the timestamp in the local timezone (mimicking get_local_time())
        timestamp_str = format_datetime(memory_edit_timestamp, timezone)
        # Create a metadata block of info so the agent knows about the metadata of out-of-context memories
        metadata_lines = [
            "<memory_metadata>",
            f"- The current time is: {get_local_time_fast(timezone)}",
            f"- Memory blocks were last modified: {timestamp_str}",
            f"- {previous_message_count} previous messages between you and the user are stored in recall memory (use tools to access them)",
        ]
        # Only include archival memory line if there are archival memories
        # (also guards against archival_memory_size being passed as None)
        if archival_memory_size is not None and archival_memory_size > 0:
            metadata_lines.append(
                f"- {archival_memory_size} total memories you created are stored in archival memory (use tools to access them)"
            )
        metadata_lines.append("</memory_metadata>")
        memory_metadata_block = "\n".join(metadata_lines)
        return memory_metadata_block
    @staticmethod
    def safe_format(template: str, variables: dict) -> str:
        """
        Safely formats a template string, preserving empty {} and {unknown_vars}
        while substituting known variables.
        If we simply use {} in format_map, it'll be treated as a positional field
        """
        # First escape any empty {} by doubling them
        escaped = template.replace("{}", "{{}}")
        # Now use format_map with our custom mapping
        # NOTE(review): PreserveMapping is defined elsewhere in this module;
        # presumably it echoes "{key}" back for unknown keys so unmatched
        # placeholders survive formatting — confirm against its definition.
        return escaped.format_map(PreserveMapping(variables))
    @trace_method
    @staticmethod
    def get_system_message_from_compiled_memory(
        system_prompt: str,
        memory_with_sources: str,
        in_context_memory_last_edit: datetime,  # TODO move this inside of BaseMemory?
        timezone: str,
        user_defined_variables: Optional[dict] = None,
        append_icm_if_missing: bool = True,
        template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
        previous_message_count: int = 0,
        archival_memory_size: int = 0,
    ) -> str:
        """Prepare the final/full system message that will be fed into the LLM API
        The base system message may be templated, in which case we need to render the variables.
        The following are reserved variables:
          - CORE_MEMORY: the in-context memory of the LLM
        """
        if user_defined_variables is not None:
            # TODO eventually support the user defining their own variables to inject
            raise NotImplementedError
        else:
            variables = {}
        # Add the protected memory variable
        if IN_CONTEXT_MEMORY_KEYWORD in variables:
            raise ValueError(f"Found protected variable '{IN_CONTEXT_MEMORY_KEYWORD}' in user-defined vars: {str(user_defined_variables)}")
        else:
            # TODO should this all put into the memory.__repr__ function?
            # Build the metadata block describing out-of-context memory state,
            # then append it after the compiled memory text.
            memory_metadata_string = PromptGenerator.compile_memory_metadata_block(
                memory_edit_timestamp=in_context_memory_last_edit,
                previous_message_count=previous_message_count,
                archival_memory_size=archival_memory_size,
                timezone=timezone,
            )
            full_memory_string = memory_with_sources + "\n\n" + memory_metadata_string
            # Add to the variables list to inject
            variables[IN_CONTEXT_MEMORY_KEYWORD] = full_memory_string
        if template_format == "f-string":
            memory_variable_string = "{" + IN_CONTEXT_MEMORY_KEYWORD + "}"
            # Catch the special case where the system prompt is unformatted
            if append_icm_if_missing:
                if memory_variable_string not in system_prompt:
                    # In this case, append it to the end to make sure memory is still injected
                    # warnings.warn(f"{IN_CONTEXT_MEMORY_KEYWORD} variable was missing from system prompt, appending instead")
                    system_prompt += "\n\n" + memory_variable_string
            # render the variables using the built-in templater
            try:
                # NOTE(review): user_defined_variables is always None at this point
                # (a non-None value raised NotImplementedError above), so the
                # safe_format branch is currently unreachable dead code.
                if user_defined_variables:
                    formatted_prompt = PromptGenerator.safe_format(system_prompt, variables)
                else:
                    formatted_prompt = system_prompt.replace(memory_variable_string, full_memory_string)
            except Exception as e:
                raise ValueError(f"Failed to format system prompt - {str(e)}. System prompt value:\n{system_prompt}")
        else:
            # TODO support for mustache and jinja2
            raise NotImplementedError(template_format)
        return formatted_prompt
    @trace_method
    @staticmethod
    async def compile_system_message_async(
        system_prompt: str,
        in_context_memory: Memory,
        in_context_memory_last_edit: datetime,  # TODO move this inside of BaseMemory?
        timezone: str,
        user_defined_variables: Optional[dict] = None,
        append_icm_if_missing: bool = True,
        template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
        previous_message_count: int = 0,
        archival_memory_size: int = 0,
        tool_rules_solver: Optional[ToolRulesSolver] = None,
        sources: Optional[List] = None,
        max_files_open: Optional[int] = None,
    ) -> str:
        """Compile the agent's full system message asynchronously.

        Renders the in-context memory (including any tool-usage rule prompts
        from `tool_rules_solver`, attached `sources`, and `max_files_open`)
        via the Memory object's async compile, then splices the result into
        the base system prompt together with the memory-metadata block.

        Raises:
            NotImplementedError: if `user_defined_variables` is provided
                (not yet supported downstream).
        """
        tool_constraint_block = None
        if tool_rules_solver is not None:
            tool_constraint_block = tool_rules_solver.compile_tool_rule_prompts()
        if user_defined_variables is not None:
            # TODO eventually support the user defining their own variables to inject
            raise NotImplementedError
        else:
            pass
        memory_with_sources = await in_context_memory.compile_in_thread_async(
            tool_usage_rules=tool_constraint_block, sources=sources, max_files_open=max_files_open
        )
        return PromptGenerator.get_system_message_from_compiled_memory(
            system_prompt=system_prompt,
            memory_with_sources=memory_with_sources,
            in_context_memory_last_edit=in_context_memory_last_edit,
            timezone=timezone,
            user_defined_variables=user_defined_variables,
            append_icm_if_missing=append_icm_if_missing,
            template_format=template_format,
            previous_message_count=previous_message_count,
            archival_memory_size=archival_memory_size,
        )

View File

@@ -1,15 +1,17 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
from pydantic import BaseModel, Field
from letta.helpers.datetime_helpers import get_utc_time
from letta.schemas.agent import AgentState, CreateAgent
from letta.schemas.block import Block, CreateBlock
from letta.schemas.enums import MessageRole
from letta.schemas.file import FileAgent, FileAgentBase, FileMetadata, FileMetadataBase
from letta.schemas.group import Group, GroupCreate
from letta.schemas.mcp import MCPServer
from letta.schemas.message import Message, MessageCreate
from letta.schemas.message import Message, MessageCreate, ToolReturn
from letta.schemas.source import Source, SourceCreate
from letta.schemas.tool import Tool
from letta.schemas.user import User
@@ -46,6 +48,15 @@ class MessageSchema(MessageCreate):
role: MessageRole = Field(..., description="The role of the participant.")
model: Optional[str] = Field(None, description="The model used to make the function call")
agent_id: Optional[str] = Field(None, description="The unique identifier of the agent")
tool_calls: Optional[List[OpenAIToolCall]] = Field(
default=None, description="The list of tool calls requested. Only applicable for role assistant."
)
tool_call_id: Optional[str] = Field(default=None, description="The ID of the tool call. Only applicable for role tool.")
tool_returns: Optional[List[ToolReturn]] = Field(default=None, description="Tool execution return information for prior tool calls")
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
# TODO: Should we also duplicate the steps here?
# TODO: What about tool_return?
@classmethod
def from_message(cls, message: Message) -> "MessageSchema":
@@ -64,6 +75,10 @@ class MessageSchema(MessageCreate):
group_id=message.group_id,
model=message.model,
agent_id=message.agent_id,
tool_calls=message.tool_calls,
tool_call_id=message.tool_call_id,
tool_returns=message.tool_returns,
created_at=message.created_at,
)
@@ -114,7 +129,7 @@ class AgentSchema(CreateAgent):
memory_blocks=[], # TODO: Convert from agent_state.memory if needed
tools=[],
tool_ids=[tool.id for tool in agent_state.tools] if agent_state.tools else [],
source_ids=[], # [source.id for source in agent_state.sources] if agent_state.sources else [],
source_ids=[source.id for source in agent_state.sources] if agent_state.sources else [],
block_ids=[block.id for block in agent_state.memory.blocks],
tool_rules=agent_state.tool_rules,
tags=agent_state.tags,

View File

@@ -1,6 +1,6 @@
from datetime import datetime
from enum import Enum
from typing import Optional
from typing import List, Optional
from pydantic import Field
@@ -108,3 +108,26 @@ class FileAgent(FileAgentBase):
default_factory=datetime.utcnow,
description="Row last-update timestamp (UTC).",
)
class AgentFileAttachment(LettaBase):
    """Response model for agent file attachments showing file status in agent context"""

    # Note: "folder" here is the user-facing name for what the backend stores as a
    # "source" (see the router code that maps fa.source_id -> folder_id).
    id: str = Field(..., description="Unique identifier of the file-agent relationship")
    file_id: str = Field(..., description="Unique identifier of the file")
    file_name: str = Field(..., description="Name of the file")
    folder_id: str = Field(..., description="Unique identifier of the folder/source")
    folder_name: str = Field(..., description="Name of the folder/source")
    is_open: bool = Field(..., description="Whether the file is currently open in the agent's context")
    last_accessed_at: Optional[datetime] = Field(None, description="Timestamp of last access by the agent")
    visible_content: Optional[str] = Field(None, description="Portion of the file visible to the agent if open")
    # start_line/end_line are only populated when the file was opened with an
    # explicit line range; otherwise they stay None.
    start_line: Optional[int] = Field(None, description="Starting line number if file was opened with line range")
    end_line: Optional[int] = Field(None, description="Ending line number if file was opened with line range")
class PaginatedAgentFiles(LettaBase):
    """Paginated response for agent files"""

    files: List[AgentFileAttachment] = Field(..., description="List of file attachments for the agent")
    # The cursor is the id of the last file-agent relationship in this page; pass it
    # back as the `cursor` query parameter to fetch the next page.
    next_cursor: Optional[str] = Field(None, description="Cursor for fetching the next page (file-agent relationship ID)")
    has_more: bool = Field(..., description="Whether more results exist after this page")

View File

@@ -4,6 +4,7 @@ from typing import List, Optional
from pydantic import BaseModel, ConfigDict, Field
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.helpers.datetime_helpers import get_utc_time
from letta.schemas.enums import JobStatus, JobType
from letta.schemas.letta_base import OrmMetadataBase
from letta.schemas.letta_message import MessageType
@@ -12,6 +13,7 @@ from letta.schemas.letta_message import MessageType
class JobBase(OrmMetadataBase):
__id_prefix__ = "job"
status: JobStatus = Field(default=JobStatus.created, description="The status of the job.")
created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the job was created.")
completed_at: Optional[datetime] = Field(None, description="The unix timestamp of when the job was completed.")
metadata: Optional[dict] = Field(None, validation_alias="metadata_", description="The metadata of the job.")
job_type: JobType = Field(default=JobType.JOB, description="The type of the job.")

View File

@@ -52,6 +52,8 @@ class LettaMessage(BaseModel):
sender_id: str | None = None
step_id: str | None = None
is_err: bool | None = None
seq_id: int | None = None
run_id: str | None = None
@field_serializer("date")
def serialize_datetime(self, dt: datetime, _info):

View File

@@ -46,6 +46,10 @@ class LettaStreamingRequest(LettaRequest):
default=False,
description="Whether to include periodic keepalive ping messages in the stream to prevent connection timeouts.",
)
background: bool = Field(
default=False,
description="Whether to process the request in the background.",
)
class LettaAsyncRequest(LettaRequest):
@@ -66,3 +70,21 @@ class CreateBatch(BaseModel):
"'status' is the final batch status (e.g., 'completed', 'failed'), and "
"'completed_at' is an ISO 8601 timestamp indicating when the batch job completed.",
)
class RetrieveStreamRequest(BaseModel):
    """Request body for resuming/retrieving the SSE stream of a background run.

    `starting_after` is a chunk sequence id (not a Redis entry id); chunks with
    seq_id <= starting_after are skipped, enabling cursor-based reconnection.
    """

    starting_after: int = Field(
        0, description="Sequence id to use as a cursor for pagination. Response will start streaming after this chunk sequence id"
    )
    include_pings: Optional[bool] = Field(
        default=False,
        description="Whether to include periodic keepalive ping messages in the stream to prevent connection timeouts.",
    )
    poll_interval: Optional[float] = Field(
        default=0.1,
        description="Seconds to wait between polls when no new data.",
    )
    batch_size: Optional[int] = Field(
        default=100,
        description="Number of entries to read per batch.",
    )

View File

@@ -414,6 +414,8 @@ class Message(BaseMessage):
except json.JSONDecodeError:
raise ValueError(f"Failed to decode function return: {text_content}")
# if self.tool_call_id is None:
# import pdb;pdb.set_trace()
assert self.tool_call_id is not None
return ToolReturnMessage(
@@ -844,7 +846,7 @@ class Message(BaseMessage):
}
content = []
# COT / reasoning / thinking
if self.content is not None and len(self.content) > 1:
if self.content is not None and len(self.content) >= 1:
for content_part in self.content:
if isinstance(content_part, ReasoningContent):
content.append(
@@ -861,6 +863,13 @@ class Message(BaseMessage):
"data": content_part.data,
}
)
if isinstance(content_part, TextContent):
content.append(
{
"type": "text",
"text": content_part.text,
}
)
elif text_content is not None:
content.append(
{

View File

@@ -18,6 +18,7 @@ logger = get_logger(__name__)
class BedrockProvider(Provider):
provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.")
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
access_key: str = Field(..., description="AWS secret access key for Bedrock.")
region: str = Field(..., description="AWS region for Bedrock")
async def bedrock_get_model_list_async(self) -> list[dict]:

View File

@@ -0,0 +1,300 @@
"""Redis stream manager for reading and writing SSE chunks with batching and TTL."""
import asyncio
import json
import time
from collections import defaultdict
from typing import AsyncIterator, Dict, List, Optional
from letta.data_sources.redis_client import AsyncRedisClient
from letta.log import get_logger
logger = get_logger(__name__)
class RedisSSEStreamWriter:
    """
    Efficiently writes SSE chunks to Redis streams with batching and TTL management.

    Features:
    - Batches writes using Redis pipelines for performance
    - Automatically sets/refreshes TTL on streams
    - Tracks sequential IDs for cursor-based recovery
    - Handles flush on size or time thresholds
    """

    def __init__(
        self,
        redis_client: AsyncRedisClient,
        flush_interval: float = 0.5,
        flush_size: int = 50,
        stream_ttl_seconds: int = 10800,  # 3 hours default
        max_stream_length: int = 10000,  # Max entries per stream
    ):
        """
        Initialize the Redis SSE stream writer.

        Args:
            redis_client: Redis client instance
            flush_interval: Seconds between automatic flushes
            flush_size: Number of chunks to buffer before flushing
            stream_ttl_seconds: TTL for streams in seconds (default: 3 hours)
            max_stream_length: Maximum entries per stream before trimming
        """
        self.redis = redis_client
        self.flush_interval = flush_interval
        self.flush_size = flush_size
        self.stream_ttl = stream_ttl_seconds
        self.max_stream_length = max_stream_length

        # Buffer for batching: run_id -> list of chunks
        self.buffer: Dict[str, List[Dict]] = defaultdict(list)
        # Track sequence IDs per run (1-based; first chunk gets seq_id 1)
        self.seq_counters: Dict[str, int] = defaultdict(lambda: 1)
        # Track last flush time per run
        # NOTE(review): defaultdict(float) means last_flush starts at 0.0, so the
        # very first write for a run always satisfies the time-based flush check.
        self.last_flush: Dict[str, float] = defaultdict(float)

        # Background flush task (created by start(), cancelled by stop())
        self._flush_task = None
        self._running = False

    async def start(self):
        """Start the background flush task."""
        if not self._running:
            self._running = True
            self._flush_task = asyncio.create_task(self._periodic_flush())

    async def stop(self):
        """Stop the background flush task and flush remaining data."""
        self._running = False
        if self._flush_task:
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass
        # Drain anything still buffered so no chunks are lost on shutdown.
        for run_id in list(self.buffer.keys()):
            if self.buffer[run_id]:
                await self._flush_run(run_id)

    async def write_chunk(
        self,
        run_id: str,
        data: str,
        is_complete: bool = False,
    ) -> int:
        """
        Write an SSE chunk to the buffer for a specific run.

        Args:
            run_id: The run ID to write to
            data: SSE-formatted chunk data
            is_complete: Whether this is the final chunk

        Returns:
            The sequence ID assigned to this chunk
        """
        seq_id = self.seq_counters[run_id]
        self.seq_counters[run_id] += 1

        chunk = {
            "seq_id": seq_id,
            "data": data,
            "timestamp": int(time.time() * 1000),
        }
        if is_complete:
            # Stored as the string "true" because Redis stream fields are strings.
            chunk["complete"] = "true"

        self.buffer[run_id].append(chunk)

        # Flush on size threshold, on stream completion, or when the buffer is stale.
        should_flush = (
            len(self.buffer[run_id]) >= self.flush_size or is_complete or (time.time() - self.last_flush[run_id]) > self.flush_interval
        )
        if should_flush:
            await self._flush_run(run_id)

        return seq_id

    async def _flush_run(self, run_id: str):
        """Flush buffered chunks for a specific run to Redis."""
        if not self.buffer[run_id]:
            return

        # Swap the buffer out first so concurrent writers append to a fresh list.
        chunks = self.buffer[run_id]
        self.buffer[run_id] = []

        stream_key = f"sse:run:{run_id}"

        try:
            client = await self.redis.get_client()
            # Non-transactional pipeline: one round trip for all XADDs + EXPIRE.
            async with client.pipeline(transaction=False) as pipe:
                for chunk in chunks:
                    pipe.xadd(stream_key, chunk, maxlen=self.max_stream_length, approximate=True)
                # Refresh the TTL on every flush so active streams never expire.
                pipe.expire(stream_key, self.stream_ttl)
                await pipe.execute()

            self.last_flush[run_id] = time.time()

            logger.debug(
                f"Flushed {len(chunks)} chunks to Redis stream {stream_key}, " f"seq_ids {chunks[0]['seq_id']}-{chunks[-1]['seq_id']}"
            )

            if chunks[-1].get("complete") == "true":
                self._cleanup_run(run_id)
        except Exception as e:
            logger.error(f"Failed to flush chunks for run {run_id}: {e}")
            # Put chunks back in buffer to retry
            self.buffer[run_id] = chunks + self.buffer[run_id]
            raise

    async def _periodic_flush(self):
        """Background task to periodically flush buffers."""
        while self._running:
            try:
                await asyncio.sleep(self.flush_interval)

                # Check each run for time-based flush
                current_time = time.time()
                runs_to_flush = [
                    run_id
                    for run_id, last_flush in self.last_flush.items()
                    if (current_time - last_flush) > self.flush_interval and self.buffer[run_id]
                ]

                for run_id in runs_to_flush:
                    await self._flush_run(run_id)
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the flusher alive on transient errors; _flush_run re-buffers.
                logger.error(f"Error in periodic flush: {e}")

    def _cleanup_run(self, run_id: str):
        """Clean up tracking data for a completed run."""
        # NOTE(review): cleanup only happens for runs that emit a "complete" chunk;
        # abandoned runs keep their counters until the writer is discarded.
        self.buffer.pop(run_id, None)
        self.seq_counters.pop(run_id, None)
        self.last_flush.pop(run_id, None)

    async def mark_complete(self, run_id: str):
        """Mark a stream as complete and flush."""
        # Add a [DONE] marker
        await self.write_chunk(run_id, "data: [DONE]\n\n", is_complete=True)
async def create_background_stream_processor(
    stream_generator,
    redis_client: AsyncRedisClient,
    run_id: str,
    writer: Optional[RedisSSEStreamWriter] = None,
) -> None:
    """
    Consume a stream generator and persist every SSE chunk to Redis.

    Intended to run as a background task so that a client can later replay the
    stream via the stored chunks.

    Args:
        stream_generator: Async generator yielding SSE chunks (str, or a tuple
            whose first element is the chunk)
        redis_client: Redis client instance (used only when no writer is given)
        run_id: The run ID to store chunks under
        writer: Optional pre-configured writer; when omitted, a writer is
            created here and stopped when the stream ends
    """
    # Only manage the writer's lifecycle if we created it ourselves.
    owns_writer = writer is None
    if owns_writer:
        writer = RedisSSEStreamWriter(redis_client)
        await writer.start()

    try:
        async for item in stream_generator:
            # Some generators yield (chunk, metadata) tuples; keep only the chunk.
            payload = item[0] if isinstance(item, tuple) else item

            # Terminal markers: normal completion or an upstream error event.
            terminal = isinstance(payload, str) and ("data: [DONE]" in payload or "event: error" in payload)

            await writer.write_chunk(run_id=run_id, data=payload, is_complete=terminal)

            if terminal:
                break
    except Exception as e:
        logger.error(f"Error processing stream for run {run_id}: {e}")
        # Persist an error chunk so consumers see the failure instead of a stall.
        error_chunk = {"error": {"message": str(e)}}
        await writer.write_chunk(run_id=run_id, data=f"event: error\ndata: {json.dumps(error_chunk)}\n\n", is_complete=True)
    finally:
        if owns_writer:
            await writer.stop()
async def redis_sse_stream_generator(
    redis_client: AsyncRedisClient,
    run_id: str,
    starting_after: Optional[int] = None,
    poll_interval: float = 0.1,
    batch_size: int = 100,
) -> AsyncIterator[str]:
    """
    Generate SSE events from Redis stream chunks.

    This generator reads chunks stored in Redis streams and yields them as SSE events.
    It supports cursor-based recovery by allowing you to start from a specific seq_id.

    Args:
        redis_client: Redis client instance
        run_id: The run ID to read chunks for
        starting_after: Sequential ID (integer) to start reading from (default: None for beginning)
        poll_interval: Seconds to wait between polls when no new data (default: 0.1)
        batch_size: Number of entries to read per batch (default: 100)

    Yields:
        SSE-formatted chunks from the Redis stream
    """
    stream_key = f"sse:run:{run_id}"
    # Two cursors: last_redis_id is the Redis entry id ("-" = stream start) used
    # for XRANGE paging; cursor_seq_id is the application-level sequence id used
    # to skip chunks the client has already seen.
    last_redis_id = "-"
    cursor_seq_id = starting_after or 0

    logger.debug(f"Starting redis_sse_stream_generator for run_id={run_id}, stream_key={stream_key}")

    while True:
        entries = await redis_client.xrange(stream_key, start=last_redis_id, count=batch_size)

        if entries:
            yielded_any = False
            for entry_id, fields in entries:
                # XRANGE's start bound is inclusive, so the first entry of each
                # batch after the first is the one we already processed.
                if entry_id == last_redis_id:
                    continue

                chunk_seq_id = int(fields.get("seq_id", 0))
                if chunk_seq_id > cursor_seq_id:
                    data = fields.get("data", "")
                    if not data:
                        logger.debug(f"No data found for chunk {chunk_seq_id} in run {run_id}")
                        continue
                    # Back-fill run_id/seq_id into chunks that were serialized
                    # before those values were known (textual patch on the JSON
                    # payload; assumes compact separators — TODO confirm).
                    if '"run_id":null' in data:
                        data = data.replace('"run_id":null', f'"run_id":"{run_id}"')
                    if '"seq_id":null' in data:
                        data = data.replace('"seq_id":null', f'"seq_id":{chunk_seq_id}')
                    yield data
                    yielded_any = True

                # The writer marks the final chunk with complete="true";
                # terminate the stream once it is reached.
                if fields.get("complete") == "true":
                    return

                last_redis_id = entry_id

            # Whole batch was at or before the cursor: page forward immediately
            # instead of sleeping.
            if not yielded_any and len(entries) > 1:
                continue

        # Nothing new (empty stream, or only the already-seen boundary entry):
        # wait before polling again.
        if not entries or (len(entries) == 1 and entries[0][0] == last_redis_id):
            await asyncio.sleep(poll_interval)

View File

@@ -14,7 +14,7 @@ from starlette.responses import Response, StreamingResponse
from letta.agents.letta_agent import LettaAgent
from letta.constants import AGENT_ID_PATTERN, DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, REDIS_RUN_ID_PREFIX
from letta.data_sources.redis_client import get_redis_client
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
from letta.errors import AgentExportIdMappingError, AgentExportProcessingError, AgentFileImportError, AgentNotFoundForExportError
from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
@@ -26,6 +26,7 @@ from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent
from letta.schemas.agent_file import AgentFileSchema
from letta.schemas.block import Block, BlockUpdate
from letta.schemas.enums import JobType
from letta.schemas.file import AgentFileAttachment, PaginatedAgentFiles
from letta.schemas.group import Group
from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig
from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion, MessageType
@@ -39,6 +40,7 @@ from letta.schemas.source import Source
from letta.schemas.tool import Tool
from letta.schemas.user import User
from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
from letta.server.rest_api.redis_stream_manager import create_background_stream_processor, redis_sse_stream_generator
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
from letta.services.summarizer.enums import SummarizationMode
@@ -249,6 +251,7 @@ async def import_agent(
override_existing_tools: bool = True,
project_id: str | None = None,
strip_messages: bool = False,
env_vars: Optional[dict[str, Any]] = None,
) -> List[str]:
"""
Import an agent using the new AgentFileSchema format.
@@ -259,7 +262,13 @@ async def import_agent(
raise HTTPException(status_code=422, detail=f"Invalid agent file schema: {e!s}")
try:
import_result = await server.agent_serialization_manager.import_file(schema=agent_schema, actor=actor)
import_result = await server.agent_serialization_manager.import_file(
schema=agent_schema,
actor=actor,
append_copy_suffix=append_copy_suffix,
override_existing_tools=override_existing_tools,
env_vars=env_vars,
)
if not import_result.success:
raise HTTPException(
@@ -297,7 +306,9 @@ async def import_agent_serialized(
False,
description="If set to True, strips all messages from the agent before importing.",
),
env_vars: Optional[Dict[str, Any]] = Form(None, description="Environment variables to pass to the agent for tool execution."),
env_vars_json: Optional[str] = Form(
None, description="Environment variables as a JSON string to pass to the agent for tool execution."
),
):
"""
Import a serialized agent file and recreate the agent(s) in the system.
@@ -311,6 +322,17 @@ async def import_agent_serialized(
except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="Corrupted agent file format.")
# Parse env_vars_json if provided
env_vars = None
if env_vars_json:
try:
env_vars = json.loads(env_vars_json)
except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="env_vars_json must be a valid JSON string")
if not isinstance(env_vars, dict):
raise HTTPException(status_code=400, detail="env_vars_json must be a valid JSON string")
# Check if the JSON is AgentFileSchema or AgentSchema
# TODO: This is kind of hacky, but should work as long as dont' change the schema
if "agents" in agent_json and isinstance(agent_json.get("agents"), list):
@@ -323,6 +345,7 @@ async def import_agent_serialized(
override_existing_tools=override_existing_tools,
project_id=project_id,
strip_messages=strip_messages,
env_vars=env_vars,
)
else:
# This is a legacy AgentSchema
@@ -728,6 +751,49 @@ async def list_agent_folders(
return await server.agent_manager.list_attached_sources_async(agent_id=agent_id, actor=actor)
@router.get("/{agent_id}/files", response_model=PaginatedAgentFiles, operation_id="list_agent_files")
async def list_agent_files(
    agent_id: str,
    cursor: Optional[str] = Query(None, description="Pagination cursor from previous response"),
    limit: int = Query(20, ge=1, le=100, description="Number of items to return (1-100)"),
    is_open: Optional[bool] = Query(None, description="Filter by open status (true for open files, false for closed files)"),
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: str | None = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    """
    Get the files attached to an agent with their open/closed status (paginated).
    """
    actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)

    # Fetch one page of file-agent relationships for this agent.
    file_agents, next_cursor, has_more = await server.file_agent_manager.list_files_for_agent_paginated(
        agent_id=agent_id, actor=actor, cursor=cursor, limit=limit, is_open=is_open
    )

    # Join each relationship with its folder/source metadata to build the response rows.
    enriched_files = []
    for file_agent in file_agents:
        folder = await server.source_manager.get_source_by_id(source_id=file_agent.source_id, actor=actor)
        enriched_files.append(
            AgentFileAttachment(
                id=file_agent.id,
                file_id=file_agent.file_id,
                file_name=file_agent.file_name,
                folder_id=file_agent.source_id,
                folder_name=folder.name if folder else "Unknown",
                is_open=file_agent.is_open,
                last_accessed_at=file_agent.last_accessed_at,
                visible_content=file_agent.visible_content,
                start_line=file_agent.start_line,
                end_line=file_agent.end_line,
            )
        )

    return PaginatedAgentFiles(files=enriched_files, next_cursor=next_cursor, has_more=has_more)
# TODO: remove? can also get with agent blocks
@router.get("/{agent_id}/core-memory", response_model=Memory, operation_id="retrieve_agent_memory")
async def retrieve_agent_memory(
@@ -999,7 +1065,8 @@ async def send_message(
"bedrock",
"ollama",
"azure",
"together",
"xai",
"groq",
]
# Create a new run for execution tracking
@@ -1143,7 +1210,8 @@ async def send_message_streaming(
"bedrock",
"ollama",
"azure",
"together",
"xai",
"groq",
]
model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"]
@@ -1157,6 +1225,7 @@ async def send_message_streaming(
metadata={
"job_type": "send_message_streaming",
"agent_id": agent_id,
"background": request.background or False,
},
request_config=LettaRequestConfig(
use_assistant_message=request.use_assistant_message,
@@ -1211,8 +1280,58 @@ async def send_message_streaming(
else SummarizationMode.PARTIAL_EVICT_MESSAGE_BUFFER
),
)
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode, add_keepalive_to_stream
if request.background and settings.track_agent_run:
if isinstance(redis_client, NoopAsyncRedisClient):
raise HTTPException(
status_code=503,
detail=(
"Background streaming requires Redis to be running. "
"Please ensure Redis is properly configured. "
f"LETTA_REDIS_HOST: {settings.redis_host}, LETTA_REDIS_PORT: {settings.redis_port}"
),
)
if request.stream_tokens and model_compatible_token_streaming:
raw_stream = agent_loop.step_stream(
input_messages=request.messages,
max_steps=request.max_steps,
use_assistant_message=request.use_assistant_message,
request_start_timestamp_ns=request_start_timestamp_ns,
include_return_message_types=request.include_return_message_types,
)
else:
raw_stream = agent_loop.step_stream_no_tokens(
request.messages,
max_steps=request.max_steps,
use_assistant_message=request.use_assistant_message,
request_start_timestamp_ns=request_start_timestamp_ns,
include_return_message_types=request.include_return_message_types,
)
asyncio.create_task(
create_background_stream_processor(
stream_generator=raw_stream,
redis_client=redis_client,
run_id=run.id,
)
)
stream = redis_sse_stream_generator(
redis_client=redis_client,
run_id=run.id,
)
if request.include_pings and settings.enable_keepalive:
stream = add_keepalive_to_stream(stream, keepalive_interval=settings.keepalive_interval)
return StreamingResponseWithStatusCode(
stream,
media_type="text/event-stream",
)
if request.stream_tokens and model_compatible_token_streaming:
raw_stream = agent_loop.step_stream(
input_messages=request.messages,
@@ -1350,6 +1469,7 @@ async def _process_message_background(
"google_vertex",
"bedrock",
"ollama",
"groq",
]
if agent_eligible and model_compatible:
if agent.enable_sleeptime and agent.agent_type != AgentType.voice_convo_agent:
@@ -1538,7 +1658,8 @@ async def preview_raw_payload(
"bedrock",
"ollama",
"azure",
"together",
"xai",
"groq",
]
if agent_eligible and model_compatible:
@@ -1608,7 +1729,8 @@ async def summarize_agent_conversation(
"bedrock",
"ollama",
"azure",
"together",
"xai",
"groq",
]
if agent_eligible and model_compatible:

View File

@@ -7,6 +7,7 @@ from typing import List, Optional
from fastapi import APIRouter, Depends, Header, HTTPException, Query, UploadFile
from starlette import status
from starlette.responses import Response
import letta.constants as constants
from letta.helpers.pinecone_utils import (
@@ -34,7 +35,7 @@ from letta.services.file_processor.file_types import get_allowed_media_types, ge
from letta.services.file_processor.parser.markitdown_parser import MarkitdownFileParser
from letta.services.file_processor.parser.mistral_parser import MistralFileParser
from letta.settings import settings
from letta.utils import safe_create_task, sanitize_filename
from letta.utils import safe_create_file_processing_task, safe_create_task, sanitize_filename
logger = get_logger(__name__)
@@ -138,8 +139,11 @@ async def create_folder(
# TODO: need to asyncify this
if not folder_create.embedding_config:
if not folder_create.embedding:
# TODO: modify error type
raise ValueError("Must specify either embedding or embedding_config in request")
if settings.default_embedding_handle is None:
# TODO: modify error type
raise ValueError("Must specify either embedding or embedding_config in request")
else:
folder_create.embedding = settings.default_embedding_handle
folder_create.embedding_config = await server.get_embedding_config_from_handle_async(
handle=folder_create.embedding,
embedding_chunk_size=folder_create.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
@@ -257,13 +261,16 @@ async def upload_file_to_folder(
# Store original filename and handle duplicate logic
# Use custom name if provided, otherwise use the uploaded file's name
original_filename = sanitize_filename(name if name else file.filename) # Basic sanitization only
# If custom name is provided, use it directly (it's just metadata, not a filesystem path)
# Otherwise, sanitize the uploaded filename for security
original_filename = name if name else sanitize_filename(file.filename) # Basic sanitization only
# Check if duplicate exists
existing_file = await server.file_manager.get_file_by_original_name_and_source(
original_filename=original_filename, source_id=folder_id, actor=actor
)
unique_filename = None
if existing_file:
# Duplicate found, handle based on strategy
if duplicate_handling == DuplicateFileHandling.ERROR:
@@ -305,8 +312,11 @@ async def upload_file_to_folder(
# Use cloud processing for all files (simple files always, complex files with Mistral key)
logger.info("Running experimental cloud based file processing...")
safe_create_task(
safe_create_file_processing_task(
load_file_to_source_cloud(server, agent_states, content, folder_id, actor, folder.embedding_config, file_metadata),
file_metadata=file_metadata,
server=server,
actor=actor,
logger=logger,
label="file_processor.process",
)

View File

@@ -1,16 +1,23 @@
from datetime import timedelta
from typing import Annotated, List, Optional
from fastapi import APIRouter, Depends, Header, HTTPException, Query
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
from pydantic import Field
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
from letta.helpers.datetime_helpers import get_utc_time
from letta.orm.errors import NoResultFound
from letta.schemas.enums import JobStatus, JobType, MessageRole
from letta.schemas.letta_message import LettaMessageUnion
from letta.schemas.letta_request import RetrieveStreamRequest
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.run import Run
from letta.schemas.step import Step
from letta.server.rest_api.redis_stream_manager import redis_sse_stream_generator
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode, add_keepalive_to_stream
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
from letta.settings import settings
router = APIRouter(prefix="/runs", tags=["runs"])
@@ -19,6 +26,14 @@ router = APIRouter(prefix="/runs", tags=["runs"])
def list_runs(
server: "SyncServer" = Depends(get_letta_server),
agent_ids: Optional[List[str]] = Query(None, description="The unique identifier of the agent associated with the run."),
background: Optional[bool] = Query(None, description="If True, filters for runs that were created in background mode."),
after: Optional[str] = Query(None, description="Cursor for pagination"),
before: Optional[str] = Query(None, description="Cursor for pagination"),
limit: Optional[int] = Query(50, description="Maximum number of runs to return"),
ascending: bool = Query(
False,
description="Whether to sort agents oldest to newest (True) or newest to oldest (False, default)",
),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
@@ -26,18 +41,29 @@ def list_runs(
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
runs = [Run.from_job(job) for job in server.job_manager.list_jobs(actor=actor, job_type=JobType.RUN)]
if not agent_ids:
return runs
return [run for run in runs if "agent_id" in run.metadata and run.metadata["agent_id"] in agent_ids]
runs = [
Run.from_job(job)
for job in server.job_manager.list_jobs(
actor=actor,
job_type=JobType.RUN,
limit=limit,
before=before,
after=after,
ascending=False,
)
]
if agent_ids:
runs = [run for run in runs if "agent_id" in run.metadata and run.metadata["agent_id"] in agent_ids]
if background is not None:
runs = [run for run in runs if "background" in run.metadata and run.metadata["background"] == background]
return runs
@router.get("/active", response_model=List[Run], operation_id="list_active_runs")
def list_active_runs(
server: "SyncServer" = Depends(get_letta_server),
agent_ids: Optional[List[str]] = Query(None, description="The unique identifier of the agent associated with the run."),
background: Optional[bool] = Query(None, description="If True, filters for runs that were created in background mode."),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
@@ -46,13 +72,15 @@ def list_active_runs(
actor = server.user_manager.get_user_or_default(user_id=actor_id)
active_runs = server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running], job_type=JobType.RUN)
active_runs = [Run.from_job(job) for job in active_runs]
if not agent_ids:
return active_runs
if agent_ids:
active_runs = [run for run in active_runs if "agent_id" in run.metadata and run.metadata["agent_id"] in agent_ids]
return [run for run in active_runs if "agent_id" in run.metadata and run.metadata["agent_id"] in agent_ids]
if background is not None:
active_runs = [run for run in active_runs if "background" in run.metadata and run.metadata["background"] == background]
return active_runs
@router.get("/{run_id}", response_model=Run, operation_id="retrieve_run")
@@ -213,3 +241,65 @@ async def delete_run(
return Run.from_job(job)
except NoResultFound:
raise HTTPException(status_code=404, detail="Run not found")
@router.post(
    "/{run_id}/stream",
    response_model=None,
    operation_id="retrieve_stream",
    responses={
        200: {
            "description": "Successful response",
            "content": {
                "text/event-stream": {"description": "Server-Sent Events stream"},
            },
        }
    },
)
async def retrieve_stream(
    run_id: str,
    request: RetrieveStreamRequest = Body(None),
    actor_id: Optional[str] = Header(None, alias="user_id"),
    server: "SyncServer" = Depends(get_letta_server),
):
    """
    Replay the SSE stream of a background run from Redis, starting after an
    optional chunk-sequence cursor.

    Raises:
        404 if the run does not exist; 400 if the run was not created in
        background mode; 410 if the run is past the 3-hour retention window;
        503 if Redis (required for background streaming) is unavailable.
    """
    actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
    try:
        job = server.job_manager.get_job_by_id(job_id=run_id, actor=actor)
    except NoResultFound:
        raise HTTPException(status_code=404, detail="Run not found")

    run = Run.from_job(job)

    # Body(None) leaves `request` as None when the client sends no JSON body;
    # fall back to the model's defaults instead of dereferencing None.
    if request is None:
        request = RetrieveStreamRequest()

    # metadata may be None for runs created without metadata; guard before lookup.
    if not run.metadata or not run.metadata.get("background"):
        raise HTTPException(status_code=400, detail="Run was not created in background mode, so it cannot be retrieved.")

    # Stored chunks expire (stream TTL), so refuse replays of stale runs.
    if run.created_at < get_utc_time() - timedelta(hours=3):
        raise HTTPException(status_code=410, detail="Run was created more than 3 hours ago, and is now expired.")

    redis_client = await get_redis_client()
    if isinstance(redis_client, NoopAsyncRedisClient):
        raise HTTPException(
            status_code=503,
            detail=(
                "Background streaming requires Redis to be running. "
                "Please ensure Redis is properly configured. "
                f"LETTA_REDIS_HOST: {settings.redis_host}, LETTA_REDIS_PORT: {settings.redis_port}"
            ),
        )

    stream = redis_sse_stream_generator(
        redis_client=redis_client,
        run_id=run_id,
        starting_after=request.starting_after,
        poll_interval=request.poll_interval,
        batch_size=request.batch_size,
    )

    if request.include_pings and settings.enable_keepalive:
        stream = add_keepalive_to_stream(stream, keepalive_interval=settings.keepalive_interval)

    return StreamingResponseWithStatusCode(
        stream,
        media_type="text/event-stream",
    )

View File

@@ -2,18 +2,17 @@ import asyncio
import mimetypes
import os
import tempfile
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import List, Optional
from fastapi import APIRouter, Depends, Header, HTTPException, Query, UploadFile
from starlette import status
from starlette.responses import Response
import letta.constants as constants
from letta.helpers.pinecone_utils import (
delete_file_records_from_pinecone_index,
delete_source_records_from_pinecone_index,
list_pinecone_index_for_files,
should_use_pinecone,
)
from letta.log import get_logger
@@ -35,14 +34,13 @@ from letta.services.file_processor.file_types import get_allowed_media_types, ge
from letta.services.file_processor.parser.markitdown_parser import MarkitdownFileParser
from letta.services.file_processor.parser.mistral_parser import MistralFileParser
from letta.settings import settings
from letta.utils import safe_create_task, sanitize_filename
from letta.utils import safe_create_file_processing_task, safe_create_task, sanitize_filename
logger = get_logger(__name__)
# Register all supported file types with Python's mimetypes module
register_mime_types()
router = APIRouter(prefix="/sources", tags=["sources"])
@@ -139,8 +137,11 @@ async def create_source(
# TODO: need to asyncify this
if not source_create.embedding_config:
if not source_create.embedding:
# TODO: modify error type
raise ValueError("Must specify either embedding or embedding_config in request")
if settings.default_embedding_handle is None:
# TODO: modify error type
raise ValueError("Must specify either embedding or embedding_config in request")
else:
source_create.embedding = settings.default_embedding_handle
source_create.embedding_config = await server.get_embedding_config_from_handle_async(
handle=source_create.embedding,
embedding_chunk_size=source_create.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
@@ -258,7 +259,9 @@ async def upload_file_to_source(
# Store original filename and handle duplicate logic
# Use custom name if provided, otherwise use the uploaded file's name
original_filename = sanitize_filename(name if name else file.filename) # Basic sanitization only
# If custom name is provided, use it directly (it's just metadata, not a filesystem path)
# Otherwise, sanitize the uploaded filename for security
original_filename = name if name else sanitize_filename(file.filename) # Basic sanitization only
# Check if duplicate exists
existing_file = await server.file_manager.get_file_by_original_name_and_source(
@@ -307,8 +310,11 @@ async def upload_file_to_source(
# Use cloud processing for all files (simple files always, complex files with Mistral key)
logger.info("Running experimental cloud based file processing...")
safe_create_task(
safe_create_file_processing_task(
load_file_to_source_cloud(server, agent_states, content, source_id, actor, source.embedding_config, file_metadata),
file_metadata=file_metadata,
server=server,
actor=actor,
logger=logger,
label="file_processor.process",
)
@@ -358,6 +364,10 @@ async def list_source_files(
limit: int = Query(1000, description="Number of files to return"),
after: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
include_content: bool = Query(False, description="Whether to include full file content"),
check_status_updates: bool = Query(
True,
description="Whether to check and update file processing status (from the vector db service). If False, will not fetch and update the status, which may lead to performance gains.",
),
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
@@ -372,6 +382,7 @@ async def list_source_files(
actor=actor,
include_content=include_content,
strip_directory_prefix=True, # TODO: Reconsider this. This is purely for aesthetics.
check_status_updates=check_status_updates,
)
@@ -400,51 +411,8 @@ async def get_file_metadata(
if file_metadata.source_id != source_id:
raise HTTPException(status_code=404, detail=f"File with id={file_id} not found in source {source_id}.")
# Check for timeout if status is not terminal
if not file_metadata.processing_status.is_terminal_state():
if file_metadata.created_at:
# Handle timezone differences between PostgreSQL (timezone-aware) and SQLite (timezone-naive)
if settings.letta_pg_uri_no_default:
# PostgreSQL: both datetimes are timezone-aware
timeout_threshold = datetime.now(timezone.utc) - timedelta(minutes=settings.file_processing_timeout_minutes)
file_created_at = file_metadata.created_at
else:
# SQLite: both datetimes should be timezone-naive
timeout_threshold = datetime.utcnow() - timedelta(minutes=settings.file_processing_timeout_minutes)
file_created_at = file_metadata.created_at
if file_created_at < timeout_threshold:
# Move file to error status with timeout message
timeout_message = settings.file_processing_timeout_error_message.format(settings.file_processing_timeout_minutes)
try:
file_metadata = await server.file_manager.update_file_status(
file_id=file_metadata.id, actor=actor, processing_status=FileProcessingStatus.ERROR, error_message=timeout_message
)
except ValueError as e:
# state transition was blocked - log it but don't fail the request
logger.warning(f"Could not update file to timeout error state: {str(e)}")
# continue with existing file_metadata
if should_use_pinecone() and file_metadata.processing_status == FileProcessingStatus.EMBEDDING:
ids = await list_pinecone_index_for_files(file_id=file_id, actor=actor)
logger.info(
f"Embedded chunks {len(ids)}/{file_metadata.total_chunks} for {file_id} ({file_metadata.file_name}) in organization {actor.organization_id}"
)
if len(ids) != file_metadata.chunks_embedded or len(ids) == file_metadata.total_chunks:
if len(ids) != file_metadata.total_chunks:
file_status = file_metadata.processing_status
else:
file_status = FileProcessingStatus.COMPLETED
try:
file_metadata = await server.file_manager.update_file_status(
file_id=file_metadata.id, actor=actor, chunks_embedded=len(ids), processing_status=file_status
)
except ValueError as e:
# state transition was blocked - this is a race condition
# log it but don't fail the request since we're just reading metadata
logger.warning(f"Race condition detected in get_file_metadata: {str(e)}")
# return the current file state without updating
# Check and update file status (timeout check and pinecone embedding sync)
file_metadata = await server.file_manager.check_and_update_file_status(file_metadata, actor)
return file_metadata

View File

@@ -1,18 +1,28 @@
from typing import Optional
from fastapi import APIRouter, Depends, Header
from letta.schemas.provider_trace import ProviderTrace
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
from letta.settings import settings
router = APIRouter(prefix="/telemetry", tags=["telemetry"])
@router.get("/{step_id}", response_model=Optional[ProviderTrace], operation_id="retrieve_provider_trace")
async def retrieve_provider_trace_by_step_id(
    step_id: str,
    server: SyncServer = Depends(get_letta_server),
    actor_id: str | None = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    """Fetch the provider trace recorded for a given step, if tracing is enabled.

    Returns None when provider-trace tracking is disabled in settings, or when
    the lookup fails for any reason (a missing trace is treated as "no data",
    not as an error, so the endpoint never 500s on absent telemetry).
    """
    provider_trace = None
    if settings.track_provider_trace:
        try:
            provider_trace = await server.telemetry_manager.get_provider_trace_by_step_id_async(
                step_id=step_id, actor=await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
            )
        except Exception:
            # Best-effort lookup: swallow ordinary failures and fall through to None.
            # NOTE(review): the previous bare `except:` also swallowed BaseException
            # (e.g. asyncio.CancelledError), which breaks request cancellation in an
            # async handler; narrowing to Exception preserves cancellation semantics.
            pass
    return provider_trace

View File

@@ -547,7 +547,7 @@ async def add_mcp_server_to_config(
server_name=request.server_name,
server_type=request.type,
server_url=request.server_url,
token=request.resolve_token() if not request.custom_headers else None,
token=request.resolve_token(),
custom_headers=request.custom_headers,
)
elif isinstance(request, StreamableHTTPServerConfig):
@@ -555,7 +555,7 @@ async def add_mcp_server_to_config(
server_name=request.server_name,
server_type=request.type,
server_url=request.server_url,
token=request.resolve_token() if not request.custom_headers else None,
token=request.resolve_token(),
custom_headers=request.custom_headers,
)

View File

@@ -10,6 +10,7 @@ import anyio
from fastapi.responses import StreamingResponse
from starlette.types import Send
from letta.errors import LettaUnexpectedStreamCancellationError
from letta.log import get_logger
from letta.schemas.enums import JobStatus
from letta.schemas.letta_ping import LettaPing
@@ -288,33 +289,11 @@ class StreamingResponseWithStatusCode(StreamingResponse):
# Handle client timeouts (should throw error to inform user)
except asyncio.CancelledError as exc:
logger.warning("Stream was cancelled due to client timeout or unexpected disconnection")
logger.warning("Stream was terminated due to unexpected cancellation from server")
# Handle unexpected cancellation with error
more_body = False
error_resp = {"error": {"message": "Request was unexpectedly cancelled (likely due to client timeout or disconnection)"}}
error_event = f"event: error\ndata: {json.dumps(error_resp)}\n\n".encode(self.charset)
if not self.response_started:
await send(
{
"type": "http.response.start",
"status": 408, # Request Timeout
"headers": self.raw_headers,
}
)
raise
if self._client_connected:
try:
await send(
{
"type": "http.response.body",
"body": error_event,
"more_body": more_body,
}
)
except anyio.ClosedResourceError:
self._client_connected = False
capture_sentry_exception(exc)
return
raise LettaUnexpectedStreamCancellationError("Stream was terminated due to unexpected cancellation from server")
except Exception as exc:
logger.exception("Unhandled Streaming Error")

View File

@@ -2068,7 +2068,6 @@ class SyncServer(Server):
raise ValueError(f"No client was created for MCP server: {mcp_server_name}")
tools = await self.mcp_clients[mcp_server_name].list_tools()
# Add health information to each tool
for tool in tools:
if tool.inputSchema:

View File

@@ -42,6 +42,7 @@ from letta.orm.sandbox_config import AgentEnvironmentVariable
from letta.orm.sandbox_config import AgentEnvironmentVariable as AgentEnvironmentVariableModel
from letta.orm.sqlalchemy_base import AccessType
from letta.otel.tracing import trace_method
from letta.prompts.prompt_generator import PromptGenerator
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent, get_prompt_template_for_agent_type
from letta.schemas.block import DEFAULT_BLOCKS
@@ -89,7 +90,6 @@ from letta.services.helpers.agent_manager_helper import (
check_supports_structured_output,
compile_system_message,
derive_system_message,
get_system_message_from_compiled_memory,
initialize_message_sequence,
initialize_message_sequence_async,
package_initial_message_sequence,
@@ -1783,7 +1783,7 @@ class AgentManager:
# update memory (TODO: potentially update recall/archival stats separately)
new_system_message_str = get_system_message_from_compiled_memory(
new_system_message_str = PromptGenerator.get_system_message_from_compiled_memory(
system_prompt=agent_state.system,
memory_with_sources=curr_memory_str,
in_context_memory_last_edit=memory_edit_timestamp,

View File

@@ -1,8 +1,16 @@
import asyncio
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from letta.constants import MCP_TOOL_TAG_NAME_PREFIX
from letta.errors import AgentExportIdMappingError, AgentExportProcessingError, AgentFileImportError, AgentNotFoundForExportError
from letta.errors import (
AgentExportIdMappingError,
AgentExportProcessingError,
AgentFileExportError,
AgentFileImportError,
AgentNotFoundForExportError,
)
from letta.helpers.pinecone_utils import should_use_pinecone
from letta.log import get_logger
from letta.schemas.agent import AgentState, CreateAgent
@@ -420,6 +428,8 @@ class AgentSerializationManager:
self,
schema: AgentFileSchema,
actor: User,
append_copy_suffix: bool = False,
override_existing_tools: bool = True,
dry_run: bool = False,
env_vars: Optional[Dict[str, Any]] = None,
) -> ImportResult:
@@ -481,7 +491,9 @@ class AgentSerializationManager:
pydantic_tools.append(Tool(**tool_schema.model_dump(exclude={"id"})))
# bulk upsert all tools at once
created_tools = await self.tool_manager.bulk_upsert_tools_async(pydantic_tools, actor)
created_tools = await self.tool_manager.bulk_upsert_tools_async(
pydantic_tools, actor, override_existing_tools=override_existing_tools
)
# map file ids to database ids
# note: tools are matched by name during upsert, so we need to match by name here too
@@ -513,8 +525,20 @@ class AgentSerializationManager:
if schema.sources:
# convert source schemas to pydantic sources
pydantic_sources = []
# First, do a fast batch check for existing source names to avoid conflicts
source_names_to_check = [s.name for s in schema.sources]
existing_source_names = await self.source_manager.get_existing_source_names(source_names_to_check, actor)
for source_schema in schema.sources:
source_data = source_schema.model_dump(exclude={"id", "embedding", "embedding_chunk_size"})
# Check if source name already exists, if so add unique suffix
original_name = source_data["name"]
if original_name in existing_source_names:
unique_suffix = uuid.uuid4().hex[:8]
source_data["name"] = f"{original_name}_{unique_suffix}"
pydantic_sources.append(Source(**source_data))
# bulk upsert all sources at once
@@ -523,13 +547,15 @@ class AgentSerializationManager:
# map file ids to database ids
# note: sources are matched by name during upsert, so we need to match by name here too
created_sources_by_name = {source.name: source for source in created_sources}
for source_schema in schema.sources:
created_source = created_sources_by_name.get(source_schema.name)
for i, source_schema in enumerate(schema.sources):
# Use the pydantic source name (which may have been modified for uniqueness)
source_name = pydantic_sources[i].name
created_source = created_sources_by_name.get(source_name)
if created_source:
file_to_db_ids[source_schema.id] = created_source.id
imported_count += 1
else:
logger.warning(f"Source {source_schema.name} was not created during bulk upsert")
logger.warning(f"Source {source_name} was not created during bulk upsert")
# 4. Create files (depends on sources)
for file_schema in schema.files:
@@ -548,38 +574,49 @@ class AgentSerializationManager:
imported_count += 1
# 5. Process files for chunking/embedding (depends on files and sources)
if should_use_pinecone():
embedder = PineconeEmbedder(embedding_config=schema.agents[0].embedding_config)
else:
embedder = OpenAIEmbedder(embedding_config=schema.agents[0].embedding_config)
file_processor = FileProcessor(
file_parser=self.file_parser,
embedder=embedder,
actor=actor,
using_pinecone=self.using_pinecone,
)
# Start background tasks for file processing
background_tasks = []
if schema.files and any(f.content for f in schema.files):
if should_use_pinecone():
embedder = PineconeEmbedder(embedding_config=schema.agents[0].embedding_config)
else:
embedder = OpenAIEmbedder(embedding_config=schema.agents[0].embedding_config)
file_processor = FileProcessor(
file_parser=self.file_parser,
embedder=embedder,
actor=actor,
using_pinecone=self.using_pinecone,
)
for file_schema in schema.files:
if file_schema.content: # Only process files with content
file_db_id = file_to_db_ids[file_schema.id]
source_db_id = file_to_db_ids[file_schema.source_id]
for file_schema in schema.files:
if file_schema.content: # Only process files with content
file_db_id = file_to_db_ids[file_schema.id]
source_db_id = file_to_db_ids[file_schema.source_id]
# Get the created file metadata (with caching)
if file_db_id not in file_metadata_cache:
file_metadata_cache[file_db_id] = await self.file_manager.get_file_by_id(file_db_id, actor)
file_metadata = file_metadata_cache[file_db_id]
# Get the created file metadata (with caching)
if file_db_id not in file_metadata_cache:
file_metadata_cache[file_db_id] = await self.file_manager.get_file_by_id(file_db_id, actor)
file_metadata = file_metadata_cache[file_db_id]
# Save the db call of fetching content again
file_metadata.content = file_schema.content
# Save the db call of fetching content again
file_metadata.content = file_schema.content
# Process the file for chunking/embedding
passages = await file_processor.process_imported_file(file_metadata=file_metadata, source_id=source_db_id)
imported_count += len(passages)
# Create background task for file processing
# TODO: This can be moved to celery or RQ or something
task = asyncio.create_task(
self._process_file_async(
file_metadata=file_metadata, source_id=source_db_id, file_processor=file_processor, actor=actor
)
)
background_tasks.append(task)
logger.info(f"Started background processing for file {file_metadata.file_name} (ID: {file_db_id})")
# 6. Create agents with empty message history
for agent_schema in schema.agents:
# Convert AgentSchema back to CreateAgent, remapping tool/block IDs
agent_data = agent_schema.model_dump(exclude={"id", "in_context_message_ids", "messages"})
if append_copy_suffix:
agent_data["name"] = agent_data.get("name") + "_copy"
# Remap tool_ids from file IDs to database IDs
if agent_data.get("tool_ids"):
@@ -589,6 +626,10 @@ class AgentSerializationManager:
if agent_data.get("block_ids"):
agent_data["block_ids"] = [file_to_db_ids[file_id] for file_id in agent_data["block_ids"]]
# Remap source_ids from file IDs to database IDs
if agent_data.get("source_ids"):
agent_data["source_ids"] = [file_to_db_ids[file_id] for file_id in agent_data["source_ids"]]
if env_vars:
for var in agent_data["tool_exec_environment_variables"]:
var["value"] = env_vars.get(var["key"], "")
@@ -635,14 +676,16 @@ class AgentSerializationManager:
for file_agent_schema in agent_schema.files_agents:
file_db_id = file_to_db_ids[file_agent_schema.file_id]
# Use cached file metadata if available
# Use cached file metadata if available (with content)
if file_db_id not in file_metadata_cache:
file_metadata_cache[file_db_id] = await self.file_manager.get_file_by_id(file_db_id, actor)
file_metadata_cache[file_db_id] = await self.file_manager.get_file_by_id(
file_db_id, actor, include_content=True
)
file_metadata = file_metadata_cache[file_db_id]
files_for_agent.append(file_metadata)
if file_agent_schema.visible_content:
visible_content_map[file_db_id] = file_agent_schema.visible_content
visible_content_map[file_metadata.file_name] = file_agent_schema.visible_content
# Bulk attach files to agent
await self.file_agent_manager.attach_files_bulk(
@@ -669,9 +712,19 @@ class AgentSerializationManager:
file_to_db_ids[group.id] = created_group.id
imported_count += 1
# prepare result message
num_background_tasks = len(background_tasks)
if num_background_tasks > 0:
message = (
f"Import completed successfully. Imported {imported_count} entities. "
f"{num_background_tasks} file(s) are being processed in the background for embeddings."
)
else:
message = f"Import completed successfully. Imported {imported_count} entities."
return ImportResult(
success=True,
message=f"Import completed successfully. Imported {imported_count} entities.",
message=message,
imported_count=imported_count,
imported_agent_ids=imported_agent_ids,
id_mappings=file_to_db_ids,
@@ -849,3 +902,47 @@ class AgentSerializationManager:
except AttributeError:
allowed = model_cls.__fields__.keys() # Pydantic v1
return {k: v for k, v in data.items() if k in allowed}
async def _process_file_async(self, file_metadata: FileMetadata, source_id: str, file_processor: FileProcessor, actor: User):
    """Run chunking/embedding for one imported file as a fire-and-forget task.

    Runs in the background so the main import flow is not blocked. On success,
    returns the passages produced by the processor (which also marks the file
    COMPLETED). On failure, the file is marked ERROR and nothing is re-raised,
    so a single bad file cannot abort the surrounding import.

    Args:
        file_metadata: File metadata, expected to already carry its content.
        source_id: Database ID of the source the file belongs to.
        file_processor: Processor instance performing chunking/embedding.
        actor: User on whose behalf the processing runs.
    """
    fid, fname = file_metadata.id, file_metadata.file_name
    logger.info(f"Starting background processing for file {fname} (ID: {fid})")
    try:
        # Chunk + embed; process_imported_file flips the status to COMPLETED itself.
        passages = await file_processor.process_imported_file(file_metadata=file_metadata, source_id=source_id)
    except Exception as e:
        logger.error(f"Failed to process file {fname} (ID: {fid}) in background: {e}")
        # Some exception types stringify to "" — fall back to the class name so
        # the ERROR status always carries a non-empty message.
        reason = str(e) or f"Agent serialization failed: {type(e).__name__}"
        try:
            await self.file_manager.update_file_status(
                file_id=fid,
                actor=actor,
                processing_status=FileProcessingStatus.ERROR,
                error_message=reason,
            )
        except Exception as update_error:
            logger.error(f"Failed to update file status to ERROR for {fid}: {update_error}")
        # Deliberately not re-raised: background task, file is flagged ERROR
        # and the import continues.
        return None
    logger.info(f"Successfully processed file {fname} with {len(passages)} passages")
    return passages

View File

@@ -1,6 +1,6 @@
import asyncio
import os
from datetime import datetime
from datetime import datetime, timedelta, timezone
from typing import List, Optional
from sqlalchemy import func, select, update
@@ -9,6 +9,8 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import selectinload
from letta.constants import MAX_FILENAME_LENGTH
from letta.helpers.pinecone_utils import list_pinecone_index_for_files, should_use_pinecone
from letta.log import get_logger
from letta.orm.errors import NoResultFound
from letta.orm.file import FileContent as FileContentModel
from letta.orm.file import FileMetadata as FileMetadataModel
@@ -20,8 +22,11 @@ from letta.schemas.source import Source as PydanticSource
from letta.schemas.source_metadata import FileStats, OrganizationSourcesStats, SourceStats
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.settings import settings
from letta.utils import enforce_types
logger = get_logger(__name__)
class DuplicateFileError(Exception):
"""Raised when a duplicate file is encountered and error handling is specified"""
@@ -174,6 +179,10 @@ class FileManager:
if processing_status is None and error_message is None and total_chunks is None and chunks_embedded is None:
raise ValueError("Nothing to update")
# validate that ERROR status must have an error message
if processing_status == FileProcessingStatus.ERROR and not error_message:
raise ValueError("Error message is required when setting processing status to ERROR")
values: dict[str, object] = {"updated_at": datetime.utcnow()}
if processing_status is not None:
values["processing_status"] = processing_status
@@ -273,6 +282,79 @@ class FileManager:
)
return await file_orm.to_pydantic_async()
@enforce_types
@trace_method
async def check_and_update_file_status(
    self,
    file_metadata: PydanticFileMetadata,
    actor: PydanticUser,
) -> PydanticFileMetadata:
    """
    Check and update file status for timeout and embedding completion.

    This method consolidates logic for:
    1. Checking if a file has timed out during processing (non-terminal status
       older than ``settings.file_processing_timeout_minutes`` is moved to ERROR)
    2. Checking Pinecone embedding status and updating embedded-chunk counts
       (and flipping to COMPLETED once all chunks are indexed)

    Blocked state transitions raised by ``update_file_status`` are logged and
    ignored — this is a read-side sync, so it must never fail the caller.

    Args:
        file_metadata: The file metadata to check
        actor: User performing the check

    Returns:
        Updated file metadata with current status
    """
    # check for timeout if status is not terminal
    if not file_metadata.processing_status.is_terminal_state():
        if file_metadata.created_at:
            # handle timezone differences between PostgreSQL (timezone-aware) and SQLite (timezone-naive)
            if settings.letta_pg_uri_no_default:
                # postgresql: both datetimes are timezone-aware
                timeout_threshold = datetime.now(timezone.utc) - timedelta(minutes=settings.file_processing_timeout_minutes)
                file_created_at = file_metadata.created_at
            else:
                # sqlite: both datetimes should be timezone-naive
                timeout_threshold = datetime.utcnow() - timedelta(minutes=settings.file_processing_timeout_minutes)
                file_created_at = file_metadata.created_at

            if file_created_at < timeout_threshold:
                # move file to error status with timeout message
                timeout_message = settings.file_processing_timeout_error_message.format(settings.file_processing_timeout_minutes)
                try:
                    file_metadata = await self.update_file_status(
                        file_id=file_metadata.id,
                        actor=actor,
                        processing_status=FileProcessingStatus.ERROR,
                        error_message=timeout_message,
                    )
                except ValueError as e:
                    # state transition was blocked - log it but don't fail
                    logger.warning(f"Could not update file to timeout error state: {str(e)}")
                    # continue with existing file_metadata

    # check pinecone embedding status; only relevant while a file is mid-embedding
    if should_use_pinecone() and file_metadata.processing_status == FileProcessingStatus.EMBEDDING:
        ids = await list_pinecone_index_for_files(file_id=file_metadata.id, actor=actor)
        logger.info(
            f"Embedded chunks {len(ids)}/{file_metadata.total_chunks} for {file_metadata.id} ({file_metadata.file_name}) in organization {actor.organization_id}"
        )
        # update only when the indexed count changed, or when embedding just finished
        if len(ids) != file_metadata.chunks_embedded or len(ids) == file_metadata.total_chunks:
            if len(ids) != file_metadata.total_chunks:
                # still in progress: keep current (EMBEDDING) status, refresh the count
                file_status = file_metadata.processing_status
            else:
                # all chunks indexed: promote to COMPLETED
                file_status = FileProcessingStatus.COMPLETED
            try:
                file_metadata = await self.update_file_status(
                    file_id=file_metadata.id, actor=actor, chunks_embedded=len(ids), processing_status=file_status
                )
            except ValueError as e:
                # state transition was blocked - this is a race condition
                # log it but don't fail since we're just checking status
                logger.warning(f"Race condition detected in check_and_update_file_status: {str(e)}")
                # return the current file state without updating

    return file_metadata
@enforce_types
@trace_method
async def upsert_file_content(
@@ -328,8 +410,22 @@ class FileManager:
limit: Optional[int] = 50,
include_content: bool = False,
strip_directory_prefix: bool = False,
check_status_updates: bool = False,
) -> List[PydanticFileMetadata]:
"""List all files with optional pagination."""
"""List all files with optional pagination and status checking.
Args:
source_id: Source to list files from
actor: User performing the request
after: Pagination cursor
limit: Maximum number of files to return
include_content: Whether to include file content
strip_directory_prefix: Whether to strip directory prefix from filenames
check_status_updates: Whether to check and update status for timeout and embedding completion
Returns:
List of file metadata
"""
async with db_registry.async_session() as session:
options = [selectinload(FileMetadataModel.content)] if include_content else None
@@ -341,10 +437,19 @@ class FileManager:
source_id=source_id,
query_options=options,
)
return [
await file.to_pydantic_async(include_content=include_content, strip_directory_prefix=strip_directory_prefix)
for file in files
]
# convert all files to pydantic models
file_metadatas = await asyncio.gather(
*[file.to_pydantic_async(include_content=include_content, strip_directory_prefix=strip_directory_prefix) for file in files]
)
# if status checking is enabled, check all files concurrently
if check_status_updates:
file_metadatas = await asyncio.gather(
*[self.check_and_update_file_status(file_metadata, actor) for file_metadata in file_metadatas]
)
return file_metadatas
@enforce_types
@trace_method

View File

@@ -264,7 +264,10 @@ class FileProcessor:
},
)
await self.file_manager.update_file_status(
file_id=file_metadata.id, actor=self.actor, processing_status=FileProcessingStatus.ERROR, error_message=str(e)
file_id=file_metadata.id,
actor=self.actor,
processing_status=FileProcessingStatus.ERROR,
error_message=str(e) if str(e) else f"File processing failed: {type(e).__name__}",
)
return []
@@ -361,7 +364,7 @@ class FileProcessor:
file_id=file_metadata.id,
actor=self.actor,
processing_status=FileProcessingStatus.ERROR,
error_message=str(e),
error_message=str(e) if str(e) else f"Import file processing failed: {type(e).__name__}",
)
return []

View File

@@ -293,6 +293,66 @@ class FileAgentManager:
else:
return [r.to_pydantic() for r in rows]
@enforce_types
@trace_method
async def list_files_for_agent_paginated(
    self,
    agent_id: str,
    actor: PydanticUser,
    cursor: Optional[str] = None,
    limit: int = 20,
    is_open: Optional[bool] = None,
) -> tuple[List[PydanticFileAgent], Optional[str], bool]:
    """
    Return one page of file associations for an agent.

    Pagination is keyset-based on the file-agent row ID: pass the returned
    cursor back in to fetch the next page.

    Args:
        agent_id: The agent ID to get files for
        actor: User performing the action
        cursor: Pagination cursor (file-agent ID to start after)
        limit: Maximum number of results to return
        is_open: Optional filter for open/closed status (None = all, True = open only, False = closed only)

    Returns:
        Tuple of (file_agents, next_cursor, has_more)
    """
    async with db_registry.async_session() as session:
        filters = [
            FileAgentModel.agent_id == agent_id,
            FileAgentModel.organization_id == actor.organization_id,
            FileAgentModel.is_deleted == False,
        ]
        if is_open is not None:
            filters.append(FileAgentModel.is_open == is_open)
        if cursor:
            # keyset cursor: only rows strictly after the given ID
            filters.append(FileAgentModel.id > cursor)

        # stable ordering by ID; fetch one extra row to detect a further page
        stmt = select(FileAgentModel).where(and_(*filters)).order_by(FileAgentModel.id).limit(limit + 1)
        records = (await session.execute(stmt)).scalars().all()

        has_more = len(records) > limit
        page = records[:limit] if has_more else records
        next_cursor = page[-1].id if page else None

        return [record.to_pydantic() for record in page], next_cursor, has_more
@enforce_types
@trace_method
async def list_agents_for_file(

View File

@@ -21,7 +21,7 @@ from letta.constants import (
STRUCTURED_OUTPUT_MODELS,
)
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import format_datetime, get_local_time, get_local_time_fast
from letta.helpers.datetime_helpers import get_local_time
from letta.llm_api.llm_client import LLMClient
from letta.orm.agent import Agent as AgentModel
from letta.orm.agents_tags import AgentsTags
@@ -33,6 +33,7 @@ from letta.orm.sources_agents import SourcesAgents
from letta.orm.sqlite_functions import adapt_array
from letta.otel.tracing import trace_method
from letta.prompts import gpt_system
from letta.prompts.prompt_generator import PromptGenerator
from letta.schemas.agent import AgentState, AgentType
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole
@@ -217,60 +218,6 @@ def derive_system_message(agent_type: AgentType, enable_sleeptime: Optional[bool
return system
# TODO: This code is kind of wonky and deserves a rewrite
def compile_memory_metadata_block(
    memory_edit_timestamp: datetime,
    timezone: str,
    previous_message_count: int = 0,
    archival_memory_size: Optional[int] = 0,
) -> str:
    """Build the ``<memory_metadata>`` section of the agent's system prompt.

    The block tells the agent the current time, when its memory blocks were
    last edited, and how much out-of-context memory (recall / archival) it can
    reach through tools.

    Args:
        memory_edit_timestamp: When memory blocks were last modified.
        timezone: Timezone used to format timestamps (e.g. 'America/Los_Angeles').
        previous_message_count: Number of messages stored in recall memory.
        archival_memory_size: Number of items in archival memory; the archival
            line is emitted only when this is a positive count.

    Returns:
        The metadata block as a newline-joined string wrapped in
        ``<memory_metadata>`` ... ``</memory_metadata>`` tags.
    """
    # Render the last-edit timestamp in the agent's local timezone.
    last_edit_str = format_datetime(memory_edit_timestamp, timezone)

    lines = [
        "<memory_metadata>",
        f"- The current time is: {get_local_time_fast(timezone)}",
        f"- Memory blocks were last modified: {last_edit_str}",
        f"- {previous_message_count} previous messages between you and the user are stored in recall memory (use tools to access them)",
    ]

    # Mention archival memory only when there is something to retrieve.
    if archival_memory_size is not None and archival_memory_size > 0:
        lines.append(
            f"- {archival_memory_size} total memories you created are stored in archival memory (use tools to access them)"
        )

    lines.append("</memory_metadata>")
    return "\n".join(lines)
class PreserveMapping(dict):
"""Used to preserve (do not modify) undefined variables in the system prompt"""
@@ -331,7 +278,7 @@ def compile_system_message(
raise ValueError(f"Found protected variable '{IN_CONTEXT_MEMORY_KEYWORD}' in user-defined vars: {str(user_defined_variables)}")
else:
# TODO should this all put into the memory.__repr__ function?
memory_metadata_string = compile_memory_metadata_block(
memory_metadata_string = PromptGenerator.compile_memory_metadata_block(
memory_edit_timestamp=in_context_memory_last_edit,
previous_message_count=previous_message_count,
archival_memory_size=archival_memory_size,
@@ -372,154 +319,6 @@ def compile_system_message(
return formatted_prompt
@trace_method
def get_system_message_from_compiled_memory(
    system_prompt: str,
    memory_with_sources: str,
    in_context_memory_last_edit: datetime,  # TODO move this inside of BaseMemory?
    timezone: str,
    user_defined_variables: Optional[dict] = None,
    append_icm_if_missing: bool = True,
    template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
    previous_message_count: int = 0,
    archival_memory_size: int = 0,
) -> str:
    """Assemble the final/full system message that will be fed into the LLM API.

    Joins the already-compiled memory sources with a freshly built memory
    metadata block, then injects the combined text into the system prompt
    wherever the reserved placeholder appears.

    The base system message may be templated, in which case we need to render the variables.
    The following are reserved variables:
      - CORE_MEMORY: the in-context memory of the LLM
    """
    # Custom template variables are not supported yet.
    if user_defined_variables is not None:
        # TODO eventually support the user defining their own variables to inject
        raise NotImplementedError
    variables = {}

    # Guard the reserved memory variable against collisions.
    if IN_CONTEXT_MEMORY_KEYWORD in variables:
        raise ValueError(f"Found protected variable '{IN_CONTEXT_MEMORY_KEYWORD}' in user-defined vars: {str(user_defined_variables)}")

    # TODO should this all put into the memory.__repr__ function?
    memory_metadata_string = compile_memory_metadata_block(
        memory_edit_timestamp=in_context_memory_last_edit,
        previous_message_count=previous_message_count,
        archival_memory_size=archival_memory_size,
        timezone=timezone,
    )
    full_memory_string = memory_with_sources + "\n\n" + memory_metadata_string
    # Register the combined memory text under the protected keyword.
    variables[IN_CONTEXT_MEMORY_KEYWORD] = full_memory_string

    if template_format != "f-string":
        # TODO support for mustache and jinja2
        raise NotImplementedError(template_format)

    memory_variable_string = "{" + IN_CONTEXT_MEMORY_KEYWORD + "}"
    # Special case: the system prompt was written without the placeholder —
    # append it to the end so memory is still injected.
    if append_icm_if_missing and memory_variable_string not in system_prompt:
        system_prompt += "\n\n" + memory_variable_string

    # Render the variables using the built-in templater.
    try:
        if user_defined_variables:
            formatted_prompt = safe_format(system_prompt, variables)
        else:
            formatted_prompt = system_prompt.replace(memory_variable_string, full_memory_string)
    except Exception as e:
        raise ValueError(f"Failed to format system prompt - {str(e)}. System prompt value:\n{system_prompt}")

    return formatted_prompt
@trace_method
async def compile_system_message_async(
    system_prompt: str,
    in_context_memory: Memory,
    in_context_memory_last_edit: datetime,  # TODO move this inside of BaseMemory?
    timezone: str,
    user_defined_variables: Optional[dict] = None,
    append_icm_if_missing: bool = True,
    template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
    previous_message_count: int = 0,
    archival_memory_size: int = 0,
    tool_rules_solver: Optional[ToolRulesSolver] = None,
    sources: Optional[List] = None,
    max_files_open: Optional[int] = None,
) -> str:
    """Prepare the final/full system message that will be fed into the LLM API.

    Asynchronously compiles the in-context memory (including any attached
    sources and tool-usage rules), appends a memory metadata block, and
    injects the result into the system prompt template.

    The base system message may be templated, in which case we need to render the variables.
    The following are reserved variables:
      - CORE_MEMORY: the in-context memory of the LLM

    Raises:
        NotImplementedError: if ``user_defined_variables`` is provided, or
            ``template_format`` is anything other than "f-string".
        ValueError: if prompt formatting fails.
    """
    # Add tool rule constraints if available
    tool_constraint_block = None
    if tool_rules_solver is not None:
        tool_constraint_block = tool_rules_solver.compile_tool_rule_prompts()

    if user_defined_variables is not None:
        # TODO eventually support the user defining their own variables to inject
        raise NotImplementedError
    else:
        variables = {}

    # Add the protected memory variable
    if IN_CONTEXT_MEMORY_KEYWORD in variables:
        raise ValueError(f"Found protected variable '{IN_CONTEXT_MEMORY_KEYWORD}' in user-defined vars: {str(user_defined_variables)}")
    else:
        # TODO should this all put into the memory.__repr__ function?
        # Metadata block summarizes recall/archival counts and edit timestamps.
        memory_metadata_string = compile_memory_metadata_block(
            memory_edit_timestamp=in_context_memory_last_edit,
            previous_message_count=previous_message_count,
            archival_memory_size=archival_memory_size,
            timezone=timezone,
        )
        # Compile memory off the event loop's critical path (in a thread),
        # including tool-usage rules and attached sources.
        memory_with_sources = await in_context_memory.compile_in_thread_async(
            tool_usage_rules=tool_constraint_block, sources=sources, max_files_open=max_files_open
        )
        full_memory_string = memory_with_sources + "\n\n" + memory_metadata_string

    # Add to the variables list to inject
    variables[IN_CONTEXT_MEMORY_KEYWORD] = full_memory_string

    if template_format == "f-string":
        memory_variable_string = "{" + IN_CONTEXT_MEMORY_KEYWORD + "}"

        # Catch the special case where the system prompt is unformatted
        if append_icm_if_missing:
            if memory_variable_string not in system_prompt:
                # In this case, append it to the end to make sure memory is still injected
                # warnings.warn(f"{IN_CONTEXT_MEMORY_KEYWORD} variable was missing from system prompt, appending instead")
                system_prompt += "\n\n" + memory_variable_string

        # render the variables using the built-in templater
        try:
            if user_defined_variables:
                # NOTE(review): unreachable today — user_defined_variables is always None here (guarded above)
                formatted_prompt = safe_format(system_prompt, variables)
            else:
                formatted_prompt = system_prompt.replace(memory_variable_string, full_memory_string)
        except Exception as e:
            raise ValueError(f"Failed to format system prompt - {str(e)}. System prompt value:\n{system_prompt}")
    else:
        # TODO support for mustache and jinja2
        raise NotImplementedError(template_format)

    return formatted_prompt
@trace_method
def initialize_message_sequence(
agent_state: AgentState,
@@ -601,7 +400,7 @@ async def initialize_message_sequence_async(
if memory_edit_timestamp is None:
memory_edit_timestamp = get_local_time()
full_system_message = await compile_system_message_async(
full_system_message = await PromptGenerator.compile_system_message_async(
system_prompt=agent_state.system,
in_context_memory=agent_state.memory,
in_context_memory_last_edit=memory_edit_timestamp,

View File

@@ -70,13 +70,16 @@ def runtime_override_tool_json_schema(
tool_list: list[JsonDict],
response_format: ResponseFormatUnion | None,
request_heartbeat: bool = True,
terminal_tools: set[str] | None = None,
) -> list[JsonDict]:
"""Override the tool JSON schemas at runtime if certain conditions are met.
Cases:
1. We will inject `send_message` tool calls with `response_format` if provided
2. Tools will have an additional `request_heartbeat` parameter added.
2. Tools will have an additional `request_heartbeat` parameter added (except for terminal tools).
"""
if terminal_tools is None:
terminal_tools = set()
for tool_json in tool_list:
if tool_json["name"] == SEND_MESSAGE_TOOL_NAME and response_format and response_format.type != ResponseFormatType.text:
if response_format.type == ResponseFormatType.json_schema:
@@ -89,8 +92,8 @@ def runtime_override_tool_json_schema(
"properties": {},
}
if request_heartbeat:
# TODO (cliandy): see support for tool control loop parameters
if tool_json["name"] != SEND_MESSAGE_TOOL_NAME:
# Only add request_heartbeat to non-terminal tools
if tool_json["name"] not in terminal_tools:
tool_json["parameters"]["properties"][REQUEST_HEARTBEAT_PARAM] = {
"type": "boolean",
"description": REQUEST_HEARTBEAT_DESCRIPTION,

View File

@@ -14,9 +14,15 @@ logger = get_logger(__name__)
# TODO: Get rid of Async prefix on this class name once we deprecate old sync code
class AsyncBaseMCPClient:
def __init__(self, server_config: BaseServerConfig, oauth_provider: Optional[OAuthClientProvider] = None):
# HTTP headers
AGENT_ID_HEADER = "X-Agent-Id"
def __init__(
self, server_config: BaseServerConfig, oauth_provider: Optional[OAuthClientProvider] = None, agent_id: Optional[str] = None
):
self.server_config = server_config
self.oauth_provider = oauth_provider
self.agent_id = agent_id
self.exit_stack = AsyncExitStack()
self.session: Optional[ClientSession] = None
self.initialized = False

View File

@@ -16,8 +16,10 @@ logger = get_logger(__name__)
# TODO: Get rid of Async prefix on this class name once we deprecate old sync code
class AsyncSSEMCPClient(AsyncBaseMCPClient):
def __init__(self, server_config: SSEServerConfig, oauth_provider: Optional[OAuthClientProvider] = None):
super().__init__(server_config, oauth_provider)
def __init__(
self, server_config: SSEServerConfig, oauth_provider: Optional[OAuthClientProvider] = None, agent_id: Optional[str] = None
):
super().__init__(server_config, oauth_provider, agent_id)
async def _initialize_connection(self, server_config: SSEServerConfig) -> None:
headers = {}
@@ -27,6 +29,9 @@ class AsyncSSEMCPClient(AsyncBaseMCPClient):
if server_config.auth_header and server_config.auth_token:
headers[server_config.auth_header] = server_config.auth_token
if self.agent_id:
headers[self.AGENT_ID_HEADER] = self.agent_id
# Use OAuth provider if available, otherwise use regular headers
if self.oauth_provider:
sse_cm = sse_client(url=server_config.server_url, headers=headers if headers else None, auth=self.oauth_provider)

View File

@@ -1,3 +1,5 @@
from typing import Optional
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
@@ -10,6 +12,9 @@ logger = get_logger(__name__)
# TODO: Get rid of Async prefix on this class name once we deprecate old sync code
class AsyncStdioMCPClient(AsyncBaseMCPClient):
def __init__(self, server_config: StdioServerConfig, oauth_provider=None, agent_id: Optional[str] = None):
super().__init__(server_config, oauth_provider, agent_id)
async def _initialize_connection(self, server_config: StdioServerConfig) -> None:
args = [arg.split() for arg in server_config.args]
# flatten

View File

@@ -12,8 +12,13 @@ logger = get_logger(__name__)
class AsyncStreamableHTTPMCPClient(AsyncBaseMCPClient):
def __init__(self, server_config: StreamableHTTPServerConfig, oauth_provider: Optional[OAuthClientProvider] = None):
super().__init__(server_config, oauth_provider)
def __init__(
self,
server_config: StreamableHTTPServerConfig,
oauth_provider: Optional[OAuthClientProvider] = None,
agent_id: Optional[str] = None,
):
super().__init__(server_config, oauth_provider, agent_id)
async def _initialize_connection(self, server_config: BaseServerConfig) -> None:
if not isinstance(server_config, StreamableHTTPServerConfig):
@@ -28,6 +33,10 @@ class AsyncStreamableHTTPMCPClient(AsyncBaseMCPClient):
if server_config.auth_header and server_config.auth_token:
headers[server_config.auth_header] = server_config.auth_token
# Add agent ID header if provided
if self.agent_id:
headers[self.AGENT_ID_HEADER] = self.agent_id
# Use OAuth provider if available, otherwise use regular headers
if self.oauth_provider:
streamable_http_cm = streamablehttp_client(

View File

@@ -41,6 +41,7 @@ from letta.services.mcp.sse_client import MCP_CONFIG_TOPLEVEL_KEY, AsyncSSEMCPCl
from letta.services.mcp.stdio_client import AsyncStdioMCPClient
from letta.services.mcp.streamable_http_client import AsyncStreamableHTTPMCPClient
from letta.services.tool_manager import ToolManager
from letta.settings import tool_settings
from letta.utils import enforce_types, printd
logger = get_logger(__name__)
@@ -55,19 +56,18 @@ class MCPManager:
self.cached_mcp_servers = {} # maps id -> async connection
@enforce_types
async def list_mcp_server_tools(self, mcp_server_name: str, actor: PydanticUser) -> List[MCPTool]:
async def list_mcp_server_tools(self, mcp_server_name: str, actor: PydanticUser, agent_id: Optional[str] = None) -> List[MCPTool]:
"""Get a list of all tools for a specific MCP server."""
mcp_client = None
try:
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor=actor)
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
server_config = mcp_config.to_config()
mcp_client = await self.get_mcp_client(server_config, actor)
mcp_client = await self.get_mcp_client(server_config, actor, agent_id=agent_id)
await mcp_client.connect_to_server()
# list tools
tools = await mcp_client.list_tools()
# Add health information to each tool
for tool in tools:
if tool.inputSchema:
@@ -92,33 +92,34 @@ class MCPManager:
tool_args: Optional[Dict[str, Any]],
environment_variables: Dict[str, str],
actor: PydanticUser,
agent_id: Optional[str] = None,
) -> Tuple[str, bool]:
"""Call a specific tool from a specific MCP server."""
from letta.settings import tool_settings
mcp_client = None
try:
if not tool_settings.mcp_read_from_config:
# read from DB
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor=actor)
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
server_config = mcp_config.to_config(environment_variables)
else:
# read from config file
mcp_config = self.read_mcp_config()
if mcp_server_name not in mcp_config:
raise ValueError(f"MCP server {mcp_server_name} not found in config.")
server_config = mcp_config[mcp_server_name]
if not tool_settings.mcp_read_from_config:
# read from DB
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor=actor)
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
server_config = mcp_config.to_config(environment_variables)
else:
# read from config file
mcp_config = self.read_mcp_config()
if mcp_server_name not in mcp_config:
raise ValueError(f"MCP server {mcp_server_name} not found in config.")
server_config = mcp_config[mcp_server_name]
mcp_client = await self.get_mcp_client(server_config, actor, agent_id=agent_id)
await mcp_client.connect_to_server()
mcp_client = await self.get_mcp_client(server_config, actor)
await mcp_client.connect_to_server()
# call tool
result, success = await mcp_client.execute_tool(tool_name, tool_args)
logger.info(f"MCP Result: {result}, Success: {success}")
# TODO: change to pydantic tool
await mcp_client.cleanup()
return result, success
# call tool
result, success = await mcp_client.execute_tool(tool_name, tool_args)
logger.info(f"MCP Result: {result}, Success: {success}")
# TODO: change to pydantic tool
return result, success
finally:
if mcp_client:
await mcp_client.cleanup()
@enforce_types
async def add_tool_from_mcp_server(self, mcp_server_name: str, mcp_tool_name: str, actor: PydanticUser) -> PydanticTool:
@@ -129,7 +130,6 @@ class MCPManager:
raise ValueError(f"MCP server '{mcp_server_name}' not found")
mcp_tools = await self.list_mcp_server_tools(mcp_server_name, actor=actor)
for mcp_tool in mcp_tools:
# TODO: @jnjpng move health check to tool class
if mcp_tool.name == mcp_tool_name:
@@ -450,6 +450,7 @@ class MCPManager:
server_config: Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig],
actor: PydanticUser,
oauth_provider: Optional[Any] = None,
agent_id: Optional[str] = None,
) -> Union[AsyncSSEMCPClient, AsyncStdioMCPClient, AsyncStreamableHTTPMCPClient]:
"""
Helper function to create the appropriate MCP client based on server configuration.
@@ -482,13 +483,13 @@ class MCPManager:
if server_config.type == MCPServerType.SSE:
server_config = SSEServerConfig(**server_config.model_dump())
return AsyncSSEMCPClient(server_config=server_config, oauth_provider=oauth_provider)
return AsyncSSEMCPClient(server_config=server_config, oauth_provider=oauth_provider, agent_id=agent_id)
elif server_config.type == MCPServerType.STDIO:
server_config = StdioServerConfig(**server_config.model_dump())
return AsyncStdioMCPClient(server_config=server_config, oauth_provider=oauth_provider)
return AsyncStdioMCPClient(server_config=server_config, oauth_provider=oauth_provider, agent_id=agent_id)
elif server_config.type == MCPServerType.STREAMABLE_HTTP:
server_config = StreamableHTTPServerConfig(**server_config.model_dump())
return AsyncStreamableHTTPMCPClient(server_config=server_config, oauth_provider=oauth_provider)
return AsyncStreamableHTTPMCPClient(server_config=server_config, oauth_provider=oauth_provider, agent_id=agent_id)
else:
raise ValueError(f"Unsupported server config type: {type(server_config)}")

View File

@@ -143,7 +143,6 @@ class SourceManager:
update_dict[col.name] = excluded[col.name]
upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict)
await session.execute(upsert_stmt)
await session.commit()
@@ -397,3 +396,29 @@ class SourceManager:
sources_orm = result.scalars().all()
return [source.to_pydantic() for source in sources_orm]
@enforce_types
@trace_method
async def get_existing_source_names(self, source_names: List[str], actor: PydanticUser) -> set[str]:
    """Batch-check which of the given source names already exist for the organization.

    Args:
        source_names: Candidate source names to check.
        actor: User performing the action; scopes the lookup to their organization.

    Returns:
        The subset of ``source_names`` that already exist, as a set.
    """
    # Nothing to look up — skip the database round trip entirely.
    if not source_names:
        return set()

    async with db_registry.async_session() as session:
        stmt = select(SourceModel.name).where(
            SourceModel.name.in_(source_names),
            SourceModel.organization_id == actor.organization_id,
            SourceModel.is_deleted == False,
        )
        rows = await session.execute(stmt)
        return set(rows.scalars().all())

View File

@@ -15,6 +15,8 @@ from letta.schemas.letta_message_content import TextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message, MessageCreate
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
from letta.services.message_manager import MessageManager
from letta.services.summarizer.enums import SummarizationMode
from letta.system import package_summarize_message_no_counts
from letta.templates.template_helper import render_template
@@ -36,6 +38,10 @@ class Summarizer:
message_buffer_limit: int = 10,
message_buffer_min: int = 3,
partial_evict_summarizer_percentage: float = 0.30,
agent_manager: Optional[AgentManager] = None,
message_manager: Optional[MessageManager] = None,
actor: Optional[User] = None,
agent_id: Optional[str] = None,
):
self.mode = mode
@@ -46,6 +52,12 @@ class Summarizer:
self.summarizer_agent = summarizer_agent
self.partial_evict_summarizer_percentage = partial_evict_summarizer_percentage
# for partial buffer only
self.agent_manager = agent_manager
self.message_manager = message_manager
self.actor = actor
self.agent_id = agent_id
@trace_method
async def summarize(
self,
@@ -121,9 +133,6 @@ class Summarizer:
logger.debug("Not forcing summarization, returning in-context messages as is.")
return all_in_context_messages, False
# Very ugly code to pull LLMConfig etc from the SummarizerAgent if we're not using it for anything else
assert self.summarizer_agent is not None
# First step: determine how many messages to retain
total_message_count = len(all_in_context_messages)
assert self.partial_evict_summarizer_percentage >= 0.0 and self.partial_evict_summarizer_percentage <= 1.0
@@ -147,15 +156,13 @@ class Summarizer:
# Dynamically get the LLMConfig from the summarizer agent
# Pretty cringe code here that we need the agent for this but we don't use it
agent_state = await self.summarizer_agent.agent_manager.get_agent_by_id_async(
agent_id=self.summarizer_agent.agent_id, actor=self.summarizer_agent.actor
)
agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=self.agent_id, actor=self.actor)
# TODO if we do this via the "agent", then we can more easily allow toggling on the memory block version
summary_message_str = await simple_summary(
messages=messages_to_summarize,
llm_config=agent_state.llm_config,
actor=self.summarizer_agent.actor,
actor=self.actor,
include_ack=True,
)
@@ -185,9 +192,9 @@ class Summarizer:
)[0]
# Create the message in the DB
await self.summarizer_agent.message_manager.create_many_messages_async(
await self.message_manager.create_many_messages_async(
pydantic_msgs=[summary_message_obj],
actor=self.summarizer_agent.actor,
actor=self.actor,
)
updated_in_context_messages = all_in_context_messages[assistant_message_index:]
@@ -354,7 +361,11 @@ async def simple_summary(messages: List[Message], llm_config: LLMConfig, actor:
# NOTE: we should disable the inner_thoughts_in_kwargs here, because we don't use it
# I'm leaving it commented it out for now for safety but is fine assuming the var here is a copy not a reference
# llm_config.put_inner_thoughts_in_kwargs = False
response_data = await llm_client.request_async(request_data, llm_config)
try:
response_data = await llm_client.request_async(request_data, llm_config)
except Exception as e:
# handle LLM error (likely a context window exceeded error)
raise llm_client.handle_llm_error(e)
response = llm_client.convert_response_to_chat_completion(response_data, input_messages_obj, llm_config)
if response.choices[0].message.content is None:
logger.warning("No content returned from summarizer")

View File

@@ -151,16 +151,16 @@ class LettaFileToolExecutor(ToolExecutor):
offset = file_request.offset
length = file_request.length
# Convert 1-indexed offset/length to 0-indexed start/end for LineChunker
# Use 0-indexed offset/length directly for LineChunker
start, end = None, None
if offset is not None or length is not None:
if offset is not None and offset < 1:
raise ValueError(f"Offset for file {file_name} must be >= 1 (1-indexed), got {offset}")
if offset is not None and offset < 0:
raise ValueError(f"Offset for file {file_name} must be >= 0 (0-indexed), got {offset}")
if length is not None and length < 1:
raise ValueError(f"Length for file {file_name} must be >= 1, got {length}")
# Convert to 0-indexed for LineChunker
start = (offset - 1) if offset is not None else None
# Use offset directly as it's already 0-indexed
start = offset if offset is not None else None
if start is not None and length is not None:
end = start + length
else:
@@ -193,7 +193,7 @@ class LettaFileToolExecutor(ToolExecutor):
visible_content=visible_content,
max_files_open=agent_state.max_files_open,
start_line=start + 1 if start is not None else None, # convert to 1-indexed for user display
end_line=end if end is not None else None, # end is already exclusive in slicing, so this is correct
end_line=end if end is not None else None, # end is already exclusive, shows as 1-indexed inclusive
)
opened_files.append(file_name)
@@ -220,10 +220,14 @@ class LettaFileToolExecutor(ToolExecutor):
for req in file_requests:
previous_info = format_previous_range(req.file_name)
if req.offset is not None and req.length is not None:
end_line = req.offset + req.length - 1
file_summaries.append(f"{req.file_name} (lines {req.offset}-{end_line}){previous_info}")
# Display as 1-indexed for user readability: (offset+1) to (offset+length)
start_line = req.offset + 1
end_line = req.offset + req.length
file_summaries.append(f"{req.file_name} (lines {start_line}-{end_line}){previous_info}")
elif req.offset is not None:
file_summaries.append(f"{req.file_name} (lines {req.offset}-end){previous_info}")
# Display as 1-indexed
start_line = req.offset + 1
file_summaries.append(f"{req.file_name} (lines {start_line}-end){previous_info}")
else:
file_summaries.append(f"{req.file_name}{previous_info}")

View File

@@ -37,8 +37,10 @@ class ExternalMCPToolExecutor(ToolExecutor):
# TODO: may need to have better client connection management
environment_variables = {}
agent_id = None
if agent_state:
environment_variables = agent_state.get_agent_env_vars_as_dict()
agent_id = agent_state.id
function_response, success = await mcp_manager.execute_mcp_server_tool(
mcp_server_name=mcp_server_name,
@@ -46,6 +48,7 @@ class ExternalMCPToolExecutor(ToolExecutor):
tool_args=function_args,
environment_variables=environment_variables,
actor=actor,
agent_id=agent_id,
)
return ToolExecutionResult(

View File

@@ -1,3 +1,4 @@
import asyncio
import traceback
from typing import Any, Dict, Optional, Type
@@ -129,6 +130,18 @@ class ToolExecutionManager:
result.func_return = FUNCTION_RETURN_VALUE_TRUNCATED(return_str, len(return_str), tool.return_char_limit)
return result
except asyncio.CancelledError as e:
self.logger.error(f"Aysnc cancellation error executing tool {function_name}: {str(e)}")
error_message = get_friendly_error_msg(
function_name=function_name,
exception_name=type(e).__name__,
exception_message=str(e),
)
return ToolExecutionResult(
status="error",
func_return=error_message,
stderr=[traceback.format_exc()],
)
except Exception as e:
status = "error"
self.logger.error(f"Error executing tool {function_name}: {str(e)}")

View File

@@ -184,7 +184,9 @@ class ToolManager:
@enforce_types
@trace_method
async def bulk_upsert_tools_async(self, pydantic_tools: List[PydanticTool], actor: PydanticUser) -> List[PydanticTool]:
async def bulk_upsert_tools_async(
self, pydantic_tools: List[PydanticTool], actor: PydanticUser, override_existing_tools: bool = True
) -> List[PydanticTool]:
"""
Bulk create or update multiple tools in a single database transaction.
@@ -227,10 +229,10 @@ class ToolManager:
if settings.letta_pg_uri_no_default:
# use optimized postgresql bulk upsert
async with db_registry.async_session() as session:
return await self._bulk_upsert_postgresql(session, pydantic_tools, actor)
return await self._bulk_upsert_postgresql(session, pydantic_tools, actor, override_existing_tools)
else:
# fallback to individual upserts for sqlite
return await self._upsert_tools_individually(pydantic_tools, actor)
return await self._upsert_tools_individually(pydantic_tools, actor, override_existing_tools)
@enforce_types
@trace_method
@@ -784,8 +786,10 @@ class ToolManager:
return await self._upsert_tools_individually(tool_data_list, actor)
@trace_method
async def _bulk_upsert_postgresql(self, session, tool_data_list: List[PydanticTool], actor: PydanticUser) -> List[PydanticTool]:
"""hyper-optimized postgresql bulk upsert using on_conflict_do_update."""
async def _bulk_upsert_postgresql(
self, session, tool_data_list: List[PydanticTool], actor: PydanticUser, override_existing_tools: bool = True
) -> List[PydanticTool]:
"""hyper-optimized postgresql bulk upsert using on_conflict_do_update or on_conflict_do_nothing."""
from sqlalchemy import func, select
from sqlalchemy.dialects.postgresql import insert
@@ -809,32 +813,51 @@ class ToolManager:
# use postgresql's native bulk upsert
stmt = insert(table).values(insert_data)
# on conflict, update all columns except id, created_at, and _created_by_id
excluded = stmt.excluded
update_dict = {}
for col in table.columns:
if col.name not in ("id", "created_at", "_created_by_id"):
if col.name == "updated_at":
update_dict[col.name] = func.now()
else:
update_dict[col.name] = excluded[col.name]
if override_existing_tools:
# on conflict, update all columns except id, created_at, and _created_by_id
excluded = stmt.excluded
update_dict = {}
for col in table.columns:
if col.name not in ("id", "created_at", "_created_by_id"):
if col.name == "updated_at":
update_dict[col.name] = func.now()
else:
update_dict[col.name] = excluded[col.name]
upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict)
upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict)
else:
# on conflict, do nothing (skip existing tools)
upsert_stmt = stmt.on_conflict_do_nothing(index_elements=["name", "organization_id"])
await session.execute(upsert_stmt)
await session.commit()
# fetch results
# fetch results (includes both inserted and skipped tools)
tool_names = [tool.name for tool in tool_data_list]
result_query = select(ToolModel).where(ToolModel.name.in_(tool_names), ToolModel.organization_id == actor.organization_id)
result = await session.execute(result_query)
return [tool.to_pydantic() for tool in result.scalars()]
@trace_method
async def _upsert_tools_individually(self, tool_data_list: List[PydanticTool], actor: PydanticUser) -> List[PydanticTool]:
async def _upsert_tools_individually(
self, tool_data_list: List[PydanticTool], actor: PydanticUser, override_existing_tools: bool = True
) -> List[PydanticTool]:
"""fallback to individual upserts for sqlite (original approach)."""
tools = []
for tool in tool_data_list:
upserted_tool = await self.create_or_update_tool_async(tool, actor)
tools.append(upserted_tool)
if override_existing_tools:
# update existing tools if they exist
upserted_tool = await self.create_or_update_tool_async(tool, actor)
tools.append(upserted_tool)
else:
# skip existing tools, only create new ones
existing_tool_id = await self.get_tool_id_by_name_async(tool_name=tool.name, actor=actor)
if existing_tool_id:
# tool exists, fetch and return it without updating
existing_tool = await self.get_tool_by_id_async(existing_tool_id, actor=actor)
tools.append(existing_tool)
else:
# tool doesn't exist, create it
created_tool = await self.create_tool_async(tool, actor=actor)
tools.append(created_tool)
return tools

View File

@@ -252,6 +252,7 @@ class Settings(BaseSettings):
track_errored_messages: bool = Field(default=True, description="Enable tracking for errored messages")
track_stop_reason: bool = Field(default=True, description="Enable tracking stop reason on steps.")
track_agent_run: bool = Field(default=True, description="Enable tracking agent run with cancellation support")
track_provider_trace: bool = Field(default=True, description="Enable tracking raw llm request and response at each step")
# FastAPI Application Settings
uvicorn_workers: int = 1

View File

@@ -1103,6 +1103,43 @@ def safe_create_task(coro, logger: Logger, label: str = "background task"):
return asyncio.create_task(wrapper())
def safe_create_file_processing_task(coro, file_metadata, server, actor, logger: Logger, label: str = "file processing task"):
    """Schedule a file-processing coroutine, marking the file as ERROR if it fails.

    A specialized version of safe_create_task: when the task raises, the file's
    status is updated to ERROR with a meaningful message instead of failing silently.

    Args:
        coro: The coroutine to execute
        file_metadata: FileMetadata object being processed
        server: Server instance with file_manager
        actor: User performing the operation
        logger: Logger instance for error logging
        label: Description of the task for logging
    """
    from letta.schemas.enums import FileProcessingStatus

    async def _guarded():
        try:
            await coro
        except Exception as e:
            logger.exception(f"{label} failed for file {file_metadata.file_name} with {type(e).__name__}: {e}")
            # Prefer the exception text; fall back to the exception type when it is empty.
            reason = f"Processing failed: {str(e)}" if str(e) else f"Processing failed: {type(e).__name__}"
            # Best-effort status update — a failure here must not mask the original error.
            try:
                await server.file_manager.update_file_status(
                    file_id=file_metadata.id,
                    actor=actor,
                    processing_status=FileProcessingStatus.ERROR,
                    error_message=reason,
                )
            except Exception as update_error:
                logger.error(f"Failed to update file status to ERROR for {file_metadata.id}: {update_error}")

    return asyncio.create_task(_guarded())
class CancellationSignal:
"""
A signal that can be checked for cancellation during streaming operations.

2763
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -7,7 +7,7 @@
"lock": {
"executor": "@nxlv/python:run-commands",
"options": {
"command": "poetry lock --no-update",
"command": "uv lock --no-update",
"cwd": "apps/core"
}
},
@@ -26,10 +26,7 @@
"dev": {
"executor": "@nxlv/python:run-commands",
"options": {
"commands": [
"./otel/start-otel-collector.sh",
"poetry run letta server"
],
"commands": ["./otel/start-otel-collector.sh", "uv run letta server"],
"parallel": true,
"cwd": "apps/core"
}
@@ -39,7 +36,7 @@
"options": {
"commands": [
"./otel/start-otel-collector.sh",
"poetry run letta server --debug --reload"
"uv run letta server --debug --reload"
],
"parallel": true,
"cwd": "apps/core"
@@ -58,21 +55,21 @@
"install": {
"executor": "@nxlv/python:run-commands",
"options": {
"command": "poetry install --all-extras",
"command": "uv sync --all-extras",
"cwd": "apps/core"
}
},
"lint": {
"executor": "@nxlv/python:run-commands",
"options": {
"command": "poetry run isort --profile black . && poetry run black . && poetry run autoflake --remove-all-unused-imports --remove-unused-variables --in-place --recursive --ignore-init-module-imports .",
"command": "uv run isort --profile black . && uv run black . && uv run autoflake --remove-all-unused-imports --remove-unused-variables --in-place --recursive --ignore-init-module-imports .",
"cwd": "apps/core"
}
},
"database:migrate": {
"executor": "@nxlv/python:run-commands",
"options": {
"command": "poetry run alembic upgrade head",
"command": "uv run alembic upgrade head",
"cwd": "apps/core"
}
},
@@ -83,7 +80,7 @@
"{workspaceRoot}/coverage/apps/core"
],
"options": {
"command": "poetry run pytest tests/",
"command": "uv run pytest tests/",
"cwd": "apps/core"
}
}

View File

@@ -1,6 +1,96 @@
[project]
name = "letta"
version = "0.10.0"
description = "Create LLM agents with long-term memory and custom tools"
authors = [
{name = "Letta Team", email = "contact@letta.com"},
]
license = {text = "Apache License"}
readme = "README.md"
requires-python = "<3.14,>=3.11"
dependencies = [
"typer>=0.15.2",
"questionary>=2.0.1",
"pytz>=2023.3.post1",
"tqdm>=4.66.1",
"black[jupyter]>=24.2.0",
"setuptools>=70",
"prettytable>=3.9.0",
"docstring-parser>=0.16,<0.17",
"httpx>=0.28.0",
"numpy>=2.1.0",
"demjson3>=3.0.6",
"pyyaml>=6.0.1",
"sqlalchemy-json>=0.7.0",
"pydantic>=2.10.6",
"html2text>=2020.1.16",
"sqlalchemy[asyncio]>=2.0.41",
"python-box>=7.1.1",
"sqlmodel>=0.0.16",
"python-multipart>=0.0.19",
"sqlalchemy-utils>=0.41.2",
"pydantic-settings>=2.2.1",
"httpx-sse>=0.4.0",
"nltk>=3.8.1",
"jinja2>=3.1.5",
"composio-core>=0.7.7",
"alembic>=1.13.3",
"pyhumps>=3.8.0",
"pathvalidate>=3.2.1",
"sentry-sdk[fastapi]==2.19.1",
"rich>=13.9.4",
"brotli>=1.1.0",
"grpcio>=1.68.1",
"grpcio-tools>=1.68.1",
"llama-index>=0.12.2",
"llama-index-embeddings-openai>=0.3.1",
"anthropic>=0.49.0",
"letta_client>=0.1.276",
"openai>=1.99.9",
"opentelemetry-api==1.30.0",
"opentelemetry-sdk==1.30.0",
"opentelemetry-instrumentation-requests==0.51b0",
"opentelemetry-instrumentation-sqlalchemy==0.51b0",
"opentelemetry-exporter-otlp==1.30.0",
"faker>=36.1.0",
"colorama>=0.4.6",
"marshmallow-sqlalchemy>=1.4.1",
"datamodel-code-generator[http]>=0.25.0",
"mcp[cli]>=1.9.4",
"firecrawl-py==2.16.5",
"apscheduler>=3.11.0",
"aiomultiprocess>=0.9.1",
"matplotlib>=3.10.1",
"tavily-python>=0.7.2",
"mistralai>=1.8.1",
"structlog>=25.4.0",
"certifi>=2025.6.15",
"markitdown[docx,pdf,pptx]>=0.1.2",
"orjson>=3.11.1",
]
[project.optional-dependencies]
postgres = ["pgvector>=0.2.3", "pg8000>=1.30.3", "psycopg2-binary>=2.9.10", "psycopg2>=2.9.10", "asyncpg>=0.30.0"]
redis = ["redis>=6.2.0"]
pinecone = ["pinecone[asyncio]>=7.3.0"]
dev = ["pytest>=8.0.0", "pytest-asyncio>=0.24.0", "pexpect>=4.9.0", "black>=24.2.0", "pre-commit>=3.5.0", "pyright>=1.1.347", "pytest-order>=1.2.0", "autoflake>=2.3.0", "isort>=5.13.2", "locust>=2.31.5"]
experimental = ["uvloop>=0.21.0; sys_platform != 'win32'", "granian[reload]>=2.3.2", "google-cloud-profiler>=4.1.0"]
server = ["websockets>=12.0", "fastapi>=0.115.6", "uvicorn>=0.24.0.post1"]
cloud-tool-sandbox = ["e2b-code-interpreter==1.5.2", "modal>=1.1.0"]
external-tools = ["docker>=7.1.0", "langchain>=0.3.7", "wikipedia>=1.4.0", "langchain-community>=0.3.7", "firecrawl-py==2.16.5"]
tests = ["wikipedia>=1.4.0", "pytest-asyncio>=0.24.0"]
sqlite = ["aiosqlite>=0.21.0", "sqlite-vec>=0.1.7a2"]
bedrock = ["boto3>=1.36.24", "aioboto3>=14.3.0"]
google = ["google-genai>=1.15.0"]
desktop = ["pyright>=1.1.347", "fastapi>=0.115.6", "uvicorn>=0.24.0.post1", "docker>=7.1.0", "langchain>=0.3.7", "wikipedia>=1.4.0", "langchain-community>=0.3.7", "locust>=2.31.5", "sqlite-vec>=0.1.7a2", "pgvector>=0.2.3"]
all = ["pgvector>=0.2.3", "turbopuffer>=0.5.17", "pg8000>=1.30.3", "psycopg2-binary>=2.9.10", "psycopg2>=2.9.10", "pytest", "pytest-asyncio>=0.24.0", "pexpect>=4.9.0", "black>=24.2.0", "pre-commit>=3.5.0", "pyright>=1.1.347", "pytest-order>=1.2.0", "autoflake>=2.3.0", "isort>=5.13.2", "fastapi>=0.115.6", "uvicorn>=0.24.0.post1", "docker>=7.1.0", "langchain>=0.3.7", "wikipedia>=1.4.0", "langchain-community>=0.3.7", "locust>=2.31.5", "uvloop>=0.21.0; sys_platform != 'win32'", "granian[reload]>=2.3.2", "redis>=6.2.0", "pinecone[asyncio]>=7.3.0", "google-cloud-profiler>=4.1.0"]
[project.scripts]
letta = "letta.main:app"
[tool.poetry]
name = "letta"
version = "0.11.4"
version = "0.11.5"
packages = [
{include = "letta"},
]
@@ -72,7 +162,7 @@ llama-index = "^0.12.2"
llama-index-embeddings-openai = "^0.3.1"
e2b-code-interpreter = {version = "^1.0.3", optional = true}
anthropic = "^0.49.0"
letta_client = "^0.1.220"
letta_client = "^0.1.277"
openai = "^1.99.9"
opentelemetry-api = "1.30.0"
opentelemetry-sdk = "1.30.0"

View File

@@ -1,8 +1,8 @@
{
"context_window": 8192,
"model": "llama-3.1-70b-versatile",
"model_endpoint_type": "groq",
"model_endpoint": "https://api.groq.com/openai/v1",
"model_wrapper": null,
"put_inner_thoughts_in_kwargs": true
"context_window": 8192,
"model": "qwen/qwen3-32b",
"model_endpoint_type": "groq",
"model_endpoint": "https://api.groq.com/openai/v1",
"model_wrapper": null,
"put_inner_thoughts_in_kwargs": true
}

View File

@@ -1,10 +1,12 @@
import logging
import os
from datetime import datetime, timezone
from typing import Generator
import pytest
from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchRequestCounts
from letta.server.db import db_registry
from letta.services.organization_manager import OrganizationManager
from letta.services.user_manager import UserManager
from letta.settings import tool_settings
@@ -14,6 +16,36 @@ def pytest_configure(config):
logging.basicConfig(level=logging.DEBUG)
@pytest.fixture(scope="session", autouse=True)
def disable_db_pooling_for_tests():
    """Turn off SQLAlchemy connection pooling for the whole test session.

    Sets the LETTA_DISABLE_SQLALCHEMY_POOLING environment variable before any
    test runs, then removes it again once the session tears down.
    """
    env_key = "LETTA_DISABLE_SQLALCHEMY_POOLING"
    os.environ[env_key] = "true"
    yield
    # Remove the flag so it does not leak past the test session.
    os.environ.pop(env_key, None)
@pytest.fixture(autouse=True)
async def cleanup_db_connections():
    """Dispose of async DB engines and reset the registry after each test.

    Runs automatically after every test so that connections opened during the
    test do not leak into (or poison) the next one.
    """
    yield
    try:
        # Only touch the registry if it actually tracks async engines.
        if hasattr(db_registry, "_async_engines"):
            for engine in db_registry._async_engines.values():
                if engine:
                    try:
                        await engine.dispose()
                    except Exception:
                        # Suppress common teardown errors that don't affect test validity
                        pass
            # Reset bookkeeping so the next test re-initializes from scratch.
            db_registry._initialized["async"] = False
            db_registry._async_engines.clear()
            db_registry._async_session_factories.clear()
    except Exception:
        # Suppress all cleanup errors to avoid confusing test failures
        pass
@pytest.fixture
def disable_e2b_api_key() -> Generator[None, None, None]:
"""

View File

@@ -3,7 +3,7 @@ import os
import time
from typing import Optional, Union
import requests
from letta_client import AsyncLetta, Letta
from letta.functions.functions import parse_source_code
from letta.functions.schema_generator import generate_schema
@@ -254,7 +254,8 @@ def validate_context_window_overview(
assert len(overview.functions_definitions) > 0
def upload_test_agentfile_from_disk(server_url: str, filename: str) -> ImportedAgentsResponse:
# Changed this from server_url to client since client may be authenticated or not
def upload_test_agentfile_from_disk(client: Letta, filename: str) -> ImportedAgentsResponse:
"""
Upload a given .af file to live FastAPI server.
"""
@@ -263,18 +264,87 @@ def upload_test_agentfile_from_disk(server_url: str, filename: str) -> ImportedA
file_path = os.path.join(path_to_test_agent_files, filename)
with open(file_path, "rb") as f:
files = {"file": (filename, f, "application/json")}
return client.agents.import_file(file=f, append_copy_suffix=True, override_existing_tools=False)
# Send parameters as form data instead of query parameters
form_data = {
"append_copy_suffix": "true",
"override_existing_tools": "false",
}
response = requests.post(
f"{server_url}/v1/agents/import",
headers={"user_id": ""},
files=files,
data=form_data, # Send as form data
)
return ImportedAgentsResponse(**response.json())
async def upload_test_agentfile_from_disk_async(client: AsyncLetta, filename: str) -> ImportedAgentsResponse:
    """Upload the named .af agent file to a live FastAPI server (async client).

    The file is resolved relative to the sibling ``test_agent_files`` directory
    next to this helpers module.
    """
    helpers_dir = os.path.dirname(__file__)
    agent_files_dir = helpers_dir.removesuffix("/helpers") + "/test_agent_files"
    target_path = os.path.join(agent_files_dir, filename)
    with open(target_path, "rb") as fh:
        return await client.agents.import_file(file=fh, append_copy_suffix=True, override_existing_tools=False)
def upload_file_and_wait(
    client: Letta,
    source_id: str,
    file_path: str,
    name: Optional[str] = None,
    max_wait: int = 60,
    duplicate_handling: Optional[str] = None,
):
    """Upload a file to a source and block until processing finishes.

    Polls the file's metadata once per second until its processing status is
    ``completed`` or ``error``.

    Raises:
        TimeoutError: if processing exceeds ``max_wait`` seconds.
        RuntimeError: if processing ends in the ``error`` state.
    """
    with open(file_path, "rb") as fh:
        if duplicate_handling:
            file_metadata = client.sources.files.upload(source_id=source_id, file=fh, duplicate_handling=duplicate_handling, name=name)
        else:
            file_metadata = client.sources.files.upload(source_id=source_id, file=fh, name=name)

    # Poll until the server reports a terminal state or we run out of time.
    deadline = time.time() + max_wait
    while file_metadata.processing_status != "completed" and file_metadata.processing_status != "error":
        if time.time() > deadline:
            raise TimeoutError(f"File processing timed out after {max_wait} seconds")
        time.sleep(1)
        file_metadata = client.sources.get_file_metadata(source_id=source_id, file_id=file_metadata.id)
        print("Waiting for file processing to complete...", file_metadata.processing_status)

    if file_metadata.processing_status == "error":
        raise RuntimeError(f"File processing failed: {file_metadata.error_message}")

    return file_metadata
def upload_file_and_wait_list_files(
    client: Letta,
    source_id: str,
    file_path: str,
    name: Optional[str] = None,
    max_wait: int = 60,
    duplicate_handling: Optional[str] = None,
):
    """Upload a file and wait for processing, polling via ``list_files``.

    Same contract as ``upload_file_and_wait`` except progress is observed
    through ``client.sources.files.list`` instead of ``get_file_metadata``.

    Raises:
        TimeoutError: if processing exceeds ``max_wait`` seconds.
        RuntimeError: if the file disappears from the listing or processing
            ends in the ``error`` state.
    """
    with open(file_path, "rb") as fh:
        if duplicate_handling:
            file_metadata = client.sources.files.upload(source_id=source_id, file=fh, duplicate_handling=duplicate_handling, name=name)
        else:
            file_metadata = client.sources.files.upload(source_id=source_id, file=fh, name=name)

    deadline = time.time() + max_wait
    while file_metadata.processing_status != "completed" and file_metadata.processing_status != "error":
        if time.time() > deadline:
            raise TimeoutError(f"File processing timed out after {max_wait} seconds")
        time.sleep(1)
        # Re-read this file's state via the listing endpoint.
        listed = client.sources.files.list(source_id=source_id, limit=100)
        refreshed = next((entry for entry in listed if entry.id == file_metadata.id), None)
        if refreshed is None:
            raise RuntimeError(f"File {file_metadata.id} not found in source files list")
        file_metadata = refreshed
        print("Waiting for file processing to complete (via list_files)...", file_metadata.processing_status)

    if file_metadata.processing_status == "error":
        raise RuntimeError(f"File processing failed: {file_metadata.error_message}")

    return file_metadata

View File

@@ -26,13 +26,6 @@ agent_uuid = str(uuid.uuid5(namespace, "test_agent_tool_graph"))
config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
@pytest.fixture(scope="module")
def event_loop():
loop = asyncio.new_event_loop()
yield loop
loop.close()
@pytest.fixture()
def server():
config = LettaConfig.load()

View File

@@ -321,7 +321,7 @@ def tool_with_pip_requirements(test_user):
import requests
# Simple usage to verify packages work
response = requests.get("https://httpbin.org/json", timeout=5)
response = requests.get("https://httpbin.org/json", timeout=30)
arr = np.array([1, 2, 3])
return f"Success! Status: {response.status_code}, Array sum: {np.sum(arr)}"
except ImportError as e:

View File

@@ -70,7 +70,7 @@ def client(server_url: str) -> Letta:
yield client_instance
@pytest.fixture(scope="module")
@pytest.fixture(scope="function")
def agent_state(client: Letta) -> AgentState:
"""
Creates and returns an agent state for testing with a pre-configured agent.
@@ -333,7 +333,7 @@ def test_web_search(
], f"Invalid api_key_source: {response_json['api_key_source']}"
@pytest.mark.asyncio
@pytest.mark.asyncio(scope="function")
async def test_web_search_uses_agent_env_var_model():
"""Test that web search uses the model specified in agent tool exec env vars."""

View File

@@ -32,20 +32,6 @@ from letta.services.tool_sandbox.modal_sandbox_v2 import AsyncToolSandboxModalV2
from letta.services.tool_sandbox.modal_version_manager import ModalVersionManager, get_version_manager
from letta.services.user_manager import UserManager
@pytest.fixture
def event_loop():
"""Create an instance of the default event loop for the test session."""
loop = asyncio.new_event_loop()
yield loop
# Cleanup tasks before closing loop
pending = asyncio.all_tasks(loop)
for task in pending:
task.cancel()
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
loop.close()
# ============================================================================
# SHARED FIXTURES
# ============================================================================
@@ -90,12 +76,12 @@ def basic_tool(test_user):
source_code="""
def calculate(operation: str, a: float, b: float) -> float:
'''Perform a calculation on two numbers.
Args:
operation: The operation to perform (add, subtract, multiply, divide)
a: The first number
b: The second number
Returns:
float: The result of the calculation
'''
@@ -145,11 +131,11 @@ import asyncio
async def fetch_data(url: str, delay: float = 0.1) -> Dict:
'''Simulate fetching data from a URL.
Args:
url: The URL to fetch data from
delay: The delay in seconds before returning
Returns:
Dict: A dictionary containing the fetched data
'''
@@ -194,17 +180,17 @@ import hashlib
def process_json(data: str) -> Dict:
'''Process JSON data and return metadata.
Args:
data: The JSON string to process
Returns:
Dict: Metadata about the JSON data
'''
try:
parsed = json.loads(data)
data_hash = hashlib.md5(data.encode()).hexdigest()
return {
"valid": True,
"keys": list(parsed.keys()) if isinstance(parsed, dict) else None,

View File

@@ -10,7 +10,7 @@ from dotenv import load_dotenv
from letta_client import AsyncLetta, MessageCreate, ReasoningMessage, ToolCallMessage
from letta_client.core import RequestOptions
from tests.helpers.utils import upload_test_agentfile_from_disk
from tests.helpers.utils import upload_test_agentfile_from_disk_async
REASONING_THROTTLE_MS = 100
TEST_USER_MESSAGE = "What products or services does 11x AI sell?"
@@ -66,7 +66,7 @@ async def test_pinecone_tool(client: AsyncLetta, server_url: str) -> None:
"""
Test the Pinecone tool integration with the Letta client.
"""
response = upload_test_agentfile_from_disk(server_url, "knowledge-base.af")
response = await upload_test_agentfile_from_disk_async(client, "knowledge-base.af")
agent_id = response.agent_ids[0]

View File

@@ -144,7 +144,7 @@ USER_MESSAGE_BASE64_IMAGE: List[MessageCreate] = [
]
# configs for models that are too dumb to do much other than messaging
limited_configs = ["ollama.json", "together-qwen-2.5-72b-instruct.json", "vllm.json", "lmstudio.json"]
limited_configs = ["ollama.json", "together-qwen-2.5-72b-instruct.json", "vllm.json", "lmstudio.json", "groq.json"]
all_configs = [
"openai-gpt-4o-mini.json",
@@ -161,6 +161,7 @@ all_configs = [
"gemini-2.5-pro-vertex.json",
"ollama.json",
"together-qwen-2.5-72b-instruct.json",
"groq.json",
]
reasoning_configs = [
@@ -398,7 +399,7 @@ def validate_openai_format_scrubbing(messages: List[Dict[str, Any]]) -> None:
assert content is None
def validate_anthropic_format_scrubbing(messages: List[Dict[str, Any]]) -> None:
def validate_anthropic_format_scrubbing(messages: List[Dict[str, Any]], reasoning_enabled: bool) -> None:
"""
Validate that Anthropic/Claude format assistant messages with tool_use have no <thinking> tags.
Args:
@@ -432,10 +433,12 @@ def validate_anthropic_format_scrubbing(messages: List[Dict[str, Any]]) -> None:
# Verify that the message only contains tool_use items
tool_use_items = [item for item in content_list if item.get("type") == "tool_use"]
assert len(tool_use_items) > 0, "Assistant message should have at least one tool_use item"
assert len(content_list) == len(tool_use_items), (
f"Assistant message should ONLY contain tool_use items when reasoning is disabled. "
f"Found {len(content_list)} total items but only {len(tool_use_items)} are tool_use items."
)
if not reasoning_enabled:
assert len(content_list) == len(tool_use_items), (
f"Assistant message should ONLY contain tool_use items when reasoning is disabled. "
f"Found {len(content_list)} total items but only {len(tool_use_items)} are tool_use items."
)
def validate_google_format_scrubbing(contents: List[Dict[str, Any]]) -> None:
@@ -1131,6 +1134,146 @@ def test_token_streaming_agent_loop_error(
assert len(messages_from_db) == 0
@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
    ids=[c.model for c in TESTED_LLM_CONFIGS],
)
def test_background_token_streaming_greeting_with_assistant_message(
    disable_e2b_api_key: Any,
    client: Letta,
    agent_state: AgentState,
    llm_config: LLMConfig,
) -> None:
    """
    Background token streaming, with assistant messages enabled.

    Sends a message with background=True and stream_tokens=True, verifies the
    live stream and the persisted messages, then re-streams the same run via
    ``client.runs.stream`` — first from the beginning, then resuming from a
    cursor just before the final assistant message.
    """
    # Remember the latest message so persisted results can be listed after it.
    last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
    agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)

    # Use longer message for Anthropic models to test if they stream in chunks
    if llm_config.model_endpoint_type == "anthropic":
        messages_to_send = USER_MESSAGE_FORCE_LONG_REPLY
    else:
        messages_to_send = USER_MESSAGE_FORCE_REPLY

    response = client.agents.messages.create_stream(
        agent_id=agent_state.id,
        messages=messages_to_send,
        stream_tokens=True,
        background=True,
    )
    messages = accumulate_chunks(
        list(response), verify_token_streaming=(llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"])
    )
    assert_greeting_with_assistant_message_response(messages, streaming=True, token_streaming=True, llm_config=llm_config)
    messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
    assert_greeting_with_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config)

    # Background runs expose a run_id that can be re-streamed later.
    run_id = messages[0].run_id
    assert run_id is not None

    # Replay the full run from the start and verify it matches.
    response = client.runs.stream(run_id=run_id, starting_after=0)
    messages = accumulate_chunks(
        list(response), verify_token_streaming=(llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"])
    )
    assert_greeting_with_assistant_message_response(messages, streaming=True, token_streaming=True, llm_config=llm_config)

    # Resume just before the last assistant message: expect exactly the final
    # assistant_message, stop_reason, and usage_statistics chunks.
    last_message_cursor = messages[-3].seq_id - 1
    response = client.runs.stream(run_id=run_id, starting_after=last_message_cursor)
    messages = accumulate_chunks(
        list(response), verify_token_streaming=(llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"])
    )
    assert len(messages) == 3
    assert messages[0].message_type == "assistant_message" and messages[0].seq_id == last_message_cursor + 1
    assert messages[1].message_type == "stop_reason"
    assert messages[2].message_type == "usage_statistics"
@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
    ids=[c.model for c in TESTED_LLM_CONFIGS],
)
def test_background_token_streaming_greeting_without_assistant_message(
    disable_e2b_api_key: Any,
    client: Letta,
    agent_state: AgentState,
    llm_config: LLMConfig,
) -> None:
    """
    Background token streaming with ``use_assistant_message=False``.

    Same flow as the assistant-message variant: stream with background=True and
    stream_tokens=True, then verify both the streamed chunks and the messages
    persisted to the database.
    """
    # Remember the latest message so persisted results can be listed after it.
    last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
    agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)

    # Use longer message for Anthropic models to force chunking
    if llm_config.model_endpoint_type == "anthropic":
        messages_to_send = USER_MESSAGE_FORCE_LONG_REPLY
    else:
        messages_to_send = USER_MESSAGE_FORCE_REPLY

    response = client.agents.messages.create_stream(
        agent_id=agent_state.id,
        messages=messages_to_send,
        use_assistant_message=False,
        stream_tokens=True,
        background=True,
    )
    messages = accumulate_chunks(
        list(response), verify_token_streaming=(llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"])
    )
    assert_greeting_without_assistant_message_response(messages, streaming=True, token_streaming=True, llm_config=llm_config)
    messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id, use_assistant_message=False)
    assert_greeting_without_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config)
@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
    ids=[c.model for c in TESTED_LLM_CONFIGS],
)
def test_background_token_streaming_tool_call(
    disable_e2b_api_key: Any,
    client: Letta,
    agent_state: AgentState,
    llm_config: LLMConfig,
) -> None:
    """
    Background token streaming for a message that triggers a tool call.

    Skipped for "limited" model configs that cannot reliably call tools.
    Verifies both the streamed chunks and the persisted messages.
    """
    # get the config filename
    config_filename = None
    for filename in filenames:
        config = get_llm_config(filename)
        if config.model_dump() == llm_config.model_dump():
            config_filename = filename
            break

    # skip if this is a limited model
    if not config_filename or config_filename in limited_configs:
        pytest.skip(f"Skipping test for limited model {llm_config.model}")

    # Remember the latest message so persisted results can be listed after it.
    last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
    agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)

    # Use longer message for Anthropic models to force chunking
    if llm_config.model_endpoint_type == "anthropic":
        messages_to_send = USER_MESSAGE_ROLL_DICE_LONG
    else:
        messages_to_send = USER_MESSAGE_ROLL_DICE

    response = client.agents.messages.create_stream(
        agent_id=agent_state.id,
        messages=messages_to_send,
        stream_tokens=True,
        background=True,
    )
    messages = accumulate_chunks(
        list(response), verify_token_streaming=(llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"])
    )
    assert_tool_call_response(messages, streaming=True, llm_config=llm_config)
    messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
    assert_tool_call_response(messages_from_db, from_db=True, llm_config=llm_config)
def wait_for_run_completion(client: Letta, run_id: str, timeout: float = 30.0, interval: float = 0.5) -> Run:
start = time.time()
while True:
@@ -1842,7 +1985,7 @@ def test_inner_thoughts_toggle_interleaved(
validate_openai_format_scrubbing(messages)
elif llm_config.model_endpoint_type == "anthropic":
messages = response["messages"]
validate_anthropic_format_scrubbing(messages)
validate_anthropic_format_scrubbing(messages, llm_config.enable_reasoner)
elif llm_config.model_endpoint_type in ["google_ai", "google_vertex"]:
# Google uses 'contents' instead of 'messages'
contents = response.get("contents", response.get("messages", []))

View File

@@ -7,6 +7,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from letta.functions.mcp_client.types import MCPTool, MCPToolHealth
from letta.functions.schema_generator import generate_tool_schema_for_mcp
from letta.functions.schema_validator import SchemaHealth, validate_complete_json_schema
@@ -113,10 +114,8 @@ def test_empty_object_in_required_marked_invalid():
@pytest.mark.asyncio
async def test_add_mcp_tool_rejects_non_strict_schemas():
"""Test that adding MCP tools with non-strict schemas is rejected."""
from fastapi import HTTPException
async def test_add_mcp_tool_accepts_non_strict_schemas():
"""Test that adding MCP tools with non-strict schemas is allowed."""
from letta.server.rest_api.routers.v1.tools import add_mcp_tool
from letta.settings import tool_settings
@@ -138,15 +137,19 @@ async def test_add_mcp_tool_rejects_non_strict_schemas():
mock_server = AsyncMock()
mock_server.get_tools_from_mcp_server = AsyncMock(return_value=[non_strict_tool])
mock_server.user_manager.get_user_or_default = MagicMock()
mock_server.tool_manager.create_mcp_tool_async = AsyncMock(return_value=non_strict_tool)
mock_get_server.return_value = mock_server
# Should raise HTTPException for non-strict schema
with pytest.raises(HTTPException) as exc_info:
await add_mcp_tool(mcp_server_name="test_server", mcp_tool_name="test_tool", server=mock_server, actor_id=None)
# Should accept non-strict schema without raising an exception
result = await add_mcp_tool(mcp_server_name="test_server", mcp_tool_name="test_tool", server=mock_server, actor_id=None)
assert exc_info.value.status_code == 400
assert "non-strict schema" in exc_info.value.detail["message"].lower()
assert exc_info.value.detail["health_status"] == SchemaHealth.NON_STRICT_ONLY.value
# Verify the tool was added successfully
assert result is not None
# Verify create_mcp_tool_async was called with the right parameters
mock_server.tool_manager.create_mcp_tool_async.assert_called_once()
call_args = mock_server.tool_manager.create_mcp_tool_async.call_args
assert call_args.kwargs["mcp_server_name"] == "test_server"
@pytest.mark.asyncio
@@ -183,3 +186,271 @@ async def test_add_mcp_tool_rejects_invalid_schemas():
assert exc_info.value.status_code == 400
assert "invalid schema" in exc_info.value.detail["message"].lower()
assert exc_info.value.detail["health_status"] == SchemaHealth.INVALID.value
def test_mcp_schema_healing_for_optional_fields():
    """Optional MCP fields are made required/nullable only when strict=True."""
    # Tool with one required field ('a') and one optional field ('b').
    tool = MCPTool(
        name="test_tool",
        description="A test tool",
        inputSchema={
            "type": "object",
            "properties": {
                "a": {"type": "integer", "description": "Required field"},
                "b": {"type": "integer", "description": "Optional field"},
            },
            "required": ["a"],  # only 'a' is required
            "additionalProperties": False,
        },
    )

    # Non-strict generation must leave the optional field untouched.
    relaxed = generate_tool_schema_for_mcp(tool, append_heartbeat=False, strict=False)
    relaxed_params = relaxed["parameters"]
    assert "a" in relaxed_params["required"]
    assert "b" not in relaxed_params["required"]  # stays optional
    assert relaxed_params["properties"]["b"]["type"] == "integer"  # no null added
    health, _ = validate_complete_json_schema(relaxed_params)
    # Relaxed validator still reports strict compliance.
    assert health == SchemaHealth.STRICT_COMPLIANT

    # Strict generation heals: 'b' becomes required and accepts null.
    healed = generate_tool_schema_for_mcp(tool, append_heartbeat=False, strict=True)
    healed_params = healed["parameters"]
    assert healed["strict"] is True
    assert "a" in healed_params["required"]
    assert "b" in healed_params["required"]
    assert set(healed_params["properties"]["b"]["type"]) == {"integer", "null"}
    health, _ = validate_complete_json_schema(healed_params)
    assert health == SchemaHealth.STRICT_COMPLIANT
def test_mcp_schema_healing_with_anyof():
    """Healing flattens anyOf variants on optional fields into a type array."""
    tool = MCPTool(
        name="test_tool",
        description="A test tool",
        inputSchema={
            "type": "object",
            "properties": {
                "a": {"type": "string", "description": "Required field"},
                "b": {
                    "anyOf": [{"type": "integer"}, {"type": "null"}],
                    "description": "Optional field with anyOf",
                },
            },
            "required": ["a"],  # only 'a' is required
            "additionalProperties": False,
        },
    )

    healed = generate_tool_schema_for_mcp(tool, append_heartbeat=False, strict=True)
    params = healed["parameters"]
    assert healed["strict"] is True
    assert "a" in params["required"]
    assert "b" in params["required"]  # healed into required
    # anyOf variants collapse into a deduplicated type array.
    assert set(params["properties"]["b"]["type"]) == {"integer", "null"}

    health, _ = validate_complete_json_schema(params)
    assert health == SchemaHealth.STRICT_COMPLIANT
def test_mcp_schema_type_deduplication():
    """Duplicate anyOf type entries are removed during strict generation."""
    tool = MCPTool(
        name="test_tool",
        description="A test tool",
        inputSchema={
            "type": "object",
            "properties": {
                "field": {
                    "anyOf": [
                        {"type": "string"},
                        {"type": "string"},  # deliberate duplicate
                        {"type": "null"},
                    ],
                    "description": "Field with duplicate types",
                },
            },
            "required": [],
            "additionalProperties": False,
        },
    )

    healed = generate_tool_schema_for_mcp(tool, append_heartbeat=False, strict=True)

    # The flattened type array must be duplicate-free.
    type_list = healed["parameters"]["properties"]["field"]["type"]
    assert len(type_list) == len(set(type_list))
    assert set(type_list) == {"string", "null"}
def test_mcp_schema_healing_preserves_existing_null():
    """Healing must not append a second null to a field that already has one."""
    tool = MCPTool(
        name="test_tool",
        description="A test tool",
        inputSchema={
            "type": "object",
            "properties": {
                "field": {
                    "type": ["string", "null"],  # null already present
                    "description": "Field that already accepts null",
                },
            },
            "required": [],  # field is optional, so healing would want to add null
            "additionalProperties": False,
        },
    )

    healed = generate_tool_schema_for_mcp(tool, append_heartbeat=False, strict=True)

    # Exactly one null entry should survive.
    type_list = healed["parameters"]["properties"]["field"]["type"]
    assert type_list.count("null") == 1
def test_mcp_schema_healing_all_fields_already_required():
    """When nothing is optional, strict generation leaves the schema as-is."""
    tool = MCPTool(
        name="test_tool",
        description="A test tool",
        inputSchema={
            "type": "object",
            "properties": {
                "a": {"type": "string", "description": "Field A"},
                "b": {"type": "integer", "description": "Field B"},
            },
            "required": ["a", "b"],  # nothing left to heal
            "additionalProperties": False,
        },
    )

    healed = generate_tool_schema_for_mcp(tool, append_heartbeat=False, strict=True)
    params = healed["parameters"]

    # Required set and scalar types are untouched (no null injected).
    assert set(params["required"]) == {"a", "b"}
    assert params["properties"]["a"]["type"] == "string"
    assert params["properties"]["b"]["type"] == "integer"

    health, _ = validate_complete_json_schema(params)
    assert health == SchemaHealth.STRICT_COMPLIANT
def test_mcp_schema_with_uuid_format():
    """anyOf with a uuid-formatted string must not yield duplicate string types."""
    tool = MCPTool(
        name="test_tool",
        description="A test tool with UUID formatted field",
        inputSchema={
            "type": "object",
            "properties": {
                "session_id": {
                    "anyOf": [{"type": "string"}, {"format": "uuid", "type": "string"}, {"type": "null"}],
                    "description": "Session ID that can be a string, UUID, or null",
                },
            },
            "required": [],
            "additionalProperties": False,
        },
    )

    healed = generate_tool_schema_for_mcp(tool, append_heartbeat=False, strict=True)
    session_schema = healed["parameters"]["properties"]["session_id"]

    # Two string variants collapse into one; null kept.
    assert set(session_schema["type"]) == {"string", "null"}
    # Format is dropped because the field is optional (nullable).
    assert "format" not in session_schema
    # Optional field is healed into the required list.
    assert "session_id" in healed["parameters"]["required"]

    health, _ = validate_complete_json_schema(healed["parameters"])
    assert health == SchemaHealth.STRICT_COMPLIANT
def test_mcp_schema_healing_only_in_strict_mode():
    """Test that schema healing only happens in strict mode."""
    tool = MCPTool(
        name="test_tool",
        description="Test that healing only happens in strict mode",
        inputSchema={
            "type": "object",
            "properties": {
                "required_field": {"type": "string", "description": "Already required"},
                "optional_field1": {"type": "integer", "description": "Optional 1"},
                "optional_field2": {"type": "boolean", "description": "Optional 2"},
            },
            "required": ["required_field"],
            "additionalProperties": False,
        },
    )
    # Non-strict generation must leave the schema untouched: no strict flag,
    # original required list, no null types added.
    relaxed = generate_tool_schema_for_mcp(tool, append_heartbeat=False, strict=False)
    relaxed_params = relaxed["parameters"]
    assert "strict" not in relaxed
    assert relaxed_params["required"] == ["required_field"]
    for field_name, expected_type in (
        ("required_field", "string"),
        ("optional_field1", "integer"),
        ("optional_field2", "boolean"),
    ):
        assert relaxed_params["properties"][field_name]["type"] == expected_type
    # Strict generation heals: every property becomes required and optional
    # fields gain "null" in their type list.
    healed = generate_tool_schema_for_mcp(tool, append_heartbeat=False, strict=True)
    healed_params = healed["parameters"]
    assert healed["strict"] is True
    assert set(healed_params["required"]) == {"required_field", "optional_field1", "optional_field2"}
    assert healed_params["properties"]["required_field"]["type"] == "string"
    assert set(healed_params["properties"]["optional_field1"]["type"]) == {"integer", "null"}
    assert set(healed_params["properties"]["optional_field2"]["type"]) == {"boolean", "null"}
    # Both variants satisfy the (relaxed) strict-compliance validator.
    for params in (relaxed_params, healed_params):
        health, _ = validate_complete_json_schema(params)
        assert health == SchemaHealth.STRICT_COMPLIANT
def test_mcp_schema_with_uuid_format_required_field():
    """Test that UUID format is preserved for required fields that don't have null type."""
    uuid_tool = MCPTool(
        name="test_tool",
        description="A test tool with required UUID formatted field",
        inputSchema={
            "type": "object",
            "properties": {
                "session_id": {
                    "anyOf": [{"type": "string"}, {"format": "uuid", "type": "string"}],
                    "description": "Session ID that must be a string with UUID format",
                },
            },
            "required": ["session_id"],  # Required field
            "additionalProperties": False,
        },
    )
    params = generate_tool_schema_for_mcp(uuid_tool, append_heartbeat=False, strict=True)["parameters"]
    field = params["properties"]["session_id"]
    # No null, no duplicate "string" entries for an already-required field.
    assert field["type"] == ["string"]
    # Unlike the optional case, "format" survives for a required field.
    assert field.get("format") == "uuid"
    assert "session_id" in params["required"]
    health, _ = validate_complete_json_schema(params)
    assert health == SchemaHealth.STRICT_COMPLIANT

View File

@@ -22,7 +22,7 @@ class TestSchemaValidator:
"additionalProperties": False,
},
},
"required": ["name", "age"],
"required": ["name", "age", "address"], # All properties must be required for strict mode
"additionalProperties": False,
}
@@ -235,22 +235,22 @@ class TestSchemaValidator:
assert status == SchemaHealth.INVALID
assert any("dict" in reason for reason in reasons)
def test_schema_with_defaults_strict_compliant(self):
"""Test that root-level schemas without required field are STRICT_COMPLIANT."""
def test_schema_with_defaults_non_strict(self):
"""Test that root-level schemas without required field are STRICT_COMPLIANT (validator is relaxed)."""
schema = {
"type": "object",
"properties": {"name": {"type": "string"}, "optional": {"type": "string"}},
# Missing "required" field at root level is OK
# Missing "required" field at root level - validator now accepts this
"additionalProperties": False,
}
status, reasons = validate_complete_json_schema(schema)
# After fix, root level without required should be STRICT_COMPLIANT
# Validator is relaxed - schemas with optional fields are now STRICT_COMPLIANT
assert status == SchemaHealth.STRICT_COMPLIANT
assert reasons == []
def test_composio_schema_with_optional_root_properties_strict_compliant(self):
"""Test that Composio-like schemas with optional root properties are STRICT_COMPLIANT."""
def test_composio_schema_with_optional_root_properties_non_strict(self):
"""Test that Composio-like schemas with optional root properties are STRICT_COMPLIANT (validator is relaxed)."""
schema = {
"type": "object",
"properties": {
@@ -267,25 +267,25 @@ class TestSchemaValidator:
assert status == SchemaHealth.STRICT_COMPLIANT
assert reasons == []
def test_root_level_without_required_strict_compliant(self):
"""Test that root-level objects without 'required' field are STRICT_COMPLIANT."""
def test_root_level_without_required_non_strict(self):
"""Test that root-level objects without 'required' field are STRICT_COMPLIANT (validator is relaxed)."""
schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"},
},
# No "required" field at root level
# No "required" field at root level - validator now accepts this
"additionalProperties": False,
}
status, reasons = validate_complete_json_schema(schema)
# Root level without required should be STRICT_COMPLIANT
# Validator is relaxed - accepts schemas without required field
assert status == SchemaHealth.STRICT_COMPLIANT
assert reasons == []
def test_nested_object_without_required_non_strict(self):
"""Test that nested objects without 'required' remain NON_STRICT_ONLY."""
"""Test that nested objects without 'required' are STRICT_COMPLIANT (validator is relaxed)."""
schema = {
"type": "object",
"properties": {
@@ -309,6 +309,37 @@ class TestSchemaValidator:
}
status, reasons = validate_complete_json_schema(schema)
assert status == SchemaHealth.NON_STRICT_ONLY
# Should have warning about nested preferences object missing 'required'
assert any("required" in reason and "preferences" in reason for reason in reasons)
assert status == SchemaHealth.STRICT_COMPLIANT
assert reasons == []
def test_user_example_schema_non_strict(self):
"""Test the user's example schema with optional properties - now STRICT_COMPLIANT (validator is relaxed)."""
schema = {
"type": "object",
"properties": {
"a": {"title": "A", "type": "integer"},
"b": {"anyOf": [{"type": "integer"}, {"type": "null"}], "default": None, "title": "B"},
},
"required": ["a"], # Only 'a' is required, 'b' is not
"additionalProperties": False,
}
status, reasons = validate_complete_json_schema(schema)
assert status == SchemaHealth.STRICT_COMPLIANT
assert reasons == []
def test_all_properties_required_strict_compliant(self):
"""Test that schemas with all properties required are STRICT_COMPLIANT."""
schema = {
"type": "object",
"properties": {
"a": {"title": "A", "type": "integer"},
"b": {"anyOf": [{"type": "integer"}, {"type": "null"}], "default": None, "title": "B"},
},
"required": ["a", "b"], # All properties are required
"additionalProperties": False,
}
status, reasons = validate_complete_json_schema(schema)
assert status == SchemaHealth.STRICT_COMPLIANT
assert reasons == []

View File

@@ -1,4 +1,3 @@
import asyncio
from typing import List, Optional
import pytest
@@ -36,14 +35,6 @@ from tests.utils import create_tool_from_func
# ------------------------------
@pytest.fixture(scope="module")
def event_loop():
"""Use a single asyncio loop for the entire test session."""
loop = asyncio.new_event_loop()
yield loop
loop.close()
def _clear_tables():
from letta.server.db import db_context

52
tests/test_embeddings.py Normal file
View File

@@ -0,0 +1,52 @@
import glob
import json
import os
import pytest
from letta.llm_api.llm_client import LLMClient
from letta.schemas.embedding_config import EmbeddingConfig
from letta.server.server import SyncServer
# Whitelist of embedding config files to exercise; other JSON files in the
# directory are discovered but filtered out below.
included_files = [
    # "ollama.json",
    "letta-hosted.json",
    "openai_embed.json",
]
config_dir = "tests/configs/embedding_model_configs"
config_files = glob.glob(os.path.join(config_dir, "*.json"))
# One EmbeddingConfig per whitelisted file; consumed by the parametrized test below.
embedding_configs = [
    EmbeddingConfig(**json.load(open(config_file))) for config_file in config_files if config_file.split("/")[-1] in included_files
]
@pytest.fixture
async def default_organization(server: SyncServer):
    """Yield the default organization, creating it via the organization manager."""
    yield server.organization_manager.create_default_organization()
@pytest.fixture
def default_user(server: SyncServer, default_organization):
    """Yield the default user belonging to the default organization."""
    yield server.user_manager.create_default_user(org_id=default_organization.id)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "embedding_config",
    embedding_configs,
    ids=[c.embedding_model for c in embedding_configs],
)
async def test_embeddings(embedding_config: EmbeddingConfig, default_user):
    """Request a single embedding from each configured provider and verify its shape."""
    client = LLMClient.create(
        provider_type=embedding_config.embedding_endpoint_type,
        actor=default_user,
    )
    sample_text = "This is a test input."
    vectors = await client.request_embeddings([sample_text], embedding_config)
    # One input -> one vector, sized to the configured embedding dimension.
    assert len(vectors) == 1
    assert len(vectors[0]) == embedding_config.embedding_dim

View File

@@ -49,14 +49,6 @@ EXPECTED_ROLES = ["system", "assistant", "tool", "user", "user"]
# --------------------------------------------------------------------------- #
@pytest.fixture(scope="module")
def event_loop():
"""Use a single asyncio loop for the entire test session."""
loop = asyncio.new_event_loop()
yield loop
loop.close()
@pytest.fixture
def weather_tool(server):
def get_weather(location: str) -> str:

View File

@@ -28,6 +28,10 @@ def server_url() -> str:
start_server(debug=True)
api_url = os.getenv("LETTA_API_URL")
if api_url:
return api_url
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
if not os.getenv("LETTA_SERVER_URL"):
@@ -56,12 +60,18 @@ def client(server_url: str) -> Letta:
"""
Creates and returns a synchronous Letta REST client for testing.
"""
client_instance = Letta(base_url=server_url)
api_url = os.getenv("LETTA_API_URL")
api_key = os.getenv("LETTA_API_KEY")
if api_url and not api_key:
raise ValueError("LETTA_API_KEY is required when passing LETTA_API_URL")
client_instance = Letta(token=api_key, base_url=api_url if api_url else server_url)
return client_instance
async def test_deep_research_agent(client, server_url, disable_e2b_api_key):
imported_af = upload_test_agentfile_from_disk(server_url, "deep-thought.af")
async def test_deep_research_agent(client: Letta, server_url, disable_e2b_api_key):
imported_af = upload_test_agentfile_from_disk(client, "deep-thought.af")
agent_id = imported_af.agent_ids[0]
@@ -69,6 +79,7 @@ async def test_deep_research_agent(client, server_url, disable_e2b_api_key):
response = client.agents.messages.create_stream(
agent_id=agent_id,
stream_tokens=True,
include_pings=True,
messages=[
MessageCreate(
role="user",
@@ -90,8 +101,8 @@ async def test_deep_research_agent(client, server_url, disable_e2b_api_key):
client.agents.delete(agent_id=agent_id)
async def test_11x_agent(client, server_url, disable_e2b_api_key):
imported_af = upload_test_agentfile_from_disk(server_url, "mock_alice.af")
async def test_11x_agent(client: Letta, server_url, disable_e2b_api_key):
imported_af = upload_test_agentfile_from_disk(client, "mock_alice.af")
agent_id = imported_af.agent_ids[0]

File diff suppressed because it is too large Load Diff

View File

@@ -6,7 +6,7 @@ from letta.settings import settings
@pytest.mark.asyncio
async def test_default_experimental_decorator(event_loop):
async def test_default_experimental_decorator():
settings.plugin_register = "experimental_check=tests.helpers.plugins_helper:is_experimental_okay"
@experimental("test_just_pass", fallback_function=lambda: False, kwarg1=3)
@@ -18,7 +18,7 @@ async def test_default_experimental_decorator(event_loop):
@pytest.mark.asyncio
async def test_overwrite_arg_success(event_loop):
async def test_overwrite_arg_success():
settings.plugin_register = "experimental_check=tests.helpers.plugins_helper:is_experimental_okay"
@experimental("test_override_kwarg", fallback_function=lambda *args, **kwargs: False, bool_val=True)
@@ -31,7 +31,7 @@ async def test_overwrite_arg_success(event_loop):
@pytest.mark.asyncio
async def test_overwrite_arg_fail(event_loop):
async def test_overwrite_arg_fail():
# Should fallback to lambda
settings.plugin_register = "experimental_check=tests.helpers.plugins_helper:is_experimental_okay"
@@ -61,7 +61,7 @@ async def test_overwrite_arg_fail(event_loop):
@pytest.mark.asyncio
async def test_redis_flag(event_loop):
async def test_redis_flag():
settings.plugin_register = "experimental_check=tests.helpers.plugins_helper:is_experimental_okay"
@experimental("test_redis_flag", fallback_function=lambda *args, **kwargs: _raise())

View File

@@ -130,7 +130,7 @@ async def test_provider_trace_experimental_step(message, agent_state, default_us
@pytest.mark.asyncio
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
async def test_provider_trace_experimental_step_stream(message, agent_state, default_user, event_loop):
async def test_provider_trace_experimental_step_stream(message, agent_state, default_user):
experimental_agent = LettaAgent(
agent_id=agent_state.id,
message_manager=MessageManager(),
@@ -169,7 +169,7 @@ async def test_provider_trace_experimental_step_stream(message, agent_state, def
@pytest.mark.asyncio
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
async def test_provider_trace_step(client, agent_state, default_user, message, event_loop):
async def test_provider_trace_step(client, agent_state, default_user, message):
client.agents.messages.create(agent_id=agent_state.id, messages=[])
response = client.agents.messages.create(
agent_id=agent_state.id,
@@ -186,7 +186,7 @@ async def test_provider_trace_step(client, agent_state, default_user, message, e
@pytest.mark.asyncio
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
async def test_noop_provider_trace(message, agent_state, default_user, event_loop):
async def test_noop_provider_trace(message, agent_state, default_user):
experimental_agent = LettaAgent(
agent_id=agent_state.id,
message_manager=MessageManager(),

View File

@@ -4,7 +4,7 @@ from letta.data_sources.redis_client import get_redis_client
@pytest.mark.asyncio
async def test_redis_client(event_loop):
async def test_redis_client():
test_values = {"LETTA_TEST_0": [1, 2, 3], "LETTA_TEST_1": ["apple", "pear", "banana"], "LETTA_TEST_2": ["{}", 3.2, "cat"]}
redis_client = await get_redis_client()

View File

@@ -0,0 +1,278 @@
"""
Test schema validation for OpenAI strict mode compliance.
"""
from letta.functions.schema_validator import SchemaHealth, validate_complete_json_schema
def test_user_example_schema_now_strict():
    """Test that schemas with optional fields are now considered STRICT_COMPLIANT (will be healed)."""
    payload = {
        "type": "object",
        "additionalProperties": False,
        "properties": {
            "a": {"title": "A", "type": "integer"},
            "b": {
                "anyOf": [{"type": "integer"}, {"type": "null"}],
                "default": None,
                "title": "B",
            },
        },
        "required": ["a"],  # 'b' is optional — healable, so still strict-compliant
    }
    outcome, issues = validate_complete_json_schema(payload)
    assert (outcome, issues) == (SchemaHealth.STRICT_COMPLIANT, [])
def test_all_properties_required_is_strict():
    """Test that schemas with all properties required are STRICT_COMPLIANT."""
    payload = {
        "type": "object",
        "additionalProperties": False,
        "properties": {
            "a": {"type": "integer"},
            # Optional-by-value via null in the union, but listed as required.
            "b": {"anyOf": [{"type": "integer"}, {"type": "null"}]},
        },
        "required": ["a", "b"],
    }
    outcome, issues = validate_complete_json_schema(payload)
    assert (outcome, issues) == (SchemaHealth.STRICT_COMPLIANT, [])
def test_nested_object_missing_required_now_strict():
    """Test that nested objects with optional fields are now STRICT_COMPLIANT (will be healed)."""
    inner = {
        "type": "object",
        "properties": {
            "host": {"type": "string"},
            "port": {"type": "integer"},
            "optional_field": {"anyOf": [{"type": "string"}, {"type": "null"}]},
        },
        "required": ["host", "port"],  # optional_field left out — healable
        "additionalProperties": False,
    }
    payload = {
        "type": "object",
        "properties": {"config": inner},
        "required": ["config"],
        "additionalProperties": False,
    }
    outcome, issues = validate_complete_json_schema(payload)
    assert (outcome, issues) == (SchemaHealth.STRICT_COMPLIANT, [])
def test_nested_object_all_required_is_strict():
    """Test that nested objects with all properties required are STRICT_COMPLIANT."""
    inner = {
        "type": "object",
        "properties": {
            "host": {"type": "string"},
            "port": {"type": "integer"},
            "timeout": {"anyOf": [{"type": "integer"}, {"type": "null"}]},
        },
        "required": ["host", "port", "timeout"],  # every nested property required
        "additionalProperties": False,
    }
    payload = {
        "type": "object",
        "properties": {"config": inner},
        "required": ["config"],
        "additionalProperties": False,
    }
    outcome, issues = validate_complete_json_schema(payload)
    assert (outcome, issues) == (SchemaHealth.STRICT_COMPLIANT, [])
def test_empty_object_no_properties_is_strict():
    """Test that objects with no properties are STRICT_COMPLIANT."""
    # Degenerate but valid: nothing to require, nothing extra allowed.
    payload = {"type": "object", "properties": {}, "required": [], "additionalProperties": False}
    outcome, issues = validate_complete_json_schema(payload)
    assert (outcome, issues) == (SchemaHealth.STRICT_COMPLIANT, [])
def test_missing_additionalproperties_not_strict():
    """Test that missing additionalProperties makes schema NON_STRICT_ONLY."""
    payload = {
        "type": "object",
        "properties": {"field": {"type": "string"}},
        "required": ["field"],
        # additionalProperties intentionally omitted
    }
    outcome, issues = validate_complete_json_schema(payload)
    assert outcome == SchemaHealth.NON_STRICT_ONLY
    # At least one complaint must name the missing additionalProperties key.
    flagged = [msg for msg in issues if "additionalProperties" in msg and "not explicitly set" in msg]
    assert flagged
def test_additionalproperties_true_not_strict():
    """Test that additionalProperties: true makes schema NON_STRICT_ONLY."""
    payload = {
        "type": "object",
        "properties": {"field": {"type": "string"}},
        "required": ["field"],
        "additionalProperties": True,  # open object — disallowed in strict mode
    }
    outcome, issues = validate_complete_json_schema(payload)
    assert outcome == SchemaHealth.NON_STRICT_ONLY
    flagged = [msg for msg in issues if "additionalProperties" in msg and "not false" in msg]
    assert flagged
def test_complex_schema_with_arrays():
    """Test a complex schema with arrays and nested objects."""
    item_schema = {
        "type": "object",
        "properties": {
            "id": {"type": "integer"},
            "name": {"type": "string"},
            "tags": {"type": "array", "items": {"type": "string"}},
        },
        "required": ["id", "name", "tags"],  # every item property required
        "additionalProperties": False,
    }
    payload = {
        "type": "object",
        "properties": {
            "items": {"type": "array", "items": item_schema},
            "total": {"type": "integer"},
        },
        "required": ["items", "total"],  # every top-level property required
        "additionalProperties": False,
    }
    outcome, issues = validate_complete_json_schema(payload)
    # All levels fully required -> strict compliant with no complaints.
    assert (outcome, issues) == (SchemaHealth.STRICT_COMPLIANT, [])
def test_fastmcp_tool_schema_now_strict():
    """Test that a schema from FastMCP with optional field 'b' is now STRICT_COMPLIANT."""
    # Exact shape emitted by FastMCP for a function with one optional argument.
    payload = {
        "type": "object",
        "additionalProperties": False,
        "properties": {
            "a": {"title": "A", "type": "integer"},
            "b": {"anyOf": [{"type": "integer"}, {"type": "null"}], "default": None, "title": "B"},
        },
        "required": ["a"],  # 'b' is optional, but healing makes this acceptable
    }
    outcome, issues = validate_complete_json_schema(payload)
    assert (outcome, issues) == (SchemaHealth.STRICT_COMPLIANT, [])
def test_union_types_with_anyof():
    """Test that anyOf unions are handled correctly."""
    payload = {
        "type": "object",
        "properties": {
            "value": {
                "anyOf": [
                    {"type": "string"},
                    {"type": "number"},
                    {"type": "null"},
                ]
            }
        },
        "required": ["value"],
        "additionalProperties": False,
    }
    outcome, issues = validate_complete_json_schema(payload)
    # anyOf unions are permitted in strict mode when the property is required.
    assert (outcome, issues) == (SchemaHealth.STRICT_COMPLIANT, [])
def test_healed_schema_with_type_array():
    """Test that healed schemas with type arrays including null are STRICT_COMPLIANT."""
    # Simulates the post-healing state: the once-optional field is now required
    # and its type list carries "null".
    payload = {
        "type": "object",
        "properties": {
            "required_field": {"type": "string"},
            "optional_field": {"type": ["integer", "null"]},
        },
        "required": ["required_field", "optional_field"],
        "additionalProperties": False,
    }
    outcome, issues = validate_complete_json_schema(payload)
    assert (outcome, issues) == (SchemaHealth.STRICT_COMPLIANT, [])
def test_healed_nested_schema():
    """Test that healed nested schemas are STRICT_COMPLIANT."""
    inner = {
        "type": "object",
        "properties": {
            "host": {"type": "string"},
            "port": {"type": ["integer", "null"]},  # healed optional field
            "timeout": {"type": ["number", "null"]},  # healed optional field
        },
        "required": ["host", "port", "timeout"],
        "additionalProperties": False,
    }
    payload = {
        "type": "object",
        "properties": {"config": inner},
        "required": ["config"],
        "additionalProperties": False,
    }
    outcome, issues = validate_complete_json_schema(payload)
    assert (outcome, issues) == (SchemaHealth.STRICT_COMPLIANT, [])

View File

@@ -1,5 +1,7 @@
import io
import json
import os
import textwrap
import threading
import time
import uuid
@@ -15,6 +17,8 @@ from letta_client.core import ApiError
from letta_client.types import AgentState, ToolReturnMessage
from pydantic import BaseModel, Field
from tests.helpers.utils import upload_file_and_wait
# Constants
SERVER_PORT = 8283
@@ -60,6 +64,117 @@ def agent(client: LettaSDKClient):
client.agents.delete(agent_id=agent_state.id)
@pytest.fixture(scope="function")
def fibonacci_tool(client: LettaSDKClient):
    """Function-scoped fixture: register a Fibonacci tool via the SDK and delete it on teardown.

    NOTE(review): the inner function's source code and docstring are persisted
    verbatim by upsert_from_function and become the tool's schema/description —
    do not edit them without expecting the stored tool to change.
    """
    def calculate_fibonacci(n: int) -> int:
        """Calculate the nth Fibonacci number.
        Args:
            n: The position in the Fibonacci sequence to calculate.
        Returns:
            The nth Fibonacci number.
        """
        if n <= 0:
            return 0
        elif n == 1:
            return 1
        else:
            a, b = 0, 1
            for _ in range(2, n + 1):
                a, b = b, a + b
            return b
    tool = client.tools.upsert_from_function(func=calculate_fibonacci, tags=["math", "utility"])
    yield tool
    # Teardown: remove the tool so each test starts from a clean registry.
    client.tools.delete(tool.id)
@pytest.fixture(scope="function")
def preferences_tool(client: LettaSDKClient):
    """Function-scoped fixture: register a user-preferences lookup tool, deleting it on teardown.

    NOTE(review): the inner function is uploaded verbatim as the tool's source;
    its docstring becomes the tool description.
    """
    def get_user_preferences(category: str) -> str:
        """Get user preferences for a specific category.
        Args:
            category: The preference category to retrieve (notification, theme, language).
        Returns:
            The user's preference for the specified category, or "not specified" if unknown.
        """
        preferences = {"notification": "email only", "theme": "dark mode", "language": "english"}
        return preferences.get(category, "not specified")
    tool = client.tools.upsert_from_function(func=get_user_preferences, tags=["user", "preferences"])
    yield tool
    # Teardown: drop the registered tool.
    client.tools.delete(tool.id)
@pytest.fixture(scope="function")
def data_analysis_tool(client: LettaSDKClient):
    """Function-scoped fixture: register a simple numeric-analysis tool, deleting it on teardown.

    NOTE(review): the inner function is persisted verbatim as the tool source.
    """
    def analyze_data(data_type: str, values: List[float]) -> str:
        """Analyze data and provide insights.
        Args:
            data_type: Type of data to analyze.
            values: Numerical values to analyze.
        Returns:
            Analysis results including average, max, and min values.
        """
        if not values:
            return "No data provided"
        avg = sum(values) / len(values)
        max_val = max(values)
        min_val = min(values)
        return f"Analysis of {data_type}: avg={avg:.2f}, max={max_val}, min={min_val}"
    tool = client.tools.upsert_from_function(func=analyze_data, tags=["analysis", "data"])
    yield tool
    # Teardown: drop the registered tool.
    client.tools.delete(tool.id)
@pytest.fixture(scope="function")
def persona_block(client: LettaSDKClient):
    """Function-scoped fixture: create a persona memory block, deleting it on teardown.

    NOTE(review): serialization tests assert on substrings of this value
    (e.g. "Alex") — keep the text stable.
    """
    block = client.blocks.create(
        label="persona",
        value="You are Alex, a data analyst and mathematician who helps users with calculations and insights. You have extensive experience in statistical analysis and prefer to provide clear, accurate results.",
        limit=8000,
    )
    yield block
    # Teardown: remove the block.
    client.blocks.delete(block.id)
@pytest.fixture(scope="function")
def human_block(client: LettaSDKClient):
    """Function-scoped fixture: create a human memory block, deleting it on teardown.

    NOTE(review): serialization tests assert on the "sarah_researcher" substring —
    keep the text stable.
    """
    block = client.blocks.create(
        label="human",
        value="username: sarah_researcher\noccupation: data scientist\ninterests: machine learning, statistics, fibonacci sequences\npreferred_communication: detailed explanations with examples",
        limit=4000,
    )
    yield block
    # Teardown: remove the block.
    client.blocks.delete(block.id)
@pytest.fixture(scope="function")
def context_block(client: LettaSDKClient):
    """Function-scoped fixture: create a project-context memory block, deleting it on teardown.

    NOTE(review): serialization tests assert on the "financial markets" substring —
    keep the text stable.
    """
    block = client.blocks.create(
        label="project_context",
        value="Current project: Building predictive models for financial markets. Sarah is working on sequence analysis and pattern recognition. Recently interested in mathematical sequences like Fibonacci for trend analysis.",
        limit=6000,
    )
    yield block
    # Teardown: remove the block.
    client.blocks.delete(block.id)
def test_shared_blocks(client: LettaSDKClient):
# create a block
block = client.blocks.create(
@@ -1465,7 +1580,6 @@ def test_tool_name_auto_update_with_multiple_functions(client: LettaSDKClient):
def test_tool_rename_with_json_schema_and_source_code(client: LettaSDKClient):
"""Test that passing both new JSON schema AND source code still renames the tool based on source code"""
import textwrap
# Create initial tool
def initial_tool(x: int) -> int:
@@ -1543,3 +1657,346 @@ def test_tool_rename_with_json_schema_and_source_code(client: LettaSDKClient):
finally:
# Clean up
client.tools.delete(tool_id=tool.id)
def test_import_agent_file_from_disk(
    client: LettaSDKClient, fibonacci_tool, preferences_tool, data_analysis_tool, persona_block, human_block, context_block
):
    """Test exporting an agent to file and importing it back from disk.

    Round-trips an agent through the v2 agent-file format: create -> export ->
    write to disk -> import from the file handle -> sanity-check the copy.
    """
    # Create a comprehensive agent (similar to test_agent_serialization_v2)
    name = f"test_export_import_{str(uuid.uuid4())}"
    temp_agent = client.agents.create(
        name=name,
        memory_blocks=[persona_block, human_block, context_block],
        model="openai/gpt-4.1-mini",
        embedding="openai/text-embedding-3-small",
        tool_ids=[fibonacci_tool.id, preferences_tool.id, data_analysis_tool.id],
        include_base_tools=True,
        tags=["test", "export", "import"],
        system="You are a helpful assistant specializing in data analysis and mathematical computations.",
    )
    # Add archival memory
    archival_passages = ["Test archival passage for export/import testing.", "Another passage with data about testing procedures."]
    for passage_text in archival_passages:
        client.agents.passages.create(agent_id=temp_agent.id, text=passage_text)
    # Send a test message
    client.agents.messages.create(
        agent_id=temp_agent.id,
        messages=[
            MessageCreate(
                role="user",
                content="Test message for export",
            ),
        ],
    )
    # Export the agent
    serialized_v2 = client.agents.export_file(agent_id=temp_agent.id, use_legacy_format=False)
    # Save to file
    file_path = os.path.join(os.path.dirname(__file__), "test_agent_files", "test_basic_agent_with_blocks_tools_messages_v2.af")
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, "w") as f:
        json.dump(serialized_v2, f, indent=2)
    # Now import from the file
    with open(file_path, "rb") as f:
        import_result = client.agents.import_file(
            file=f, append_copy_suffix=True, override_existing_tools=True  # Use suffix to avoid name conflict
        )
    # Basic verification
    assert import_result is not None, "Import result should not be None"
    assert len(import_result.agent_ids) > 0, "Should have imported at least one agent"
    # Get the imported agent
    imported_agent_id = import_result.agent_ids[0]
    imported_agent = client.agents.retrieve(agent_id=imported_agent_id)
    # Basic checks
    assert imported_agent is not None, "Should be able to retrieve imported agent"
    assert imported_agent.name is not None, "Imported agent should have a name"
    assert imported_agent.memory is not None, "Agent should have memory"
    assert len(imported_agent.tools) > 0, "Agent should have tools"
    assert imported_agent.system is not None, "Agent should have a system prompt"
def test_agent_serialization_v2(
    client: LettaSDKClient, fibonacci_tool, preferences_tool, data_analysis_tool, persona_block, human_block, context_block
):
    """Test agent serialization with comprehensive setup including custom tools, blocks, messages, and archival memory.

    Exports an agent using the v2 format, re-imports it in-memory (no disk),
    then verifies that basic properties, memory blocks, tools, and message
    history survive the round trip.
    """
    name = f"comprehensive_test_agent_{str(uuid.uuid4())}"
    temp_agent = client.agents.create(
        name=name,
        memory_blocks=[persona_block, human_block, context_block],
        model="openai/gpt-4.1-mini",
        embedding="openai/text-embedding-3-small",
        tool_ids=[fibonacci_tool.id, preferences_tool.id, data_analysis_tool.id],
        include_base_tools=True,
        tags=["test", "comprehensive", "serialization"],
        system="You are a helpful assistant specializing in data analysis and mathematical computations.",
    )
    # Add archival memory
    archival_passages = [
        "Project background: Sarah is working on a financial prediction model that uses Fibonacci retracements for technical analysis.",
        "Research notes: Golden ratio (1.618) derived from Fibonacci sequence is often used in financial markets for support/resistance levels.",
    ]
    for passage_text in archival_passages:
        client.agents.passages.create(agent_id=temp_agent.id, text=passage_text)
    # Send some messages
    client.agents.messages.create(
        agent_id=temp_agent.id,
        messages=[
            MessageCreate(
                role="user",
                content="Test message",
            ),
        ],
    )
    # Serialize using v2
    serialized_v2 = client.agents.export_file(agent_id=temp_agent.id, use_legacy_format=False)
    # Convert dict to JSON bytes for import
    json_str = json.dumps(serialized_v2)
    file_obj = io.BytesIO(json_str.encode("utf-8"))
    # Import again
    import_result = client.agents.import_file(file=file_obj, append_copy_suffix=False, override_existing_tools=True)
    # Verify import was successful
    assert len(import_result.agent_ids) == 1, "Should have imported exactly one agent"
    imported_agent_id = import_result.agent_ids[0]
    imported_agent = client.agents.retrieve(agent_id=imported_agent_id)
    # ========== BASIC AGENT PROPERTIES ==========
    # Name should be the same (if append_copy_suffix=False) or have suffix
    assert imported_agent.name == name, f"Agent name mismatch: {imported_agent.name} != {name}"
    # LLM and embedding configs should be preserved
    assert (
        imported_agent.llm_config.model == temp_agent.llm_config.model
    ), f"LLM model mismatch: {imported_agent.llm_config.model} != {temp_agent.llm_config.model}"
    assert imported_agent.embedding_config.embedding_model == temp_agent.embedding_config.embedding_model, "Embedding model mismatch"
    # System prompt should be preserved
    assert imported_agent.system == temp_agent.system, "System prompt was not preserved"
    # Tags should be preserved
    assert set(imported_agent.tags) == set(temp_agent.tags), f"Tags mismatch: {imported_agent.tags} != {temp_agent.tags}"
    # Agent type should be preserved
    assert (
        imported_agent.agent_type == temp_agent.agent_type
    ), f"Agent type mismatch: {imported_agent.agent_type} != {temp_agent.agent_type}"
    # ========== MEMORY BLOCKS ==========
    # Compare memory blocks directly from AgentState objects
    original_blocks = temp_agent.memory.blocks
    imported_blocks = imported_agent.memory.blocks
    # Should have same number of blocks
    assert len(imported_blocks) == len(original_blocks), f"Block count mismatch: {len(imported_blocks)} != {len(original_blocks)}"
    # Verify each block by label
    original_blocks_by_label = {block.label: block for block in original_blocks}
    imported_blocks_by_label = {block.label: block for block in imported_blocks}
    # Check persona block (substring checks match the fixture values)
    assert "persona" in imported_blocks_by_label, "Persona block missing in imported agent"
    assert "Alex" in imported_blocks_by_label["persona"].value, "Persona block content not preserved"
    assert imported_blocks_by_label["persona"].limit == original_blocks_by_label["persona"].limit, "Persona block limit mismatch"
    # Check human block
    assert "human" in imported_blocks_by_label, "Human block missing in imported agent"
    assert "sarah_researcher" in imported_blocks_by_label["human"].value, "Human block content not preserved"
    assert imported_blocks_by_label["human"].limit == original_blocks_by_label["human"].limit, "Human block limit mismatch"
    # Check context block
    assert "project_context" in imported_blocks_by_label, "Context block missing in imported agent"
    assert "financial markets" in imported_blocks_by_label["project_context"].value, "Context block content not preserved"
    assert (
        imported_blocks_by_label["project_context"].limit == original_blocks_by_label["project_context"].limit
    ), "Context block limit mismatch"
    # ========== TOOLS ==========
    # Compare tools directly from AgentState objects
    original_tools = temp_agent.tools
    imported_tools = imported_agent.tools
    # Should have same number of tools
    assert len(imported_tools) == len(original_tools), f"Tool count mismatch: {len(imported_tools)} != {len(original_tools)}"
    original_tool_names = {tool.name for tool in original_tools}
    imported_tool_names = {tool.name for tool in imported_tools}
    # Check custom tools are present
    assert "calculate_fibonacci" in imported_tool_names, "Fibonacci tool missing in imported agent"
    assert "get_user_preferences" in imported_tool_names, "Preferences tool missing in imported agent"
    assert "analyze_data" in imported_tool_names, "Data analysis tool missing in imported agent"
    # Check for base tools (since we set include_base_tools=True when creating the agent)
    # Base tools should also be present (at least some core ones)
    base_tool_names = {"send_message", "conversation_search"}
    missing_base_tools = base_tool_names - imported_tool_names
    assert len(missing_base_tools) == 0, f"Missing base tools: {missing_base_tools}"
    # Verify tool names match exactly
    assert original_tool_names == imported_tool_names, f"Tool names don't match: {original_tool_names} != {imported_tool_names}"
    # ========== MESSAGE HISTORY ==========
    # Get messages for both agents
    original_messages = client.agents.messages.list(agent_id=temp_agent.id, limit=100)
    imported_messages = client.agents.messages.list(agent_id=imported_agent_id, limit=100)
    # Should have same number of messages
    assert len(imported_messages) >= 1, "Imported agent should have messages"
    # Filter for user messages (excluding system-generated login messages)
    original_user_msgs = [msg for msg in original_messages if msg.message_type == "user_message" and "Test message" in msg.content]
    imported_user_msgs = [msg for msg in imported_messages if msg.message_type == "user_message" and "Test message" in msg.content]
    # Should have the same number of test messages
    assert len(imported_user_msgs) == len(
        original_user_msgs
    ), f"User message count mismatch: {len(imported_user_msgs)} != {len(original_user_msgs)}"
    # Verify test message content is preserved
    if len(original_user_msgs) > 0 and len(imported_user_msgs) > 0:
        assert imported_user_msgs[0].content == original_user_msgs[0].content, "User message content not preserved"
        assert "Test message" in imported_user_msgs[0].content, "Test message content not found"
def test_export_import_agent_with_files(client: LettaSDKClient):
    """Test exporting and importing an agent with files attached.

    Creates a source with two uploaded files, attaches it to a new agent,
    exports the agent in the new (non-legacy) format, re-imports it, and
    verifies that the source, its files, and the file blocks survive the
    round trip. Cleanup runs in a ``finally`` so a failed assertion does
    not leak agents/sources into subsequent tests.
    """
    # Clean up any existing sources from previous runs so file counts below are deterministic
    for existing_source in client.sources.list():
        client.sources.delete(source_id=existing_source.id)

    # Create a source and upload test files
    source = client.sources.create(name="test_export_source", embedding="openai/text-embedding-3-small")
    test_files = ["tests/data/test.txt", "tests/data/test.md"]
    for file_path in test_files:
        upload_file_and_wait(client, source.id, file_path)

    # Verify files were uploaded successfully
    files_in_source = client.sources.files.list(source_id=source.id, limit=10)
    assert len(files_in_source) == len(test_files), f"Expected {len(test_files)} files, got {len(files_in_source)}"

    temp_agent = None
    imported_agent_id = None
    try:
        # Create a simple agent with the source attached
        temp_agent = client.agents.create(
            memory_blocks=[
                CreateBlock(label="human", value="username: sarah"),
            ],
            model="openai/gpt-4o-mini",
            embedding="openai/text-embedding-3-small",
            source_ids=[source.id],  # Attach the source with files
        )

        # Verify the agent has the source and file blocks before exporting
        agent_state = client.agents.retrieve(agent_id=temp_agent.id)
        assert len(agent_state.sources) == 1, "Agent should have one source attached"
        assert agent_state.sources[0].id == source.id, "Agent should have the correct source attached"
        file_blocks = agent_state.memory.file_blocks
        assert len(file_blocks) == len(test_files), f"Expected {len(test_files)} file blocks, got {len(file_blocks)}"

        # Export the agent, then round-trip through JSON bytes the way a client upload would
        serialized_agent = client.agents.export_file(agent_id=temp_agent.id, use_legacy_format=False)
        file_obj = io.BytesIO(json.dumps(serialized_agent).encode("utf-8"))

        # Import the agent
        import_result = client.agents.import_file(file=file_obj, append_copy_suffix=True, override_existing_tools=True)
        assert len(import_result.agent_ids) == 1, "Should have imported exactly one agent"
        imported_agent_id = import_result.agent_ids[0]
        imported_agent = client.agents.retrieve(agent_id=imported_agent_id)

        # Verify the source is attached to the imported agent
        assert len(imported_agent.sources) == 1, "Imported agent should have one source attached"
        imported_source = imported_agent.sources[0]

        # Check that the imported source carries the same number of files
        imported_files = client.sources.files.list(source_id=imported_source.id, limit=10)
        assert len(imported_files) == len(test_files), f"Imported source should have {len(test_files)} files"

        # Verify file blocks (and their viewer headers) are preserved in the imported agent
        imported_file_blocks = imported_agent.memory.file_blocks
        assert len(imported_file_blocks) == len(test_files), f"Imported agent should have {len(test_files)} file blocks"
        for file_block in imported_file_blocks:
            assert file_block.value is not None and len(file_block.value) > 0, "Imported file block should have content"
            assert "[Viewing file start" in file_block.value, "Imported file block should show file viewing header"

        # Test that files can be opened on the imported agent
        if len(imported_files) > 0:
            client.agents.files.open(agent_id=imported_agent_id, file_id=imported_files[0].id)
    finally:
        # Clean up even when an assertion above fails, so later tests start from a known state
        if temp_agent is not None:
            client.agents.delete(agent_id=temp_agent.id)
        if imported_agent_id is not None:
            client.agents.delete(agent_id=imported_agent_id)
        client.sources.delete(source_id=source.id)
def test_import_agent_with_files_from_disk(client: LettaSDKClient):
    """Test importing an agent from a pre-exported ``.af`` file on disk.

    The checked-in fixture was produced from an agent that had a source with
    two files attached; this verifies the source, its files, and the file
    blocks are reconstructed on import. Nothing is exported here — the
    export half of the round trip lives in the fixture itself.
    """
    # The fixture agent was exported with two files attached; only the count matters here
    expected_file_count = 2

    # Path to the checked-in agent-file fixture
    file_path = os.path.join(os.path.dirname(__file__), "test_agent_files", "test_agent_with_files_and_sources.af")

    # Import the agent from disk
    with open(file_path, "rb") as f:
        import_result = client.agents.import_file(
            file=f, append_copy_suffix=True, override_existing_tools=True  # Use suffix to avoid name conflict
        )

    # Verify import was successful
    assert len(import_result.agent_ids) == 1, "Should have imported exactly one agent"
    imported_agent_id = import_result.agent_ids[0]
    imported_agent = client.agents.retrieve(agent_id=imported_agent_id)

    imported_source = None
    try:
        # Verify the source is attached to the imported agent
        assert len(imported_agent.sources) == 1, "Imported agent should have one source attached"
        imported_source = imported_agent.sources[0]

        # Check that the imported source carries the expected files
        imported_files = client.sources.files.list(source_id=imported_source.id, limit=10)
        assert len(imported_files) == expected_file_count, f"Imported source should have {expected_file_count} files"

        # Verify file blocks (and their viewer headers) are preserved in the imported agent
        imported_file_blocks = imported_agent.memory.file_blocks
        assert len(imported_file_blocks) == expected_file_count, f"Imported agent should have {expected_file_count} file blocks"
        for file_block in imported_file_blocks:
            assert file_block.value is not None and len(file_block.value) > 0, "Imported file block should have content"
            assert "[Viewing file start" in file_block.value, "Imported file block should show file viewing header"

        # Test that files can be opened on the imported agent
        if len(imported_files) > 0:
            client.agents.files.open(agent_id=imported_agent_id, file_id=imported_files[0].id)
    finally:
        # Clean up even when an assertion fails, so later tests start from a known state
        client.agents.delete(agent_id=imported_agent_id)
        if imported_source is not None:
            client.sources.delete(source_id=imported_source.id)

View File

@@ -485,7 +485,7 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user: User):
@pytest.mark.asyncio
async def test_read_local_llm_configs(server: SyncServer, user: User, event_loop):
async def test_read_local_llm_configs(server: SyncServer, user: User):
configs_base_dir = os.path.join(os.path.expanduser("~"), ".letta", "llm_configs")
clean_up_dir = False
if not os.path.exists(configs_base_dir):
@@ -1016,7 +1016,7 @@ async def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, b
@pytest.mark.asyncio
async def test_messages_with_provider_override(server: SyncServer, user_id: str, event_loop):
async def test_messages_with_provider_override(server: SyncServer, user_id: str):
actor = await server.user_manager.get_actor_or_default_async(actor_id=user_id)
provider = server.provider_manager.create_provider(
request=ProviderCreate(
@@ -1096,7 +1096,7 @@ async def test_messages_with_provider_override(server: SyncServer, user_id: str,
@pytest.mark.asyncio
async def test_unique_handles_for_provider_configs(server: SyncServer, user: User, event_loop):
async def test_unique_handles_for_provider_configs(server: SyncServer, user: User):
models = await server.list_llm_models_async(actor=user)
model_handles = [model.handle for model in models]
assert sorted(model_handles) == sorted(list(set(model_handles))), "All models should have unique handles"

View File

@@ -1,3 +1,4 @@
import asyncio
import os
import re
import tempfile
@@ -7,17 +8,19 @@ from datetime import datetime, timedelta
import pytest
from dotenv import load_dotenv
from letta_client import CreateBlock, DuplicateFileHandling
from letta_client import CreateBlock
from letta_client import Letta as LettaSDKClient
from letta_client import LettaRequest
from letta_client import MessageCreate as ClientMessageCreate
from letta_client.types import AgentState
from letta.constants import DEFAULT_ORG_ID, FILES_TOOLS
from letta.helpers.pinecone_utils import should_use_pinecone
from letta.schemas.enums import FileProcessingStatus, ToolType
from letta.schemas.message import MessageCreate
from letta.schemas.user import User
from letta.settings import settings
from tests.helpers.utils import upload_file_and_wait, upload_file_and_wait_list_files
from tests.utils import wait_for_server
# Constants
@@ -71,31 +74,6 @@ def client() -> LettaSDKClient:
yield client
def upload_file_and_wait(
client: LettaSDKClient, source_id: str, file_path: str, max_wait: int = 60, duplicate_handling: DuplicateFileHandling = None
):
"""Helper function to upload a file and wait for processing to complete"""
with open(file_path, "rb") as f:
if duplicate_handling:
file_metadata = client.sources.files.upload(source_id=source_id, file=f, duplicate_handling=duplicate_handling)
else:
file_metadata = client.sources.files.upload(source_id=source_id, file=f)
# Wait for the file to be processed
start_time = time.time()
while file_metadata.processing_status != "completed" and file_metadata.processing_status != "error":
if time.time() - start_time > max_wait:
pytest.fail(f"File processing timed out after {max_wait} seconds")
time.sleep(1)
file_metadata = client.sources.get_file_metadata(source_id=source_id, file_id=file_metadata.id)
print("Waiting for file processing to complete...", file_metadata.processing_status)
if file_metadata.processing_status == "error":
pytest.fail(f"File processing failed: {file_metadata.error_message}")
return file_metadata
@pytest.fixture
def agent_state(disable_pinecone, client: LettaSDKClient):
open_file_tool = client.tools.list(name="open_files")[0]
@@ -418,7 +396,7 @@ def test_agent_uses_open_close_file_correctly(disable_pinecone, client: LettaSDK
assert initial_content_length > 10, f"Expected file content > 10 chars, got {initial_content_length}"
# Ask agent to open the file for a specific range using offset/length
offset, length = 1, 5 # 1-indexed offset, 5 lines
offset, length = 0, 5 # 0-indexed offset, 5 lines
print(f"Requesting agent to open file with offset={offset}, length={length}")
open_response1 = client.agents.messages.create(
agent_id=agent_state.id,
@@ -447,7 +425,7 @@ def test_agent_uses_open_close_file_correctly(disable_pinecone, client: LettaSDK
assert "5: " in old_value, f"Expected line 5 to be present, got: {old_value}"
# Ask agent to open the file for a different range
offset, length = 6, 5 # Different offset, same length
offset, length = 5, 5 # Different offset, same length
open_response2 = client.agents.messages.create(
agent_id=agent_state.id,
messages=[
@@ -476,8 +454,8 @@ def test_agent_uses_open_close_file_correctly(disable_pinecone, client: LettaSDK
assert "10: " in new_value, f"Expected line 10 to be present, got: {new_value}"
print(f"Comparing content ranges:")
print(f" First range (offset=1, length=5): '{old_value}'")
print(f" Second range (offset=6, length=5): '{new_value}'")
print(f" First range (offset=0, length=5): '{old_value}'")
print(f" Second range (offset=5, length=5): '{new_value}'")
assert new_value != old_value, f"Different view ranges should have different content. New: '{new_value}', Old: '{old_value}'"
@@ -697,7 +675,7 @@ def test_view_ranges_have_metadata(disable_pinecone, client: LettaSDKClient, age
assert block.value.startswith("[Viewing file start (out of 100 lines)]")
# Open a specific range using offset/length
offset = 50 # 1-indexed line 50
offset = 49 # 0-indexed for line 50
length = 5 # 5 lines (50-54)
open_response = client.agents.messages.create(
agent_id=agent_state.id,
@@ -851,6 +829,87 @@ def test_duplicate_file_handling_replace(disable_pinecone, client: LettaSDKClien
os.unlink(temp_file_path)
def test_upload_file_with_custom_name(disable_pinecone, client: LettaSDKClient):
"""Test that uploading a file with a custom name overrides the original filename"""
# Create agent
agent_state = client.agents.create(
name="test_agent_custom_name",
memory_blocks=[
CreateBlock(
label="persona",
value="I am a helpful assistant",
),
CreateBlock(
label="human",
value="The user is a developer",
),
],
model="openai/gpt-4o-mini",
embedding="openai/text-embedding-3-small",
)
# Create source
source = client.sources.create(name="test_source_custom_name", embedding="openai/text-embedding-3-small")
# Attach source to agent
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
# Create a temporary file with specific content
import tempfile
temp_file_path = None
try:
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
f.write("This is a test file for custom naming")
temp_file_path = f.name
# Upload file with custom name
custom_name = "my_custom_file_name.txt"
file_metadata = upload_file_and_wait(client, source.id, temp_file_path, name=custom_name)
# Verify the file uses the custom name
assert file_metadata.file_name == custom_name
assert file_metadata.original_file_name == custom_name
# Verify file appears in source files list with custom name
files = client.sources.files.list(source_id=source.id, limit=1)
assert len(files) == 1
assert files[0].file_name == custom_name
assert files[0].original_file_name == custom_name
# Verify the custom name is used in file blocks
agent_state = client.agents.retrieve(agent_id=agent_state.id)
file_blocks = agent_state.memory.file_blocks
assert len(file_blocks) == 1
# Check that the custom name appears in the block label
assert custom_name.replace(".txt", "") in file_blocks[0].label
# Test duplicate handling with custom name - upload same file with same custom name
from letta.schemas.enums import DuplicateFileHandling
with pytest.raises(Exception) as exc_info:
upload_file_and_wait(client, source.id, temp_file_path, name=custom_name, duplicate_handling=DuplicateFileHandling.ERROR)
assert "already exists" in str(exc_info.value).lower()
# Upload same file with different custom name should succeed
different_custom_name = "folder_a/folder_b/another_custom_name.txt"
file_metadata2 = upload_file_and_wait(client, source.id, temp_file_path, name=different_custom_name)
assert file_metadata2.file_name == different_custom_name
assert file_metadata2.original_file_name == different_custom_name
# Verify both files exist
files = client.sources.files.list(source_id=source.id, limit=10)
assert len(files) == 2
file_names = {f.file_name for f in files}
assert custom_name in file_names
assert different_custom_name in file_names
finally:
# Clean up temporary file
if temp_file_path and os.path.exists(temp_file_path):
os.unlink(temp_file_path)
def test_open_files_schema_descriptions(disable_pinecone, client: LettaSDKClient):
"""Test that open_files tool schema contains correct descriptions from docstring"""
@@ -873,9 +932,9 @@ def test_open_files_schema_descriptions(disable_pinecone, client: LettaSDKClient
# Check that examples are included
assert "Examples:" in description
assert 'FileOpenRequest(file_name="project_utils/config.py")' in description
assert 'FileOpenRequest(file_name="project_utils/config.py", offset=1, length=50)' in description
assert 'FileOpenRequest(file_name="project_utils/config.py", offset=0, length=50)' in description
assert "# Lines 1-50" in description
assert "# Lines 100-199" in description
assert "# Lines 101-200" in description
assert "# Entire file" in description
assert "close_all_others=True" in description
assert "View specific portions of large files (e.g. functions or definitions)" in description
@@ -922,7 +981,7 @@ def test_open_files_schema_descriptions(disable_pinecone, client: LettaSDKClient
# Check offset field
assert "offset" in file_request_properties
offset_prop = file_request_properties["offset"]
expected_offset_desc = "Optional starting line number (1-indexed). If not specified, starts from beginning of file."
expected_offset_desc = "Optional offset for starting line number (0-indexed). If not specified, starts from beginning of file."
assert offset_prop["description"] == expected_offset_desc
assert offset_prop["type"] == "integer"
@@ -1074,10 +1133,43 @@ def test_pinecone_search_files_tool(client: LettaSDKClient):
), f"Search results should contain relevant content: {search_results}"
def test_pinecone_list_files_status(client: LettaSDKClient):
"""Test that list_source_files properly syncs embedding status with Pinecone"""
if not should_use_pinecone():
pytest.skip("Pinecone not configured (missing API key or disabled), skipping Pinecone-specific tests")
# create source
source = client.sources.create(name="test_list_files_status", embedding="openai/text-embedding-3-small")
file_paths = ["tests/data/long_test.txt"]
uploaded_files = []
for file_path in file_paths:
# use the new helper that polls via list_files
file_metadata = upload_file_and_wait_list_files(client, source.id, file_path)
uploaded_files.append(file_metadata)
assert file_metadata.processing_status == "completed", f"File {file_path} should be completed"
# now get files using list_source_files to verify status checking works
files_list = client.sources.files.list(source_id=source.id, limit=100)
# verify all files show completed status and have proper embedding counts
assert len(files_list) == len(uploaded_files), f"Expected {len(uploaded_files)} files, got {len(files_list)}"
for file_metadata in files_list:
assert file_metadata.processing_status == "completed", f"File {file_metadata.file_name} should show completed status"
# verify embedding counts for files that have chunks
if file_metadata.total_chunks and file_metadata.total_chunks > 0:
assert (
file_metadata.chunks_embedded == file_metadata.total_chunks
), f"File {file_metadata.file_name} should have all chunks embedded: {file_metadata.chunks_embedded}/{file_metadata.total_chunks}"
# cleanup
client.sources.delete(source_id=source.id)
def test_pinecone_lifecycle_file_and_source_deletion(client: LettaSDKClient):
"""Test that file and source deletion removes records from Pinecone"""
import asyncio
from letta.helpers.pinecone_utils import list_pinecone_index_for_files, should_use_pinecone
if not should_use_pinecone():
@@ -1146,8 +1238,6 @@ def test_pinecone_lifecycle_file_and_source_deletion(client: LettaSDKClient):
len(records_after) == 0
), f"All source records should be removed from Pinecone after source deletion, but found {len(records_after)}"
print("✓ Pinecone lifecycle verified - namespace is clean after source deletion")
def test_agent_open_file(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
"""Test client.agents.open_file() function"""

View File

@@ -1,5 +1,3 @@
import asyncio
import pytest
from letta.constants import MAX_FILENAME_LENGTH
@@ -522,19 +520,8 @@ def test_line_chunker_only_start_parameter():
# ---------------------- Alembic Revision TESTS ---------------------- #
@pytest.fixture(scope="module")
def event_loop():
"""
Create an event loop for the entire test session.
Ensures all async tasks use the same loop, avoiding cross-loop errors.
"""
loop = asyncio.new_event_loop()
yield loop
loop.close()
@pytest.mark.asyncio
async def test_get_latest_alembic_revision(event_loop):
async def test_get_latest_alembic_revision():
"""Test that get_latest_alembic_revision returns a valid revision ID from the database."""
from letta.utils import get_latest_alembic_revision
@@ -553,7 +540,7 @@ async def test_get_latest_alembic_revision(event_loop):
@pytest.mark.asyncio
async def test_get_latest_alembic_revision_consistency(event_loop):
async def test_get_latest_alembic_revision_consistency():
"""Test that get_latest_alembic_revision returns the same value on multiple calls."""
from letta.utils import get_latest_alembic_revision

5958
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff