chore: bump version 0.16.2 (#3140)

Authored by cthomas on 2026-01-12 11:04:11 -08:00; committed by GitHub.
165 changed files with 15002 additions and 1832 deletions


@@ -184,7 +184,7 @@ def _start_server_once() -> str:
thread.start()
# Poll until up
-timeout_seconds = 30
+timeout_seconds = 60
deadline = time.time() + timeout_seconds
while time.time() < deadline:
try:

.python-version (new file)

@@ -0,0 +1 @@
3.12


@@ -90,6 +90,36 @@ Workflow for Postgres-targeted migration:
- `uv run alembic upgrade head`
- `uv run alembic revision --autogenerate -m "..."`
### 5. Resetting local Postgres for clean migration generation
If your local Postgres database has drifted from main (e.g., it has applied
migrations that no longer exist, or has a stale schema), you can reset it to
generate a clean migration.
From the repo root (`/Users/sarahwooders/repos/letta-cloud`):
```bash
# 1. Remove postgres data directory
rm -rf ./data/postgres
# 2. Stop the running postgres container
docker stop $(docker ps -q --filter ancestor=ankane/pgvector)
# 3. Restart services (creates fresh postgres)
just start-services
# 4. Wait a moment for postgres to be ready, then apply all migrations
cd apps/core
export LETTA_PG_URI=postgresql+pg8000://postgres:postgres@localhost:5432/letta-core
uv run alembic upgrade head
# 5. Now generate your new migration
uv run alembic revision --autogenerate -m "your migration message"
```
This ensures the migration is generated against a clean database state matching
main, avoiding spurious diffs from local-only schema changes.
## Troubleshooting
- **"Target database is not up to date" when autogenerating**
@@ -101,7 +131,7 @@ Workflow for Postgres-targeted migration:
changed model is imported in Alembic env context.
- **Autogenerated migration has unexpected drops/renames**
- Review model changes; consider explicit operations instead of relying on
-autogenerate.
+autogenerate. Reset local Postgres (see workflow 5) to get a clean baseline.
## References


@@ -0,0 +1,97 @@
"""add conversations tables and run conversation_id
Revision ID: 27de0f58e076
Revises: ee2b43eea55e
Create Date: 2026-01-01 20:36:09.101274
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "27de0f58e076"
down_revision: Union[str, None] = "ee2b43eea55e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"conversations",
sa.Column("id", sa.String(), nullable=False),
sa.Column("agent_id", sa.String(), nullable=False),
sa.Column("summary", sa.String(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
sa.Column("_created_by_id", sa.String(), nullable=True),
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
sa.Column("organization_id", sa.String(), nullable=False),
sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(
["organization_id"],
["organizations.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("ix_conversations_agent_id", "conversations", ["agent_id"], unique=False)
op.create_index("ix_conversations_org_agent", "conversations", ["organization_id", "agent_id"], unique=False)
op.create_table(
"conversation_messages",
sa.Column("id", sa.String(), nullable=False),
sa.Column("conversation_id", sa.String(), nullable=True),
sa.Column("agent_id", sa.String(), nullable=False),
sa.Column("message_id", sa.String(), nullable=False),
sa.Column("position", sa.Integer(), nullable=False),
sa.Column("in_context", sa.Boolean(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
sa.Column("_created_by_id", sa.String(), nullable=True),
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
sa.Column("organization_id", sa.String(), nullable=False),
sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["conversation_id"], ["conversations.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["message_id"], ["messages.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(
["organization_id"],
["organizations.id"],
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("conversation_id", "message_id", name="unique_conversation_message"),
)
op.create_index("ix_conv_msg_agent_conversation", "conversation_messages", ["agent_id", "conversation_id"], unique=False)
op.create_index("ix_conv_msg_agent_id", "conversation_messages", ["agent_id"], unique=False)
op.create_index("ix_conv_msg_conversation_position", "conversation_messages", ["conversation_id", "position"], unique=False)
op.create_index("ix_conv_msg_message_id", "conversation_messages", ["message_id"], unique=False)
op.add_column("messages", sa.Column("conversation_id", sa.String(), nullable=True))
op.create_index(op.f("ix_messages_conversation_id"), "messages", ["conversation_id"], unique=False)
op.create_foreign_key(None, "messages", "conversations", ["conversation_id"], ["id"], ondelete="SET NULL")
op.add_column("runs", sa.Column("conversation_id", sa.String(), nullable=True))
op.create_index("ix_runs_conversation_id", "runs", ["conversation_id"], unique=False)
op.create_foreign_key(None, "runs", "conversations", ["conversation_id"], ["id"], ondelete="SET NULL")
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(None, "runs", type_="foreignkey")
op.drop_index("ix_runs_conversation_id", table_name="runs")
op.drop_column("runs", "conversation_id")
op.drop_constraint(None, "messages", type_="foreignkey")
op.drop_index(op.f("ix_messages_conversation_id"), table_name="messages")
op.drop_column("messages", "conversation_id")
op.drop_index("ix_conv_msg_message_id", table_name="conversation_messages")
op.drop_index("ix_conv_msg_conversation_position", table_name="conversation_messages")
op.drop_index("ix_conv_msg_agent_id", table_name="conversation_messages")
op.drop_index("ix_conv_msg_agent_conversation", table_name="conversation_messages")
op.drop_table("conversation_messages")
op.drop_index("ix_conversations_org_agent", table_name="conversations")
op.drop_index("ix_conversations_agent_id", table_name="conversations")
op.drop_table("conversations")
# ### end Alembic commands ###


@@ -0,0 +1,31 @@
"""add request_id to steps table
Revision ID: ee2b43eea55e
Revises: 39577145c45d
Create Date: 2025-12-17 13:48:08.642245
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "ee2b43eea55e"
down_revision: Union[str, None] = "39577145c45d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("steps", sa.Column("request_id", sa.String(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("steps", "request_id")
# ### end Alembic commands ###

File diff suppressed because it is too large.


@@ -5,7 +5,7 @@ try:
__version__ = version("letta")
except PackageNotFoundError:
# Fallback for development installations
__version__ = "0.16.1"
__version__ = "0.16.2"
if os.environ.get("LETTA_VERSION"):
__version__ = os.environ["LETTA_VERSION"]


@@ -87,9 +87,13 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
raise self.llm_client.handle_llm_error(e)
# Process the stream and yield chunks immediately for TTFT
-async for chunk in self.interface.process(stream):  # TODO: add ttft span
-    # Yield each chunk immediately as it arrives
-    yield chunk
+# Wrap in error handling to convert provider errors to common LLMError types
+try:
+    async for chunk in self.interface.process(stream):  # TODO: add ttft span
+        # Yield each chunk immediately as it arrives
+        yield chunk
+except Exception as e:
+    raise self.llm_client.handle_llm_error(e)
# After streaming completes, extract the accumulated data
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()


@@ -75,7 +75,7 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
run_id=self.run_id,
step_id=step_id,
)
-elif self.llm_config.model_endpoint_type in [ProviderType.openai, ProviderType.deepseek]:
+elif self.llm_config.model_endpoint_type in [ProviderType.openai, ProviderType.deepseek, ProviderType.zai]:
# Decide interface based on payload shape
use_responses = "input" in request_data and "messages" not in request_data
# No support for Responses API proxy


@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
-from typing import AsyncGenerator
+from typing import TYPE_CHECKING, AsyncGenerator
from letta.constants import DEFAULT_MAX_STEPS
from letta.log import get_logger
@@ -10,6 +10,9 @@ from letta.schemas.letta_response import LettaResponse
from letta.schemas.message import MessageCreate
from letta.schemas.user import User
if TYPE_CHECKING:
from letta.schemas.letta_request import ClientToolSchema
class BaseAgentV2(ABC):
"""
@@ -42,9 +45,14 @@ class BaseAgentV2(ABC):
use_assistant_message: bool = True,
include_return_message_types: list[MessageType] | None = None,
request_start_timestamp_ns: int | None = None,
client_tools: list["ClientToolSchema"] | None = None,
) -> LettaResponse:
"""
Execute the agent loop in blocking mode, returning all messages at once.
Args:
client_tools: Optional list of client-side tools. When called, execution pauses
for client to provide tool returns.
"""
raise NotImplementedError
@@ -58,11 +66,17 @@ class BaseAgentV2(ABC):
use_assistant_message: bool = True,
include_return_message_types: list[MessageType] | None = None,
request_start_timestamp_ns: int | None = None,
conversation_id: str | None = None,
client_tools: list["ClientToolSchema"] | None = None,
) -> AsyncGenerator[LettaMessage | LegacyLettaMessage | MessageStreamStatus, None]:
"""
Execute the agent loop in streaming mode, yielding chunks as they become available.
If stream_tokens is True, individual tokens are streamed as they arrive from the LLM,
providing the lowest latency experience, otherwise each complete step (reasoning +
tool call + tool return) is yielded as it completes.
Args:
client_tools: Optional list of client-side tools. When called, execution pauses
for client to provide tool returns.
"""
raise NotImplementedError
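Not part of the commit: a minimal caller-side sketch of the new client-tools flow. It assumes `ClientToolSchema` is constructed from the `name`/`description`/`parameters` fields used by the schema-merge logic later in this commit, and that the streaming entry point on a concrete `BaseAgentV2` implementation is named `stream(...)`.

```python
from letta.schemas.letta_request import ClientToolSchema
from letta.schemas.message import MessageCreate

# Hypothetical client-side tool; the server never executes it.
weather_tool = ClientToolSchema(
    name="get_weather",
    description="Look up current weather (runs on the client).",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}},
)

async def run(agent):  # agent: any concrete BaseAgentV2 implementation
    async for chunk in agent.stream(
        input_messages=[MessageCreate(role="user", content="Weather in Tokyo?")],
        client_tools=[weather_tool],  # execution pauses for client tool returns
    ):
        print(chunk)
```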


@@ -6,6 +6,7 @@ from uuid import UUID, uuid4
from letta.errors import PendingApprovalError
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import get_utc_time
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState
@@ -131,16 +132,21 @@ async def _prepare_in_context_messages_no_persist_async(
message_manager: MessageManager,
actor: User,
run_id: Optional[str] = None,
conversation_id: Optional[str] = None,
) -> Tuple[List[Message], List[Message]]:
"""
Prepares in-context messages for an agent, based on the current state and a new user input.
When conversation_id is provided, messages are loaded from the conversation_messages
table instead of agent_state.message_ids.
Args:
input_messages (List[MessageCreate]): The new user input messages to process.
agent_state (AgentState): The current state of the agent, including message buffer config.
message_manager (MessageManager): The manager used to retrieve and create messages.
actor (User): The user performing the action, used for access control and attribution.
run_id (str): The run ID associated with this message processing.
conversation_id (str): Optional conversation ID to load messages from.
Returns:
Tuple[List[Message], List[Message]]: A tuple containing:
@@ -148,12 +154,74 @@ async def _prepare_in_context_messages_no_persist_async(
- The new in-context messages (messages created from the new input).
"""
-if agent_state.message_buffer_autoclear:
-    # If autoclear is enabled, only include the most recent system message (usually at index 0)
-    current_in_context_messages = [await message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)]
+if conversation_id:
+    # Conversation mode: load messages from conversation_messages table
+    from letta.services.conversation_manager import ConversationManager
+
+    conversation_manager = ConversationManager()
+    message_ids = await conversation_manager.get_message_ids_for_conversation(
+        conversation_id=conversation_id,
+        actor=actor,
+    )
+    if agent_state.message_buffer_autoclear and message_ids:
+        # If autoclear is enabled, only include the system message
+        current_in_context_messages = [await message_manager.get_message_by_id_async(message_id=message_ids[0], actor=actor)]
+    elif message_ids:
+        # Otherwise, include the full list of messages from the conversation
+        current_in_context_messages = await message_manager.get_messages_by_ids_async(message_ids=message_ids, actor=actor)
+    else:
+        # No messages in conversation yet - compile a new system message for this conversation
+        # Each conversation gets its own system message (captures memory state at conversation start)
+        from letta.prompts.prompt_generator import PromptGenerator
+        from letta.services.passage_manager import PassageManager
+
+        num_messages = await message_manager.size_async(actor=actor, agent_id=agent_state.id)
+        passage_manager = PassageManager()
+        num_archival_memories = await passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_state.id)
+        system_message_str = await PromptGenerator.compile_system_message_async(
+            system_prompt=agent_state.system,
+            in_context_memory=agent_state.memory,
+            in_context_memory_last_edit=get_utc_time(),
+            timezone=agent_state.timezone,
+            user_defined_variables=None,
+            append_icm_if_missing=True,
+            previous_message_count=num_messages,
+            archival_memory_size=num_archival_memories,
+            sources=agent_state.sources,
+            max_files_open=agent_state.max_files_open,
+        )
+        system_message = Message.dict_to_message(
+            agent_id=agent_state.id,
+            model=agent_state.llm_config.model,
+            openai_message_dict={"role": "system", "content": system_message_str},
+        )
+        # Persist the new system message
+        persisted_messages = await message_manager.create_many_messages_async([system_message], actor=actor)
+        system_message = persisted_messages[0]
+        # Add it to the conversation tracking
+        await conversation_manager.add_messages_to_conversation(
+            conversation_id=conversation_id,
+            agent_id=agent_state.id,
+            message_ids=[system_message.id],
+            actor=actor,
+            starting_position=0,
+        )
+        current_in_context_messages = [system_message]
 else:
-    # Otherwise, include the full list of messages by ID for context
-    current_in_context_messages = await message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor)
+    # Default mode: load messages from agent_state.message_ids
+    if agent_state.message_buffer_autoclear:
+        # If autoclear is enabled, only include the most recent system message (usually at index 0)
+        current_in_context_messages = [
+            await message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)
+        ]
+    else:
+        # Otherwise, include the full list of messages by ID for context
+        current_in_context_messages = await message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor)
# Check for approval-related message validation
if input_messages[0].type == "approval":


@@ -34,6 +34,7 @@ from letta.schemas.agent import AgentState, UpdateAgent
from letta.schemas.enums import AgentType, MessageStreamStatus, RunStatus, StepStatus
from letta.schemas.letta_message import LettaMessage, MessageType
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.letta_request import ClientToolSchema
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import Message, MessageCreate, MessageUpdate
@@ -173,6 +174,7 @@ class LettaAgentV2(BaseAgentV2):
use_assistant_message: bool = True,
include_return_message_types: list[MessageType] | None = None,
request_start_timestamp_ns: int | None = None,
client_tools: list[ClientToolSchema] | None = None,
) -> LettaResponse:
"""
Execute the agent loop in blocking mode, returning all messages at once.
@@ -184,6 +186,7 @@ class LettaAgentV2(BaseAgentV2):
use_assistant_message: Whether to use assistant message format
include_return_message_types: Filter for which message types to return
request_start_timestamp_ns: Start time for tracking request duration
client_tools: Optional list of client-side tools (not used in V2, for API compatibility)
Returns:
LettaResponse: Complete response with all messages and metadata
@@ -251,6 +254,8 @@ class LettaAgentV2(BaseAgentV2):
use_assistant_message: bool = True,
include_return_message_types: list[MessageType] | None = None,
request_start_timestamp_ns: int | None = None,
conversation_id: str | None = None, # Not used in V2, but accepted for API compatibility
client_tools: list[ClientToolSchema] | None = None,
) -> AsyncGenerator[str, None]:
"""
Execute the agent loop in streaming mode, yielding chunks as they become available.
@@ -268,6 +273,7 @@ class LettaAgentV2(BaseAgentV2):
use_assistant_message: Whether to use assistant message format
include_return_message_types: Filter for which message types to return
request_start_timestamp_ns: Start time for tracking request duration
client_tools: Optional list of client-side tools (not used in V2, for API compatibility)
Yields:
str: JSON-formatted SSE data chunks for each completed step
@@ -1168,16 +1174,15 @@ class LettaAgentV2(BaseAgentV2):
"""
from letta.schemas.tool_execution_result import ToolExecutionResult
-tool_name = target_tool.name
-# Special memory case
+# Check for None before accessing attributes
 if not target_tool:
     # TODO: fix this error message
     return ToolExecutionResult(
-        func_return=f"Tool {tool_name} not found",
+        func_return="Tool not found",
         status="error",
     )
+tool_name = target_tool.name
# TODO: This temp. Move this logic and code to executors
if agent_step_span:


@@ -31,6 +31,7 @@ from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import ApprovalReturn, LettaErrorMessage, LettaMessage, MessageType
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.letta_request import ClientToolSchema
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.llm_config import LLMConfig
@@ -46,6 +47,7 @@ from letta.server.rest_api.utils import (
create_parallel_tool_messages_from_llm_response,
create_tool_returns_for_denials,
)
from letta.services.conversation_manager import ConversationManager
from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema
from letta.services.summarizer.summarizer_all import summarize_all
from letta.services.summarizer.summarizer_config import CompactionSettings
@@ -79,6 +81,10 @@ class LettaAgentV3(LettaAgentV2):
# affecting step-level telemetry.
self.context_token_estimate: int | None = None
self.in_context_messages: list[Message] = [] # in-memory tracker
# Conversation mode: when set, messages are tracked per-conversation
self.conversation_id: str | None = None
# Client-side tools passed in the request (executed by client, not server)
self.client_tools: list[ClientToolSchema] = []
def _compute_tool_return_truncation_chars(self) -> int:
"""Compute a dynamic cap for tool returns in requests.
@@ -101,6 +107,8 @@ class LettaAgentV3(LettaAgentV2):
use_assistant_message: bool = True, # NOTE: not used
include_return_message_types: list[MessageType] | None = None,
request_start_timestamp_ns: int | None = None,
conversation_id: str | None = None,
client_tools: list[ClientToolSchema] | None = None,
) -> LettaResponse:
"""
Execute the agent loop in blocking mode, returning all messages at once.
@@ -112,16 +120,27 @@ class LettaAgentV3(LettaAgentV2):
use_assistant_message: Whether to use assistant message format
include_return_message_types: Filter for which message types to return
request_start_timestamp_ns: Start time for tracking request duration
conversation_id: Optional conversation ID for conversation-scoped messaging
client_tools: Optional list of client-side tools. When called, execution pauses
for client to provide tool returns.
Returns:
LettaResponse: Complete response with all messages and metadata
"""
self._initialize_state()
self.conversation_id = conversation_id
self.client_tools = client_tools or []
request_span = self._request_checkpoint_start(request_start_timestamp_ns=request_start_timestamp_ns)
response_letta_messages = []
# Prepare in-context messages (conversation mode if conversation_id provided)
curr_in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
-    input_messages, self.agent_state, self.message_manager, self.actor, run_id
+    input_messages,
+    self.agent_state,
+    self.message_manager,
+    self.actor,
+    run_id,
+    conversation_id=conversation_id,
)
follow_up_messages = []
if len(input_messages_to_persist) > 1 and input_messages_to_persist[0].role == "approval":
@@ -234,6 +253,8 @@ class LettaAgentV3(LettaAgentV2):
use_assistant_message: bool = True, # NOTE: not used
include_return_message_types: list[MessageType] | None = None,
request_start_timestamp_ns: int | None = None,
conversation_id: str | None = None,
client_tools: list[ClientToolSchema] | None = None,
) -> AsyncGenerator[str, None]:
"""
Execute the agent loop in streaming mode, yielding chunks as they become available.
@@ -251,11 +272,16 @@ class LettaAgentV3(LettaAgentV2):
use_assistant_message: Whether to use assistant message format
include_return_message_types: Filter for which message types to return
request_start_timestamp_ns: Start time for tracking request duration
conversation_id: Optional conversation ID for conversation-scoped messaging
client_tools: Optional list of client-side tools. When called, execution pauses
for client to provide tool returns.
Yields:
str: JSON-formatted SSE data chunks for each completed step
"""
self._initialize_state()
self.conversation_id = conversation_id
self.client_tools = client_tools or []
request_span = self._request_checkpoint_start(request_start_timestamp_ns=request_start_timestamp_ns)
response_letta_messages = []
first_chunk = True
@@ -273,8 +299,14 @@ class LettaAgentV3(LettaAgentV2):
)
try:
# Prepare in-context messages (conversation mode if conversation_id provided)
in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
-    input_messages, self.agent_state, self.message_manager, self.actor, run_id
+    input_messages,
+    self.agent_state,
+    self.message_manager,
+    self.actor,
+    run_id,
+    conversation_id=conversation_id,
)
follow_up_messages = []
if len(input_messages_to_persist) > 1 and input_messages_to_persist[0].role == "approval":
@@ -424,7 +456,7 @@ class LettaAgentV3(LettaAgentV2):
This handles:
- Persisting the new messages into the `messages` table
- Updating the in-memory trackers for in-context messages (`self.in_context_messages`) and agent state (`self.agent_state.message_ids`)
-    - Updating the DB with the current in-context messages (`self.agent_state.message_ids`)
+    - Updating the DB with the current in-context messages (`self.agent_state.message_ids`) OR conversation_messages table
Args:
run_id: The run ID to associate with the messages
@@ -446,14 +478,33 @@ class LettaAgentV3(LettaAgentV2):
template_id=self.agent_state.template_id,
)
# persist the in-context messages
# TODO: somehow make sure all the message ids are already persisted
-await self.agent_manager.update_message_ids_async(
-    agent_id=self.agent_state.id,
-    message_ids=[m.id for m in in_context_messages],
-    actor=self.actor,
-)
-self.agent_state.message_ids = [m.id for m in in_context_messages]  # update in-memory state
+if self.conversation_id:
+    # Conversation mode: update conversation_messages table
+    # Add new messages to conversation tracking
+    new_message_ids = [m.id for m in new_messages]
+    if new_message_ids:
+        await ConversationManager().add_messages_to_conversation(
+            conversation_id=self.conversation_id,
+            agent_id=self.agent_state.id,
+            message_ids=new_message_ids,
+            actor=self.actor,
+        )
+    # Update which messages are in context
+    await ConversationManager().update_in_context_messages(
+        conversation_id=self.conversation_id,
+        in_context_message_ids=[m.id for m in in_context_messages],
+        actor=self.actor,
+    )
+else:
+    # Default mode: update agent.message_ids
+    await self.agent_manager.update_message_ids_async(
+        agent_id=self.agent_state.id,
+        message_ids=[m.id for m in in_context_messages],
+        actor=self.actor,
+    )
+    self.agent_state.message_ids = [m.id for m in in_context_messages]  # update in-memory state
self.in_context_messages = in_context_messages # update in-memory state
@trace_method
@@ -658,7 +709,8 @@ class LettaAgentV3(LettaAgentV2):
use_assistant_message=False, # NOTE: set to false
requires_approval_tools=self.tool_rules_solver.get_requires_approval_tools(
set([t["name"] for t in valid_tools])
-    ),
+    )
+    + [ct.name for ct in self.client_tools],
step_id=step_id,
actor=self.actor,
)
@@ -684,7 +736,7 @@ class LettaAgentV3(LettaAgentV2):
# checkpoint summarized messages
# TODO: might want to delay this checkpoint in case of corrupted state
try:
-summary_message, messages = await self.compact(
+summary_message, messages, _ = await self.compact(
messages, trigger_threshold=self.agent_state.llm_config.context_window
)
self.logger.info("Summarization succeeded, continuing to retry LLM request")
@@ -726,6 +778,15 @@ class LettaAgentV3(LettaAgentV2):
else:
tool_calls = []
# Enforce parallel_tool_calls=false by truncating to first tool call
# Some providers (e.g. Gemini) don't respect this setting via API, so we enforce it client-side
if len(tool_calls) > 1 and not self.agent_state.llm_config.parallel_tool_calls:
self.logger.warning(
f"LLM returned {len(tool_calls)} tool calls but parallel_tool_calls=false. "
f"Truncating to first tool call: {tool_calls[0].function.name}"
)
tool_calls = [tool_calls[0]]
# get the new generated `Message` objects from handling the LLM response
new_messages, self.should_continue, self.stop_reason = await self._handle_ai_response(
tool_calls=tool_calls,
@@ -795,7 +856,7 @@ class LettaAgentV3(LettaAgentV2):
self.logger.info(
f"Context window exceeded (current: {self.context_token_estimate}, threshold: {self.agent_state.llm_config.context_window}), trying to compact messages"
)
-summary_message, messages = await self.compact(messages, trigger_threshold=self.agent_state.llm_config.context_window)
+summary_message, messages, _ = await self.compact(messages, trigger_threshold=self.agent_state.llm_config.context_window)
# TODO: persist + return the summary message
# TODO: convert this to a SummaryMessage
self.response_messages.append(summary_message)
@@ -964,10 +1025,22 @@ class LettaAgentV3(LettaAgentV2):
messages_to_persist = (initial_messages or []) + assistant_message
return messages_to_persist, continue_stepping, stop_reason
-# 2. Check whether tool call requires approval
+# 2. Check whether tool call requires approval (includes client-side tools)
if not is_approval_response:
-requested_tool_calls = [t for t in tool_calls if tool_rules_solver.is_requires_approval_tool(t.function.name)]
-allowed_tool_calls = [t for t in tool_calls if not tool_rules_solver.is_requires_approval_tool(t.function.name)]
+# Get names of client-side tools (these are executed by client, not server)
+client_tool_names = {ct.name for ct in self.client_tools} if self.client_tools else set()
+# Tools requiring approval: requires_approval tools OR client-side tools
+requested_tool_calls = [
+    t
+    for t in tool_calls
+    if tool_rules_solver.is_requires_approval_tool(t.function.name) or t.function.name in client_tool_names
+]
+allowed_tool_calls = [
+    t
+    for t in tool_calls
+    if not tool_rules_solver.is_requires_approval_tool(t.function.name) and t.function.name not in client_tool_names
+]
if requested_tool_calls:
approval_messages = create_approval_request_message_from_llm_response(
agent_id=self.agent_state.id,
@@ -1037,15 +1110,11 @@ class LettaAgentV3(LettaAgentV2):
-# 5. Unified tool execution path (works for both single and multiple tools)
-# 5a. Validate parallel tool calling constraints
-if len(tool_calls) > 1:
-    # No parallel tool calls with tool rules
-    if self.agent_state.tool_rules and len([r for r in self.agent_state.tool_rules if r.type != "requires_approval"]) > 0:
-        raise ValueError(
-            "Parallel tool calling is not allowed when tool rules are present. Disable tool rules to use parallel tool calls."
-        )
+# 5. Unified tool execution path (works for both single and multiple tools)
+# Note: Parallel tool calling with tool rules is validated at agent create/update time.
+# At runtime, we trust that if tool_rules exist, parallel_tool_calls=false is enforced earlier.
-# 5b. Prepare execution specs for all tools
+# 5a. Prepare execution specs for all tools
exec_specs = []
for tc in tool_calls:
call_id = tc.id or f"call_{uuid.uuid4().hex[:8]}"
@@ -1321,7 +1390,25 @@ class LettaAgentV3(LettaAgentV2):
last_function_response=self.last_function_response,
error_on_empty=False, # Return empty list instead of raising error
) or list(set(t.name for t in tools))
-allowed_tools = [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)]
+# Get client tool names to filter out server tools with same name (client tools override)
+client_tool_names = {ct.name for ct in self.client_tools} if self.client_tools else set()
+# Build allowed tools from server tools, excluding those overridden by client tools
+allowed_tools = [
+    enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names) and t.name not in client_tool_names
+]
+# Merge client-side tools (use flat format matching enable_strict_mode output)
+if self.client_tools:
+    for ct in self.client_tools:
+        client_tool_schema = {
+            "name": ct.name,
+            "description": ct.description,
+            "parameters": ct.parameters or {"type": "object", "properties": {}},
+        }
+        allowed_tools.append(client_tool_schema)
terminal_tool_names = {rule.tool_name for rule in self.tool_rules_solver.terminal_tool_rules}
allowed_tools = runtime_override_tool_json_schema(
tool_list=allowed_tools,
@@ -1332,7 +1419,9 @@ class LettaAgentV3(LettaAgentV2):
return allowed_tools
@trace_method
-async def compact(self, messages, trigger_threshold: Optional[int] = None) -> Message:
+async def compact(
+    self, messages, trigger_threshold: Optional[int] = None, compaction_settings: Optional["CompactionSettings"] = None
+) -> tuple[Message, list[Message], str]:
"""Compact the current in-context messages for this agent.
Compaction uses a summarizer LLM configuration derived from
@@ -1341,9 +1430,11 @@ class LettaAgentV3(LettaAgentV2):
localized to summarization.
"""
-# Use agent's compaction_settings if set, otherwise fall back to
-# global defaults based on the agent's model handle.
-if self.agent_state.compaction_settings is not None:
+# Use the passed-in compaction_settings first, then agent's compaction_settings if set,
+# otherwise fall back to global defaults based on the agent's model handle.
+if compaction_settings is not None:
+    summarizer_config = compaction_settings
+elif self.agent_state.compaction_settings is not None:
summarizer_config = self.agent_state.compaction_settings
else:
# Prefer the new handle field if set, otherwise derive from llm_config
@@ -1466,7 +1557,7 @@ class LettaAgentV3(LettaAgentV2):
if len(compacted_messages) > 1:
final_messages += compacted_messages[1:]
-return summary_message_obj, final_messages
+return summary_message_obj, final_messages, summary
@staticmethod
def _build_summarizer_llm_config(
@@ -1489,17 +1580,16 @@ class LettaAgentV3(LettaAgentV2):
# Parse provider/model from the handle, falling back to the agent's
# provider type when only a model name is given.
if "/" in summarizer_config.model:
-    provider, model_name = summarizer_config.model.split("/", 1)
-    if provider == "openai-proxy":
-        # fix for pydantic LLMConfig validation
-        provider = "openai"
+    provider_name, model_name = summarizer_config.model.split("/", 1)
 else:
-    provider = agent_llm_config.model_endpoint_type
+    provider_name = agent_llm_config.provider_name
     model_name = summarizer_config.model
-# Start from the agent's config and override model + provider + handle
+# Start from the agent's config and override model + provider_name + handle
+# Note: model_endpoint_type is NOT overridden - the parsed provider_name
+# is a custom label (e.g. "claude-pro-max"), not the endpoint type (e.g. "anthropic")
 base = agent_llm_config.model_copy()
-base.model_endpoint_type = provider
+base.provider_name = provider_name
base.model = model_name
base.handle = summarizer_config.model


@@ -8,6 +8,26 @@ LETTA_TOOL_EXECUTION_DIR = os.path.join(LETTA_DIR, "tool_execution_dir")
LETTA_MODEL_ENDPOINT = "https://inference.letta.com/v1/"
DEFAULT_TIMEZONE = "UTC"
# Provider ordering for model listing (matches original _enabled_providers list order)
PROVIDER_ORDER = {
"letta": 0,
"openai": 1,
"anthropic": 2,
"ollama": 3,
"google_ai": 4,
"google_vertex": 5,
"azure": 6,
"groq": 7,
"together": 8,
"vllm": 9,
"bedrock": 10,
"deepseek": 11,
"xai": 12,
"lmstudio": 13,
"zai": 14,
"openrouter": 15, # Note: OpenRouter uses OpenRouterProvider, not a ProviderType enum
}
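Illustrative only (not part of the commit): one way a model-listing routine might sort by `PROVIDER_ORDER`, pushing unknown providers to the end.

```python
providers = ["zai", "openai", "letta", "openrouter"]
providers.sort(key=lambda p: PROVIDER_ORDER.get(p, len(PROVIDER_ORDER)))
# -> ["letta", "openai", "zai", "openrouter"]
```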
ADMIN_PREFIX = "/v1/admin"
API_PREFIX = "/v1"
OLLAMA_API_PREFIX = "/v1"
@@ -432,6 +452,10 @@ REDIS_SET_DEFAULT_VAL = "None"
REDIS_DEFAULT_CACHE_PREFIX = "letta_cache"
REDIS_RUN_ID_PREFIX = "agent:send_message:run_id"
# Conversation lock constants
CONVERSATION_LOCK_PREFIX = "conversation:lock:"
CONVERSATION_LOCK_TTL_SECONDS = 300 # 5 minutes
# TODO: This is temporary, eventually use token-based eviction
# File based controls
DEFAULT_MAX_FILES_OPEN = 5


@@ -2,17 +2,20 @@ import asyncio
from functools import wraps
from typing import Any, Dict, List, Optional, Set, Union
-from letta.constants import REDIS_EXCLUDE, REDIS_INCLUDE, REDIS_SET_DEFAULT_VAL
+from letta.constants import CONVERSATION_LOCK_PREFIX, CONVERSATION_LOCK_TTL_SECONDS, REDIS_EXCLUDE, REDIS_INCLUDE, REDIS_SET_DEFAULT_VAL
from letta.errors import ConversationBusyError
from letta.log import get_logger
from letta.settings import settings
try:
from redis import RedisError
from redis.asyncio import ConnectionPool, Redis
from redis.asyncio.lock import Lock
except ImportError:
RedisError = None
Redis = None
ConnectionPool = None
Lock = None
logger = get_logger(__name__)
@@ -171,6 +174,62 @@ class AsyncRedisClient:
client = await self.get_client()
return await client.delete(*keys)
async def acquire_conversation_lock(
self,
conversation_id: str,
token: str,
) -> Optional["Lock"]:
"""
Acquire a distributed lock for a conversation.
Args:
conversation_id: The ID for the conversation
token: Unique identifier for the lock holder (for debugging/tracing)
Returns:
Lock object if acquired, raises ConversationBusyError if in use
"""
if Lock is None:
return None
client = await self.get_client()
lock_key = f"{CONVERSATION_LOCK_PREFIX}{conversation_id}"
lock = Lock(
client,
lock_key,
timeout=CONVERSATION_LOCK_TTL_SECONDS,
blocking=False,
thread_local=False, # We manage token explicitly
raise_on_release_error=False, # We handle release errors ourselves
)
if await lock.acquire(token=token):
return lock
lock_holder_token = await client.get(lock_key)
raise ConversationBusyError(
conversation_id=conversation_id,
lock_holder_token=lock_holder_token,
)
async def release_conversation_lock(self, conversation_id: str) -> bool:
"""
Release a conversation lock by conversation_id.
Args:
conversation_id: The conversation ID to release the lock for
Returns:
True if lock was released, False if release failed
"""
try:
client = await self.get_client()
lock_key = f"{CONVERSATION_LOCK_PREFIX}{conversation_id}"
await client.delete(lock_key)
return True
except Exception as e:
logger.warning(f"Failed to release conversation lock for conversation {conversation_id}: {e}")
return False
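Not part of the commit: a hedged sketch of wrapping request handling in the new conversation lock, assuming `get_redis_client()` (imported elsewhere in this commit) returns an `AsyncRedisClient` and the caller supplies its own `handler` coroutine.

```python
from uuid import uuid4

from letta.data_sources.redis_client import get_redis_client

async def send_locked(conversation_id: str, handler) -> None:
    redis_client = await get_redis_client()
    token = str(uuid4())
    # Raises ConversationBusyError if another request holds the lock;
    # the noop client returns None and the request proceeds unlocked.
    await redis_client.acquire_conversation_lock(conversation_id, token=token)
    try:
        await handler()
    finally:
        await redis_client.release_conversation_lock(conversation_id)
```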
@with_retry()
async def exists(self, *keys: str) -> int:
"""Check if keys exist."""
@@ -395,6 +454,16 @@ class NoopAsyncRedisClient(AsyncRedisClient):
async def delete(self, *keys: str) -> int:
return 0
async def acquire_conversation_lock(
self,
conversation_id: str,
token: str,
) -> Optional["Lock"]:
return None
async def release_conversation_lock(self, conversation_id: str) -> bool:
return False
async def check_inclusion_and_exclusion(self, member: str, group: str) -> bool:
return False


@@ -53,6 +53,42 @@ class PendingApprovalError(LettaError):
super().__init__(message=message, code=code, details=details)
class NoActiveRunsToCancelError(LettaError):
"""Error raised when attempting to cancel but there are no active runs to cancel."""
def __init__(self, agent_id: Optional[str] = None):
message = "No active runs to cancel"
if agent_id:
message = f"No active runs to cancel for agent {agent_id}"
details = {"error_code": "NO_ACTIVE_RUNS_TO_CANCEL", "agent_id": agent_id}
super().__init__(message=message, code=ErrorCode.CONFLICT, details=details)
class ConcurrentUpdateError(LettaError):
"""Error raised when a resource was updated by another transaction (optimistic locking conflict)."""
def __init__(self, resource_type: str, resource_id: str):
message = f"{resource_type} with id '{resource_id}' was updated by another transaction. Please retry your request."
details = {"error_code": "CONCURRENT_UPDATE", "resource_type": resource_type, "resource_id": resource_id}
super().__init__(message=message, code=ErrorCode.CONFLICT, details=details)
class ConversationBusyError(LettaError):
"""Error raised when attempting to send a message while another request is already processing for the same conversation."""
def __init__(self, conversation_id: str, lock_holder_token: Optional[str] = None):
self.conversation_id = conversation_id
self.lock_holder_token = lock_holder_token
message = "Cannot send a new message: Another request is currently being processed for this conversation. Please wait for the current request to complete."
code = ErrorCode.CONFLICT
details = {
"error_code": "CONVERSATION_BUSY",
"conversation_id": conversation_id,
"lock_holder_token": lock_holder_token,
}
super().__init__(message=message, code=code, details=details)
class LettaToolCreateError(LettaError):
"""Error raised when a tool cannot be created."""
@@ -90,6 +126,19 @@ class LettaConfigurationError(LettaError):
super().__init__(message=message, details={"missing_fields": self.missing_fields})
class EmbeddingConfigRequiredError(LettaError):
"""Error raised when an operation requires embedding_config but the agent doesn't have one configured."""
def __init__(self, agent_id: Optional[str] = None, operation: Optional[str] = None):
self.agent_id = agent_id
self.operation = operation
message = "This operation requires an embedding configuration, but the agent does not have one configured."
if operation:
message = f"Operation '{operation}' requires an embedding configuration, but the agent does not have one configured."
details = {"agent_id": agent_id, "operation": operation}
super().__init__(message=message, code=ErrorCode.INVALID_ARGUMENT, details=details)
class LettaAgentNotFoundError(LettaError):
"""Error raised when an agent is not found."""


@@ -217,7 +217,7 @@ async def archival_memory_search(
top_k: Maximum number of results to return (default: 10)
Returns:
-    A list of relevant memories with timestamps and content, ranked by similarity.
+    A list of relevant memories with IDs, timestamps, and content, ranked by similarity.
Examples:
# Search for project discussions


@@ -12,7 +12,9 @@ from letta.schemas.group import Group, ManagerType
from letta.schemas.job import JobUpdate
from letta.schemas.letta_message import MessageType
from letta.schemas.letta_message_content import TextContent
from letta.schemas.letta_request import ClientToolSchema
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import StopReasonType
from letta.schemas.message import Message, MessageCreate
from letta.schemas.run import Run, RunUpdate
from letta.schemas.user import User
@@ -44,6 +46,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
use_assistant_message: bool = False,
include_return_message_types: list[MessageType] | None = None,
request_start_timestamp_ns: int | None = None,
client_tools: list[ClientToolSchema] | None = None,
) -> LettaResponse:
self.run_ids = []
@@ -57,6 +60,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
use_assistant_message=use_assistant_message,
include_return_message_types=include_return_message_types,
request_start_timestamp_ns=request_start_timestamp_ns,
client_tools=client_tools,
)
await self.run_sleeptime_agents()
@@ -74,6 +78,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
use_assistant_message: bool = True,
request_start_timestamp_ns: int | None = None,
include_return_message_types: list[MessageType] | None = None,
client_tools: list[ClientToolSchema] | None = None,
) -> AsyncGenerator[str, None]:
self.run_ids = []
@@ -90,6 +95,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
use_assistant_message=use_assistant_message,
include_return_message_types=include_return_message_types,
request_start_timestamp_ns=request_start_timestamp_ns,
client_tools=client_tools,
):
yield chunk
finally:
@@ -214,6 +220,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
run_update = RunUpdate(
status=RunStatus.completed,
completed_at=datetime.now(timezone.utc).replace(tzinfo=None),
stop_reason=result.stop_reason.stop_reason if result.stop_reason else StopReasonType.end_turn,
metadata={
"result": result.model_dump(mode="json"),
"agent_id": sleeptime_agent_state.id,
@@ -225,6 +232,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
run_update = RunUpdate(
status=RunStatus.failed,
completed_at=datetime.now(timezone.utc).replace(tzinfo=None),
stop_reason=StopReasonType.error,
metadata={"error": str(e)},
)
await self.run_manager.update_run_by_id_async(run_id=run_id, update=run_update, actor=self.actor)


@@ -12,7 +12,9 @@ from letta.schemas.group import Group, ManagerType
from letta.schemas.job import JobUpdate
from letta.schemas.letta_message import MessageType
from letta.schemas.letta_message_content import TextContent
from letta.schemas.letta_request import ClientToolSchema
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import StopReasonType
from letta.schemas.message import Message, MessageCreate
from letta.schemas.run import Run, RunUpdate
from letta.schemas.user import User
@@ -44,6 +46,8 @@ class SleeptimeMultiAgentV4(LettaAgentV3):
use_assistant_message: bool = True,
include_return_message_types: list[MessageType] | None = None,
request_start_timestamp_ns: int | None = None,
conversation_id: str | None = None,
client_tools: list[ClientToolSchema] | None = None,
) -> LettaResponse:
self.run_ids = []
@@ -57,6 +61,8 @@ class SleeptimeMultiAgentV4(LettaAgentV3):
use_assistant_message=use_assistant_message,
include_return_message_types=include_return_message_types,
request_start_timestamp_ns=request_start_timestamp_ns,
conversation_id=conversation_id,
client_tools=client_tools,
)
run_ids = await self.run_sleeptime_agents()
@@ -73,6 +79,8 @@ class SleeptimeMultiAgentV4(LettaAgentV3):
use_assistant_message: bool = True,
request_start_timestamp_ns: int | None = None,
include_return_message_types: list[MessageType] | None = None,
conversation_id: str | None = None,
client_tools: list[ClientToolSchema] | None = None,
) -> AsyncGenerator[str, None]:
self.run_ids = []
@@ -89,6 +97,8 @@ class SleeptimeMultiAgentV4(LettaAgentV3):
use_assistant_message=use_assistant_message,
include_return_message_types=include_return_message_types,
request_start_timestamp_ns=request_start_timestamp_ns,
conversation_id=conversation_id,
client_tools=client_tools,
):
yield chunk
finally:
@@ -231,6 +241,7 @@ class SleeptimeMultiAgentV4(LettaAgentV3):
run_update = RunUpdate(
status=RunStatus.completed,
completed_at=datetime.now(timezone.utc).replace(tzinfo=None),
stop_reason=result.stop_reason.stop_reason if result.stop_reason else StopReasonType.end_turn,
metadata={
"result": result.model_dump(mode="json"),
"agent_id": sleeptime_agent_state.id,
@@ -242,6 +253,7 @@ class SleeptimeMultiAgentV4(LettaAgentV3):
run_update = RunUpdate(
status=RunStatus.failed,
completed_at=datetime.now(timezone.utc).replace(tzinfo=None),
stop_reason=StopReasonType.error,
metadata={"error": str(e)},
)
await self.run_manager.update_run_by_id_async(run_id=run_id, update=run_update, actor=self.actor)


@@ -6,6 +6,10 @@ from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMe
from sqlalchemy import Dialect
from letta.functions.mcp_client.types import StdioServerConfig
from letta.helpers.json_helpers import sanitize_null_bytes
from letta.log import get_logger
logger = get_logger(__name__)
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ProviderType, ToolRuleType
from letta.schemas.letta_message import ApprovalReturn, MessageReturnType
@@ -184,16 +188,22 @@ def deserialize_tool_rule(
def serialize_tool_calls(tool_calls: Optional[List[Union[OpenAIToolCall, dict]]]) -> List[Dict]:
"""Convert a list of OpenAI ToolCall objects into JSON-serializable format."""
"""Convert a list of OpenAI ToolCall objects into JSON-serializable format.
Note: Tool call arguments may contain null bytes from various sources.
These are sanitized to prevent PostgreSQL errors.
"""
if not tool_calls:
return []
serialized_calls = []
for call in tool_calls:
if isinstance(call, OpenAIToolCall):
serialized_calls.append(call.model_dump(mode="json"))
# Sanitize null bytes from tool call data to prevent PostgreSQL errors
serialized_calls.append(sanitize_null_bytes(call.model_dump(mode="json")))
elif isinstance(call, dict):
serialized_calls.append(call) # Already a dictionary, leave it as-is
# Sanitize null bytes from dictionary data
serialized_calls.append(sanitize_null_bytes(call))
else:
raise TypeError(f"Unexpected tool call type: {type(call)}")
@@ -221,16 +231,22 @@ def deserialize_tool_calls(data: Optional[List[Dict]]) -> List[OpenAIToolCall]:
def serialize_tool_returns(tool_returns: Optional[List[Union[ToolReturn, dict]]]) -> List[Dict]:
"""Convert a list of ToolReturn objects into JSON-serializable format."""
"""Convert a list of ToolReturn objects into JSON-serializable format.
Note: Tool returns may contain null bytes from sandbox execution or binary data.
These are sanitized to prevent PostgreSQL errors.
"""
if not tool_returns:
return []
serialized_tool_returns = []
for tool_return in tool_returns:
if isinstance(tool_return, ToolReturn):
serialized_tool_returns.append(tool_return.model_dump(mode="json"))
# Sanitize null bytes from tool return data to prevent PostgreSQL errors
serialized_tool_returns.append(sanitize_null_bytes(tool_return.model_dump(mode="json")))
elif isinstance(tool_return, dict):
serialized_tool_returns.append(tool_return) # Already a dictionary, leave it as-is
# Sanitize null bytes from dictionary data
serialized_tool_returns.append(sanitize_null_bytes(tool_return))
else:
raise TypeError(f"Unexpected tool return type: {type(tool_return)}")
@@ -256,18 +272,24 @@ def deserialize_tool_returns(data: Optional[List[Dict]]) -> List[ToolReturn]:
def serialize_approvals(approvals: Optional[List[Union[ApprovalReturn, ToolReturn, dict]]]) -> List[Dict]:
"""Convert a list of ToolReturn objects into JSON-serializable format."""
"""Convert a list of ToolReturn objects into JSON-serializable format.
Note: Approval data may contain null bytes from various sources.
These are sanitized to prevent PostgreSQL errors.
"""
if not approvals:
return []
serialized_approvals = []
for approval in approvals:
if isinstance(approval, ApprovalReturn):
serialized_approvals.append(approval.model_dump(mode="json"))
# Sanitize null bytes from approval data to prevent PostgreSQL errors
serialized_approvals.append(sanitize_null_bytes(approval.model_dump(mode="json")))
elif isinstance(approval, ToolReturn):
serialized_approvals.append(approval.model_dump(mode="json"))
serialized_approvals.append(sanitize_null_bytes(approval.model_dump(mode="json")))
elif isinstance(approval, dict):
serialized_approvals.append(approval) # Already a dictionary, leave it as-is
# Sanitize null bytes from dictionary data
serialized_approvals.append(sanitize_null_bytes(approval))
else:
raise TypeError(f"Unexpected approval type: {type(approval)}")
@@ -318,7 +340,11 @@ def deserialize_approvals(data: Optional[List[Dict]]) -> List[Union[ApprovalRetu
def serialize_message_content(message_content: Optional[List[Union[MessageContent, dict]]]) -> List[Dict]:
"""Convert a list of MessageContent objects into JSON-serializable format."""
"""Convert a list of MessageContent objects into JSON-serializable format.
Note: Message content may contain null bytes from various sources.
These are sanitized to prevent PostgreSQL errors.
"""
if not message_content:
return []
@@ -327,9 +353,11 @@ def serialize_message_content(message_content: Optional[List[Union[MessageConten
if isinstance(content, MessageContent):
if content.type == MessageContentType.image:
assert content.source.type == ImageSourceType.letta, f"Invalid image source type: {content.source.type}"
-            serialized_message_content.append(content.model_dump(mode="json"))
+            # Sanitize null bytes from message content to prevent PostgreSQL errors
+            serialized_message_content.append(sanitize_null_bytes(content.model_dump(mode="json")))
         elif isinstance(content, dict):
-            serialized_message_content.append(content)  # Already a dictionary, leave it as-is
+            # Sanitize null bytes from dictionary data
+            serialized_message_content.append(sanitize_null_bytes(content))
else:
raise TypeError(f"Unexpected message content type: {type(content)}")
return serialized_message_content


@@ -1,16 +1,23 @@
import asyncio
import base64
import hashlib
import os
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
from typing import Optional
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from letta.settings import settings
# Eagerly load the cryptography backend at module import time.
_CRYPTO_BACKEND = default_backend()
# Dedicated thread pool for CPU-intensive crypto operations
# Prevents crypto from blocking health checks and other operations
_crypto_executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="CryptoWorker")
# Common API key prefixes that should not be considered encrypted
# These are plaintext credentials that happen to be long strings
PLAINTEXT_PREFIXES = (
@@ -46,6 +53,11 @@ class CryptoUtils:
# Salt size for key derivation
SALT_SIZE = 16
# WARNING: DO NOT CHANGE THIS VALUE UNLESS YOU ARE SURE WHAT YOU ARE DOING
# EXISTING ENCRYPTED SECRETS MUST BE DECRYPTED WITH THE SAME ITERATIONS
# Number of PBKDF2 iterations
PBKDF2_ITERATIONS = 100000
@classmethod
@lru_cache(maxsize=256)
def _derive_key_cached(cls, master_key: str, salt: bytes) -> bytes:
@@ -55,11 +67,19 @@ class CryptoUtils:
This is a CPU-intensive operation (100k iterations of PBKDF2-HMAC-SHA256)
that can take 100-500ms. Results are cached since key derivation is deterministic.
Uses Python's standard hashlib.pbkdf2_hmac which produces identical output
to the cryptography library's PBKDF2HMAC for the same parameters.
WARNING: This is a synchronous blocking operation. Use _derive_key_async()
in async contexts to avoid blocking the event loop.
"""
-    kdf = PBKDF2HMAC(algorithm=hashes.SHA256(), length=cls.KEY_SIZE, salt=salt, iterations=100000, backend=default_backend())
-    return kdf.derive(master_key.encode())
+    return hashlib.pbkdf2_hmac(
+        hash_name="sha256",
+        password=master_key.encode(),
+        salt=salt,
+        iterations=cls.PBKDF2_ITERATIONS,
+        dklen=cls.KEY_SIZE,
+    )
@classmethod
def _derive_key(cls, master_key: str, salt: bytes) -> bytes:
@@ -69,13 +89,16 @@ class CryptoUtils:
@classmethod
async def _derive_key_async(cls, master_key: str, salt: bytes) -> bytes:
"""
-    Async version of _derive_key that runs PBKDF2 in a thread pool.
+    Async version of _derive_key that runs PBKDF2 in a dedicated thread pool.

     This prevents PBKDF2 (a CPU-intensive operation) from blocking the event loop.
-    PBKDF2 with 100k iterations typically takes 100-500ms, which would freeze
-    the event loop and prevent all other requests from being processed.
+    Uses a dedicated crypto thread pool (8 workers) to prevent PBKDF2 operations
+    from exhausting the default ThreadPoolExecutor (16 threads) and blocking
+    health checks and other operations during high load.
+
+    PBKDF2 with 100k iterations typically takes 100-500ms per operation.
     """
-    return await asyncio.to_thread(cls._derive_key, master_key, salt)
+    loop = asyncio.get_running_loop()
+    return await loop.run_in_executor(_crypto_executor, cls._derive_key, master_key, salt)
@classmethod
def encrypt(cls, plaintext: str, master_key: Optional[str] = None) -> str:
@@ -111,7 +134,7 @@ class CryptoUtils:
key = cls._derive_key(master_key, salt)
# Create cipher
-cipher = Cipher(algorithms.AES(key), modes.GCM(iv), backend=default_backend())
+cipher = Cipher(algorithms.AES(key), modes.GCM(iv), backend=_CRYPTO_BACKEND)
encryptor = cipher.encryptor()
# Encrypt the plaintext
@@ -160,7 +183,7 @@ class CryptoUtils:
key = await cls._derive_key_async(master_key, salt)
# Create cipher
-cipher = Cipher(algorithms.AES(key), modes.GCM(iv), backend=default_backend())
+cipher = Cipher(algorithms.AES(key), modes.GCM(iv), backend=_CRYPTO_BACKEND)
encryptor = cipher.encryptor()
# Encrypt the plaintext
@@ -215,7 +238,7 @@ class CryptoUtils:
key = cls._derive_key(master_key, salt)
# Create cipher
-cipher = Cipher(algorithms.AES(key), modes.GCM(iv, tag), backend=default_backend())
+cipher = Cipher(algorithms.AES(key), modes.GCM(iv, tag), backend=_CRYPTO_BACKEND)
decryptor = cipher.decryptor()
# Decrypt the ciphertext
@@ -266,7 +289,7 @@ class CryptoUtils:
key = await cls._derive_key_async(master_key, salt)
# Create cipher
-cipher = Cipher(algorithms.AES(key), modes.GCM(iv, tag), backend=default_backend())
+cipher = Cipher(algorithms.AES(key), modes.GCM(iv, tag), backend=_CRYPTO_BACKEND)
decryptor = cipher.decryptor()
# Decrypt the ciphertext

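A quick round-trip sketch, not from the commit: `encrypt(plaintext, master_key=None)` is shown above, and `decrypt` is assumed to mirror it (its full signature falls outside these hunks).

```python
token = CryptoUtils.encrypt("my-api-key", master_key="unit-test-master-key")
assert CryptoUtils.decrypt(token, master_key="unit-test-master-key") == "my-api-key"
```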

@@ -9,6 +9,7 @@ from pydantic import BaseModel
from letta.constants import REDIS_DEFAULT_CACHE_PREFIX
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
from letta.log import get_logger
from letta.otel.tracing import tracer
from letta.plugins.plugins import get_experimental_checker
from letta.settings import settings
@@ -109,35 +110,59 @@ def async_redis_cache(
@wraps(func)
async def async_wrapper(*args, **kwargs):
-        redis_client = await get_redis_client()
-        # Don't bother going through other operations for no reason.
-        if isinstance(redis_client, NoopAsyncRedisClient):
-            return await func(*args, **kwargs)
-        cache_key = get_cache_key(*args, **kwargs)
-        cached_value = await redis_client.get(cache_key)
-        try:
-            if cached_value is not None:
-                stats.hits += 1
-                if model_class:
-                    return model_class.model_validate_json(cached_value)
-                return json.loads(cached_value)
-        except Exception as e:
-            logger.warning(f"Failed to retrieve value from cache: {e}")
-        stats.misses += 1
-        result = await func(*args, **kwargs)
-        try:
-            if model_class:
-                await redis_client.set(cache_key, result.model_dump_json(), ex=ttl_s)
-            elif isinstance(result, (dict, list, str, int, float, bool)):
-                await redis_client.set(cache_key, json.dumps(result), ex=ttl_s)
-            else:
-                logger.warning(f"Cannot cache result of type {type(result).__name__} for {func.__name__}")
-        except Exception as e:
-            logger.warning(f"Redis cache set failed: {e}")
-        return result
+        with tracer.start_as_current_span("redis_cache", attributes={"cache.function": func.__name__}) as span:
+            # 1. Get Redis client
+            with tracer.start_as_current_span("redis_cache.get_client"):
+                redis_client = await get_redis_client()
+            # Don't bother going through other operations for no reason.
+            if isinstance(redis_client, NoopAsyncRedisClient):
+                span.set_attribute("cache.noop", True)
+                return await func(*args, **kwargs)
+            cache_key = get_cache_key(*args, **kwargs)
+            span.set_attribute("cache.key", cache_key)
+            # 2. Try cache read
+            with tracer.start_as_current_span("redis_cache.get") as get_span:
+                cached_value = await redis_client.get(cache_key)
+                get_span.set_attribute("cache.hit", cached_value is not None)
+            try:
+                if cached_value is not None:
+                    stats.hits += 1
+                    span.set_attribute("cache.result", "hit")
+                    # 3. Deserialize cache hit
+                    with tracer.start_as_current_span("redis_cache.deserialize"):
+                        if model_class:
+                            return model_class.model_validate_json(cached_value)
+                        return json.loads(cached_value)
+            except Exception as e:
+                logger.warning(f"Failed to retrieve value from cache: {e}")
+                span.record_exception(e)
+            stats.misses += 1
+            span.set_attribute("cache.result", "miss")
+            # 4. Call original function
+            with tracer.start_as_current_span("redis_cache.call_original"):
+                result = await func(*args, **kwargs)
+            # 5. Write to cache
+            try:
+                with tracer.start_as_current_span("redis_cache.set") as set_span:
+                    if model_class:
+                        await redis_client.set(cache_key, result.model_dump_json(), ex=ttl_s)
+                    elif isinstance(result, (dict, list, str, int, float, bool)):
+                        await redis_client.set(cache_key, json.dumps(result), ex=ttl_s)
+                    else:
+                        set_span.set_attribute("cache.set_skipped", True)
+                        logger.warning(f"Cannot cache result of type {type(result).__name__} for {func.__name__}")
+            except Exception as e:
+                logger.warning(f"Redis cache set failed: {e}")
+                span.record_exception(e)
+            return result
async def invalidate(*args, **kwargs) -> bool:
stats.invalidations += 1
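The reworked decorator wraps each lookup in one parent span with a child span per phase. A hypothetical, self-contained sketch of that span layout using the OpenTelemetry API, with an in-memory dict standing in for Redis:

```python
import functools
import json
from opentelemetry import trace

tracer = trace.get_tracer("redis_cache_demo")
_store: dict = {}  # stand-in for Redis

def cached(func):
    @functools.wraps(func)
    async def wrapper(*args):
        # Parent span covers the whole cache lookup
        with tracer.start_as_current_span("redis_cache", attributes={"cache.function": func.__name__}) as span:
            key = json.dumps(args)
            # Child span: cache read
            with tracer.start_as_current_span("redis_cache.get") as get_span:
                hit = _store.get(key)
                get_span.set_attribute("cache.hit", hit is not None)
            if hit is not None:
                span.set_attribute("cache.result", "hit")
                return json.loads(hit)
            span.set_attribute("cache.result", "miss")
            # Child span: call through on miss
            with tracer.start_as_current_span("redis_cache.call_original"):
                result = await func(*args)
            # Child span: cache write
            with tracer.start_as_current_span("redis_cache.set"):
                _store[key] = json.dumps(result)
            return result
    return wrapper
```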

View File

@@ -1,6 +1,42 @@
import base64
import json
from datetime import datetime
from typing import Any
def sanitize_null_bytes(value: Any) -> Any:
"""Recursively remove null bytes (0x00) from strings.
PostgreSQL TEXT columns don't accept null bytes in UTF-8 encoding, which causes
asyncpg.exceptions.CharacterNotInRepertoireError when data with null bytes is inserted.
This function sanitizes:
- Strings: removes all null bytes
- Dicts: recursively sanitizes all keys and values
- Lists: recursively sanitizes all elements
- Tuples: recursively sanitizes all elements (returned as tuples)
- Other types: returned as-is
Args:
value: The value to sanitize
Returns:
The sanitized value with null bytes removed from all strings
"""
if isinstance(value, str):
# Remove null bytes from strings
return value.replace("\x00", "")
elif isinstance(value, dict):
# Recursively sanitize dictionary keys and values
return {sanitize_null_bytes(k): sanitize_null_bytes(v) for k, v in value.items()}
elif isinstance(value, list):
# Recursively sanitize list elements
return [sanitize_null_bytes(item) for item in value]
elif isinstance(value, tuple):
# Recursively sanitize tuple elements (return as tuple)
return tuple(sanitize_null_bytes(item) for item in value)
else:
# Return other types as-is (int, float, bool, None, etc.)
return value
def json_loads(data):
@@ -8,15 +44,33 @@ def json_loads(data):
def json_dumps(data, indent=2) -> str:
"""Serialize data to JSON string, sanitizing null bytes to prevent PostgreSQL errors.
PostgreSQL TEXT columns reject null bytes (0x00) in UTF-8 encoding. This function
sanitizes all strings in the data structure before JSON serialization to prevent
asyncpg.exceptions.CharacterNotInRepertoireError.
Args:
data: The data to serialize
indent: JSON indentation level (default: 2)
Returns:
JSON string with null bytes removed from all string values
"""
# Sanitize null bytes before serialization to prevent PostgreSQL errors
sanitized_data = sanitize_null_bytes(data)
def safe_serializer(obj):
if isinstance(obj, datetime):
return obj.isoformat()
if isinstance(obj, bytes):
try:
return obj.decode("utf-8")
decoded = obj.decode("utf-8")
# Also sanitize decoded bytes
return decoded.replace("\x00", "")
except Exception:
# TODO: this is to handle Gemini thought signatures, b64 decode this back to bytes when sending back to Gemini
return base64.b64encode(obj).decode("utf-8")
raise TypeError(f"Type {type(obj)} not serializable")
return json.dumps(data, indent=indent, default=safe_serializer, ensure_ascii=False)
return json.dumps(sanitized_data, indent=indent, default=safe_serializer, ensure_ascii=False)
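A quick usage check for `sanitize_null_bytes` as defined above; the payload is made up:

```python
payload = {"note": "a\x00b", "tags": ["x\x00", ("y\x00",)], "count": 3}
clean = sanitize_null_bytes(payload)
# Null bytes stripped recursively; non-strings pass through untouched
assert clean == {"note": "ab", "tags": ["x", ("y",)], "count": 3}
```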

View File

@@ -12,23 +12,52 @@ from letta.schemas.letta_message_content import Base64Image, ImageContent, Image
from letta.schemas.message import Message, MessageCreate
async def _fetch_image_from_url(url: str) -> tuple[bytes, str | None]:
async def _fetch_image_from_url(url: str, max_retries: int = 1, timeout_seconds: float = 5.0) -> tuple[bytes, str | None]:
"""
Async helper to fetch image from URL without blocking the event loop.
Retries on timeout (once by default) to handle transient network issues.
Args:
url: URL of the image to fetch
max_retries: Number of retry attempts (default: 1)
timeout_seconds: Total timeout in seconds (default: 5.0)
Returns:
Tuple of (image_bytes, media_type)
Raises:
LettaImageFetchError: If image fetch fails after all retries
"""
timeout = httpx.Timeout(15.0, connect=5.0)
# Connect timeout is half of total timeout, capped at 3 seconds
connect_timeout = min(timeout_seconds / 2, 3.0)
timeout = httpx.Timeout(timeout_seconds, connect=connect_timeout)
headers = {"User-Agent": f"Letta/{__version__}"}
try:
async with httpx.AsyncClient(timeout=timeout, headers=headers) as client:
image_response = await client.get(url, follow_redirects=True)
image_response.raise_for_status()
image_bytes = image_response.content
image_media_type = image_response.headers.get("content-type")
return image_bytes, image_media_type
except (httpx.RemoteProtocolError, httpx.TimeoutException, httpx.HTTPStatusError) as e:
raise LettaImageFetchError(url=url, reason=str(e))
except Exception as e:
raise LettaImageFetchError(url=url, reason=f"Unexpected error: {e}")
last_exception = None
for attempt in range(max_retries + 1):
try:
async with httpx.AsyncClient(timeout=timeout, headers=headers) as client:
image_response = await client.get(url, follow_redirects=True)
image_response.raise_for_status()
image_bytes = image_response.content
image_media_type = image_response.headers.get("content-type")
return image_bytes, image_media_type
except httpx.TimeoutException as e:
last_exception = e
if attempt < max_retries:
# Brief delay before retry
await asyncio.sleep(0.5)
continue
# Final attempt failed
raise LettaImageFetchError(url=url, reason=f"Timeout after {max_retries + 1} attempts: {e}")
except (httpx.RemoteProtocolError, httpx.HTTPStatusError) as e:
# Don't retry on protocol errors or HTTP errors (4xx, 5xx)
raise LettaImageFetchError(url=url, reason=str(e))
except Exception as e:
raise LettaImageFetchError(url=url, reason=f"Unexpected error: {e}")
# Should never reach here, but just in case
raise LettaImageFetchError(url=url, reason=f"Failed after {max_retries + 1} attempts: {last_exception}")
async def convert_message_creates_to_messages(
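A minimal sketch of the retry-on-timeout behavior above, assuming httpx and the 0.5s backoff from the diff; for brevity it re-raises the original exception rather than wrapping it in `LettaImageFetchError`:

```python
import asyncio
import httpx

async def fetch_with_retry(url: str, max_retries: int = 1, timeout_seconds: float = 5.0) -> bytes:
    # Connect timeout is half of total timeout, capped at 3 seconds
    timeout = httpx.Timeout(timeout_seconds, connect=min(timeout_seconds / 2, 3.0))
    for attempt in range(max_retries + 1):
        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                resp = await client.get(url, follow_redirects=True)
                resp.raise_for_status()  # HTTP errors (4xx/5xx) are not retried
                return resp.content
        except httpx.TimeoutException:
            if attempt < max_retries:
                await asyncio.sleep(0.5)  # brief delay, then retry
                continue
            raise
    raise RuntimeError("unreachable")  # loop always returns or raises
```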

View File

@@ -400,6 +400,7 @@ class TurbopufferClient:
created_ats: List[datetime],
project_id: Optional[str] = None,
template_id: Optional[str] = None,
conversation_ids: Optional[List[Optional[str]]] = None,
) -> bool:
"""Insert messages into Turbopuffer.
@@ -413,6 +414,7 @@ class TurbopufferClient:
created_ats: List of creation timestamps for each message
project_id: Optional project ID for all messages
template_id: Optional template ID for all messages
conversation_ids: Optional list of conversation IDs (one per message, must match 1:1 with message_texts)
Returns:
True if successful
@@ -441,22 +443,26 @@ class TurbopufferClient:
raise ValueError(f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})")
if len(message_ids) != len(created_ats):
raise ValueError(f"message_ids length ({len(message_ids)}) must match created_ats length ({len(created_ats)})")
if conversation_ids is not None and len(conversation_ids) != len(message_ids):
raise ValueError(f"conversation_ids length ({len(conversation_ids)}) must match message_ids length ({len(message_ids)})")
# prepare column-based data for turbopuffer - optimized for batch insert
ids = []
vectors = []
texts = []
organization_ids = []
agent_ids = []
organization_ids_list = []
agent_ids_list = []
message_roles = []
created_at_timestamps = []
project_ids = []
template_ids = []
project_ids_list = []
template_ids_list = []
conversation_ids_list = []
for (original_idx, text), embedding in zip(filtered_messages, embeddings):
message_id = message_ids[original_idx]
role = roles[original_idx]
created_at = created_ats[original_idx]
conversation_id = conversation_ids[original_idx] if conversation_ids else None
# ensure the provided timestamp is timezone-aware and in UTC
if created_at.tzinfo is None:
@@ -470,31 +476,36 @@ class TurbopufferClient:
ids.append(message_id)
vectors.append(embedding)
texts.append(text)
organization_ids.append(organization_id)
agent_ids.append(agent_id)
organization_ids_list.append(organization_id)
agent_ids_list.append(agent_id)
message_roles.append(role.value)
created_at_timestamps.append(timestamp)
project_ids.append(project_id)
template_ids.append(template_id)
project_ids_list.append(project_id)
template_ids_list.append(template_id)
conversation_ids_list.append(conversation_id)
# build column-based upsert data
upsert_columns = {
"id": ids,
"vector": vectors,
"text": texts,
"organization_id": organization_ids,
"agent_id": agent_ids,
"organization_id": organization_ids_list,
"agent_id": agent_ids_list,
"role": message_roles,
"created_at": created_at_timestamps,
}
# only include conversation_id if it's provided
if conversation_ids is not None:
upsert_columns["conversation_id"] = conversation_ids_list
# only include project_id if it's provided
if project_id is not None:
upsert_columns["project_id"] = project_ids
upsert_columns["project_id"] = project_ids_list
# only include template_id if it's provided
if template_id is not None:
upsert_columns["template_id"] = template_ids
upsert_columns["template_id"] = template_ids_list
try:
# use global semaphore to limit concurrent Turbopuffer writes
@@ -506,7 +517,10 @@ class TurbopufferClient:
await namespace.write(
upsert_columns=upsert_columns,
distance_metric="cosine_distance",
schema={"text": {"type": "string", "full_text_search": True}},
schema={
"text": {"type": "string", "full_text_search": True},
"conversation_id": {"type": "string"},
},
)
logger.info(f"Successfully inserted {len(ids)} messages to Turbopuffer for agent {agent_id}")
return True
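For reference, the column-based payload the write above sends looks roughly like this; all values, including the timestamp format, are illustrative:

```python
upsert_columns = {
    "id": ["msg-1", "msg-2"],
    "vector": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
    "text": ["hello", "world"],
    "organization_id": ["org-1", "org-1"],
    "agent_id": ["agent-1", "agent-1"],
    "role": ["user", "assistant"],
    "created_at": ["2026-01-01T00:00:00+00:00", "2026-01-01T00:00:01+00:00"],
    "conversation_id": ["conv-1", None],  # None marks a default-conversation message
}
```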
@@ -561,67 +575,80 @@ class TurbopufferClient:
if search_mode not in ["vector", "fts", "hybrid", "timestamp"]:
raise ValueError(f"Invalid search_mode: {search_mode}. Must be 'vector', 'fts', 'hybrid', or 'timestamp'")
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
namespace = client.namespace(namespace_name)
try:
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
namespace = client.namespace(namespace_name)
if search_mode == "timestamp":
# retrieve most recent items by timestamp
query_params = {
"rank_by": ("created_at", "desc"),
"top_k": top_k,
"include_attributes": include_attributes,
}
if filters:
query_params["filters"] = filters
return await namespace.query(**query_params)
if search_mode == "timestamp":
# retrieve most recent items by timestamp
query_params = {
"rank_by": ("created_at", "desc"),
"top_k": top_k,
"include_attributes": include_attributes,
}
if filters:
query_params["filters"] = filters
return await namespace.query(**query_params)
elif search_mode == "vector":
# vector search query
query_params = {
"rank_by": ("vector", "ANN", query_embedding),
"top_k": top_k,
"include_attributes": include_attributes,
}
if filters:
query_params["filters"] = filters
return await namespace.query(**query_params)
elif search_mode == "vector":
# vector search query
query_params = {
"rank_by": ("vector", "ANN", query_embedding),
"top_k": top_k,
"include_attributes": include_attributes,
}
if filters:
query_params["filters"] = filters
return await namespace.query(**query_params)
elif search_mode == "fts":
# full-text search query
query_params = {
"rank_by": ("text", "BM25", query_text),
"top_k": top_k,
"include_attributes": include_attributes,
}
if filters:
query_params["filters"] = filters
return await namespace.query(**query_params)
elif search_mode == "fts":
# full-text search query
query_params = {
"rank_by": ("text", "BM25", query_text),
"top_k": top_k,
"include_attributes": include_attributes,
}
if filters:
query_params["filters"] = filters
return await namespace.query(**query_params)
else: # hybrid mode
queries = []
else: # hybrid mode
queries = []
# vector search query
vector_query = {
"rank_by": ("vector", "ANN", query_embedding),
"top_k": top_k,
"include_attributes": include_attributes,
}
if filters:
vector_query["filters"] = filters
queries.append(vector_query)
# vector search query
vector_query = {
"rank_by": ("vector", "ANN", query_embedding),
"top_k": top_k,
"include_attributes": include_attributes,
}
if filters:
vector_query["filters"] = filters
queries.append(vector_query)
# full-text search query
fts_query = {
"rank_by": ("text", "BM25", query_text),
"top_k": top_k,
"include_attributes": include_attributes,
}
if filters:
fts_query["filters"] = filters
queries.append(fts_query)
# full-text search query
fts_query = {
"rank_by": ("text", "BM25", query_text),
"top_k": top_k,
"include_attributes": include_attributes,
}
if filters:
fts_query["filters"] = filters
queries.append(fts_query)
# execute multi-query
return await namespace.multi_query(queries=[QueryParam(**q) for q in queries])
# execute multi-query
return await namespace.multi_query(queries=[QueryParam(**q) for q in queries])
except Exception as e:
# Wrap turbopuffer errors with user-friendly messages
from turbopuffer import NotFoundError
if isinstance(e, NotFoundError):
# Extract just the error message without implementation details
error_msg = str(e)
if "namespace" in error_msg.lower() and "not found" in error_msg.lower():
raise ValueError("No conversation history found. Please send a message first to enable search.") from e
raise ValueError(f"Search data not found: {error_msg}") from e
# Re-raise other errors as-is
raise
@trace_method
async def query_passages(
@@ -779,6 +806,7 @@ class TurbopufferClient:
roles: Optional[List[MessageRole]] = None,
project_id: Optional[str] = None,
template_id: Optional[str] = None,
conversation_id: Optional[str] = None,
vector_weight: float = 0.5,
fts_weight: float = 0.5,
start_date: Optional[datetime] = None,
@@ -796,6 +824,7 @@ class TurbopufferClient:
roles: Optional list of message roles to filter by
project_id: Optional project ID to filter messages by
template_id: Optional template ID to filter messages by
conversation_id: Optional conversation ID to filter messages by (use "default" for NULL)
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
start_date: Optional datetime to filter messages created after this date
@@ -862,6 +891,19 @@ class TurbopufferClient:
if template_id:
template_filter = ("template_id", "Eq", template_id)
# build conversation_id filter if provided
# three cases:
# 1. conversation_id=None (omitted) -> return all messages (no filter)
# 2. conversation_id="default" -> return only default messages (conversation_id IS NULL), for backward compatibility
# 3. conversation_id="xyz" -> return only messages in that conversation
conversation_filter = None
if conversation_id == "default":
# "default" is reserved for default messages only (conversation_id is none)
conversation_filter = ("conversation_id", "Eq", None)
elif conversation_id is not None:
# Specific conversation
conversation_filter = ("conversation_id", "Eq", conversation_id)
# combine all filters
all_filters = [agent_filter] # always include agent_id filter
if role_filter:
@@ -870,6 +912,8 @@ class TurbopufferClient:
all_filters.append(project_filter)
if template_filter:
all_filters.append(template_filter)
if conversation_filter:
all_filters.append(conversation_filter)
if date_filters:
all_filters.extend(date_filters)
@@ -888,7 +932,7 @@ class TurbopufferClient:
query_embedding=query_embedding,
query_text=query_text,
top_k=top_k,
include_attributes=["text", "organization_id", "agent_id", "role", "created_at"],
include_attributes=True,
filters=final_filter,
vector_weight=vector_weight,
fts_weight=fts_weight,
@@ -939,6 +983,7 @@ class TurbopufferClient:
agent_id: Optional[str] = None,
project_id: Optional[str] = None,
template_id: Optional[str] = None,
conversation_id: Optional[str] = None,
vector_weight: float = 0.5,
fts_weight: float = 0.5,
start_date: Optional[datetime] = None,
@@ -956,6 +1001,10 @@ class TurbopufferClient:
agent_id: Optional agent ID to filter messages by
project_id: Optional project ID to filter messages by
template_id: Optional template ID to filter messages by
conversation_id: Optional conversation ID to filter messages by. Special values:
- None (omitted): Return all messages
- "default": Return only default messages (conversation_id IS NULL)
- Any other value: Return messages in that specific conversation
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
start_date: Optional datetime to filter messages created after this date
@@ -1004,6 +1053,18 @@ class TurbopufferClient:
if template_id:
all_filters.append(("template_id", "Eq", template_id))
# conversation filter
# three cases:
# 1. conversation_id=None (omitted) -> return all messages (no filter)
# 2. conversation_id="default" -> return only default messages (conversation_id IS NULL), for backward compatibility
# 3. conversation_id="xyz" -> return only messages in that conversation
if conversation_id == "default":
# "default" is reserved for default messages only (conversation_id is none)
all_filters.append(("conversation_id", "Eq", None))
elif conversation_id is not None:
# Specific conversation
all_filters.append(("conversation_id", "Eq", conversation_id))
# date filters
if start_date:
# Convert to UTC to match stored timestamps
@@ -1036,7 +1097,7 @@ class TurbopufferClient:
query_embedding=query_embedding,
query_text=query_text,
top_k=top_k,
include_attributes=["text", "organization_id", "agent_id", "role", "created_at"],
include_attributes=True,
filters=final_filter,
vector_weight=vector_weight,
fts_weight=fts_weight,
@@ -1121,6 +1182,7 @@ class TurbopufferClient:
"agent_id": getattr(row, "agent_id", None),
"role": getattr(row, "role", None),
"created_at": getattr(row, "created_at", None),
"conversation_id": getattr(row, "conversation_id", None),
}
messages.append(message_dict)
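The three-way `conversation_id` semantics appear twice above; a small hypothetical helper capturing them, matching the filter tuples in the diff:

```python
def build_conversation_filter(conversation_id: str | None):
    if conversation_id == "default":
        return ("conversation_id", "Eq", None)  # only default (NULL) messages
    if conversation_id is not None:
        return ("conversation_id", "Eq", conversation_id)  # one specific conversation
    return None  # omitted: no filter, all messages

assert build_conversation_filter(None) is None
assert build_conversation_filter("default") == ("conversation_id", "Eq", None)
assert build_conversation_filter("conv-1") == ("conversation_id", "Eq", "conv-1")
```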

View File

@@ -88,6 +88,7 @@ class SimpleAnthropicStreamingInterface:
self.tool_call_name = None
self.accumulated_tool_call_args = ""
self.previous_parse = {}
self.thinking_signature = None
# usage trackers
self.input_tokens = 0
@@ -426,20 +427,23 @@ class SimpleAnthropicStreamingInterface:
f"Streaming integrity failed - received BetaThinkingBlock object while not in THINKING EventMode: {delta}"
)
if prev_message_type and prev_message_type != "reasoning_message":
message_index += 1
reasoning_message = ReasoningMessage(
id=self.letta_message_id,
source="reasoner_model",
reasoning=delta.thinking,
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
self.reasoning_messages.append(reasoning_message)
prev_message_type = reasoning_message.message_type
yield reasoning_message
# Only emit reasoning message if we have actual content
if delta.thinking and delta.thinking.strip():
if prev_message_type and prev_message_type != "reasoning_message":
message_index += 1
reasoning_message = ReasoningMessage(
id=self.letta_message_id,
source="reasoner_model",
reasoning=delta.thinking,
signature=self.thinking_signature,
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
self.reasoning_messages.append(reasoning_message)
prev_message_type = reasoning_message.message_type
yield reasoning_message
elif isinstance(delta, BetaSignatureDelta):
# Safety check
@@ -448,21 +452,15 @@ class SimpleAnthropicStreamingInterface:
f"Streaming integrity failed - received BetaSignatureDelta object while not in THINKING EventMode: {delta}"
)
if prev_message_type and prev_message_type != "reasoning_message":
message_index += 1
reasoning_message = ReasoningMessage(
id=self.letta_message_id,
source="reasoner_model",
reasoning="",
date=datetime.now(timezone.utc).isoformat(),
signature=delta.signature,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
self.reasoning_messages.append(reasoning_message)
prev_message_type = reasoning_message.message_type
yield reasoning_message
# Store signature but don't emit empty reasoning message
# Signature will be attached when actual thinking content arrives
self.thinking_signature = delta.signature
# Update the last reasoning message with the signature so it gets persisted
if self.reasoning_messages:
last_msg = self.reasoning_messages[-1]
if isinstance(last_msg, ReasoningMessage):
last_msg.signature = delta.signature
elif isinstance(event, BetaRawMessageStartEvent):
self.message_id = event.message.id
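Both streaming interfaces now buffer the thinking signature instead of emitting empty reasoning messages. A stripped-down sketch of that behavior with made-up types:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Reasoning:
    text: str
    signature: Optional[str] = None

class SignatureBuffer:
    def __init__(self):
        self.thinking_signature: Optional[str] = None
        self.messages: list[Reasoning] = []

    def on_thinking(self, text: str) -> Optional[Reasoning]:
        if not (text and text.strip()):
            return None  # never emit empty reasoning
        msg = Reasoning(text, self.thinking_signature)
        self.messages.append(msg)
        return msg

    def on_signature(self, signature: str) -> None:
        self.thinking_signature = signature
        if self.messages:  # backfill the last emitted message so it persists
            self.messages[-1].signature = signature
```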

View File

@@ -224,40 +224,32 @@ class SimpleGeminiStreamingInterface:
# NOTE: the thought_signature comes on the Part with the function_call
thought_signature = part.thought_signature
self.thinking_signature = base64.b64encode(thought_signature).decode("utf-8")
if prev_message_type and prev_message_type != "reasoning_message":
message_index += 1
yield ReasoningMessage(
id=self.letta_message_id,
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
source="reasoner_model",
reasoning="",
signature=self.thinking_signature,
)
prev_message_type = "reasoning_message"
# Don't emit empty reasoning message - signature will be attached to actual reasoning content
# Thinking summary content part (bool means text is thought part)
if part.thought:
reasoning_summary = part.text
if prev_message_type and prev_message_type != "reasoning_message":
message_index += 1
yield ReasoningMessage(
id=self.letta_message_id,
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
source="reasoner_model",
reasoning=reasoning_summary,
run_id=self.run_id,
step_id=self.step_id,
)
prev_message_type = "reasoning_message"
self.content_parts.append(
ReasoningContent(
is_native=True,
# Only emit reasoning message if we have actual content
if reasoning_summary and reasoning_summary.strip():
if prev_message_type and prev_message_type != "reasoning_message":
message_index += 1
yield ReasoningMessage(
id=self.letta_message_id,
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
source="reasoner_model",
reasoning=reasoning_summary,
signature=self.thinking_signature,
run_id=self.run_id,
step_id=self.step_id,
)
prev_message_type = "reasoning_message"
self.content_parts.append(
ReasoningContent(
is_native=True,
reasoning=reasoning_summary,
signature=self.thinking_signature,
)
)
)
# Plain text content part
elif part.text:

View File

@@ -5,6 +5,7 @@ import re
from typing import Dict, List, Optional, Union
import anthropic
import httpx
from anthropic import AsyncStream
from anthropic.types.beta import BetaMessage as AnthropicMessage, BetaRawMessageStreamEvent
from anthropic.types.beta.message_create_params import MessageCreateParamsNonStreaming
@@ -28,6 +29,7 @@ from letta.errors import (
)
from letta.helpers.datetime_helpers import get_utc_time_int
from letta.helpers.decorators import deprecated
from letta.llm_api.anthropic_constants import ANTHROPIC_MAX_STRICT_TOOLS, ANTHROPIC_STRICT_MODE_ALLOWLIST
from letta.llm_api.helpers import add_inner_thoughts_to_functions, unpack_all_inner_thoughts_from_kwargs
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
@@ -81,9 +83,17 @@ class AnthropicClient(LLMClientBase):
if llm_config.model.startswith("claude-opus-4-5") and llm_config.enable_reasoner:
betas.append("context-management-2025-06-27")
# Structured outputs beta
if hasattr(llm_config, "response_format") and isinstance(llm_config.response_format, JsonSchemaResponseFormat):
betas.append("structured-outputs-2025-11-13")
# Structured outputs beta - only for supported models
# Supported: Claude Sonnet 4.5, Opus 4.1, Opus 4.5, Haiku 4.5
# DISABLED: Commenting out structured outputs to investigate TTFT latency impact
# See PR #7495 for original implementation
# supports_structured_outputs = _supports_structured_outputs(llm_config.model)
#
# if supports_structured_outputs:
# # Always enable structured outputs beta on supported models.
# # NOTE: We do NOT send `strict` on tool schemas because the current Anthropic SDK
# # typed tool params reject unknown fields (e.g., `tools.0.custom.strict`).
# betas.append("structured-outputs-2025-11-13")
if betas:
response = client.beta.messages.create(**request_data, betas=betas)
@@ -94,7 +104,6 @@ class AnthropicClient(LLMClientBase):
@trace_method
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
client = await self._get_anthropic_client_async(llm_config, async_client=True)
betas: list[str] = []
# interleaved thinking for reasoner
if llm_config.enable_reasoner:
@@ -119,9 +128,13 @@ class AnthropicClient(LLMClientBase):
if llm_config.model.startswith("claude-opus-4-5") and llm_config.enable_reasoner:
betas.append("context-management-2025-06-27")
# Structured outputs beta
if hasattr(llm_config, "response_format") and isinstance(llm_config.response_format, JsonSchemaResponseFormat):
betas.append("structured-outputs-2025-11-13")
# Structured outputs beta - only for supported models
# DISABLED: Commenting out structured outputs to investigate TTFT latency impact
# See PR #7495 for original implementation
# supports_structured_outputs = _supports_structured_outputs(llm_config.model)
#
# if supports_structured_outputs:
# betas.append("structured-outputs-2025-11-13")
if betas:
response = await client.beta.messages.create(**request_data, betas=betas)
@@ -164,9 +177,13 @@ class AnthropicClient(LLMClientBase):
if llm_config.model.startswith("claude-opus-4-5") and llm_config.enable_reasoner:
betas.append("context-management-2025-06-27")
# Structured outputs beta
if hasattr(llm_config, "response_format") and isinstance(llm_config.response_format, JsonSchemaResponseFormat):
betas.append("structured-outputs-2025-11-13")
# Structured outputs beta - only for supported models
# DISABLED: Commenting out structured outputs to investigate TTFT latency impact
# See PR #7495 for original implementation
# supports_structured_outputs = _supports_structured_outputs(llm_config.model)
#
# if supports_structured_outputs:
# betas.append("structured-outputs-2025-11-13")
# log failed requests
try:
@@ -234,17 +251,35 @@ class AnthropicClient(LLMClientBase):
) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
api_key, _, _ = self.get_byok_overrides(llm_config)
# For claude-pro-max provider, use OAuth Bearer token instead of api_key
is_oauth_provider = llm_config.provider_name == "claude-pro-max"
if async_client:
return (
anthropic.AsyncAnthropic(api_key=api_key, max_retries=model_settings.anthropic_max_retries)
if api_key
else anthropic.AsyncAnthropic(max_retries=model_settings.anthropic_max_retries)
)
return (
anthropic.Anthropic(api_key=api_key, max_retries=model_settings.anthropic_max_retries)
if api_key
else anthropic.Anthropic(max_retries=model_settings.anthropic_max_retries)
)
if api_key:
if is_oauth_provider:
return anthropic.AsyncAnthropic(
max_retries=model_settings.anthropic_max_retries,
default_headers={
"Authorization": f"Bearer {api_key}",
"anthropic-version": "2023-06-01",
"anthropic-beta": "oauth-2025-04-20",
},
)
return anthropic.AsyncAnthropic(api_key=api_key, max_retries=model_settings.anthropic_max_retries)
return anthropic.AsyncAnthropic(max_retries=model_settings.anthropic_max_retries)
if api_key:
if is_oauth_provider:
return anthropic.Anthropic(
max_retries=model_settings.anthropic_max_retries,
default_headers={
"Authorization": f"Bearer {api_key}",
"anthropic-version": "2023-06-01",
"anthropic-beta": "oauth-2025-04-20",
},
)
return anthropic.Anthropic(api_key=api_key, max_retries=model_settings.anthropic_max_retries)
return anthropic.Anthropic(max_retries=model_settings.anthropic_max_retries)
@trace_method
async def _get_anthropic_client_async(
@@ -252,17 +287,35 @@ class AnthropicClient(LLMClientBase):
) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
api_key, _, _ = await self.get_byok_overrides_async(llm_config)
# For claude-pro-max provider, use OAuth Bearer token instead of api_key
is_oauth_provider = llm_config.provider_name == "claude-pro-max"
if async_client:
return (
anthropic.AsyncAnthropic(api_key=api_key, max_retries=model_settings.anthropic_max_retries)
if api_key
else anthropic.AsyncAnthropic(max_retries=model_settings.anthropic_max_retries)
)
return (
anthropic.Anthropic(api_key=api_key, max_retries=model_settings.anthropic_max_retries)
if api_key
else anthropic.Anthropic(max_retries=model_settings.anthropic_max_retries)
)
if api_key:
if is_oauth_provider:
return anthropic.AsyncAnthropic(
max_retries=model_settings.anthropic_max_retries,
default_headers={
"Authorization": f"Bearer {api_key}",
"anthropic-version": "2023-06-01",
"anthropic-beta": "oauth-2025-04-20",
},
)
return anthropic.AsyncAnthropic(api_key=api_key, max_retries=model_settings.anthropic_max_retries)
return anthropic.AsyncAnthropic(max_retries=model_settings.anthropic_max_retries)
if api_key:
if is_oauth_provider:
return anthropic.Anthropic(
max_retries=model_settings.anthropic_max_retries,
default_headers={
"Authorization": f"Bearer {api_key}",
"anthropic-version": "2023-06-01",
"anthropic-beta": "oauth-2025-04-20",
},
)
return anthropic.Anthropic(api_key=api_key, max_retries=model_settings.anthropic_max_retries)
return anthropic.Anthropic(max_retries=model_settings.anthropic_max_retries)
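The client builders now branch on an OAuth provider flag. A condensed sketch of the selection logic (the token value is fake, and the no-key fallback assumes `ANTHROPIC_API_KEY` is set in the environment):

```python
import anthropic

def make_client(api_key: str | None, is_oauth: bool) -> anthropic.Anthropic:
    if api_key and is_oauth:
        # OAuth Bearer token goes in headers, not the api_key field
        return anthropic.Anthropic(
            default_headers={
                "Authorization": f"Bearer {api_key}",
                "anthropic-version": "2023-06-01",
                "anthropic-beta": "oauth-2025-04-20",
            },
        )
    if api_key:
        return anthropic.Anthropic(api_key=api_key)
    return anthropic.Anthropic()  # falls back to ANTHROPIC_API_KEY env var
```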
@trace_method
def build_request_data(
@@ -331,11 +384,13 @@ class AnthropicClient(LLMClientBase):
}
# Structured outputs via response_format
if hasattr(llm_config, "response_format") and isinstance(llm_config.response_format, JsonSchemaResponseFormat):
data["output_format"] = {
"type": "json_schema",
"schema": llm_config.response_format.json_schema["schema"],
}
# DISABLED: Commenting out structured outputs to investigate TTFT latency impact
# See PR #7495 for original implementation
# if hasattr(llm_config, "response_format") and isinstance(llm_config.response_format, JsonSchemaResponseFormat):
# data["output_format"] = {
# "type": "json_schema",
# "schema": llm_config.response_format.json_schema["schema"],
# }
# Tools
# For an overview on tool choice:
@@ -385,7 +440,12 @@ class AnthropicClient(LLMClientBase):
if tools_for_request and len(tools_for_request) > 0:
# TODO eventually enable parallel tool use
data["tools"] = convert_tools_to_anthropic_format(tools_for_request)
# DISABLED: use_strict=False to disable structured outputs (TTFT latency impact)
# See PR #7495 for original implementation
data["tools"] = convert_tools_to_anthropic_format(
tools_for_request,
use_strict=False, # Was: _supports_structured_outputs(llm_config.model)
)
# Add cache control to the last tool for caching tool definitions
if len(data["tools"]) > 0:
data["tools"][-1]["cache_control"] = {"type": "ephemeral"}
@@ -522,13 +582,18 @@ class AnthropicClient(LLMClientBase):
async def count_tokens(self, messages: List[dict] = None, model: str = None, tools: List[OpenAITool] = None) -> int:
logging.getLogger("httpx").setLevel(logging.WARNING)
# Use the default client; token counting is lightweight and does not require BYOK overrides
client = anthropic.AsyncAnthropic()
if messages and len(messages) == 0:
messages = None
if tools and len(tools) > 0:
anthropic_tools = convert_tools_to_anthropic_format(tools)
# Token counting endpoint requires additionalProperties: false (use_strict=True)
# but does NOT support the `strict` field on tools (add_strict_field=False)
anthropic_tools = convert_tools_to_anthropic_format(
tools,
use_strict=True,
add_strict_field=False,
)
else:
anthropic_tools = None
@@ -637,6 +702,12 @@ class AnthropicClient(LLMClientBase):
if thinking_enabled:
betas.append("context-management-2025-06-27")
# Structured outputs beta - only for supported models
# DISABLED: Commenting out structured outputs to investigate TTFT latency impact
# See PR #7495 for original implementation
# if model and _supports_structured_outputs(model):
# betas.append("structured-outputs-2025-11-13")
if betas:
result = await client.beta.messages.count_tokens(**count_params, betas=betas)
else:
@@ -669,6 +740,8 @@ class AnthropicClient(LLMClientBase):
or "exceeds context" in error_str
or "too many total text bytes" in error_str
or "total text bytes" in error_str
or "request_too_large" in error_str
or "request exceeds the maximum size" in error_str
):
logger.warning(f"[Anthropic] Context window exceeded: {str(e)}")
return ContextWindowExceededError(
@@ -691,6 +764,27 @@ class AnthropicClient(LLMClientBase):
details={"cause": str(e.__cause__) if e.__cause__ else None},
)
# Handle httpx.RemoteProtocolError which can occur during streaming
# when the remote server closes the connection unexpectedly
# (e.g., "peer closed connection without sending complete message body")
if isinstance(e, httpx.RemoteProtocolError):
logger.warning(f"[Anthropic] Remote protocol error during streaming: {e}")
return LLMConnectionError(
message=f"Connection error during Anthropic streaming: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None},
)
# Handle httpx network errors which can occur during streaming
# when the connection is unexpectedly closed while reading/writing
if isinstance(e, (httpx.ReadError, httpx.WriteError, httpx.ConnectError)):
logger.warning(f"[Anthropic] Network error during streaming: {type(e).__name__}: {e}")
return LLMConnectionError(
message=f"Network error during Anthropic streaming: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__},
)
if isinstance(e, anthropic.RateLimitError):
logger.warning("[Anthropic] Rate limited (429). Consider backoff.")
return LLMRateLimitError(
@@ -750,6 +844,12 @@ class AnthropicClient(LLMClientBase):
if isinstance(e, anthropic.APIStatusError):
logger.warning(f"[Anthropic] API status error: {str(e)}")
# Handle 413 Request Entity Too Large - request payload exceeds size limits
if hasattr(e, "status_code") and e.status_code == 413:
logger.warning(f"[Anthropic] Request too large (413): {str(e)}")
return ContextWindowExceededError(
message=f"Request too large for Anthropic (413): {str(e)}",
)
if "overloaded" in str(e).lower():
return LLMProviderOverloaded(
message=f"Anthropic API is overloaded: {str(e)}",
@@ -827,7 +927,13 @@ class AnthropicClient(LLMClientBase):
if content_part.type == "tool_use":
# hack for incorrect tool format
tool_input = json.loads(json.dumps(content_part.input))
if "id" in tool_input and tool_input["id"].startswith("toolu_") and "function" in tool_input:
# Check if id is a string before calling startswith (sometimes it's an int)
if (
"id" in tool_input
and isinstance(tool_input["id"], str)
and tool_input["id"].startswith("toolu_")
and "function" in tool_input
):
if isinstance(tool_input["function"], str):
tool_input["function"] = json.loads(tool_input["function"])
arguments = json.dumps(tool_input["function"]["arguments"], indent=2)
@@ -964,7 +1070,34 @@ class AnthropicClient(LLMClientBase):
return messages
def convert_tools_to_anthropic_format(tools: List[OpenAITool]) -> List[dict]:
def _supports_structured_outputs(model: str) -> bool:
"""Check if the model supports structured outputs (strict mode).
Only these 4 models are supported:
- Claude Sonnet 4.5
- Claude Opus 4.1
- Claude Opus 4.5
- Claude Haiku 4.5
"""
model_lower = model.lower()
if "sonnet-4-5" in model_lower:
return True
elif "opus-4-1" in model_lower:
return True
elif "opus-4-5" in model_lower:
return True
elif "haiku-4-5" in model_lower:
return True
return False
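A quick sanity check of the gate above (model ids are illustrative):

```python
assert _supports_structured_outputs("claude-sonnet-4-5-20250929")
assert _supports_structured_outputs("claude-haiku-4-5")
assert not _supports_structured_outputs("claude-3-5-sonnet-20241022")
```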
def convert_tools_to_anthropic_format(
tools: List[OpenAITool],
use_strict: bool = False,
add_strict_field: bool = True,
) -> List[dict]:
"""See: https://docs.anthropic.com/claude/docs/tool-use
OpenAI style:
@@ -975,18 +1108,11 @@ def convert_tools_to_anthropic_format(tools: List[OpenAITool]) -> List[dict]:
"description": "find ....",
"parameters": {
"type": "object",
"properties": {
PARAM: {
"type": PARAM_TYPE, # eg "string"
"description": PARAM_DESCRIPTION,
},
...
},
"properties": {...},
"required": List[str],
}
}
}
]
}]
Anthropic style:
"tools": [{
@@ -994,89 +1120,114 @@ def convert_tools_to_anthropic_format(tools: List[OpenAITool]) -> List[dict]:
"description": "find ....",
"input_schema": {
"type": "object",
"properties": {
PARAM: {
"type": PARAM_TYPE, # eg "string"
"description": PARAM_DESCRIPTION,
},
...
},
"properties": {...},
"required": List[str],
}
}
]
},
}]
Two small differences:
- one less level of nesting
- "parameters" -> "input_schema"
Args:
tools: List of OpenAI-style tools to convert
use_strict: If True, add additionalProperties: false to all object schemas
add_strict_field: If True (and use_strict=True), add strict: true to allowlisted tools.
Set to False for token counting endpoint which doesn't support this field.
"""
formatted_tools = []
strict_count = 0
for tool in tools:
# Get the input schema
input_schema = tool.function.parameters or {"type": "object", "properties": {}, "required": []}
# Clean up the properties in the schema
# The presence of union types / default fields seems to cause Anthropic to produce invalid JSON for tool calls
if isinstance(input_schema, dict) and "properties" in input_schema:
cleaned_properties = {}
for prop_name, prop_schema in input_schema.get("properties", {}).items():
if isinstance(prop_schema, dict):
cleaned_properties[prop_name] = _clean_property_schema(prop_schema)
else:
cleaned_properties[prop_name] = prop_schema
# Create cleaned input schema
cleaned_input_schema = {
"type": input_schema.get("type", "object"),
"properties": cleaned_properties,
}
# Only add required field if it exists and is non-empty
if "required" in input_schema and input_schema["required"]:
cleaned_input_schema["required"] = input_schema["required"]
else:
cleaned_input_schema = input_schema
formatted_tool = {
# Use the older lightweight cleanup: remove defaults and simplify union-with-null.
# When using structured outputs (use_strict=True), also add additionalProperties: false to all object types.
cleaned_schema = (
_clean_property_schema(input_schema, add_additional_properties_false=use_strict)
if isinstance(input_schema, dict)
else input_schema
)
# Normalize to a safe "object" schema shape to avoid downstream assumptions failing.
if isinstance(cleaned_schema, dict):
if cleaned_schema.get("type") != "object":
cleaned_schema["type"] = "object"
if not isinstance(cleaned_schema.get("properties"), dict):
cleaned_schema["properties"] = {}
# Ensure additionalProperties: false for structured outputs on the top-level schema
# Must override any existing additionalProperties: true as well
if use_strict:
cleaned_schema["additionalProperties"] = False
formatted_tool: dict = {
"name": tool.function.name,
"description": tool.function.description if tool.function.description else "",
"input_schema": cleaned_input_schema,
"input_schema": cleaned_schema,
}
# Structured outputs "strict" mode: always attach `strict` for allowlisted tools
# when we are using structured outputs models. Limit the number of strict tools
# to avoid exceeding Anthropic constraints.
# NOTE: The token counting endpoint does NOT support `strict` - only the messages endpoint does.
if (
use_strict
and add_strict_field
and tool.function.name in ANTHROPIC_STRICT_MODE_ALLOWLIST
and strict_count < ANTHROPIC_MAX_STRICT_TOOLS
):
formatted_tool["strict"] = True
strict_count += 1
formatted_tools.append(formatted_tool)
return formatted_tools
def _clean_property_schema(prop_schema: dict) -> dict:
"""Clean up a property schema by removing defaults and simplifying union types."""
cleaned = {}
def _clean_property_schema(schema: dict, add_additional_properties_false: bool = False) -> dict:
"""Older schema cleanup used for Anthropic tools.
# Handle type field - simplify union types like ["null", "string"] to just "string"
if "type" in prop_schema:
prop_type = prop_schema["type"]
if isinstance(prop_type, list):
# Remove "null" from union types to simplify
# e.g., ["null", "string"] becomes "string"
non_null_types = [t for t in prop_type if t != "null"]
if len(non_null_types) == 1:
cleaned["type"] = non_null_types[0]
elif len(non_null_types) > 1:
# Keep as array if multiple non-null types
cleaned["type"] = non_null_types
Removes / simplifies fields that commonly cause Anthropic tool schema issues:
- Remove `default` values
- Simplify nullable unions like {"type": ["null", "string"]} -> {"type": "string"}
- Recurse through nested schemas (properties/items/anyOf/oneOf/allOf/etc.)
- Optionally add additionalProperties: false to object types (required for structured outputs)
"""
if not isinstance(schema, dict):
return schema
cleaned: dict = {}
# Simplify union types like ["null", "string"] to "string"
if "type" in schema:
t = schema.get("type")
if isinstance(t, list):
non_null = [x for x in t if x != "null"]
if len(non_null) == 1:
cleaned["type"] = non_null[0]
elif len(non_null) > 1:
cleaned["type"] = non_null
else:
# If only "null" was in the list, default to string
cleaned["type"] = "string"
else:
cleaned["type"] = prop_type
cleaned["type"] = t
# Copy over other fields except 'default'
for key, value in prop_schema.items():
if key not in ["type", "default"]: # Skip 'default' field
if key == "properties" and isinstance(value, dict):
# Recursively clean nested properties
cleaned["properties"] = {k: _clean_property_schema(v) if isinstance(v, dict) else v for k, v in value.items()}
else:
cleaned[key] = value
for key, value in schema.items():
if key == "type":
continue
if key == "default":
continue
if key == "properties" and isinstance(value, dict):
cleaned["properties"] = {k: _clean_property_schema(v, add_additional_properties_false) for k, v in value.items()}
elif key == "items" and isinstance(value, dict):
cleaned["items"] = _clean_property_schema(value, add_additional_properties_false)
elif key in ("anyOf", "oneOf", "allOf") and isinstance(value, list):
cleaned[key] = [_clean_property_schema(v, add_additional_properties_false) if isinstance(v, dict) else v for v in value]
elif key in ("additionalProperties",) and isinstance(value, dict):
cleaned[key] = _clean_property_schema(value, add_additional_properties_false)
else:
cleaned[key] = value
# For structured outputs, Anthropic requires additionalProperties: false on all object types
# We must override any existing additionalProperties: true as well
if add_additional_properties_false and cleaned.get("type") == "object":
cleaned["additionalProperties"] = False
return cleaned
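An illustrative before/after for `_clean_property_schema`, using a made-up schema:

```python
raw = {
    "type": "object",
    "properties": {
        "name": {"type": ["null", "string"], "default": "x"},
        "tags": {"type": "array", "items": {"type": ["null", "integer"]}},
    },
}
cleaned = _clean_property_schema(raw, add_additional_properties_false=True)
assert cleaned["properties"]["name"] == {"type": "string"}        # union simplified, default dropped
assert cleaned["properties"]["tags"]["items"] == {"type": "integer"}
assert cleaned["additionalProperties"] is False                   # required for structured outputs
```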

View File

@@ -0,0 +1,27 @@
# Anthropic-specific constants for the Letta LLM API
# Allowlist of simple tools that work with Anthropic's structured outputs (strict mode).
# These tools have few parameters and no complex nesting, making them safe for strict mode.
# Tools with many optional params or deeply nested structures should use non-strict mode.
#
# Anthropic limitations for strict mode:
# - Max 15 tools can use strict mode per request
# - Max 24 optional parameters per tool (counted recursively in undocumented ways)
# - Schema complexity limits
#
# Rather than trying to count parameters correctly, we allowlist simple tools that we know work.
ANTHROPIC_STRICT_MODE_ALLOWLIST = {
"Write", # 2 required params, no optional
"Read", # 1 required, 2 simple optional
"Edit", # 3 required, 1 simple optional
"Glob", # 1 required, 1 simple optional
"KillBash", # 1 required, no optional
"fetch_webpage", # 1 required, no optional
"EnterPlanMode", # no params
"ExitPlanMode", # no params
"Skill", # 1 required, 1 optional array
"conversation_search", # 1 required, 4 simple optional
}
# Maximum number of tools that can use strict mode in a single request
ANTHROPIC_MAX_STRICT_TOOLS = 15
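How the allowlist and cap combine in `convert_tools_to_anthropic_format` above, reduced to a sketch:

```python
def mark_strict(tool_names: list[str]) -> list[bool]:
    flags, count = [], 0
    for name in tool_names:
        # Strict only for allowlisted tools, and never more than the cap per request
        ok = name in ANTHROPIC_STRICT_MODE_ALLOWLIST and count < ANTHROPIC_MAX_STRICT_TOOLS
        flags.append(ok)
        count += int(ok)
    return flags
```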

View File

@@ -3,6 +3,7 @@ import json
import uuid
from typing import AsyncIterator, List, Optional
import httpx
from google.genai import Client, errors
from google.genai.types import (
FunctionCallingConfig,
@@ -860,6 +861,27 @@ class GoogleVertexClient(LLMClientBase):
},
)
# Handle httpx.RemoteProtocolError which can occur during streaming
# when the remote server closes the connection unexpectedly
# (e.g., "peer closed connection without sending complete message body")
if isinstance(e, httpx.RemoteProtocolError):
logger.warning(f"{self._provider_prefix()} Remote protocol error during streaming: {e}")
return LLMConnectionError(
message=f"Connection error during {self._provider_name()} streaming: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None},
)
# Handle httpx network errors which can occur during streaming
# when the connection is unexpectedly closed while reading/writing
if isinstance(e, (httpx.ReadError, httpx.WriteError, httpx.ConnectError)):
logger.warning(f"{self._provider_prefix()} Network error during streaming: {type(e).__name__}: {e}")
return LLMConnectionError(
message=f"Network error during {self._provider_name()} streaming: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__},
)
# Handle connection-related errors
if "connection" in str(e).lower() or "timeout" in str(e).lower():
logger.warning(f"{self._provider_prefix()} Connection/timeout error: {e}")

View File

@@ -79,6 +79,13 @@ class LLMClient:
put_inner_thoughts_first=put_inner_thoughts_first,
actor=actor,
)
case ProviderType.zai:
from letta.llm_api.zai_client import ZAIClient
return ZAIClient(
put_inner_thoughts_first=put_inner_thoughts_first,
actor=actor,
)
case ProviderType.groq:
from letta.llm_api.groq_client import GroqClient

View File

@@ -2,11 +2,12 @@ import json
from abc import abstractmethod
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import httpx
from anthropic.types.beta.messages import BetaMessageBatch
from openai import AsyncStream, Stream
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta.errors import LLMError
from letta.errors import ErrorCode, LLMConnectionError, LLMError
from letta.otel.tracing import log_event, trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import AgentType, ProviderCategory
@@ -215,29 +216,59 @@ class LLMClientBase:
Returns:
An LLMError subclass that represents the error in a provider-agnostic way
"""
# Handle httpx.RemoteProtocolError which can occur during streaming
# when the remote server closes the connection unexpectedly
# (e.g., "peer closed connection without sending complete message body")
if isinstance(e, httpx.RemoteProtocolError):
from letta.log import get_logger
logger = get_logger(__name__)
logger.warning(f"[LLM] Remote protocol error during streaming: {e}")
return LLMConnectionError(
message=f"Connection error during streaming: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None},
)
return LLMError(f"Unhandled LLM error: {str(e)}")
def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
"""
Returns the override key for the given llm config.
Only fetches API key from database for BYOK providers.
Base providers use environment variables directly.
"""
api_key = None
# Only fetch API key from database for BYOK providers
# Base providers should always use environment variables
if llm_config.provider_category == ProviderCategory.byok:
from letta.services.provider_manager import ProviderManager
api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor)
# If we got an empty string from the database, treat it as None
# so the client can fall back to environment variables or default behavior
if api_key == "":
api_key = None
return api_key, None, None
async def get_byok_overrides_async(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
"""
Returns the override key for the given llm config.
Only fetches API key from database for BYOK providers.
Base providers use environment variables directly.
"""
api_key = None
# Only fetch API key from database for BYOK providers
# Base providers should always use environment variables
if llm_config.provider_category == ProviderCategory.byok:
from letta.services.provider_manager import ProviderManager
api_key = await ProviderManager().get_override_key_async(llm_config.provider_name, actor=self.actor)
# If we got an empty string from the database, treat it as None
# so the client can fall back to environment variables or default behavior
if api_key == "":
api_key = None
return api_key, None, None

View File

@@ -4,6 +4,7 @@ import os
import time
from typing import Any, List, Optional
import httpx
import openai
from openai import AsyncOpenAI, AsyncStream, OpenAI
from openai.types import Reasoning
@@ -960,6 +961,27 @@ class OpenAIClient(LLMClientBase):
details={"cause": str(e.__cause__) if e.__cause__ else None},
)
# Handle httpx.RemoteProtocolError which can occur during streaming
# when the remote server closes the connection unexpectedly
# (e.g., "peer closed connection without sending complete message body")
if isinstance(e, httpx.RemoteProtocolError):
logger.warning(f"[OpenAI] Remote protocol error during streaming: {e}")
return LLMConnectionError(
message=f"Connection error during OpenAI streaming: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None},
)
# Handle httpx network errors which can occur during streaming
# when the connection is unexpectedly closed while reading/writing
if isinstance(e, (httpx.ReadError, httpx.WriteError, httpx.ConnectError)):
logger.warning(f"[OpenAI] Network error during streaming: {type(e).__name__}: {e}")
return LLMConnectionError(
message=f"Network error during OpenAI streaming: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__},
)
if isinstance(e, openai.RateLimitError):
logger.warning(f"[OpenAI] Rate limited (429). Consider backoff. Error: {e}")
return LLMRateLimitError(

View File

@@ -0,0 +1,81 @@
import os
from typing import List, Optional
from openai import AsyncOpenAI, AsyncStream, OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta.llm_api.openai_client import OpenAIClient
from letta.otel.tracing import trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import AgentType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.settings import model_settings
class ZAIClient(OpenAIClient):
"""Z.ai (ZhipuAI) client - uses OpenAI-compatible API."""
def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
return False
def supports_structured_output(self, llm_config: LLMConfig) -> bool:
return False
@trace_method
def build_request_data(
self,
agent_type: AgentType,
messages: List[PydanticMessage],
llm_config: LLMConfig,
tools: Optional[List[dict]] = None,
force_tool_call: Optional[str] = None,
requires_subsequent_tool_call: bool = False,
tool_return_truncation_chars: Optional[int] = None,
) -> dict:
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)
return data
@trace_method
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""
Performs underlying synchronous request to Z.ai API and returns raw response dict.
"""
api_key = model_settings.zai_api_key
client = OpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
response: ChatCompletion = client.chat.completions.create(**request_data)
return response.model_dump()
@trace_method
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""
Performs underlying asynchronous request to Z.ai API and returns raw response dict.
"""
api_key = model_settings.zai_api_key
client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
response: ChatCompletion = await client.chat.completions.create(**request_data)
return response.model_dump()
@trace_method
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
"""
Performs underlying asynchronous streaming request to Z.ai and returns the async stream iterator.
"""
api_key = model_settings.zai_api_key
client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
**request_data, stream=True, stream_options={"include_usage": True}
)
return response_stream
@trace_method
async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
"""Request embeddings given texts and embedding config"""
api_key = model_settings.zai_api_key
client = AsyncOpenAI(api_key=api_key, base_url=embedding_config.embedding_endpoint)
response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)
return [r.embedding for r in response.data]

View File

@@ -7,6 +7,7 @@ import asyncio
import threading
import time
import traceback
from collections import defaultdict
from typing import Optional
from letta.log import get_logger
@@ -31,9 +32,12 @@ class EventLoopWatchdog:
self._thread: Optional[threading.Thread] = None
self._stop_event = threading.Event()
self._last_heartbeat = time.time()
self._heartbeat_scheduled_at = time.time()
self._heartbeat_lock = threading.Lock()
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._monitoring = False
self._last_dump_time = 0.0 # Cooldown between task dumps
self._saturation_start: Optional[float] = None # Track when saturation began
def start(self, loop: asyncio.AbstractEventLoop):
"""Start the watchdog thread."""
@@ -43,7 +47,9 @@ class EventLoopWatchdog:
self._loop = loop
self._monitoring = True
self._stop_event.clear()
self._last_heartbeat = time.time()
now = time.time()
self._last_heartbeat = now
self._heartbeat_scheduled_at = now
self._thread = threading.Thread(target=self._watch_loop, daemon=True, name="EventLoopWatchdog")
self._thread.start()
@@ -51,7 +57,10 @@ class EventLoopWatchdog:
# Schedule periodic heartbeats on the event loop
loop.call_soon(self._schedule_heartbeats)
logger.info(f"Watchdog started (timeout: {self.timeout_threshold}s)")
logger.info(
f"Event loop watchdog started - monitoring thread running, heartbeat every 1s, "
f"checks every {self.check_interval}s, hang threshold: {self.timeout_threshold}s"
)
def stop(self):
"""Stop the watchdog thread."""
@@ -66,8 +75,16 @@ class EventLoopWatchdog:
if not self._monitoring:
return
now = time.time()
with self._heartbeat_lock:
self._last_heartbeat = time.time()
# Calculate event loop lag: time between when we scheduled this callback and when it ran
lag = now - self._heartbeat_scheduled_at
self._last_heartbeat = now
self._heartbeat_scheduled_at = now + 1.0
# Log if lag is significant (> 2 seconds means event loop is saturated)
if lag > 2.0:
logger.warning(f"Event loop lag in heartbeat: {lag:.2f}s (expected ~1.0s)")
if self._loop and self._monitoring:
self._loop.call_later(1.0, self._schedule_heartbeats)
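A self-contained sketch of the lag measurement: each heartbeat records when the next one is due, and the delta between due time and actual run time is the event loop lag (≈0 on a healthy loop):

```python
import asyncio
import time

def start_heartbeat(loop: asyncio.AbstractEventLoop, threshold: float = 2.0):
    state = {"due_at": time.time()}

    def heartbeat():
        now = time.time()
        lag = now - state["due_at"]  # how late the loop ran this callback
        if lag > threshold:
            print(f"event loop lag: {lag:.2f}s")
        state["due_at"] = now + 1.0  # next heartbeat is due in 1s
        loop.call_later(1.0, heartbeat)

    loop.call_soon(heartbeat)
```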
@@ -75,6 +92,7 @@ class EventLoopWatchdog:
def _watch_loop(self):
"""Main watchdog loop running in separate thread."""
consecutive_hangs = 0
max_lag_seen = 0.0
while not self._stop_event.is_set():
try:
@@ -82,8 +100,13 @@ class EventLoopWatchdog:
with self._heartbeat_lock:
last_beat = self._last_heartbeat
scheduled_at = self._heartbeat_scheduled_at
time_since_heartbeat = time.time() - last_beat
now = time.time()
time_since_heartbeat = now - last_beat
# Calculate current lag: how far behind schedule is the heartbeat?
current_lag = now - scheduled_at
max_lag_seen = max(max_lag_seen, current_lag)
# Try to estimate event loop load (safe from separate thread)
task_count = -1
@@ -98,9 +121,34 @@ class EventLoopWatchdog:
# ALWAYS log every check to prove watchdog is alive
logger.debug(
f"WATCHDOG_CHECK: heartbeat_age={time_since_heartbeat:.1f}s, consecutive_hangs={consecutive_hangs}, tasks={task_count}"
f"WATCHDOG_CHECK: heartbeat_age={time_since_heartbeat:.1f}s, current_lag={current_lag:.2f}s, "
f"max_lag={max_lag_seen:.2f}s, consecutive_hangs={consecutive_hangs}, tasks={task_count}"
)
# Log at INFO if we see significant lag (> 2 seconds indicates saturation)
if current_lag > 2.0:
# Track saturation duration
if self._saturation_start is None:
self._saturation_start = now
saturation_duration = now - self._saturation_start
logger.info(
f"Event loop saturation detected: lag={current_lag:.2f}s, duration={saturation_duration:.1f}s, "
f"tasks={task_count}, max_lag_seen={max_lag_seen:.2f}s"
)
# Only dump stack traces with 60s cooldown to avoid spam
if (now - self._last_dump_time) > 60.0:
self._dump_asyncio_tasks() # Dump async tasks
self._dump_state() # Dump thread stacks
self._last_dump_time = now
else:
# Reset saturation tracking when recovered
if self._saturation_start is not None:
duration = now - self._saturation_start
logger.info(f"Event loop saturation ended after {duration:.1f}s")
self._saturation_start = None
if time_since_heartbeat > self.timeout_threshold:
consecutive_hangs += 1
logger.error(
@@ -108,7 +156,8 @@ class EventLoopWatchdog:
f"tasks={task_count}"
)
# Dump basic state
# Dump both thread state and asyncio tasks
self._dump_asyncio_tasks()
self._dump_state()
if consecutive_hangs >= 2:
@@ -153,6 +202,93 @@ class EventLoopWatchdog:
except Exception as e:
logger.error(f"Failed to dump state: {e}")
def _dump_asyncio_tasks(self):
"""Dump asyncio task stack traces to diagnose event loop saturation."""
try:
if not self._loop or self._loop.is_closed():
return
active_tasks = asyncio.all_tasks(self._loop)
if not active_tasks:
return
logger.warning(f"Severe lag detected - dumping active tasks ({len(active_tasks)} total):")
# Collect task data in single pass
tasks_by_location = defaultdict(list)
for task in active_tasks:
try:
if task.done():
continue
stack = task.get_stack()
if not stack:
continue
# Find top letta frame for grouping
for frame in reversed(stack):
if "letta" in frame.f_code.co_filename:
idx = frame.f_code.co_filename.find("letta/")
path = frame.f_code.co_filename[idx + 6 :] if idx != -1 else frame.f_code.co_filename
location = f"{path}:{frame.f_lineno}:{frame.f_code.co_name}"
# For bounded tasks, use wrapped coroutine location instead
if frame.f_code.co_name == "bounded_coro":
task_name = task.get_name()
if task_name and task_name.startswith("bounded["):
location = task_name[8:-1] # Extract "file:line:func" from "bounded[...]"
tasks_by_location[location].append((task, stack))
break
except Exception:
continue
if not tasks_by_location:
return
total_tasks = sum(len(tasks) for tasks in tasks_by_location.values())
logger.warning(f" Letta tasks: {total_tasks} total")
# Sort by task count (most blocked first) and show detailed stacks for top 3
sorted_patterns = sorted(tasks_by_location.items(), key=lambda x: len(x[1]), reverse=True)
num_patterns = len(sorted_patterns)
logger.warning(f" Task patterns ({num_patterns} unique locations):")
# Show detailed stacks for top 3, summary for rest
for i, (location, tasks) in enumerate(sorted_patterns, 1):
count = len(tasks)
pct = (count / total_tasks) * 100 if total_tasks > 0 else 0
if i <= 3:
# Top 3: show detailed vertical stack trace
logger.warning(f" [{i}] {count} tasks ({pct:.0f}%) at: {location}")
_, sample_stack = tasks[0]
# Show up to 8 frames vertically for better context
for frame in sample_stack[-8:]:
filename = frame.f_code.co_filename
letta_idx = filename.find("letta/")
if letta_idx != -1:
short_path = filename[letta_idx + 6 :]
logger.warning(f" {short_path}:{frame.f_lineno} in {frame.f_code.co_name}")
else:
pkg_idx = filename.find("site-packages/")
if pkg_idx != -1:
lib_path = filename[pkg_idx + 14 :]
logger.warning(f" [{lib_path}:{frame.f_lineno}] {frame.f_code.co_name}")
elif i <= 10:
# Positions 4-10: show location only
logger.warning(f" [{i}] {count} tasks ({pct:.0f}%) at: {location}")
else:
# Beyond 10: just show count in summary
if i == 11:
remaining = sum(len(t) for _, t in sorted_patterns[10:])
remaining_patterns = num_patterns - 10
logger.warning(f" ... and {remaining} more tasks across {remaining_patterns} other locations")
except Exception as e:
logger.error(f"Failed to dump asyncio tasks: {e}")
_global_watchdog: Optional[EventLoopWatchdog] = None
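The heartbeat changes above all hang off one idea: schedule a callback for a known future time, record when it actually fires, and treat the difference as event-loop lag. The pattern in isolation (a toy script, not the watchdog itself):

```python
import asyncio
import time


def _heartbeat(loop: asyncio.AbstractEventLoop, scheduled_for: float) -> None:
    now = time.monotonic()
    lag = now - scheduled_for  # how late the loop ran this callback
    if lag > 0.5:
        print(f"event loop lag: {lag:.2f}s")
    # Reschedule one second out, carrying the expected fire time forward.
    loop.call_later(1.0, _heartbeat, loop, now + 1.0)


async def main() -> None:
    loop = asyncio.get_running_loop()
    loop.call_later(1.0, _heartbeat, loop, time.monotonic() + 1.0)
    time.sleep(3)  # deliberately block the loop so the next heartbeat is late
    await asyncio.sleep(2)


asyncio.run(main())
```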

View File

@@ -6,6 +6,8 @@ from letta.orm.base import Base
from letta.orm.block import Block
from letta.orm.block_history import BlockHistory
from letta.orm.blocks_agents import BlocksAgents
from letta.orm.conversation import Conversation
from letta.orm.conversation_messages import ConversationMessage
from letta.orm.file import FileMetadata
from letta.orm.files_agents import FileAgent
from letta.orm.group import Group

View File

@@ -1,19 +1,22 @@
import asyncio
import uuid
from datetime import datetime
from typing import TYPE_CHECKING, List, Optional, Set
from typing import TYPE_CHECKING, Any, List, Optional, Set
from sqlalchemy import JSON, Boolean, DateTime, Index, Integer, String
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy import JSON, Boolean, DateTime, Index, Integer, String, select
from sqlalchemy.ext.asyncio import AsyncAttrs, async_object_session
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.block import Block
from letta.orm.custom_columns import CompactionSettingsColumn, EmbeddingConfigColumn, LLMConfigColumn, ResponseFormatColumn, ToolRulesColumn
from letta.orm.identity import Identity
from letta.orm.message import Message as MessageModel
from letta.orm.mixins import OrganizationMixin, ProjectMixin, TemplateEntityMixin, TemplateMixin
from letta.orm.organization import Organization
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.agent import AgentState as PydanticAgentState
ENCRYPTED_PLACEHOLDER = "<encrypted>"
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import AgentType
from letta.schemas.environment_variables import AgentEnvironmentVariable as PydanticAgentEnvVar
@@ -22,11 +25,12 @@ from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import Memory
from letta.schemas.response_format import ResponseFormatUnion
from letta.schemas.tool_rule import ToolRule
from letta.utils import calculate_file_defaults_based_on_context_window
from letta.utils import bounded_gather, calculate_file_defaults_based_on_context_window
if TYPE_CHECKING:
from letta.orm.agents_tags import AgentsTags
from letta.orm.archives_agents import ArchivesAgents
from letta.orm.conversation import Conversation
from letta.orm.files_agents import FileAgent
from letta.orm.identity import Identity
from letta.orm.organization import Organization
@@ -74,7 +78,7 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin
LLMConfigColumn, nullable=True, doc="the LLM backend configuration object for this agent."
)
embedding_config: Mapped[Optional[EmbeddingConfig]] = mapped_column(
EmbeddingConfigColumn, doc="the embedding configuration object for this agent."
EmbeddingConfigColumn, nullable=True, doc="the embedding configuration object for this agent."
)
compaction_settings: Mapped[Optional[dict]] = mapped_column(
CompactionSettingsColumn, nullable=True, doc="the compaction settings configuration object for compaction."
@@ -147,7 +151,7 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin
"Run",
back_populates="agent",
cascade="all, delete-orphan",
lazy="selectin",
lazy="raise",
doc="Runs associated with the agent.",
)
identities: Mapped[List["Identity"]] = relationship(
@@ -186,6 +190,13 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin
lazy="noload",
doc="Archives accessible by this agent.",
)
conversations: Mapped[List["Conversation"]] = relationship(
"Conversation",
back_populates="agent",
cascade="all, delete-orphan",
lazy="raise",
doc="Conversations for concurrent messaging on this agent.",
)
def _get_per_file_view_window_char_limit(self) -> int:
"""Get the per_file_view_window_char_limit, calculating defaults if None."""
@@ -298,10 +309,33 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin
return self.__pydantic_model__(**state)
async def _get_pending_approval_async(self) -> Optional[Any]:
if self.message_ids and len(self.message_ids) > 0:
# Try to get the async session this object is attached to
session = async_object_session(self)
if not session:
# Object is detached, can't safely query
return None
latest_message_id = self.message_ids[-1]
result = await session.execute(select(MessageModel).where(MessageModel.id == latest_message_id))
latest_message = result.scalar_one_or_none()
if (
latest_message
and latest_message.role == "approval"
and latest_message.tool_calls is not None
and len(latest_message.tool_calls) > 0
):
pydantic_message = latest_message.to_pydantic()
return pydantic_message._convert_approval_request_message()
return None
async def to_pydantic_async(
self,
include_relationships: Optional[Set[str]] = None,
include: Optional[List[str]] = None,
decrypt: bool = True,
) -> PydanticAgentState:
"""
Converts the SQLAlchemy Agent model into its Pydantic counterpart.
@@ -368,6 +402,7 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin
"managed_group": None,
"tool_exec_environment_variables": [],
"secrets": [],
"pending_approval": None,
}
# Initialize include_relationships to an empty set if it's None
@@ -411,9 +446,20 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin
file_agents = (
self.awaitable_attrs.file_agents if "memory" in include_relationships or "agent.blocks" in include_set else empty_list_async()
)
pending_approval = self._get_pending_approval_async() if "agent.pending_approval" in include_set else none_async()
(tags, tools, sources, memory, identities, multi_agent_group, tool_exec_environment_variables, file_agents) = await asyncio.gather(
tags, tools, sources, memory, identities, multi_agent_group, tool_exec_environment_variables, file_agents
(
tags,
tools,
sources,
memory,
identities,
multi_agent_group,
tool_exec_environment_variables,
file_agents,
pending_approval,
) = await asyncio.gather(
tags, tools, sources, memory, identities, multi_agent_group, tool_exec_environment_variables, file_agents, pending_approval
)
state["tags"] = [t.tag for t in tags]
@@ -433,12 +479,29 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin
state["identities"] = [i.to_pydantic() for i in identities]
state["multi_agent_group"] = multi_agent_group
state["managed_group"] = multi_agent_group
# Convert ORM env vars to Pydantic with async decryption
env_vars_pydantic = []
for e in tool_exec_environment_variables:
env_vars_pydantic.append(await PydanticAgentEnvVar.from_orm_async(e))
# Convert ORM env vars to Pydantic, optionally skipping decryption
if decrypt:
env_vars_pydantic = await bounded_gather([PydanticAgentEnvVar.from_orm_async(e) for e in tool_exec_environment_variables])
else:
# Skip decryption - return with encrypted values (faster, no PBKDF2)
from letta.schemas.environment_variables import AgentEnvironmentVariable
from letta.schemas.secret import Secret
env_vars_pydantic = []
for e in tool_exec_environment_variables:
data = {
"id": e.id,
"key": e.key,
"description": e.description,
"organization_id": e.organization_id,
"agent_id": e.agent_id,
"value": ENCRYPTED_PLACEHOLDER,
"value_enc": Secret.from_encrypted(e.value_enc) if e.value_enc else None,
}
env_vars_pydantic.append(AgentEnvironmentVariable.model_validate(data))
state["tool_exec_environment_variables"] = env_vars_pydantic
state["secrets"] = env_vars_pydantic
state["pending_approval"] = pending_approval
state["model"] = self.llm_config.handle if self.llm_config else None
state["model_settings"] = self.llm_config._to_model_settings() if self.llm_config else None
state["embedding"] = self.embedding_config.handle if self.embedding_config else None

49
letta/orm/conversation.py Normal file
View File

@@ -0,0 +1,49 @@
import uuid
from typing import TYPE_CHECKING, List, Optional
from sqlalchemy import ForeignKey, Index, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.conversation import Conversation as PydanticConversation
if TYPE_CHECKING:
from letta.orm.agent import Agent
from letta.orm.conversation_messages import ConversationMessage
class Conversation(SqlalchemyBase, OrganizationMixin):
"""Conversations that can be created on an agent for concurrent messaging."""
__tablename__ = "conversations"
__pydantic_model__ = PydanticConversation
__table_args__ = (
Index("ix_conversations_agent_id", "agent_id"),
Index("ix_conversations_org_agent", "organization_id", "agent_id"),
)
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"conv-{uuid.uuid4()}")
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), nullable=False)
summary: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="Summary of the conversation")
# Relationships
agent: Mapped["Agent"] = relationship("Agent", back_populates="conversations", lazy="raise")
message_associations: Mapped[List["ConversationMessage"]] = relationship(
"ConversationMessage",
back_populates="conversation",
cascade="all, delete-orphan",
lazy="selectin",
)
def to_pydantic(self) -> PydanticConversation:
"""Converts the SQLAlchemy model to its Pydantic counterpart."""
return self.__pydantic_model__(
id=self.id,
agent_id=self.agent_id,
summary=self.summary,
created_at=self.created_at,
updated_at=self.updated_at,
created_by_id=self.created_by_id,
last_updated_by_id=self.last_updated_by_id,
)

View File

@@ -0,0 +1,73 @@
import uuid
from typing import TYPE_CHECKING, Optional
from sqlalchemy import Boolean, ForeignKey, Index, Integer, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
if TYPE_CHECKING:
from letta.orm.conversation import Conversation
from letta.orm.message import Message
class ConversationMessage(SqlalchemyBase, OrganizationMixin):
"""
Track in-context messages for a conversation.
This replaces the message_ids JSON list on agents with proper relational modeling.
- conversation_id=NULL represents the "default" conversation (backward compatible)
- conversation_id=<id> represents a named conversation for concurrent messaging
"""
__tablename__ = "conversation_messages"
__table_args__ = (
Index("ix_conv_msg_conversation_position", "conversation_id", "position"),
Index("ix_conv_msg_message_id", "message_id"),
Index("ix_conv_msg_agent_id", "agent_id"),
Index("ix_conv_msg_agent_conversation", "agent_id", "conversation_id"),
UniqueConstraint("conversation_id", "message_id", name="unique_conversation_message"),
)
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"conv_msg-{uuid.uuid4()}")
conversation_id: Mapped[Optional[str]] = mapped_column(
String,
ForeignKey("conversations.id", ondelete="CASCADE"),
nullable=True,
doc="NULL for default conversation, otherwise FK to conversation",
)
agent_id: Mapped[str] = mapped_column(
String,
ForeignKey("agents.id", ondelete="CASCADE"),
nullable=False,
doc="The agent this message association belongs to",
)
message_id: Mapped[str] = mapped_column(
String,
ForeignKey("messages.id", ondelete="CASCADE"),
nullable=False,
doc="The message being tracked",
)
position: Mapped[int] = mapped_column(
Integer,
nullable=False,
doc="Position in conversation (for ordering)",
)
in_context: Mapped[bool] = mapped_column(
Boolean,
default=True,
nullable=False,
doc="Whether message is currently in the agent's context window",
)
# Relationships
conversation: Mapped[Optional["Conversation"]] = relationship(
"Conversation",
back_populates="message_associations",
lazy="raise",
)
message: Mapped["Message"] = relationship(
"Message",
lazy="selectin",
)
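Given the NULL-means-default convention described in the docstring, an agent's in-context window for either conversation is one ordered query. A sketch (session handling elided; `conv_id=None` selects the default conversation):

```python
from sqlalchemy import select

from letta.orm.conversation_messages import ConversationMessage


def in_context_messages_stmt(agent_id: str, conv_id: str | None = None):
    # conversation_id IS NULL is the backward-compatible default conversation;
    # a concrete id selects a named conversation instead.
    return (
        select(ConversationMessage)
        .where(
            ConversationMessage.agent_id == agent_id,
            ConversationMessage.conversation_id.is_(None)
            if conv_id is None
            else ConversationMessage.conversation_id == conv_id,
            ConversationMessage.in_context.is_(True),
        )
        .order_by(ConversationMessage.position)
    )
```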

View File

@@ -69,6 +69,34 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin, AsyncAttrs):
cascade="all, delete-orphan",
)
def to_pydantic(self, strip_directory_prefix: bool = False) -> PydanticFileMetadata:
"""
Convert to Pydantic model without any relationship loading.
"""
file_name = self.file_name
if strip_directory_prefix and "/" in file_name:
file_name = "/".join(file_name.split("/")[1:])
return PydanticFileMetadata(
id=self.id,
organization_id=self.organization_id,
source_id=self.source_id,
file_name=file_name,
original_file_name=self.original_file_name,
file_path=self.file_path,
file_type=self.file_type,
file_size=self.file_size,
file_creation_date=self.file_creation_date,
file_last_modified_date=self.file_last_modified_date,
processing_status=self.processing_status,
error_message=self.error_message,
total_chunks=self.total_chunks,
chunks_embedded=self.chunks_embedded,
created_at=self.created_at,
updated_at=self.updated_at,
content=None,
)
async def to_pydantic_async(self, include_content: bool = False, strip_directory_prefix: bool = False) -> PydanticFileMetadata:
"""
Async version of `to_pydantic` that supports optional relationship loading
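The `strip_directory_prefix` behavior, concretely (illustrative file name):

```python
file_name = "source-abc/docs/readme.md"
assert "/".join(file_name.split("/")[1:]) == "docs/readme.md"
```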

View File

@@ -55,6 +55,12 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
nullable=True,
doc="The id of the LLMBatchItem that this message is associated with",
)
conversation_id: Mapped[Optional[str]] = mapped_column(
ForeignKey("conversations.id", ondelete="SET NULL"),
nullable=True,
index=True,
doc="The conversation this message belongs to (NULL = default conversation)",
)
is_err: Mapped[Optional[bool]] = mapped_column(
nullable=True, doc="Whether this message is part of an error step. Used only for debugging purposes."
)

View File

@@ -40,8 +40,8 @@ class BasePassage(SqlalchemyBase, OrganizationMixin):
@declared_attr
def organization(cls) -> Mapped["Organization"]:
"""Relationship to organization"""
return relationship("Organization", back_populates="passages", lazy="selectin")
"""Relationship to organization - use lazy='raise' to prevent accidental blocking in async contexts"""
return relationship("Organization", back_populates="passages", lazy="raise")
class SourcePassage(BasePassage, FileMixin, SourceMixin):
@@ -53,7 +53,7 @@ class SourcePassage(BasePassage, FileMixin, SourceMixin):
@declared_attr
def organization(cls) -> Mapped["Organization"]:
return relationship("Organization", back_populates="source_passages", lazy="selectin")
return relationship("Organization", back_populates="source_passages", lazy="raise")
@declared_attr
def __table_args__(cls):
@@ -84,7 +84,7 @@ class ArchivalPassage(BasePassage, ArchiveMixin):
@declared_attr
def organization(cls) -> Mapped["Organization"]:
return relationship("Organization", back_populates="archival_passages", lazy="selectin")
return relationship("Organization", back_populates="archival_passages", lazy="raise")
@declared_attr
def __table_args__(cls):
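With `lazy="raise"`, touching `passage.organization` outside an explicit eager load now raises instead of silently issuing blocking IO inside async code. Callers opt in per query; a sketch (the import path for the passage models is assumed):

```python
from sqlalchemy import select
from sqlalchemy.orm import selectinload

from letta.orm.passage import SourcePassage  # assumed module path

# passage.organization would now raise sqlalchemy.exc.InvalidRequestError
# unless the relationship was loaded up front:
stmt = select(SourcePassage).options(selectinload(SourcePassage.organization))
```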

View File

@@ -30,6 +30,7 @@ class Run(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateMixin):
Index("ix_runs_created_at", "created_at", "id"),
Index("ix_runs_agent_id", "agent_id"),
Index("ix_runs_organization_id", "organization_id"),
Index("ix_runs_conversation_id", "conversation_id"),
)
# Generate run ID with run- prefix
@@ -50,6 +51,11 @@ class Run(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateMixin):
# Agent relationship - A run belongs to one agent
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), nullable=False, doc="The agent that owns this run.")
# Conversation relationship - Optional, a run may be associated with a conversation
conversation_id: Mapped[Optional[str]] = mapped_column(
String, ForeignKey("conversations.id", ondelete="SET NULL"), nullable=True, doc="The conversation this run belongs to."
)
# Callback related columns
callback_url: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="When set, POST to this URL after run completion.")
callback_sent_at: Mapped[Optional[datetime]] = mapped_column(nullable=True, doc="Timestamp when the callback was last attempted.")

View File

@@ -12,6 +12,7 @@ from sqlalchemy.orm import Mapped, Session, mapped_column
from sqlalchemy.orm.exc import StaleDataError
from sqlalchemy.orm.interfaces import ORMOption
from letta.errors import ConcurrentUpdateError
from letta.log import get_logger
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError
@@ -259,28 +260,40 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
if before_obj and after_obj:
# Window-based query - get records between before and after
conditions.append(
or_(cls.created_at < before_obj.created_at, and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id))
)
conditions.append(
or_(cls.created_at > after_obj.created_at, and_(cls.created_at == after_obj.created_at, cls.id > after_obj.id))
)
# Skip pagination if either object has null created_at
if before_obj.created_at is not None and after_obj.created_at is not None:
conditions.append(
or_(cls.created_at < before_obj.created_at, and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id))
)
conditions.append(
or_(cls.created_at > after_obj.created_at, and_(cls.created_at == after_obj.created_at, cls.id > after_obj.id))
)
else:
logger.warning(
f"Skipping pagination: before_obj.created_at={before_obj.created_at}, after_obj.created_at={after_obj.created_at}"
)
else:
# Pure pagination query
if before_obj:
conditions.append(
or_(
cls.created_at < before_obj.created_at if ascending else cls.created_at > before_obj.created_at,
and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id),
if before_obj.created_at is not None:
conditions.append(
or_(
cls.created_at < before_obj.created_at if ascending else cls.created_at > before_obj.created_at,
and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id),
)
)
)
else:
logger.warning(f"Skipping 'before' pagination: before_obj.created_at is None (id={before_obj.id})")
if after_obj:
conditions.append(
or_(
cls.created_at > after_obj.created_at if ascending else cls.created_at < after_obj.created_at,
and_(cls.created_at == after_obj.created_at, cls.id > after_obj.id),
if after_obj.created_at is not None:
conditions.append(
or_(
cls.created_at > after_obj.created_at if ascending else cls.created_at < after_obj.created_at,
and_(cls.created_at == after_obj.created_at, cls.id > after_obj.id),
)
)
)
else:
logger.warning(f"Skipping 'after' pagination: after_obj.created_at is None (id={after_obj.id})")
if conditions:
query = query.where(and_(*conditions))
@@ -619,6 +632,11 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
if actor:
self._set_created_and_updated_by_fields(actor.id)
self.set_updated_at()
# Capture id before try block to avoid accessing expired attributes after rollback
object_id = self.id
class_name = self.__class__.__name__
try:
db_session.add(self)
if no_commit:
@@ -633,8 +651,10 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
# This can occur when using optimistic locking (version_id_col) and:
# 1. The row doesn't exist (0 rows matched)
# 2. The version has changed (concurrent update)
# We convert this to NoResultFound to return a proper 404 error
raise NoResultFound(f"{self.__class__.__name__} with id '{self.id}' not found or was updated by another transaction") from e
# In practice, case 1 is rare (blocks aren't frequently deleted), so we always
# return 409 ConcurrentUpdateError. If it was actually deleted, the retry will get 404.
# Not worth performing another db query to check if the row exists.
raise ConcurrentUpdateError(resource_type=class_name, resource_id=object_id) from e
except (DBAPIError, IntegrityError) as e:
self._handle_dbapi_error(e)
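The pagination guards above protect keyset pagination on `(created_at, id)` from NULL cursors. The predicate shape in isolation, with a generic `model` standing in for any `SqlalchemyBase` subclass:

```python
from sqlalchemy import and_, or_


def after_condition(model, after_obj, ascending: bool = True):
    # Keyset cursor: strictly newer created_at, or equal created_at with a
    # larger id as the tiebreaker. Skip entirely when the cursor row has no
    # created_at, since NULL comparisons would silently drop all rows.
    if after_obj.created_at is None:
        return None
    newer = model.created_at > after_obj.created_at if ascending else model.created_at < after_obj.created_at
    return or_(newer, and_(model.created_at == after_obj.created_at, model.id > after_obj.id))
```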

View File

@@ -60,6 +60,9 @@ class Step(SqlalchemyBase, ProjectMixin):
tags: Mapped[Optional[List]] = mapped_column(JSON, doc="Metadata tags.")
tid: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="Transaction ID that processed the step.")
trace_id: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The trace id of the agent step.")
request_id: Mapped[Optional[str]] = mapped_column(
None, nullable=True, doc="The API request log ID from cloud-api for correlating steps with API requests."
)
feedback: Mapped[Optional[str]] = mapped_column(
None, nullable=True, doc="The feedback for this step. Must be either 'positive' or 'negative'."
)
@@ -72,9 +75,9 @@ class Step(SqlalchemyBase, ProjectMixin):
status: Mapped[Optional[StepStatus]] = mapped_column(None, nullable=True, doc="Step status: pending, success, or failed")
# Relationships (foreign keys)
organization: Mapped[Optional["Organization"]] = relationship("Organization")
provider: Mapped[Optional["Provider"]] = relationship("Provider")
run: Mapped[Optional["Run"]] = relationship("Run", back_populates="steps")
organization: Mapped[Optional["Organization"]] = relationship("Organization", lazy="raise")
provider: Mapped[Optional["Provider"]] = relationship("Provider", lazy="raise")
run: Mapped[Optional["Run"]] = relationship("Run", back_populates="steps", lazy="raise")
# Relationships (backrefs)
messages: Mapped[List["Message"]] = relationship("Message", back_populates="step", cascade="save-update", lazy="noload")

View File

@@ -82,8 +82,8 @@ class StepMetrics(SqlalchemyBase, ProjectMixin, AgentMixin):
# Relationships (foreign keys)
step: Mapped["Step"] = relationship("Step", back_populates="metrics", uselist=False)
run: Mapped[Optional["Run"]] = relationship("Run")
agent: Mapped[Optional["Agent"]] = relationship("Agent")
run: Mapped[Optional["Run"]] = relationship("Run", lazy="raise")
agent: Mapped[Optional["Agent"]] = relationship("Agent", lazy="raise")
def create(
self,

View File

@@ -37,6 +37,9 @@ _excluded_v1_endpoints_regex: List[str] = [
async def _trace_request_middleware(request: Request, call_next):
# Capture earliest possible timestamp when request enters application
entry_time = time.time()
if not _is_tracing_initialized:
return await call_next(request)
initial_span_name = f"{request.method} {request.url.path}"
@@ -47,8 +50,17 @@ async def _trace_request_middleware(request: Request, call_next):
initial_span_name,
kind=trace.SpanKind.SERVER,
) as span:
# Record when we entered the application (useful for detecting worker queuing)
span.set_attribute("entry.timestamp_ms", int(entry_time * 1000))
try:
response = await call_next(request)
# Update span name with route pattern after FastAPI has matched the route
route = request.scope.get("route")
if route and hasattr(route, "path"):
span.update_name(f"{request.method} {route.path}")
span.set_attribute("http.status_code", response.status_code)
span.set_status(Status(StatusCode.OK if response.status_code < 400 else StatusCode.ERROR))
return response
@@ -67,44 +79,50 @@ async def _update_trace_attributes(request: Request):
if not span:
return
# Update span name with route pattern
route = request.scope.get("route")
if route and hasattr(route, "path"):
span.update_name(f"{request.method} {route.path}")
# Wrap attribute-setting work in a span to measure time before body parsing
with tracer.start_as_current_span("trace.set_attributes"):
# Update span name with route pattern
route = request.scope.get("route")
if route and hasattr(route, "path"):
span.update_name(f"{request.method} {route.path}")
# Add request info
span.set_attribute("http.method", request.method)
span.set_attribute("http.url", str(request.url))
# Add request info
span.set_attribute("http.method", request.method)
span.set_attribute("http.url", str(request.url))
# Add path params
for key, value in request.path_params.items():
span.set_attribute(f"http.{key}", value)
# Add path params
for key, value in request.path_params.items():
span.set_attribute(f"http.{key}", value)
# Add the following headers to span if available
header_attributes = {
"user_id": "user.id",
"x-organization-id": "organization.id",
"x-project-id": "project.id",
"x-agent-id": "agent.id",
"x-template-id": "template.id",
"x-base-template-id": "base_template.id",
"user-agent": "client",
"x-stainless-package-version": "sdk.version",
"x-stainless-lang": "sdk.language",
"x-letta-source": "source",
}
for header_key, span_key in header_attributes.items():
header_value = request.headers.get(header_key)
if header_value:
span.set_attribute(span_key, header_value)
# Add the following headers to span if available
header_attributes = {
"user_id": "user.id",
"x-organization-id": "organization.id",
"x-project-id": "project.id",
"x-agent-id": "agent.id",
"x-template-id": "template.id",
"x-base-template-id": "base_template.id",
"user-agent": "client",
"x-stainless-package-version": "sdk.version",
"x-stainless-lang": "sdk.language",
"x-letta-source": "source",
}
for header_key, span_key in header_attributes.items():
header_value = request.headers.get(header_key)
if header_value:
span.set_attribute(span_key, header_value)
# Add request body if available
try:
body = await request.json()
for key, value in body.items():
span.set_attribute(f"http.request.body.{key}", str(value))
except Exception:
pass
# Add request body if available (only for JSON requests)
content_type = request.headers.get("content-type", "")
if "application/json" in content_type and request.method in ("POST", "PUT", "PATCH"):
try:
with tracer.start_as_current_span("trace.request_body"):
body = await request.json()
for key, value in body.items():
span.set_attribute(f"http.request.body.{key}", str(value))
except Exception:
# Ignore JSON parsing errors (empty body, invalid JSON, etc.)
pass
async def _trace_error_handler(_request: Request, exc: Exception) -> JSONResponse:

View File

@@ -14,6 +14,7 @@ from letta.schemas.file import FileStatus
from letta.schemas.group import Group
from letta.schemas.identity import Identity
from letta.schemas.letta_base import OrmMetadataBase
from letta.schemas.letta_message import ApprovalRequestMessage
from letta.schemas.letta_stop_reason import StopReasonType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import Memory
@@ -51,6 +52,7 @@ AgentRelationships = Literal[
"agent.blocks",
"agent.identities",
"agent.managed_group",
"agent.pending_approval",
"agent.secrets",
"agent.sources",
"agent.tags",
@@ -81,8 +83,8 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
llm_config: LLMConfig = Field(
..., description="Deprecated: Use `model` field instead. The LLM configuration used by the agent.", deprecated=True
)
embedding_config: EmbeddingConfig = Field(
..., description="Deprecated: Use `embedding` field instead. The embedding configuration used by the agent.", deprecated=True
embedding_config: Optional[EmbeddingConfig] = Field(
None, description="Deprecated: Use `embedding` field instead. The embedding configuration used by the agent.", deprecated=True
)
model: Optional[str] = Field(None, description="The model handle used by the agent (format: provider/model-name).")
embedding: Optional[str] = Field(None, description="The embedding model handle used by the agent (format: provider/model-name).")
@@ -125,6 +127,9 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
[], description="Deprecated: Use `identities` field instead. The ids of the identities associated with this agent.", deprecated=True
)
identities: List[Identity] = Field([], description="The identities associated with this agent.")
pending_approval: Optional[ApprovalRequestMessage] = Field(
None, description="The latest approval request message pending for this agent, if any."
)
# An advanced configuration that makes it so this agent does not remember any previous messages
message_buffer_autoclear: bool = Field(

View File

@@ -0,0 +1,28 @@
from typing import List, Optional
from pydantic import BaseModel, Field
from letta.schemas.letta_base import OrmMetadataBase
class Conversation(OrmMetadataBase):
"""Represents a conversation on an agent for concurrent messaging."""
__id_prefix__ = "conv"
id: str = Field(..., description="The unique identifier of the conversation.")
agent_id: str = Field(..., description="The ID of the agent this conversation belongs to.")
summary: Optional[str] = Field(None, description="A summary of the conversation.")
in_context_message_ids: List[str] = Field(default_factory=list, description="The IDs of in-context messages for the conversation.")
class CreateConversation(BaseModel):
"""Request model for creating a new conversation."""
summary: Optional[str] = Field(None, description="A summary of the conversation.")
class UpdateConversation(BaseModel):
"""Request model for updating a conversation."""
summary: Optional[str] = Field(None, description="A summary of the conversation.")

View File

@@ -26,6 +26,7 @@ class PrimitiveType(str, Enum):
SANDBOX_CONFIG = "sandbox" # Note: sandbox_config IDs use "sandbox" prefix
STEP = "step"
IDENTITY = "identity"
CONVERSATION = "conv"
# Infrastructure types
MCP_SERVER = "mcp_server"
@@ -67,6 +68,7 @@ class ProviderType(str, Enum):
together = "together"
vllm = "vllm"
xai = "xai"
zai = "zai"
class AgentType(str, Enum):

View File

@@ -1,11 +1,15 @@
import traceback
from typing import Optional
from pydantic import Field, model_validator
from letta.log import get_logger
from letta.schemas.enums import PrimitiveType
from letta.schemas.letta_base import LettaBase, OrmMetadataBase
from letta.schemas.secret import Secret
logger = get_logger(__name__)
# Base Environment Variable
class EnvironmentVariableBase(OrmMetadataBase):
@@ -24,18 +28,30 @@ class EnvironmentVariableBase(OrmMetadataBase):
# This validator syncs `value` and `value_enc` for backward compatibility:
# - If `value_enc` is set but `value` is empty -> populate `value` from decrypted `value_enc`
# - If `value` is set but `value_enc` is empty -> populate `value_enc` from encrypted `value`
@model_validator(mode="after")
def sync_value_and_value_enc(self):
"""Sync deprecated `value` field with `value_enc` for backward compatibility."""
if self.value_enc and not self.value:
# Decrypt value_enc -> value (for API responses)
plaintext = self.value_enc.get_plaintext()
if plaintext:
self.value = plaintext
elif self.value and not self.value_enc:
# Encrypt value -> value_enc (for backward compat when value is provided directly)
self.value_enc = Secret.from_plaintext(self.value)
return self
# @model_validator(mode="after")
# def sync_value_and_value_enc(self):
# """Sync deprecated `value` field with `value_enc` for backward compatibility."""
# if self.value_enc and not self.value:
# # ERROR: This should not happen - all code paths should populate value via async decryption
# # Log error with stack trace to identify the caller that bypassed async decryption
# logger.warning(
# f"Sync decryption fallback triggered for env var key={self.key}. "
# f"This indicates a code path that bypassed async decryption. Stack trace:\n{''.join(traceback.format_stack())}"
# )
# # Decrypt value_enc -> value (for API responses)
# plaintext = self.value_enc.get_plaintext()
# if plaintext:
# self.value = plaintext
# elif self.value and not self.value_enc:
# # WARNING: This triggers sync encryption - should use async encryption where possible
# # Log warning with stack trace to identify the caller
# logger.warning(
# f"Sync encryption fallback triggered for env var key={self.key}. "
# f"This indicates a code path that bypassed async encryption. Stack trace:\n{''.join(traceback.format_stack())}"
# )
# # Encrypt value -> value_enc (for backward compat when value is provided directly)
# self.value_enc = Secret.from_plaintext(self.value)
# return self
class EnvironmentVariableCreateBase(LettaBase):
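With the sync validator disabled, decryption is expected to go through the async Secret API already used elsewhere in this diff (`from_encrypted` / `get_plaintext_async`). The intended read path, sketched:

```python
from letta.schemas.secret import Secret


async def decrypt_env_value(value_enc: str | None) -> str | None:
    # value_enc is the stored ciphertext; decrypting asynchronously keeps
    # the PBKDF2 work out of sync code paths.
    if value_enc is None:
        return None
    return await Secret.from_encrypted(value_enc).get_plaintext_async()
```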

View File

@@ -1,5 +1,5 @@
import uuid
from typing import List, Optional, Union
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator
@@ -9,6 +9,19 @@ from letta.schemas.letta_message_content import LettaMessageContentUnion
from letta.schemas.message import MessageCreate, MessageCreateUnion, MessageRole
class ClientToolSchema(BaseModel):
"""Schema for a client-side tool passed in the request.
Client-side tools are executed by the client, not the server. When the agent
calls a client-side tool, execution pauses and returns control to the client
to execute the tool and provide the result.
"""
name: str = Field(..., description="The name of the tool function")
description: Optional[str] = Field(None, description="Description of what the tool does")
parameters: Optional[Dict[str, Any]] = Field(None, description="JSON Schema for the function parameters")
class LettaRequest(BaseModel):
messages: Optional[List[MessageCreateUnion]] = Field(None, description="The messages to be sent to the agent.")
input: Optional[Union[str, List[LettaMessageContentUnion]]] = Field(
@@ -45,6 +58,13 @@ class LettaRequest(BaseModel):
deprecated=True,
)
# Client-side tools
client_tools: Optional[List[ClientToolSchema]] = Field(
None,
description="Client-side tools that the agent can call. When the agent calls a client-side tool, "
"execution pauses and returns control to the client to execute the tool and provide the result via a ToolReturn.",
)
@field_validator("messages", mode="before")
@classmethod
def add_default_type_to_messages(cls, v):
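A request carrying a client-side tool would look roughly like this (field values illustrative, import path assumed):

```python
from letta.schemas.letta_request import ClientToolSchema, LettaRequest  # assumed module path

request = LettaRequest(
    input="What's the weather in Tokyo?",
    client_tools=[
        ClientToolSchema(
            name="get_weather",
            description="Look up current weather on the client device",
            parameters={
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        )
    ],
)
```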

View File

@@ -48,6 +48,7 @@ class LLMConfig(BaseModel):
"bedrock",
"deepseek",
"xai",
"zai",
] = Field(..., description="The endpoint type for the model.")
model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
provider_name: Optional[str] = Field(None, description="The provider name for the model.")
@@ -317,6 +318,7 @@ class LLMConfig(BaseModel):
OpenAIReasoning,
TogetherModelSettings,
XAIModelSettings,
ZAIModelSettings,
)
if self.model_endpoint_type == "openai":
@@ -359,6 +361,11 @@ class LLMConfig(BaseModel):
max_output_tokens=self.max_tokens or 4096,
temperature=self.temperature,
)
elif self.model_endpoint_type == "zai":
return ZAIModelSettings(
max_output_tokens=self.max_tokens or 4096,
temperature=self.temperature,
)
elif self.model_endpoint_type == "groq":
return GroqModelSettings(
max_output_tokens=self.max_tokens or 4096,

View File

@@ -3,7 +3,9 @@ import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from pydantic import Field
from urllib.parse import urlparse
from pydantic import Field, field_validator
logger = logging.getLogger(__name__)
@@ -51,6 +53,21 @@ class MCPServer(BaseMCPServer):
last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
metadata_: Optional[Dict[str, Any]] = Field(default_factory=dict, description="A dictionary of additional metadata for the tool.")
@field_validator("server_url")
@classmethod
def validate_server_url(cls, v: Optional[str]) -> Optional[str]:
"""Validate that server_url is a valid HTTP(S) URL if provided."""
if v is None:
return v
if not v:
raise ValueError("server_url cannot be empty")
parsed = urlparse(v)
if parsed.scheme not in ("http", "https"):
raise ValueError(f"server_url must start with 'http://' or 'https://', got: '{v}'")
if not parsed.netloc:
raise ValueError(f"server_url must have a valid host, got: '{v}'")
return v
def get_token_secret(self) -> Optional[Secret]:
"""Get the token as a Secret object."""
return self.token_enc
@@ -199,6 +216,21 @@ class UpdateSSEMCPServer(LettaBase):
token: Optional[str] = Field(None, description="The access token or API key for the MCP server (used for SSE authentication)")
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs")
@field_validator("server_url")
@classmethod
def validate_server_url(cls, v: Optional[str]) -> Optional[str]:
"""Validate that server_url is a valid HTTP(S) URL if provided."""
if v is None:
return v
if not v:
raise ValueError("server_url cannot be empty")
parsed = urlparse(v)
if parsed.scheme not in ("http", "https"):
raise ValueError(f"server_url must start with 'http://' or 'https://', got: '{v}'")
if not parsed.netloc:
raise ValueError(f"server_url must have a valid host, got: '{v}'")
return v
class UpdateStdioMCPServer(LettaBase):
"""Update a Stdio MCP server"""
@@ -218,6 +250,21 @@ class UpdateStreamableHTTPMCPServer(LettaBase):
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs")
@field_validator("server_url")
@classmethod
def validate_server_url(cls, v: Optional[str]) -> Optional[str]:
"""Validate that server_url is a valid HTTP(S) URL if provided."""
if v is None:
return v
if not v:
raise ValueError("server_url cannot be empty")
parsed = urlparse(v)
if parsed.scheme not in ("http", "https"):
raise ValueError(f"server_url must start with 'http://' or 'https://', got: '{v}'")
if not parsed.netloc:
raise ValueError(f"server_url must have a valid host, got: '{v}'")
return v
UpdateMCPServer = Union[UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer]
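The same urlparse checks recur across the create and update schemas here; pulled out as a plain function (not part of this commit), the accept/reject behavior is:

```python
from urllib.parse import urlparse


def validate_http_url(v: str) -> str:
    # Same checks as the repeated @field_validator bodies above.
    if not v:
        raise ValueError("server_url cannot be empty")
    parsed = urlparse(v)
    if parsed.scheme not in ("http", "https"):
        raise ValueError(f"server_url must start with 'http://' or 'https://', got: '{v}'")
    if not parsed.netloc:
        raise ValueError(f"server_url must have a valid host, got: '{v}'")
    return v


validate_http_url("https://mcp.example.com/sse")  # returns the URL unchanged
try:
    validate_http_url("ftp://mcp.example.com")
except ValueError as e:
    print(e)  # server_url must start with 'http://' or 'https://', ...
```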

View File

@@ -1,8 +1,9 @@
import json
from datetime import datetime
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
from urllib.parse import urlparse
from pydantic import Field
from pydantic import Field, field_validator
from letta.functions.mcp_client.types import (
MCP_AUTH_HEADER_AUTHORIZATION,
@@ -41,6 +42,19 @@ class CreateSSEMCPServer(LettaBase):
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
custom_headers: Optional[dict[str, str]] = Field(None, description="Custom HTTP headers to include with requests")
@field_validator("server_url")
@classmethod
def validate_server_url(cls, v: str) -> str:
"""Validate that server_url is a valid HTTP(S) URL."""
if not v:
raise ValueError("server_url cannot be empty")
parsed = urlparse(v)
if parsed.scheme not in ("http", "https"):
raise ValueError(f"server_url must start with 'http://' or 'https://', got: '{v}'")
if not parsed.netloc:
raise ValueError(f"server_url must have a valid host, got: '{v}'")
return v
class CreateStreamableHTTPMCPServer(LettaBase):
"""Create a new Streamable HTTP MCP server"""
@@ -51,6 +65,19 @@ class CreateStreamableHTTPMCPServer(LettaBase):
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
custom_headers: Optional[dict[str, str]] = Field(None, description="Custom HTTP headers to include with requests")
@field_validator("server_url")
@classmethod
def validate_server_url(cls, v: str) -> str:
"""Validate that server_url is a valid HTTP(S) URL."""
if not v:
raise ValueError("server_url cannot be empty")
parsed = urlparse(v)
if parsed.scheme not in ("http", "https"):
raise ValueError(f"server_url must start with 'http://' or 'https://', got: '{v}'")
if not parsed.netloc:
raise ValueError(f"server_url must have a valid host, got: '{v}'")
return v
CreateMCPServerUnion = Union[CreateStdioMCPServer, CreateSSEMCPServer, CreateStreamableHTTPMCPServer]
@@ -99,6 +126,21 @@ class UpdateSSEMCPServer(LettaBase):
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
custom_headers: Optional[dict[str, str]] = Field(None, description="Custom HTTP headers to include with requests")
@field_validator("server_url")
@classmethod
def validate_server_url(cls, v: Optional[str]) -> Optional[str]:
"""Validate that server_url is a valid HTTP(S) URL if provided."""
if v is None:
return v
if not v:
raise ValueError("server_url cannot be empty")
parsed = urlparse(v)
if parsed.scheme not in ("http", "https"):
raise ValueError(f"server_url must start with 'http://' or 'https://', got: '{v}'")
if not parsed.netloc:
raise ValueError(f"server_url must have a valid host, got: '{v}'")
return v
class UpdateStreamableHTTPMCPServer(LettaBase):
"""Update schema for Streamable HTTP MCP server - all fields optional"""
@@ -109,6 +151,21 @@ class UpdateStreamableHTTPMCPServer(LettaBase):
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
custom_headers: Optional[dict[str, str]] = Field(None, description="Custom HTTP headers to include with requests")
@field_validator("server_url")
@classmethod
def validate_server_url(cls, v: Optional[str]) -> Optional[str]:
"""Validate that server_url is a valid HTTP(S) URL if provided."""
if v is None:
return v
if not v:
raise ValueError("server_url cannot be empty")
parsed = urlparse(v)
if parsed.scheme not in ("http", "https"):
raise ValueError(f"server_url must start with 'http://' or 'https://', got: '{v}'")
if not parsed.netloc:
raise ValueError(f"server_url must have a valid host, got: '{v}'")
return v
UpdateMCPServerUnion = Union[UpdateStdioMCPServer, UpdateSSEMCPServer, UpdateStreamableHTTPMCPServer]

View File

@@ -456,6 +456,7 @@ class CreateArchivalMemory(BaseModel):
class ArchivalMemorySearchResult(BaseModel):
id: str = Field(..., description="Unique identifier of the archival memory passage")
timestamp: str = Field(..., description="Timestamp of when the memory was created, formatted in agent's timezone")
content: str = Field(..., description="Text content of the archival memory passage")
tags: List[str] = Field(default_factory=list, description="List of tags associated with this memory")

View File

@@ -211,6 +211,7 @@ class Message(BaseMessage):
tool_returns (List[ToolReturn]): The list of tool returns requested.
group_id (str): The multi-agent group that the message was sent in.
sender_id (str): The id of the sender of the message, can be an identity id or agent id.
conversation_id (str): The conversation this message belongs to.
"""
@@ -237,6 +238,7 @@ class Message(BaseMessage):
group_id: Optional[str] = Field(default=None, description="The multi-agent group that the message was sent in")
sender_id: Optional[str] = Field(default=None, description="The id of the sender of the message, can be an identity id or agent id")
batch_item_id: Optional[str] = Field(default=None, description="The id of the LLMBatchItem that this message is associated with")
conversation_id: Optional[str] = Field(default=None, description="The conversation this message belongs to")
is_err: Optional[bool] = Field(
default=None, description="Whether this message is part of an error step. Used only for debugging purposes."
)
@@ -1639,13 +1641,13 @@ class Message(BaseMessage):
# TextContent, ImageContent, ToolCallContent, ToolReturnContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent
if isinstance(content_part, ReasoningContent):
if current_model == self.model:
content.append(
{
"type": "thinking",
"thinking": content_part.reasoning,
"signature": content_part.signature,
}
)
block = {
"type": "thinking",
"thinking": content_part.reasoning,
}
if content_part.signature:
block["signature"] = content_part.signature
content.append(block)
elif isinstance(content_part, RedactedReasoningContent):
if current_model == self.model:
content.append(
@@ -1671,13 +1673,13 @@ class Message(BaseMessage):
for content_part in self.content:
if isinstance(content_part, ReasoningContent):
if current_model == self.model:
content.append(
{
"type": "thinking",
"thinking": content_part.reasoning,
"signature": content_part.signature,
}
)
block = {
"type": "thinking",
"thinking": content_part.reasoning,
}
if content_part.signature:
block["signature"] = content_part.signature
content.append(block)
if isinstance(content_part, RedactedReasoningContent):
if current_model == self.model:
content.append(
@@ -1729,29 +1731,42 @@ class Message(BaseMessage):
elif self.role == "tool":
# NOTE: Anthropic uses role "user" for "tool" responses
content = []
for tool_return in self.tool_returns:
if not tool_return.tool_call_id:
from letta.log import get_logger
# Handle the case where tool_returns is None or empty
if self.tool_returns:
# For single tool returns, we can use the message's tool_call_id as fallback
# since self.tool_call_id == tool_returns[0].tool_call_id for legacy compatibility.
# For multiple tool returns (parallel tool calls), each must have its own ID
# to correctly map results to their corresponding tool invocations.
use_message_fallback = len(self.tool_returns) == 1
for idx, tool_return in enumerate(self.tool_returns):
# Get tool_call_id from tool_return; only use message fallback for single returns
resolved_tool_call_id = tool_return.tool_call_id
if not resolved_tool_call_id and use_message_fallback:
resolved_tool_call_id = self.tool_call_id
if not resolved_tool_call_id:
from letta.log import get_logger
logger = get_logger(__name__)
logger.error(
f"Missing tool_call_id in tool return. "
f"Message ID: {self.id}, "
f"Tool name: {getattr(tool_return, 'name', 'unknown')}, "
f"Tool return: {tool_return}"
logger = get_logger(__name__)
logger.error(
f"Missing tool_call_id in tool return and no fallback available. "
f"Message ID: {self.id}, "
f"Tool name: {self.name or 'unknown'}, "
f"Tool return index: {idx}/{len(self.tool_returns)}, "
f"Tool return status: {tool_return.status}"
)
raise TypeError(
f"Anthropic API requires tool_use_id to be set. "
f"Message ID: {self.id}, Tool: {self.name or 'unknown'}, "
f"Tool return index: {idx}/{len(self.tool_returns)}"
)
func_response = truncate_tool_return(tool_return.func_response, tool_return_truncation_chars)
content.append(
{
"type": "tool_result",
"tool_use_id": resolved_tool_call_id,
"content": func_response,
}
)
raise TypeError(
f"Anthropic API requires tool_use_id to be set. "
f"Message ID: {self.id}, Tool: {getattr(tool_return, 'name', 'unknown')}"
)
func_response = truncate_tool_return(tool_return.func_response, tool_return_truncation_chars)
content.append(
{
"type": "tool_result",
"tool_use_id": tool_return.tool_call_id,
"content": func_response,
}
)
if content:
anthropic_message = {
"role": "user",
@@ -2302,6 +2317,7 @@ class MessageSearchRequest(BaseModel):
agent_id: Optional[str] = Field(None, description="Filter messages by agent ID")
project_id: Optional[str] = Field(None, description="Filter messages by project ID")
template_id: Optional[str] = Field(None, description="Filter messages by template ID")
conversation_id: Optional[str] = Field(None, description="Filter messages by conversation ID")
limit: int = Field(50, description="Maximum number of results to return", ge=1, le=100)
start_date: Optional[datetime] = Field(None, description="Filter messages created after this date")
end_date: Optional[datetime] = Field(None, description="Filter messages created on or before this date")
@@ -2310,6 +2326,8 @@ class MessageSearchRequest(BaseModel):
class SearchAllMessagesRequest(BaseModel):
query: str = Field(..., description="Text query for full-text search")
search_mode: Literal["vector", "fts", "hybrid"] = Field("hybrid", description="Search mode to use")
agent_id: Optional[str] = Field(None, description="Filter messages by agent ID")
conversation_id: Optional[str] = Field(None, description="Filter messages by conversation ID")
limit: int = Field(50, description="Maximum number of results to return", ge=1, le=100)
start_date: Optional[datetime] = Field(None, description="Filter messages created after this date")
end_date: Optional[datetime] = Field(None, description="Filter messages created on or before this date")
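The tool_call_id fallback above, in miniature: a single tool return may inherit the message-level id (legacy messages stored it there), while parallel returns must each carry their own. A toy version with plain dicts standing in for the Pydantic objects:

```python
def resolve_tool_call_ids(message_tool_call_id, tool_returns):
    # Single return: legacy messages stored the id on the message itself.
    # Multiple returns: each must map to its own tool invocation.
    use_fallback = len(tool_returns) == 1
    resolved = []
    for tr in tool_returns:
        tool_call_id = tr.get("tool_call_id") or (message_tool_call_id if use_fallback else None)
        if tool_call_id is None:
            raise TypeError("Anthropic requires tool_use_id on every tool_result")
        resolved.append({"type": "tool_result", "tool_use_id": tool_call_id, "content": tr["content"]})
    return resolved


print(resolve_tool_call_ids("call_1", [{"content": "42"}]))
```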

View File

@@ -47,6 +47,7 @@ class Model(LLMConfig, ModelBase):
"bedrock",
"deepseek",
"xai",
"zai",
] = Field(..., description="Deprecated: Use 'provider_type' field instead. The endpoint type for the model.", deprecated=True)
context_window: int = Field(
..., description="Deprecated: Use 'max_context_window' field instead. The context window size for the model.", deprecated=True
@@ -131,6 +132,7 @@ class Model(LLMConfig, ModelBase):
ProviderType.google_vertex: GoogleVertexModelSettings,
ProviderType.azure: AzureModelSettings,
ProviderType.xai: XAIModelSettings,
ProviderType.zai: ZAIModelSettings,
ProviderType.groq: GroqModelSettings,
ProviderType.deepseek: DeepseekModelSettings,
ProviderType.together: TogetherModelSettings,
@@ -352,6 +354,22 @@ class XAIModelSettings(ModelSettings):
}
class ZAIModelSettings(ModelSettings):
"""Z.ai (ZhipuAI) model configuration (OpenAI-compatible)."""
provider_type: Literal[ProviderType.zai] = Field(ProviderType.zai, description="The type of the provider.")
temperature: float = Field(0.7, description="The temperature of the model.")
response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the model.")
def _to_legacy_config_params(self) -> dict:
return {
"temperature": self.temperature,
"max_tokens": self.max_output_tokens,
"response_format": self.response_format,
"parallel_tool_calls": self.parallel_tool_calls,
}
class GroqModelSettings(ModelSettings):
"""Groq model configuration (OpenAI-compatible)."""
@@ -424,6 +442,7 @@ ModelSettingsUnion = Annotated[
GoogleVertexModelSettings,
AzureModelSettings,
XAIModelSettings,
ZAIModelSettings,
GroqModelSettings,
DeepseekModelSettings,
TogetherModelSettings,

View File

@@ -18,6 +18,7 @@ from .openrouter import OpenRouterProvider
from .together import TogetherProvider
from .vllm import VLLMProvider
from .xai import XAIProvider
from .zai import ZAIProvider
__all__ = [
# Base classes
@@ -43,5 +44,6 @@ __all__ = [
"TogetherProvider",
"VLLMProvider", # Replaces ChatCompletions and Completions
"XAIProvider",
"ZAIProvider",
"OpenRouterProvider",
]

View File

@@ -109,18 +109,19 @@ class AnthropicProvider(Provider):
async def check_api_key(self):
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
if api_key:
anthropic_client = anthropic.Anthropic(api_key=api_key)
try:
# just use a cheap model to count some tokens - as of 5/7/2025 this is faster than fetching the list of models
anthropic_client.messages.count_tokens(model=MODEL_LIST[-1]["name"], messages=[{"role": "user", "content": "a"}])
except anthropic.AuthenticationError as e:
raise LLMAuthenticationError(message=f"Failed to authenticate with Anthropic: {e}", code=ErrorCode.UNAUTHENTICATED)
except Exception as e:
raise LLMError(message=f"{e}", code=ErrorCode.INTERNAL_SERVER_ERROR)
else:
if not api_key:
raise ValueError("No API key provided")
try:
# Use async Anthropic client
anthropic_client = anthropic.AsyncAnthropic(api_key=api_key)
# just use a cheap model to count some tokens - as of 5/7/2025 this is faster than fetching the list of models
await anthropic_client.messages.count_tokens(model=MODEL_LIST[-1]["name"], messages=[{"role": "user", "content": "a"}])
except anthropic.AuthenticationError as e:
raise LLMAuthenticationError(message=f"Failed to authenticate with Anthropic: {e}", code=ErrorCode.UNAUTHENTICATED)
except Exception as e:
raise LLMError(message=f"{e}", code=ErrorCode.INTERNAL_SERVER_ERROR)
def get_default_max_output_tokens(self, model_name: str) -> int:
"""Get the default max output tokens for Anthropic models."""
if "opus" in model_name:
@@ -138,8 +139,21 @@ class AnthropicProvider(Provider):
NOTE: currently there is no GET /models, so we need to hardcode
"""
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
# For claude-pro-max provider, use OAuth Bearer token instead of api_key
is_oauth_provider = self.name == "claude-pro-max"
if api_key:
anthropic_client = anthropic.AsyncAnthropic(api_key=api_key)
if is_oauth_provider:
anthropic_client = anthropic.AsyncAnthropic(
default_headers={
"Authorization": f"Bearer {api_key}",
"anthropic-version": "2023-06-01",
"anthropic-beta": "oauth-2025-04-20",
},
)
else:
anthropic_client = anthropic.AsyncAnthropic(api_key=api_key)
elif model_settings.anthropic_api_key:
anthropic_client = anthropic.AsyncAnthropic()
else:

View File

@@ -196,6 +196,7 @@ class Provider(ProviderBase):
TogetherProvider,
VLLMProvider,
XAIProvider,
ZAIProvider,
)
if self.base_url == "":
@@ -230,6 +231,8 @@ class Provider(ProviderBase):
return CerebrasProvider(**self.model_dump(exclude_none=True))
case ProviderType.xai:
return XAIProvider(**self.model_dump(exclude_none=True))
case ProviderType.zai:
return ZAIProvider(**self.model_dump(exclude_none=True))
case ProviderType.lmstudio_openai:
return LMStudioOpenAIProvider(**self.model_dump(exclude_none=True))
case ProviderType.bedrock:

View File

@@ -8,10 +8,13 @@ from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.providers.base import Provider
LETTA_EMBEDDING_ENDPOINT = "https://embeddings.letta.com/"
class LettaProvider(Provider):
provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.")
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
base_url: str = Field(LETTA_EMBEDDING_ENDPOINT, description="Base URL for the Letta API (used for embeddings).")
async def list_llm_models_async(self) -> list[LLMConfig]:
return [
@@ -32,7 +35,7 @@ class LettaProvider(Provider):
EmbeddingConfig(
embedding_model="letta-free", # NOTE: renamed
embedding_endpoint_type="openai",
embedding_endpoint="https://embeddings.letta.com/",
embedding_endpoint=self.base_url,
embedding_dim=1536,
embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
handle=self.get_handle("letta-free", is_embedding=True),

View File

@@ -1,8 +1,10 @@
from typing import Literal
from openai import AsyncOpenAI, AuthenticationError
from pydantic import Field
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LLM_MAX_CONTEXT_WINDOW
from letta.errors import ErrorCode, LLMAuthenticationError, LLMError
from letta.log import get_logger
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ProviderCategory, ProviderType
@@ -23,11 +25,21 @@ class OpenAIProvider(Provider):
base_url: str = Field("https://api.openai.com/v1", description="Base URL for the OpenAI API.")
async def check_api_key(self):
from letta.llm_api.openai import openai_check_valid_api_key # TODO: DO NOT USE THIS - old code path
# Decrypt API key before using
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
openai_check_valid_api_key(self.base_url, api_key)
if not api_key:
raise ValueError("No API key provided")
try:
# Use async OpenAI client to check API key validity
client = AsyncOpenAI(api_key=api_key, base_url=self.base_url)
# Just list models to verify API key works
await client.models.list()
except AuthenticationError as e:
raise LLMAuthenticationError(message=f"Failed to authenticate with OpenAI: {e}", code=ErrorCode.UNAUTHENTICATED)
except Exception as e:
raise LLMError(message=f"{e}", code=ErrorCode.INTERNAL_SERVER_ERROR)
def get_default_max_output_tokens(self, model_name: str) -> int:
"""Get the default max output tokens for OpenAI models."""

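The OpenAI check takes the same shape: one lightweight read-only call on the async client, with SDK auth failures translated into typed errors. A minimal sketch, assuming only the `openai` package; `PermissionError` stands in for `LLMAuthenticationError`:

```python
from openai import AsyncOpenAI, AuthenticationError


async def check_openai_key(api_key: str, base_url: str = "https://api.openai.com/v1") -> None:
    """Verify an OpenAI-compatible key by listing models (cheap, read-only)."""
    if not api_key:
        raise ValueError("No API key provided")
    client = AsyncOpenAI(api_key=api_key, base_url=base_url)
    try:
        await client.models.list()
    except AuthenticationError as e:
        raise PermissionError(f"Failed to authenticate with OpenAI: {e}") from e
```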
View File

@@ -0,0 +1,71 @@
from typing import Literal
from letta.log import get_logger
from pydantic import Field
from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.providers.openai import OpenAIProvider
logger = get_logger(__name__)
# Z.ai model context windows
# Reference: https://docs.z.ai/
MODEL_CONTEXT_WINDOWS = {
"glm-4.5": 128000,
"glm-4.6": 200000,
"glm-4.7": 200000,
}
class ZAIProvider(OpenAIProvider):
"""Z.ai (ZhipuAI) provider - https://docs.z.ai/"""
provider_type: Literal[ProviderType.zai] = Field(ProviderType.zai, description="The type of the provider.")
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
api_key: str | None = Field(None, description="API key for the Z.ai API.", deprecated=True)
base_url: str = Field("https://api.z.ai/api/paas/v4/", description="Base URL for the Z.ai API.")
def get_model_context_window_size(self, model_name: str) -> int | None:
# Z.ai doesn't return context window in the model listing,
# this is hardcoded from documentation
return MODEL_CONTEXT_WINDOWS.get(model_name)
async def list_llm_models_async(self) -> list[LLMConfig]:
from letta.llm_api.openai import openai_get_model_list_async
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
response = await openai_get_model_list_async(self.base_url, api_key=api_key)
data = response.get("data", response)
configs = []
for model in data:
assert "id" in model, f"Z.ai model missing 'id' field: {model}"
model_name = model["id"]
# In case Z.ai starts supporting it in the future:
if "context_length" in model:
context_window_size = model["context_length"]
else:
context_window_size = self.get_model_context_window_size(model_name)
if not context_window_size:
logger.warning(f"Couldn't find context window size for model {model_name}")
continue
configs.append(
LLMConfig(
model=model_name,
model_endpoint_type=self.provider_type.value,
model_endpoint=self.base_url,
context_window=context_window_size,
handle=self.get_handle(model_name),
max_tokens=self.get_default_max_output_tokens(model_name),
provider_name=self.name,
provider_category=self.provider_category,
)
)
return configs

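Because Z.ai's model listing carries no `context_length` field today, the provider falls back to a hardcoded table and skips models it cannot size. A self-contained sketch of that fallback, with an illustrative model list standing in for the API response:

```python
# Hardcoded from Z.ai docs; models absent from this table get skipped.
MODEL_CONTEXT_WINDOWS = {"glm-4.5": 128000, "glm-4.6": 200000, "glm-4.7": 200000}


def resolve_context_window(model: dict) -> int | None:
    # Prefer the API-reported value if Z.ai ever starts returning one.
    if "context_length" in model:
        return model["context_length"]
    return MODEL_CONTEXT_WINDOWS.get(model["id"])


models = [{"id": "glm-4.6"}, {"id": "glm-unknown"}]
sized = {m["id"]: resolve_context_window(m) for m in models}
print(sized)  # {'glm-4.6': 200000, 'glm-unknown': None} -> the None entry is skipped
```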
View File

@@ -27,6 +27,9 @@ class Run(RunBase):
# Agent relationship
agent_id: str = Field(..., description="The unique identifier of the agent associated with the run.")
# Conversation relationship
conversation_id: Optional[str] = Field(None, description="The unique identifier of the conversation associated with the run.")
# Template fields
base_template_id: Optional[str] = Field(None, description="The base template ID that the run belongs to.")

View File

@@ -5,6 +5,7 @@ from pydantic_core import core_schema
from letta.helpers.crypto_utils import CryptoUtils
from letta.log import get_logger
from letta.utils import bounded_gather
logger = get_logger(__name__)
@@ -67,6 +68,72 @@ class Secret(BaseModel):
return instance
raise # Re-raise if it's a different error
@classmethod
async def from_plaintext_async(cls, value: Optional[str]) -> "Secret":
"""
Create a Secret from a plaintext value, encrypting it asynchronously.
This async version runs encryption in a thread pool to avoid blocking
the event loop during the CPU-intensive PBKDF2 key derivation (100-500ms).
Use this method in all async contexts (FastAPI endpoints, async services, etc.)
to avoid blocking the event loop.
Args:
value: The plaintext value to encrypt
Returns:
A Secret instance with the encrypted (or plaintext) value
"""
if value is None:
return cls.model_construct(encrypted_value=None)
# Guard against double encryption - check if value is already encrypted
if CryptoUtils.is_encrypted(value):
logger.warning("Creating Secret from already-encrypted value. This can be dangerous.")
# Try to encrypt asynchronously, but fall back to storing plaintext if no encryption key
try:
encrypted = await CryptoUtils.encrypt_async(value)
return cls.model_construct(encrypted_value=encrypted)
except ValueError as e:
# No encryption key available, store as plaintext in the _enc column
if "No encryption key configured" in str(e):
logger.warning(
"No encryption key configured. Storing Secret value as plaintext in _enc column. "
"Set LETTA_ENCRYPTION_KEY environment variable to enable encryption."
)
instance = cls.model_construct(encrypted_value=value)
instance._plaintext_cache = value # Cache it since we know the plaintext
return instance
raise # Re-raise if it's a different error
@classmethod
async def from_plaintexts_async(cls, values: dict[str, str], max_concurrency: int = 10) -> dict[str, "Secret"]:
"""
Create multiple Secrets from plaintexts concurrently with bounded concurrency.
Uses bounded_gather() to encrypt values in parallel while limiting
concurrent operations to prevent overwhelming the event loop.
Args:
values: Dict of key -> plaintext value
max_concurrency: Maximum number of concurrent encryption operations (default: 10)
Returns:
Dict of key -> Secret
"""
if not values:
return {}
keys = list(values.keys())
async def encrypt_one(key: str) -> "Secret":
return await cls.from_plaintext_async(values[key])
secrets = await bounded_gather([encrypt_one(k) for k in keys], max_concurrency=max_concurrency)
return dict(zip(keys, secrets))
@classmethod
def from_encrypted(cls, encrypted_value: Optional[str]) -> "Secret":
"""

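`from_plaintexts_async` leans on `bounded_gather` from `letta.utils` to cap concurrent PBKDF2 work. That helper is not shown in this diff; a minimal equivalent, assuming the semantics its call site implies (run coroutines concurrently, at most N in flight, preserving input order):

```python
import asyncio
from typing import Awaitable, TypeVar

T = TypeVar("T")


async def bounded_gather(coros: list[Awaitable[T]], max_concurrency: int = 10) -> list[T]:
    """Run coroutines concurrently with at most `max_concurrency` in flight."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def run_one(coro: Awaitable[T]) -> T:
        async with semaphore:
            return await coro

    return await asyncio.gather(*(run_one(c) for c in coros))


async def main() -> None:
    async def work(i: int) -> int:
        await asyncio.sleep(0.01)  # stand-in for a CPU-offloaded encrypt_async
        return i * i

    print(await bounded_gather([work(i) for i in range(5)], max_concurrency=2))


asyncio.run(main())
```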
View File

@@ -38,6 +38,7 @@ class Step(StepBase):
tags: List[str] = Field([], description="Metadata tags.")
tid: Optional[str] = Field(None, description="The unique identifier of the transaction that processed this step.")
trace_id: Optional[str] = Field(None, description="The trace id of the agent step.")
request_id: Optional[str] = Field(None, description="The API request log ID from cloud-api for correlating steps with API requests.")
messages: List[Message] = Field(
[],
description="The messages generated during this step. Deprecated: use `GET /v1/steps/{step_id}/messages` endpoint instead",

View File

@@ -39,12 +39,17 @@ else:
# Add asyncpg-specific settings for connection
if not settings.disable_sqlalchemy_pooling:
engine_args["connect_args"] = {
connect_args = {
"timeout": settings.pg_pool_timeout,
"prepared_statement_name_func": lambda: f"__asyncpg_{uuid.uuid4()}__",
"statement_cache_size": 0,
"prepared_statement_cache_size": 0,
}
# Only add SSL if not already specified in connection string
if "sslmode" not in async_pg_uri and "ssl" not in async_pg_uri:
connect_args["ssl"] = "require"
engine_args["connect_args"] = connect_args
# Create the engine once at module level
engine: AsyncEngine = create_async_engine(async_pg_uri, **engine_args)

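The hunk above only adds `ssl="require"` when the URI does not already pin SSL behavior, so explicit `sslmode=`/`ssl=` query parameters keep winning. A compact sketch of that decision, assuming SQLAlchemy's asyncpg dialect is installed; the URI is illustrative:

```python
from sqlalchemy.ext.asyncio import create_async_engine


def build_connect_args(async_pg_uri: str, pool_timeout: float = 30.0) -> dict:
    connect_args = {"timeout": pool_timeout}
    # Respect SSL settings already baked into the connection string.
    if "sslmode" not in async_pg_uri and "ssl" not in async_pg_uri:
        connect_args["ssl"] = "require"
    return connect_args


uri = "postgresql+asyncpg://user:pass@db.example.com:5432/letta"  # illustrative
engine = create_async_engine(uri, connect_args=build_connect_args(uri))
```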
View File

@@ -18,7 +18,7 @@ faulthandler.enable()
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.responses import JSONResponse, ORJSONResponse
from marshmallow import ValidationError
from sqlalchemy.exc import IntegrityError, OperationalError
from starlette.middleware.cors import CORSMiddleware
@@ -32,6 +32,9 @@ from letta.errors import (
AgentFileImportError,
AgentNotFoundForExportError,
BedrockPermissionError,
ConcurrentUpdateError,
ConversationBusyError,
EmbeddingConfigRequiredError,
HandleNotFoundError,
LettaAgentNotFoundError,
LettaExpiredError,
@@ -49,6 +52,7 @@ from letta.errors import (
LLMProviderOverloaded,
LLMRateLimitError,
LLMTimeoutError,
NoActiveRunsToCancelError,
PendingApprovalError,
)
from letta.helpers.pinecone_utils import get_pinecone_indices, should_use_pinecone, upsert_pinecone_indices
@@ -69,7 +73,7 @@ from letta.server.global_exception_handler import setup_global_exception_handler
# NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests
from letta.server.rest_api.auth.index import setup_auth_router # TODO: probably remove right?
from letta.server.rest_api.interface import StreamingServerInterface
from letta.server.rest_api.middleware import CheckPasswordMiddleware, LoggingMiddleware
from letta.server.rest_api.middleware import CheckPasswordMiddleware, LoggingMiddleware, RequestIdMiddleware
from letta.server.rest_api.routers.v1 import ROUTERS as v1_routes
from letta.server.rest_api.routers.v1.organizations import router as organizations_router
from letta.server.rest_api.routers.v1.users import router as users_router # TODO: decide on admin
@@ -241,10 +245,6 @@ def create_application() -> "FastAPI":
os.environ.setdefault("DD_PROFILING_MEMORY_ENABLED", str(telemetry_settings.datadog_profiling_memory_enabled).lower())
os.environ.setdefault("DD_PROFILING_HEAP_ENABLED", str(telemetry_settings.datadog_profiling_heap_enabled).lower())
# Enable LLM Observability for tracking LLM calls, prompts, and completions
os.environ.setdefault("DD_LLMOBS_ENABLED", "1")
os.environ.setdefault("DD_LLMOBS_ML_APP", "memgpt-server")
# Note: DD_LOGS_INJECTION, DD_APPSEC_ENABLED, DD_IAST_ENABLED, DD_APPSEC_SCA_ENABLED
# are set via deployment configs and automatically picked up by ddtrace
@@ -252,6 +252,49 @@ def create_application() -> "FastAPI":
import ddtrace
ddtrace.patch_all() # Auto-instrument FastAPI, HTTP, DB, etc.
llmobs_flag = os.getenv("DD_LLMOBS_ENABLED", "")
from ddtrace.llmobs import LLMObs
try:
from ddtrace.llmobs._constants import MODEL_PROVIDER
from ddtrace.llmobs._integrations.openai import OpenAIIntegration
if not getattr(OpenAIIntegration, "_letta_provider_patch_done", False):
original_set_tags = OpenAIIntegration._llmobs_set_tags
def _letta_set_tags(self, span, args, kwargs, response=None, operation=""):
original_set_tags(self, span, args, kwargs, response=response, operation=operation)
base_url = span.get_tag("openai.api_base")
if not base_url:
try:
client = getattr(self, "_client", None)
base_url = str(getattr(client, "_base_url", "") or "")
except Exception:
base_url = ""
u = (base_url or "").lower()
provider = None
if "openrouter" in u:
provider = "openrouter"
elif "groq" in u:
provider = "groq"
if provider:
span._set_ctx_item(MODEL_PROVIDER, provider)
span._set_tag_str("openai.request.provider", provider)
OpenAIIntegration._llmobs_set_tags = _letta_set_tags
OpenAIIntegration._letta_provider_patch_done = True
except Exception:
logger.exception("Failed to patch ddtrace OpenAI LLMObs provider detection")
if llmobs_flag:
LLMObs.enable(
ml_app=os.getenv("DD_LLMOBS_ML_APP") or telemetry_settings.datadog_service_name,
)
logger.info(
f"Datadog tracer initialized: env={dd_env}, "
f"service={telemetry_settings.datadog_service_name}, "
@@ -296,6 +339,7 @@ def create_application() -> "FastAPI":
version=letta_version,
debug=debug_mode, # if True, the stack trace will be printed in the response
lifespan=lifespan,
default_response_class=ORJSONResponse, # Use orjson for 10x faster JSON serialization
)
# === Global Exception Handlers ===
@@ -431,6 +475,7 @@ def create_application() -> "FastAPI":
app.add_exception_handler(LettaToolCreateError, _error_handler_400)
app.add_exception_handler(LettaToolNameConflictError, _error_handler_400)
app.add_exception_handler(AgentFileImportError, _error_handler_400)
app.add_exception_handler(EmbeddingConfigRequiredError, _error_handler_400)
app.add_exception_handler(ValueError, _error_handler_400)
# 404 Not Found errors
@@ -451,7 +496,10 @@ def create_application() -> "FastAPI":
app.add_exception_handler(ForeignKeyConstraintViolationError, _error_handler_409)
app.add_exception_handler(UniqueConstraintViolationError, _error_handler_409)
app.add_exception_handler(IntegrityError, _error_handler_409)
app.add_exception_handler(ConcurrentUpdateError, _error_handler_409)
app.add_exception_handler(ConversationBusyError, _error_handler_409)
app.add_exception_handler(PendingApprovalError, _error_handler_409)
app.add_exception_handler(NoActiveRunsToCancelError, _error_handler_409)
# 415 Unsupported Media Type errors
app.add_exception_handler(LettaUnsupportedFileUploadError, _error_handler_415)
@@ -586,6 +634,10 @@ def create_application() -> "FastAPI":
# Add unified logging middleware - enriches log context and logs exceptions
app.add_middleware(LoggingMiddleware)
# Add request ID middleware - extracts x-api-request-log-id header and sets it in contextvar
# This is a pure ASGI middleware to properly propagate contextvars to streaming responses
app.add_middleware(RequestIdMiddleware)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
@@ -634,9 +686,6 @@ def create_application() -> "FastAPI":
# app.include_router(route, prefix="", include_in_schema=False)
app.include_router(route, prefix="/latest", include_in_schema=False)
# NOTE: ethan these are the extra routes
# TODO(ethan) remove
# admin/users
app.include_router(users_router, prefix=ADMIN_PREFIX)
app.include_router(organizations_router, prefix=ADMIN_PREFIX)

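Two app-level changes land here: `ORJSONResponse` becomes the default response class, and the request-id middleware is registered alongside logging as pure ASGI. A minimal sketch of the same wiring on a bare app, assuming `fastapi` and `orjson` are installed; the no-op middleware is a hypothetical stand-in for `RequestIdMiddleware`:

```python
from fastapi import FastAPI
from fastapi.responses import ORJSONResponse
from starlette.types import ASGIApp, Receive, Scope, Send


class NoopAsgiMiddleware:
    """Pure ASGI middleware: contextvars set here remain visible to streaming
    response generators, unlike with BaseHTTPMiddleware."""

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        await self.app(scope, receive, send)


app = FastAPI(default_response_class=ORJSONResponse)  # orjson-serialized JSON by default
app.add_middleware(NoopAsgiMiddleware)


@app.get("/health")
def health() -> dict:
    return {"status": "ok"}
```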
View File

@@ -3,6 +3,8 @@ from typing import TYPE_CHECKING, Optional
from fastapi import Header
from pydantic import BaseModel
from letta.otel.tracing import tracer
if TYPE_CHECKING:
from letta.server.server import SyncServer
@@ -39,25 +41,27 @@ def get_headers(
modal_sandbox: Optional[str] = Header(None, alias="X-Experimental-Modal-Sandbox"),
) -> HeaderParams:
"""Dependency injection function to extract common headers from requests."""
return HeaderParams(
actor_id=actor_id,
user_agent=user_agent,
project_id=project_id,
letta_source=letta_source,
sdk_version=sdk_version,
experimental_params=ExperimentalParams(
message_async=(message_async == "true") if message_async else None,
letta_v1_agent=(letta_v1_agent == "true") if letta_v1_agent else None,
letta_v1_agent_message_async=(letta_v1_agent_message_async == "true") if letta_v1_agent_message_async else None,
modal_sandbox=(modal_sandbox == "true") if modal_sandbox else None,
),
)
with tracer.start_as_current_span("dependency.get_headers"):
return HeaderParams(
actor_id=actor_id,
user_agent=user_agent,
project_id=project_id,
letta_source=letta_source,
sdk_version=sdk_version,
experimental_params=ExperimentalParams(
message_async=(message_async == "true") if message_async else None,
letta_v1_agent=(letta_v1_agent == "true") if letta_v1_agent else None,
letta_v1_agent_message_async=(letta_v1_agent_message_async == "true") if letta_v1_agent_message_async else None,
modal_sandbox=(modal_sandbox == "true") if modal_sandbox else None,
),
)
# TODO: why does this double up the interface?
async def get_letta_server() -> "SyncServer":
# Check if a global server is already instantiated
from letta.server.rest_api.app import server
with tracer.start_as_current_span("dependency.get_letta_server"):
# Check if a global server is already instantiated
from letta.server.rest_api.app import server
# assert isinstance(server, SyncServer)
return server
# assert isinstance(server, SyncServer)
return server

View File

@@ -1,4 +1,5 @@
from letta.server.rest_api.middleware.check_password import CheckPasswordMiddleware
from letta.server.rest_api.middleware.logging import LoggingMiddleware
from letta.server.rest_api.middleware.request_id import RequestIdMiddleware
__all__ = ["CheckPasswordMiddleware", "LoggingMiddleware"]
__all__ = ["CheckPasswordMiddleware", "LoggingMiddleware", "RequestIdMiddleware"]

View File

@@ -2,7 +2,6 @@
Unified logging middleware that enriches log context and ensures exceptions are logged.
"""
import re
import traceback
from typing import Callable
@@ -11,6 +10,7 @@ from starlette.requests import Request
from letta.log import get_logger
from letta.log_context import clear_log_context, update_log_context
from letta.otel.tracing import tracer
from letta.schemas.enums import PrimitiveType
from letta.validators import PRIMITIVE_ID_PATTERNS
@@ -33,95 +33,96 @@ class LoggingMiddleware(BaseHTTPMiddleware):
clear_log_context()
try:
# Extract and set log context
context = {}
with tracer.start_as_current_span("middleware.logging"):
# Extract and set log context
context = {}
with tracer.start_as_current_span("middleware.logging.context"):
# Headers
actor_id = request.headers.get("user_id")
if actor_id:
context["actor_id"] = actor_id
# Headers
actor_id = request.headers.get("user_id")
if actor_id:
context["actor_id"] = actor_id
project_id = request.headers.get("x-project-id")
if project_id:
context["project_id"] = project_id
project_id = request.headers.get("x-project-id")
if project_id:
context["project_id"] = project_id
org_id = request.headers.get("x-organization-id")
if org_id:
context["org_id"] = org_id
org_id = request.headers.get("x-organization-id")
if org_id:
context["org_id"] = org_id
user_agent = request.headers.get("x-agent-id")
if user_agent:
context["agent_id"] = user_agent
user_agent = request.headers.get("x-agent-id")
if user_agent:
context["agent_id"] = user_agent
run_id_header = request.headers.get("x-run-id") or request.headers.get("run-id")
if run_id_header:
context["run_id"] = run_id_header
run_id_header = request.headers.get("x-run-id") or request.headers.get("run-id")
if run_id_header:
context["run_id"] = run_id_header
path = request.url.path
path_parts = [p for p in path.split("/") if p]
path = request.url.path
path_parts = [p for p in path.split("/") if p]
# Path
matched_parts = set()
for part in path_parts:
if part in matched_parts:
continue
# Path
matched_parts = set()
for part in path_parts:
if part in matched_parts:
continue
for primitive_type in PrimitiveType:
prefix = primitive_type.value
pattern = PRIMITIVE_ID_PATTERNS.get(prefix)
for primitive_type in PrimitiveType:
prefix = primitive_type.value
pattern = PRIMITIVE_ID_PATTERNS.get(prefix)
if pattern and pattern.match(part):
context_key = f"{primitive_type.name.lower()}_id"
if pattern and pattern.match(part):
context_key = f"{primitive_type.name.lower()}_id"
if primitive_type == PrimitiveType.ORGANIZATION:
context_key = "org_id"
elif primitive_type == PrimitiveType.USER:
context_key = "user_id"
if primitive_type == PrimitiveType.ORGANIZATION:
context_key = "org_id"
elif primitive_type == PrimitiveType.USER:
context_key = "user_id"
context[context_key] = part
matched_parts.add(part)
break
context[context_key] = part
matched_parts.add(part)
break
# Query Parameters
for param_value in request.query_params.values():
if param_value in matched_parts:
continue
# Query Parameters
for param_value in request.query_params.values():
if param_value in matched_parts:
continue
for primitive_type in PrimitiveType:
prefix = primitive_type.value
pattern = PRIMITIVE_ID_PATTERNS.get(prefix)
for primitive_type in PrimitiveType:
prefix = primitive_type.value
pattern = PRIMITIVE_ID_PATTERNS.get(prefix)
if pattern and pattern.match(param_value):
context_key = f"{primitive_type.name.lower()}_id"
if pattern and pattern.match(param_value):
context_key = f"{primitive_type.name.lower()}_id"
if primitive_type == PrimitiveType.ORGANIZATION:
context_key = "org_id"
elif primitive_type == PrimitiveType.USER:
context_key = "user_id"
if primitive_type == PrimitiveType.ORGANIZATION:
context_key = "org_id"
elif primitive_type == PrimitiveType.USER:
context_key = "user_id"
# Only set if not already set from path (path takes precedence over query params)
# Query params can overwrite headers, but path values take precedence
if context_key not in context:
context[context_key] = param_value
matched_parts.add(param_value)
break
# Only set if not already set from path (path takes precedence over query params)
# Query params can overwrite headers, but path values take precedence
if context_key not in context:
context[context_key] = param_value
matched_parts.add(param_value)
break
if context:
update_log_context(**context)
if context:
update_log_context(**context)
logger.debug(
f"Incoming request: {request.method} {request.url.path}",
extra={
"method": request.method,
"url": str(request.url),
"path": request.url.path,
"query_params": dict(request.query_params),
"client_host": request.client.host if request.client else None,
},
)
logger.debug(
f"Incoming request: {request.method} {request.url.path}",
extra={
"method": request.method,
"url": str(request.url),
"path": request.url.path,
"query_params": dict(request.query_params),
"client_host": request.client.host if request.client else None,
},
)
response = await call_next(request)
return response
response = await call_next(request)
return response
except Exception as exc:
# Extract request context

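The middleware recovers typed IDs from path segments and query parameters via prefix-specific regexes, with path matches taking precedence. A toy version of that scan, with illustrative patterns standing in for `PRIMITIVE_ID_PATTERNS`:

```python
import re

# Illustrative stand-ins for letta.validators.PRIMITIVE_ID_PATTERNS
ID_PATTERNS = {
    "agent": re.compile(r"^agent-[0-9a-f-]{36}$"),
    "run": re.compile(r"^run-[0-9a-f-]{36}$"),
}


def extract_ids(path: str, query_params: dict[str, str]) -> dict[str, str]:
    context: dict[str, str] = {}
    for part in (p for p in path.split("/") if p):
        for kind, pattern in ID_PATTERNS.items():
            if pattern.match(part):
                context[f"{kind}_id"] = part
                break
    for value in query_params.values():
        for kind, pattern in ID_PATTERNS.items():
            # Path values take precedence over query params.
            if pattern.match(value) and f"{kind}_id" not in context:
                context[f"{kind}_id"] = value
                break
    return context


print(extract_ids("/v1/agents/agent-00000000-0000-0000-0000-000000000000", {}))
```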
View File

@@ -0,0 +1,66 @@
"""
Middleware for extracting and propagating API request IDs from cloud-api.
Uses a pure ASGI middleware pattern to properly propagate the request_id
to streaming responses. BaseHTTPMiddleware has a known limitation where
contextvars are not propagated to streaming response generators.
See: https://github.com/encode/starlette/discussions/1729
This middleware:
1. Extracts the x-api-request-log-id header from cloud-api
2. Sets it in the contextvar (for non-streaming code)
3. Stores it in request.state (for streaming responses where contextvars don't propagate)
"""
from contextvars import ContextVar
from typing import Optional
from starlette.requests import Request
from starlette.types import ASGIApp, Receive, Scope, Send
from letta.otel.tracing import tracer
# Contextvar for storing the request ID across async boundaries
request_id_var: ContextVar[Optional[str]] = ContextVar("request_id", default=None)
def get_request_id() -> Optional[str]:
"""Get the request ID from the current context."""
return request_id_var.get()
class RequestIdMiddleware:
"""
Pure ASGI middleware that extracts and propagates the API request ID.
The request ID comes from cloud-api via the x-api-request-log-id header
and is used to correlate steps with API request logs.
This middleware stores the request_id in:
- The request_id_var contextvar (works for non-streaming responses)
- request.state.request_id (works for streaming responses where contextvars may not propagate)
"""
def __init__(self, app: ASGIApp) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
with tracer.start_as_current_span("middleware.request_id"):
# Create a Request object for easier header access
request = Request(scope)
# Extract request_id from header
request_id = request.headers.get("x-api-request-log-id")
# Set in contextvar (for non-streaming code paths)
request_id_var.set(request_id)
# Also store in request.state for streaming responses where contextvars don't propagate
# This is accessible via request.state.request_id throughout the request lifecycle
request.state.request_id = request_id
await self.app(scope, receive, send)

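Downstream code can read the propagated ID from either place: the contextvar on non-streaming paths, or `request.state` where contextvars may not survive. A hypothetical route showing both access paths, assuming this module's `get_request_id` helper:

```python
from fastapi import APIRouter, Request

from letta.server.rest_api.middleware.request_id import get_request_id

router = APIRouter()


@router.get("/whoami")
async def whoami(request: Request) -> dict:
    # Non-streaming path: contextvar set by RequestIdMiddleware.
    ctx_request_id = get_request_id()
    # Streaming-safe path: request.state survives where contextvars may not.
    state_request_id = getattr(request.state, "request_id", None)
    return {"contextvar": ctx_request_id, "state": state_request_id}
```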
View File

@@ -4,7 +4,9 @@ import asyncio
import json
import time
from collections import defaultdict
from typing import AsyncIterator, Dict, List, Optional
from collections.abc import AsyncGenerator, AsyncIterator
from contextlib import aclosing
from typing import Dict, List, Optional
from letta.data_sources.redis_client import AsyncRedisClient
from letta.log import get_logger
@@ -194,12 +196,13 @@ class RedisSSEStreamWriter:
async def create_background_stream_processor(
stream_generator,
stream_generator: AsyncGenerator[str | bytes | tuple[str | bytes, int], None],
redis_client: AsyncRedisClient,
run_id: str,
writer: Optional[RedisSSEStreamWriter] = None,
run_manager: Optional[RunManager] = None,
actor: Optional[User] = None,
conversation_id: Optional[str] = None,
) -> None:
"""
Process a stream in the background and store chunks to Redis.
@@ -214,10 +217,12 @@ async def create_background_stream_processor(
writer: Optional pre-configured writer (creates new if not provided)
run_manager: Optional run manager for updating run status
actor: Optional actor for run status updates
conversation_id: Optional conversation ID for releasing lock on terminal states
"""
stop_reason = None
saw_done = False
saw_error = False
error_metadata = None
if writer is None:
writer = RedisSSEStreamWriter(redis_client)
@@ -227,32 +232,52 @@ async def create_background_stream_processor(
should_stop_writer = False
try:
async for chunk in stream_generator:
if isinstance(chunk, tuple):
chunk = chunk[0]
# Always close the upstream async generator so its `finally` blocks run.
# (e.g., stream adapters may persist terminal error metadata on close)
async with aclosing(stream_generator):
async for chunk in stream_generator:
if isinstance(chunk, tuple):
chunk = chunk[0]
# Track terminal events
if isinstance(chunk, str):
if "data: [DONE]" in chunk:
saw_done = True
if "event: error" in chunk:
saw_error = True
# Track terminal events
if isinstance(chunk, str):
if "data: [DONE]" in chunk:
saw_done = True
if "event: error" in chunk:
saw_error = True
is_done = saw_done or saw_error
# Best-effort extraction of the error payload so we can persist it on the run.
# Chunk format is typically: "event: error\ndata: {json}\n\n"
if saw_error and error_metadata is None:
try:
# Grab the first `data:` line after `event: error`
for line in chunk.splitlines():
if line.startswith("data: "):
maybe_json = line[len("data: ") :].strip()
if maybe_json and maybe_json[0] in "[{":
error_metadata = {"error": json.loads(maybe_json)}
else:
error_metadata = {"error": {"message": maybe_json}}
break
except Exception:
# Don't let parsing failures interfere with streaming
error_metadata = {"error": {"message": "Failed to parse error payload from stream."}}
await writer.write_chunk(run_id=run_id, data=chunk, is_complete=is_done)
is_done = saw_done or saw_error
if is_done:
break
await writer.write_chunk(run_id=run_id, data=chunk, is_complete=is_done)
try:
# Extract stop_reason from stop_reason chunks
maybe_json_chunk = chunk.split("data: ")[1]
maybe_stop_reason = json.loads(maybe_json_chunk) if maybe_json_chunk and maybe_json_chunk[0] == "{" else None
if maybe_stop_reason and maybe_stop_reason.get("message_type") == "stop_reason":
stop_reason = maybe_stop_reason.get("stop_reason")
except:
pass
if is_done:
break
try:
# Extract stop_reason from stop_reason chunks
maybe_json_chunk = chunk.split("data: ")[1]
maybe_stop_reason = json.loads(maybe_json_chunk) if maybe_json_chunk and maybe_json_chunk[0] == "{" else None
if maybe_stop_reason and maybe_stop_reason.get("message_type") == "stop_reason":
stop_reason = maybe_stop_reason.get("stop_reason")
except:
pass
# Stream ended naturally - check if we got a proper terminal
if not saw_done and not saw_error:
@@ -319,6 +344,7 @@ async def create_background_stream_processor(
run_id=run_id,
update=RunUpdate(status=RunStatus.failed, stop_reason=StopReasonType.error.value, metadata={"error": str(e)}),
actor=actor,
conversation_id=conversation_id,
)
finally:
if should_stop_writer:
@@ -357,10 +383,15 @@ async def create_background_stream_processor(
logger.warning(f"Unknown stop_reason '{final_stop_reason}' for run {run_id}, defaulting to completed")
run_status = RunStatus.completed
update_kwargs = {"status": run_status, "stop_reason": final_stop_reason}
if run_status == RunStatus.failed and error_metadata is not None:
update_kwargs["metadata"] = error_metadata
await run_manager.update_run_by_id_async(
run_id=run_id,
update=RunUpdate(status=run_status, stop_reason=final_stop_reason),
update=RunUpdate(**update_kwargs),
actor=actor,
conversation_id=conversation_id,
)
# Belt-and-suspenders: always append a terminal [DONE] chunk to ensure clients terminate

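The processor now best-effort-parses the `data:` line of an `event: error` chunk so the failure payload can be persisted on the run. A standalone version of that parse, assuming the chunk shape noted in the comments above:

```python
import json


def extract_error_metadata(chunk: str) -> dict:
    """Parse 'event: error\\ndata: {...}' chunks; never raise on bad input."""
    try:
        for line in chunk.splitlines():
            if line.startswith("data: "):
                payload = line[len("data: "):].strip()
                if payload and payload[0] in "[{":
                    return {"error": json.loads(payload)}
                return {"error": {"message": payload}}
    except Exception:
        pass
    return {"error": {"message": "Failed to parse error payload from stream."}}


print(extract_error_metadata('event: error\ndata: {"code": "overloaded"}\n\n'))
# {'error': {'code': 'overloaded'}}
```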
View File

@@ -3,6 +3,7 @@ from letta.server.rest_api.routers.v1.anthropic import router as anthropic_route
from letta.server.rest_api.routers.v1.archives import router as archives_router
from letta.server.rest_api.routers.v1.blocks import router as blocks_router
from letta.server.rest_api.routers.v1.chat_completions import router as chat_completions_router, router as openai_chat_completions_router
from letta.server.rest_api.routers.v1.conversations import router as conversations_router
from letta.server.rest_api.routers.v1.embeddings import router as embeddings_router
from letta.server.rest_api.routers.v1.folders import router as folders_router
from letta.server.rest_api.routers.v1.groups import router as groups_router
@@ -36,6 +37,7 @@ ROUTERS = [
sources_router,
folders_router,
agents_router,
conversations_router,
chat_completions_router,
groups_router,
identities_router,

View File

@@ -24,8 +24,8 @@ from letta.errors import (
AgentExportProcessingError,
AgentFileImportError,
AgentNotFoundForExportError,
NoActiveRunsToCancelError,
PendingApprovalError,
RunCancelError,
)
from letta.groups.sleeptime_multi_agent_v4 import SleeptimeMultiAgentV4
from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns
@@ -66,6 +66,7 @@ from letta.server.server import SyncServer
from letta.services.lettuce import LettuceClient
from letta.services.run_manager import RunManager
from letta.services.streaming_service import StreamingService
from letta.services.summarizer.summarizer_config import CompactionSettings
from letta.settings import settings
from letta.utils import is_1_0_sdk_version, safe_create_shielded_task, safe_create_task, truncate_file_visible_content
from letta.validators import AgentId, BlockId, FileId, MessageId, SourceId, ToolId
@@ -389,7 +390,7 @@ async def import_agent(
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
try:
serialized_data = file.file.read()
serialized_data = await file.read()
file_size_mb = len(serialized_data) / (1024 * 1024)
logger.info(f"Agent import: loaded {file_size_mb:.2f} MB into memory")
agent_json = json.loads(serialized_data)
@@ -744,9 +745,10 @@ async def detach_source(
if not agent_state.sources:
agent_state = await server.agent_manager.detach_all_files_tools_async(agent_state=agent_state, actor=actor)
files = await server.file_manager.list_files(source_id, actor)
file_ids = [f.id for f in files]
await server.remove_files_from_context_window(agent_state=agent_state, file_ids=file_ids, actor=actor)
# Query files_agents directly to get exactly what was attached, regardless of source changes
file_ids = await server.file_agent_manager.get_file_ids_for_agent_by_source(agent_id=agent_id, source_id=source_id, actor=actor)
if file_ids:
await server.remove_files_from_context_window(agent_state=agent_state, file_ids=file_ids, actor=actor)
if agent_state.enable_sleeptime:
try:
@@ -775,9 +777,10 @@ async def detach_folder_from_agent(
if not agent_state.sources:
agent_state = await server.agent_manager.detach_all_files_tools_async(agent_state=agent_state, actor=actor)
files = await server.file_manager.list_files(folder_id, actor)
file_ids = [f.id for f in files]
await server.remove_files_from_context_window(agent_state=agent_state, file_ids=file_ids, actor=actor)
# Query files_agents directly to get exactly what was attached, regardless of source changes
file_ids = await server.file_agent_manager.get_file_ids_for_agent_by_source(agent_id=agent_id, source_id=folder_id, actor=actor)
if file_ids:
await server.remove_files_from_context_window(agent_state=agent_state, file_ids=file_ids, actor=actor)
if agent_state.enable_sleeptime:
try:
@@ -1382,6 +1385,7 @@ async def list_messages(
),
order_by: Literal["created_at"] = Query("created_at", description="Field to sort by"),
group_id: str | None = Query(None, description="Group ID to filter messages by."),
conversation_id: str | None = Query(None, description="Conversation ID to filter messages by."),
use_assistant_message: bool = Query(True, description="Whether to use assistant messages", deprecated=True),
assistant_message_tool_name: str = Query(DEFAULT_MESSAGE_TOOL, description="The name of the designated message tool.", deprecated=True),
assistant_message_tool_kwarg: str = Query(DEFAULT_MESSAGE_TOOL_KWARG, description="The name of the message argument.", deprecated=True),
@@ -1401,6 +1405,7 @@ async def list_messages(
before=before,
limit=limit,
group_id=group_id,
conversation_id=conversation_id,
reverse=(order == "desc"),
return_message_object=False,
use_assistant_message=use_assistant_message,
@@ -1521,6 +1526,7 @@ async def send_message(
"ollama",
"azure",
"xai",
"zai",
"groq",
"deepseek",
]
@@ -1558,6 +1564,7 @@ async def send_message(
use_assistant_message=request.use_assistant_message,
request_start_timestamp_ns=request_start_timestamp_ns,
include_return_message_types=request.include_return_message_types,
client_tools=request.client_tools,
)
else:
result = await server.send_message_to_agent(
@@ -1615,6 +1622,7 @@ async def send_message(
},
}
},
deprecated=True,
)
async def send_message_streaming(
request_obj: Request, # FastAPI Request
@@ -1625,6 +1633,9 @@ async def send_message_streaming(
) -> StreamingResponse | LettaResponse:
"""
Process a user message and return the agent's response.
Deprecated: Use the `POST /{agent_id}/messages` endpoint with `streaming=true` in the request body instead.
This endpoint accepts a message from a user and processes it through the agent.
It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True.
"""
@@ -1668,40 +1679,50 @@ async def cancel_message(
raise HTTPException(status_code=400, detail="Agent run tracking is disabled")
run_ids = request.run_ids if request else None
if not run_ids:
redis_client = await get_redis_client()
run_id = await redis_client.get(f"{REDIS_RUN_ID_PREFIX}:{agent_id}")
run_id = None
try:
redis_client = await get_redis_client()
run_id = await redis_client.get(f"{REDIS_RUN_ID_PREFIX}:{agent_id}")
except Exception as e:
# Redis is optional; fall back to DB to avoid surfacing 5XXs for cancellation.
logger.warning(f"Failed to look up run to cancel in redis for agent {agent_id}, falling back to DB: {e}")
if run_id is None:
logger.warning("Cannot find run associated with agent to cancel in redis, fetching from db.")
run_ids = await server.run_manager.list_runs(
runs = await server.run_manager.list_runs(
actor=actor,
statuses=[RunStatus.created, RunStatus.running],
ascending=False,
agent_id=agent_id, # NOTE: this will override agent_ids if provided
limit=100, # Limit to 10 most recent active runs for cancellation
limit=100, # Limit to 100 most recent active runs for cancellation
)
run_ids = [run.id for run in run_ids]
run_ids = [run.id for run in runs]
else:
run_ids = [run_id]
if not run_ids:
raise NoActiveRunsToCancelError(agent_id=agent_id)
results = {}
failed_to_cancel = []
for run_id in run_ids:
run = await server.run_manager.get_run_by_id(run_id=run_id, actor=actor)
if run.metadata.get("lettuce"):
lettuce_client = await LettuceClient.create()
await lettuce_client.cancel(run_id)
try:
run = await server.run_manager.cancel_run(actor=actor, agent_id=agent_id, run_id=run_id)
run = await server.run_manager.get_run_by_id(run_id=run_id, actor=actor)
if run.metadata and run.metadata.get("lettuce"):
try:
lettuce_client = await LettuceClient.create()
await lettuce_client.cancel(run_id)
except Exception as e:
# Do not surface cancellation failures as 5XXs.
logger.error(f"Failed to cancel Lettuce run {run_id}: {e}")
await server.run_manager.cancel_run(actor=actor, agent_id=agent_id, run_id=run_id)
except Exception as e:
results[run_id] = "failed"
# Cancellation failures should not raise errors back to the client.
logger.error(f"Failed to cancel run {run_id}: {str(e)}")
failed_to_cancel.append(run_id)
continue
results[run_id] = "cancelled"
logger.info(f"Cancelled run {run_id}")
if failed_to_cancel:
raise RunCancelError(f"Failed to cancel runs: {failed_to_cancel}")
return results
@@ -1732,6 +1753,7 @@ async def search_messages(
agent_id=request.agent_id,
project_id=request.project_id,
template_id=request.template_id,
conversation_id=request.conversation_id,
limit=request.limit,
start_date=request.start_date,
end_date=request.end_date,
@@ -1771,6 +1793,7 @@ async def _process_message_background(
"ollama",
"azure",
"xai",
"zai",
"groq",
"deepseek",
]
@@ -2075,6 +2098,7 @@ async def preview_model_request(
"ollama",
"azure",
"xai",
"zai",
"groq",
"deepseek",
]
@@ -2091,9 +2115,23 @@ async def preview_model_request(
)
@router.post("/{agent_id}/summarize", status_code=204, operation_id="summarize_messages")
class CompactionRequest(BaseModel):
compaction_settings: Optional[CompactionSettings] = Field(
default=None,
description="Optional compaction settings to use for this summarization request. If not provided, the agent's default settings will be used.",
)
class CompactionResponse(BaseModel):
summary: str
num_messages_before: int
num_messages_after: int
@router.post("/{agent_id}/summarize", response_model=CompactionResponse, operation_id="summarize_messages")
async def summarize_messages(
agent_id: AgentId,
request: Optional[CompactionRequest] = Body(default=None),
server: SyncServer = Depends(get_letta_server),
headers: HeaderParams = Depends(get_headers),
):
@@ -2114,6 +2152,7 @@ async def summarize_messages(
"ollama",
"azure",
"xai",
"zai",
"groq",
"deepseek",
]
@@ -2121,12 +2160,27 @@ async def summarize_messages(
if agent_eligible and model_compatible:
agent_loop = LettaAgentV3(agent_state=agent, actor=actor)
in_context_messages = await server.message_manager.get_messages_by_ids_async(message_ids=agent.message_ids, actor=actor)
summary_message, messages = await agent_loop.compact(
compaction_settings = request.compaction_settings if request else None
num_messages_before = len(in_context_messages)
summary_message, messages, summary = await agent_loop.compact(
messages=in_context_messages,
compaction_settings=compaction_settings,
)
num_messages_after = len(messages)
# update the agent state
logger.info(f"Summarized {num_messages_before} messages to {num_messages_after}")
if num_messages_before <= num_messages_after:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Summarization failed to reduce the number of messages. You may need to use a different CompactionSettings (e.g. using `all` mode).",
)
await agent_loop._checkpoint_messages(run_id=None, step_id=None, new_messages=[summary_message], in_context_messages=messages)
return CompactionResponse(
summary=summary,
num_messages_before=num_messages_before,
num_messages_after=num_messages_after,
)
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
@@ -2185,6 +2239,7 @@ async def capture_messages(
response_messages = await server.message_manager.create_many_messages_async(messages_to_persist, actor=actor)
run_ids = []
sleeptime_group = agent.multi_agent_group if agent.multi_agent_group and agent.multi_agent_group.manager_type == "sleeptime" else None
if sleeptime_group:
sleeptime_agent_loop = SleeptimeMultiAgentV4(agent_state=agent, actor=actor, group=sleeptime_group)

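The summarize route now returns a `CompactionResponse` instead of a bare 204 and accepts optional compaction settings. A hypothetical client call against it, assuming `httpx`, a local server on the default port, and the agents router mounted at `/v1/agents`:

```python
import httpx

BASE_URL = "http://localhost:8283"  # assumption: default local Letta server
agent_id = "agent-00000000-0000-0000-0000-000000000000"  # illustrative

resp = httpx.post(
    f"{BASE_URL}/v1/agents/{agent_id}/summarize",
    json={"compaction_settings": None},  # omit or None -> agent's default settings
    timeout=120.0,
)
resp.raise_for_status()
body = resp.json()
print(body["summary"], body["num_messages_before"], body["num_messages_after"])
```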
View File

@@ -0,0 +1,273 @@
from datetime import timedelta
from typing import Annotated, List, Optional
from fastapi import APIRouter, Body, Depends, HTTPException, Query
from pydantic import Field
from starlette.responses import StreamingResponse
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
from letta.errors import LettaExpiredError, LettaInvalidArgumentError
from letta.helpers.datetime_helpers import get_utc_time
from letta.schemas.conversation import Conversation, CreateConversation
from letta.schemas.enums import RunStatus
from letta.schemas.letta_message import LettaMessageUnion
from letta.schemas.letta_request import LettaStreamingRequest, RetrieveStreamRequest
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
from letta.server.rest_api.redis_stream_manager import redis_sse_stream_generator
from letta.server.rest_api.streaming_response import (
StreamingResponseWithStatusCode,
add_keepalive_to_stream,
cancellation_aware_stream_wrapper,
)
from letta.server.server import SyncServer
from letta.services.conversation_manager import ConversationManager
from letta.services.run_manager import RunManager
from letta.services.streaming_service import StreamingService
from letta.settings import settings
from letta.validators import ConversationId
router = APIRouter(prefix="/conversations", tags=["conversations"])
# Instantiate manager
conversation_manager = ConversationManager()
@router.post("/", response_model=Conversation, operation_id="create_conversation")
async def create_conversation(
agent_id: str = Query(..., description="The agent ID to create a conversation for"),
conversation_create: CreateConversation = Body(default_factory=CreateConversation),
server: SyncServer = Depends(get_letta_server),
headers: HeaderParams = Depends(get_headers),
):
"""Create a new conversation for an agent."""
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
return await conversation_manager.create_conversation(
agent_id=agent_id,
conversation_create=conversation_create,
actor=actor,
)
@router.get("/", response_model=List[Conversation], operation_id="list_conversations")
async def list_conversations(
agent_id: str = Query(..., description="The agent ID to list conversations for"),
limit: int = Query(50, description="Maximum number of conversations to return"),
after: Optional[str] = Query(None, description="Cursor for pagination (conversation ID)"),
server: SyncServer = Depends(get_letta_server),
headers: HeaderParams = Depends(get_headers),
):
"""List all conversations for an agent."""
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
return await conversation_manager.list_conversations(
agent_id=agent_id,
actor=actor,
limit=limit,
after=after,
)
@router.get("/{conversation_id}", response_model=Conversation, operation_id="retrieve_conversation")
async def retrieve_conversation(
conversation_id: ConversationId,
server: SyncServer = Depends(get_letta_server),
headers: HeaderParams = Depends(get_headers),
):
"""Retrieve a specific conversation."""
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
return await conversation_manager.get_conversation_by_id(
conversation_id=conversation_id,
actor=actor,
)
ConversationMessagesResponse = Annotated[
List[LettaMessageUnion], Field(json_schema_extra={"type": "array", "items": {"$ref": "#/components/schemas/LettaMessageUnion"}})
]
@router.get(
"/{conversation_id}/messages",
response_model=ConversationMessagesResponse,
operation_id="list_conversation_messages",
)
async def list_conversation_messages(
conversation_id: ConversationId,
server: SyncServer = Depends(get_letta_server),
headers: HeaderParams = Depends(get_headers),
before: Optional[str] = Query(
None, description="Message ID cursor for pagination. Returns messages that come before this message ID in the conversation"
),
after: Optional[str] = Query(
None, description="Message ID cursor for pagination. Returns messages that come after this message ID in the conversation"
),
limit: Optional[int] = Query(100, description="Maximum number of messages to return"),
):
"""
List all messages in a conversation.
Returns LettaMessage objects (UserMessage, AssistantMessage, etc.) for all
messages in the conversation, ordered by position (oldest first),
with support for cursor-based pagination.
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
return await conversation_manager.list_conversation_messages(
conversation_id=conversation_id,
actor=actor,
limit=limit,
before=before,
after=after,
)
@router.post(
"/{conversation_id}/messages",
response_model=LettaStreamingResponse,
operation_id="send_conversation_message",
responses={
200: {
"description": "Successful response",
"content": {
"text/event-stream": {"description": "Server-Sent Events stream"},
},
}
},
)
async def send_conversation_message(
conversation_id: ConversationId,
request: LettaStreamingRequest = Body(...),
server: SyncServer = Depends(get_letta_server),
headers: HeaderParams = Depends(get_headers),
) -> StreamingResponse | LettaResponse:
"""
Send a message to a conversation and get a streaming response.
This endpoint sends a message to an existing conversation and streams
the agent's response back.
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
# Get the conversation to find the agent_id
conversation = await conversation_manager.get_conversation_by_id(
conversation_id=conversation_id,
actor=actor,
)
# Force streaming mode for this endpoint
request.streaming = True
# Use streaming service
streaming_service = StreamingService(server)
run, result = await streaming_service.create_agent_stream(
agent_id=conversation.agent_id,
actor=actor,
request=request,
run_type="send_conversation_message",
conversation_id=conversation_id,
)
return result
@router.post(
"/{conversation_id}/stream",
response_model=None,
operation_id="retrieve_conversation_stream",
responses={
200: {
"description": "Successful response",
"content": {
"text/event-stream": {
"description": "Server-Sent Events stream",
"schema": {
"oneOf": [
{"$ref": "#/components/schemas/SystemMessage"},
{"$ref": "#/components/schemas/UserMessage"},
{"$ref": "#/components/schemas/ReasoningMessage"},
{"$ref": "#/components/schemas/HiddenReasoningMessage"},
{"$ref": "#/components/schemas/ToolCallMessage"},
{"$ref": "#/components/schemas/ToolReturnMessage"},
{"$ref": "#/components/schemas/AssistantMessage"},
{"$ref": "#/components/schemas/ApprovalRequestMessage"},
{"$ref": "#/components/schemas/ApprovalResponseMessage"},
{"$ref": "#/components/schemas/LettaPing"},
{"$ref": "#/components/schemas/LettaErrorMessage"},
{"$ref": "#/components/schemas/LettaStopReason"},
{"$ref": "#/components/schemas/LettaUsageStatistics"},
]
},
},
},
}
},
)
async def retrieve_conversation_stream(
conversation_id: ConversationId,
request: RetrieveStreamRequest = Body(None),
headers: HeaderParams = Depends(get_headers),
server: SyncServer = Depends(get_letta_server),
):
"""
Resume the stream for the most recent active run in a conversation.
This endpoint allows you to reconnect to an active background stream
for a conversation, enabling recovery from network interruptions.
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
runs_manager = RunManager()
# Find the most recent active run for this conversation
active_runs = await runs_manager.list_runs(
actor=actor,
conversation_id=conversation_id,
statuses=[RunStatus.created, RunStatus.running],
limit=1,
ascending=False,
)
if not active_runs:
raise LettaInvalidArgumentError("No active runs found for this conversation.")
run = active_runs[0]
if not run.background:
raise LettaInvalidArgumentError("Run was not created in background mode, so it cannot be retrieved.")
if run.created_at < get_utc_time() - timedelta(hours=3):
raise LettaExpiredError("Run was created more than 3 hours ago, and is now expired.")
redis_client = await get_redis_client()
if isinstance(redis_client, NoopAsyncRedisClient):
raise HTTPException(
status_code=503,
detail=(
"Background streaming requires Redis to be running. "
"Please ensure Redis is properly configured. "
f"LETTA_REDIS_HOST: {settings.redis_host}, LETTA_REDIS_PORT: {settings.redis_port}"
),
)
stream = redis_sse_stream_generator(
redis_client=redis_client,
run_id=run.id,
starting_after=request.starting_after if request else None,
poll_interval=request.poll_interval if request else None,
batch_size=request.batch_size if request else None,
)
if settings.enable_cancellation_aware_streaming:
stream = cancellation_aware_stream_wrapper(
stream_generator=stream,
run_manager=server.run_manager,
run_id=run.id,
actor=actor,
)
if request and request.include_pings and settings.enable_keepalive:
stream = add_keepalive_to_stream(stream, keepalive_interval=settings.keepalive_interval, run_id=run.id)
return StreamingResponseWithStatusCode(
stream,
media_type="text/event-stream",
)

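Putting the new router together, a typical client flow is: create a conversation for an agent, send a message (this endpoint always streams), and, if the connection drops on a background run, resume via the stream endpoint. A hypothetical `httpx` sketch, assuming the router is mounted at `/v1/conversations` and an illustrative message payload:

```python
import httpx

BASE_URL = "http://localhost:8283"  # assumption: default local Letta server
agent_id = "agent-00000000-0000-0000-0000-000000000000"  # illustrative

with httpx.Client(base_url=BASE_URL, timeout=None) as client:
    # 1. Create a conversation scoped to one agent.
    conv = client.post("/v1/conversations/", params={"agent_id": agent_id}, json={}).json()

    # 2. Send a message; the server forces streaming mode for this endpoint.
    with client.stream(
        "POST",
        f"/v1/conversations/{conv['id']}/messages",
        json={"messages": [{"role": "user", "content": "hello"}]},
    ) as response:
        for line in response.iter_lines():
            if line.startswith("data: "):
                print(line)

    # 3. After a disconnect, resume the most recent active background run.
    # (Requires Redis; returns 503 when only a NoopAsyncRedisClient is available.)
    client.post(f"/v1/conversations/{conv['id']}/stream", json={"include_pings": False})
```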
View File

@@ -204,8 +204,6 @@ async def delete_folder(
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
folder = await server.source_manager.get_source_by_id(source_id=folder_id, actor=actor)
agent_states = await server.source_manager.list_attached_agents(source_id=folder_id, actor=actor)
files = await server.file_manager.list_files(folder_id, actor)
file_ids = [f.id for f in files]
if should_use_tpuf():
logger.info(f"Deleting folder {folder_id} from Turbopuffer")
@@ -218,7 +216,12 @@ async def delete_folder(
await delete_source_records_from_pinecone_index(source_id=folder_id, actor=actor)
for agent_state in agent_states:
await server.remove_files_from_context_window(agent_state=agent_state, file_ids=file_ids, actor=actor)
# Query files_agents directly to get exactly what was attached to this agent
file_ids = await server.file_agent_manager.get_file_ids_for_agent_by_source(
agent_id=agent_state.id, source_id=folder_id, actor=actor
)
if file_ids:
await server.remove_files_from_context_window(agent_state=agent_state, file_ids=file_ids, actor=actor)
if agent_state.enable_sleeptime:
block = await server.agent_manager.get_block_with_label_async(agent_id=agent_state.id, block_label=folder.name, actor=actor)
@@ -470,6 +473,30 @@ async def list_files_for_folder(
)
@router.get("/{folder_id}/files/{file_id}", response_model=FileMetadata, operation_id="retrieve_file")
async def retrieve_file(
folder_id: FolderId,
file_id: FileId,
include_content: bool = Query(False, description="Whether to include full file content"),
server: "SyncServer" = Depends(get_letta_server),
headers: HeaderParams = Depends(get_headers),
):
"""
Retrieve a file from a folder by ID.
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
# NoResultFound will propagate and be handled as 404 by the global exception handler
file_metadata = await server.file_manager.get_file_by_id(
file_id=file_id, actor=actor, include_content=include_content, strip_directory_prefix=True
)
if file_metadata.source_id != folder_id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"File with id={file_id} not found in folder {folder_id}")
return file_metadata
# @router.get("/{folder_id}/files/{file_id}", response_model=FileMetadata, operation_id="get_file_metadata")
# async def get_file_metadata(
# folder_id: str,

View File

@@ -64,6 +64,7 @@ async def list_runs(
deprecated=True,
),
project_id: Optional[str] = Query(None, description="Filter runs by project ID."),
conversation_id: Optional[str] = Query(None, description="Filter runs by conversation ID."),
duration_percentile: Optional[int] = Query(
None, description="Filter runs by duration percentile (1-100). Returns runs slower than this percentile."
),
@@ -122,6 +123,7 @@ async def list_runs(
step_count_operator=step_count_operator,
tools_used=tools_used,
project_id=project_id,
conversation_id=conversation_id,
order_by=order_by,
duration_percentile=duration_percentile,
duration_filter=duration_filter,

View File

@@ -40,6 +40,7 @@ async def list_all_messages(
order: Literal["asc", "desc"] = Query(
"desc", description="Sort order for messages by creation time. 'asc' for oldest first, 'desc' for newest first"
),
conversation_id: Optional[str] = Query(None, description="Conversation ID to filter messages by"),
):
"""
List messages across all agents for the current user.
@@ -51,6 +52,7 @@ async def list_all_messages(
limit=limit,
reverse=(order == "desc"),
return_message_object=False,
conversation_id=conversation_id,
actor=actor,
)
@@ -73,6 +75,8 @@ async def search_all_messages(
actor=actor,
query_text=request.query,
search_mode=request.search_mode,
agent_id=request.agent_id,
conversation_id=request.conversation_id,
limit=request.limit,
start_date=request.start_date,
end_date=request.end_date,

View File

@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, List, Literal, Optional
from fastapi import APIRouter, Body, Depends, Query, status
from fastapi.responses import JSONResponse
from letta.schemas.enums import ProviderType
from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.providers import Provider, ProviderBase, ProviderCheck, ProviderCreate, ProviderUpdate
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
from letta.validators import ProviderId
@@ -39,7 +39,14 @@ async def list_providers(
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
providers = await server.provider_manager.list_providers_async(
before=before, after=after, limit=limit, actor=actor, name=name, provider_type=provider_type, ascending=(order == "asc")
before=before,
after=after,
limit=limit,
actor=actor,
name=name,
provider_type=provider_type,
provider_category=[ProviderCategory.byok],
ascending=(order == "asc"),
)
return providers

View File

@@ -55,6 +55,7 @@ async def list_runs(
statuses: Optional[List[str]] = Query(None, description="Filter runs by status. Can specify multiple statuses."),
background: Optional[bool] = Query(None, description="If True, filters for runs that were created in background mode."),
stop_reason: Optional[StopReasonType] = Query(None, description="Filter runs by stop reason."),
conversation_id: Optional[str] = Query(None, description="Filter runs by conversation ID."),
before: Optional[str] = Query(
None, description="Run ID cursor for pagination. Returns runs that come before this run ID in the specified sort order"
),
@@ -109,6 +110,7 @@ async def list_runs(
ascending=sort_ascending,
stop_reason=stop_reason,
background=background,
conversation_id=conversation_id,
)
return runs

View File

@@ -184,8 +184,6 @@ async def delete_source(
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
agent_states = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor)
files = await server.file_manager.list_files(source_id, actor)
file_ids = [f.id for f in files]
if should_use_tpuf():
logger.info(f"Deleting source {source_id} from Turbopuffer")
@@ -198,7 +196,12 @@ async def delete_source(
await delete_source_records_from_pinecone_index(source_id=source_id, actor=actor)
for agent_state in agent_states:
await server.remove_files_from_context_window(agent_state=agent_state, file_ids=file_ids, actor=actor)
# Query files_agents directly to get exactly what was attached to this agent
file_ids = await server.file_agent_manager.get_file_ids_for_agent_by_source(
agent_id=agent_state.id, source_id=source_id, actor=actor
)
if file_ids:
await server.remove_files_from_context_window(agent_state=agent_state, file_ids=file_ids, actor=actor)
if agent_state.enable_sleeptime:
block = await server.agent_manager.get_block_with_label_async(agent_id=agent_state.id, block_label=source.name, actor=actor)

View File

@@ -689,7 +689,6 @@ async def connect_mcp_server(
yield oauth_stream_event(OauthStreamEvent.CONNECTION_ATTEMPT, server_name=request.server_name)
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
# Create MCP client with respective transport type
try:
request.resolve_environment_variables()
@@ -704,8 +703,18 @@ async def connect_mcp_server(
tools = await client.list_tools(serialize=True)
yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools)
return
except ConnectionError:
# TODO: jnjpng make this connection error check more specific to the 401 unauthorized error
except ConnectionError as e:
# Only trigger OAuth flow on explicit unauthorized failures
unauthorized = False
if isinstance(e.__cause__, HTTPStatusError):
unauthorized = e.__cause__.response.status_code == 401
elif "401" in str(e) or "Unauthorized" in str(e):
unauthorized = True
if not unauthorized:
yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Connection failed: {str(e)}")
return
if isinstance(client, AsyncStdioMCPClient):
logger.warning("OAuth not supported for stdio")
yield oauth_stream_event(OauthStreamEvent.ERROR, message="OAuth not supported for stdio")
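The replacement `except` block only starts the OAuth flow for genuine 401s: it prefers the typed `HTTPStatusError` chained as `__cause__`, falling back to string matching for clients that only surface a message. The same logic as a standalone helper (a sketch, not the committed API):

```python
from httpx import HTTPStatusError

def is_unauthorized(e: ConnectionError) -> bool:
    # Prefer the structured cause: httpx attaches the response (and its
    # status code) to HTTPStatusError raised by raise_for_status().
    if isinstance(e.__cause__, HTTPStatusError):
        return e.__cause__.response.status_code == 401
    # Fallback for MCP clients that flatten the error into a message string.
    return "401" in str(e) or "Unauthorized" in str(e)
```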

View File

@@ -18,7 +18,13 @@ import letta.system as system
from letta.config import LettaConfig
from letta.constants import LETTA_TOOL_EXECUTION_DIR
from letta.data_sources.connectors import DataConnector, load_data
from letta.errors import HandleNotFoundError, LettaInvalidArgumentError, LettaMCPConnectionError, LettaMCPTimeoutError
from letta.errors import (
EmbeddingConfigRequiredError,
HandleNotFoundError,
LettaInvalidArgumentError,
LettaMCPConnectionError,
LettaMCPTimeoutError,
)
from letta.functions.mcp_client.types import MCPServerType, MCPTool, MCPToolHealth, SSEServerConfig, StdioServerConfig
from letta.functions.schema_validator import validate_complete_json_schema
from letta.groups.helpers import load_multi_agent
@@ -69,6 +75,7 @@ from letta.schemas.providers import (
TogetherProvider,
VLLMProvider,
XAIProvider,
ZAIProvider,
)
from letta.schemas.sandbox_config import LocalSandboxConfig, SandboxConfigCreate
from letta.schemas.secret import Secret
@@ -91,7 +98,8 @@ from letta.services.identity_manager import IdentityManager
from letta.services.job_manager import JobManager
from letta.services.llm_batch_manager import LLMBatchManager
from letta.services.mcp.base_client import AsyncBaseMCPClient
from letta.services.mcp.sse_client import MCP_CONFIG_TOPLEVEL_KEY, AsyncSSEMCPClient
from letta.services.mcp.fastmcp_client import AsyncFastMCPSSEClient
from letta.services.mcp.sse_client import MCP_CONFIG_TOPLEVEL_KEY
from letta.services.mcp.stdio_client import AsyncStdioMCPClient
from letta.services.mcp_manager import MCPManager
from letta.services.mcp_server_manager import MCPServerManager
@@ -109,7 +117,7 @@ from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
from letta.settings import DatabaseChoice, model_settings, settings, tool_settings
from letta.streaming_interface import AgentChunkStreamingInterface
from letta.utils import get_friendly_error_msg, get_persona_text, make_key, safe_create_task
from letta.utils import get_friendly_error_msg, get_persona_text, safe_create_task
config = LettaConfig.load()
logger = get_logger(__name__)
@@ -203,12 +211,10 @@ class SyncServer(object):
"""Initialize the MCP clients (there may be multiple)"""
self.mcp_clients: Dict[str, AsyncBaseMCPClient] = {}
# TODO: Remove these in-memory caches
self._llm_config_cache = {}
self._embedding_config_cache = {}
# collect providers (always has Letta as a default)
self._enabled_providers: List[Provider] = [LettaProvider(name="letta")]
from letta.constants import LETTA_MODEL_ENDPOINT
self._enabled_providers: List[Provider] = [LettaProvider(name="letta", base_url=LETTA_MODEL_ENDPOINT)]
if model_settings.openai_api_key:
self._enabled_providers.append(
OpenAIProvider(
@@ -316,6 +322,14 @@ class SyncServer(object):
api_key_enc=Secret.from_plaintext(model_settings.xai_api_key),
)
)
if model_settings.zai_api_key:
self._enabled_providers.append(
ZAIProvider(
name="zai",
api_key_enc=Secret.from_plaintext(model_settings.zai_api_key),
base_url=model_settings.zai_base_url,
)
)
if model_settings.openrouter_api_key:
self._enabled_providers.append(
OpenRouterProvider(
@@ -332,6 +346,12 @@ class SyncServer(object):
print(f"Default user: {self.default_user} and org: {self.default_org}")
await self.tool_manager.upsert_base_tools_async(actor=self.default_user)
# Sync environment-based providers to database (idempotent, safe for multi-pod startup)
await self.provider_manager.sync_base_providers(base_providers=self._enabled_providers, actor=self.default_user)
# Sync provider models to database
await self._sync_provider_models_async()
# For OSS users, create a local sandbox config
oss_default_user = await self.user_manager.get_default_actor_async()
use_venv = False if not tool_settings.tool_exec_venv_name else True
@@ -368,13 +388,72 @@ class SyncServer(object):
force_recreate=True,
)
def _get_enabled_provider(self, provider_name: str) -> Optional[Provider]:
"""Find and return an enabled provider by name.
Args:
provider_name: The name of the provider to find
Returns:
The matching enabled provider, or None if not found
"""
for provider in self._enabled_providers:
if provider.name == provider_name:
return provider
return None
async def _sync_provider_models_async(self):
"""Sync all provider models to database at startup."""
logger.info("Syncing provider models to database")
# Get persisted providers from database (they now have IDs)
persisted_providers = await self.provider_manager.list_providers_async(actor=self.default_user)
for persisted_provider in persisted_providers:
try:
# Find the matching enabled provider instance to call list_models on
enabled_provider = self._get_enabled_provider(persisted_provider.name)
if not enabled_provider:
# Only delete base providers that are no longer enabled
# BYOK providers are user-created and should not be automatically deleted
if persisted_provider.provider_category == ProviderCategory.base:
logger.info(f"Base provider {persisted_provider.name} is no longer enabled, deleting from database")
try:
await self.provider_manager.delete_provider_by_id_async(
provider_id=persisted_provider.id, actor=self.default_user
)
except NoResultFound:
# Provider was already deleted (race condition in multi-pod startup)
logger.debug(f"Provider {persisted_provider.name} was already deleted, skipping")
else:
logger.debug(f"No enabled provider for BYOK provider {persisted_provider.name}, skipping model sync")
continue
# Fetch models from provider
llm_models = await enabled_provider.list_llm_models_async()
embedding_models = await enabled_provider.list_embedding_models_async()
# Save to database with the persisted provider (which has an ID)
await self.provider_manager.sync_provider_models_async(
provider=persisted_provider,
llm_models=llm_models,
embedding_models=embedding_models,
organization_id=None, # Global models
)
logger.info(
f"Synced {len(llm_models)} LLM models and {len(embedding_models)} embedding models for provider {persisted_provider.name}"
)
except Exception as e:
logger.error(f"Failed to sync models for provider {persisted_provider.name}: {e}", exc_info=True)
async def init_mcp_clients(self):
# TODO: remove this
mcp_server_configs = self.get_mcp_servers()
for server_name, server_config in mcp_server_configs.items():
if server_config.type == MCPServerType.SSE:
self.mcp_clients[server_name] = AsyncSSEMCPClient(server_config)
self.mcp_clients[server_name] = AsyncFastMCPSSEClient(server_config)
elif server_config.type == MCPServerType.STDIO:
self.mcp_clients[server_name] = AsyncStdioMCPClient(server_config)
else:
@@ -395,39 +474,6 @@ class SyncServer(object):
logger.info(f"MCP tools connected: {', '.join([t.name for t in mcp_tools])}")
logger.debug(f"MCP tools: {', '.join([str(t) for t in mcp_tools])}")
@trace_method
def get_cached_llm_config(self, actor: User, **kwargs):
key = make_key(**kwargs)
if key not in self._llm_config_cache:
self._llm_config_cache[key] = self.get_llm_config_from_handle(actor=actor, **kwargs)
logger.info(f"LLM config cache size: {len(self._llm_config_cache)} entries")
return self._llm_config_cache[key]
@trace_method
async def get_cached_llm_config_async(self, actor: User, **kwargs):
key = make_key(**kwargs)
if key not in self._llm_config_cache:
self._llm_config_cache[key] = await self.get_llm_config_from_handle_async(actor=actor, **kwargs)
logger.info(f"LLM config cache size: {len(self._llm_config_cache)} entries")
return self._llm_config_cache[key]
@trace_method
def get_cached_embedding_config(self, actor: User, **kwargs):
key = make_key(**kwargs)
if key not in self._embedding_config_cache:
self._embedding_config_cache[key] = self.get_embedding_config_from_handle(actor=actor, **kwargs)
logger.info(f"Embedding config cache size: {len(self._embedding_config_cache)} entries")
return self._embedding_config_cache[key]
# @async_redis_cache(key_func=lambda (actor, **kwargs): actor.id + hash(kwargs))
@trace_method
async def get_cached_embedding_config_async(self, actor: User, **kwargs):
key = make_key(**kwargs)
if key not in self._embedding_config_cache:
self._embedding_config_cache[key] = await self.get_embedding_config_from_handle_async(actor=actor, **kwargs)
logger.info(f"Embedding config cache size: {len(self._embedding_config_cache)} entries")
return self._embedding_config_cache[key]
@trace_method
async def create_agent_async(
self,
@@ -461,10 +507,9 @@ class SyncServer(object):
"max_reasoning_tokens": request.max_reasoning_tokens,
"enable_reasoner": request.enable_reasoner,
}
config_params.update(additional_config_params)
log_event(name="start get_cached_llm_config", attributes=config_params)
request.llm_config = await self.get_cached_llm_config_async(actor=actor, **config_params)
log_event(name="end get_cached_llm_config", attributes=config_params)
log_event(name="start get_llm_config_from_handle", attributes=config_params)
request.llm_config = await self.get_llm_config_from_handle_async(actor=actor, **config_params)
log_event(name="end get_llm_config_from_handle", attributes=config_params)
if request.model and isinstance(request.model, str):
assert request.llm_config.handle == request.model, (
f"LLM config handle {request.llm_config.handle} does not match request handle {request.model}"
@@ -484,19 +529,17 @@ class SyncServer(object):
if request.embedding_config is None:
if request.embedding is None:
if settings.default_embedding_handle is None:
raise LettaInvalidArgumentError(
"Must specify either embedding or embedding_config in request", argument_name="embedding"
)
else:
if settings.default_embedding_handle is not None:
request.embedding = settings.default_embedding_handle
embedding_config_params = {
"handle": request.embedding,
"embedding_chunk_size": request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
}
log_event(name="start get_cached_embedding_config", attributes=embedding_config_params)
request.embedding_config = await self.get_cached_embedding_config_async(actor=actor, **embedding_config_params)
log_event(name="end get_cached_embedding_config", attributes=embedding_config_params)
# Only resolve embedding config if we have an embedding handle
if request.embedding is not None:
embedding_config_params = {
"handle": request.embedding,
"embedding_chunk_size": request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
}
log_event(name="start get_embedding_config_from_handle", attributes=embedding_config_params)
request.embedding_config = await self.get_embedding_config_from_handle_async(actor=actor, **embedding_config_params)
log_event(name="end get_embedding_config_from_handle", attributes=embedding_config_params)
log_event(name="start create_agent db")
main_agent = await self.agent_manager.create_agent_async(
@@ -545,9 +588,9 @@ class SyncServer(object):
"context_window_limit": request.context_window_limit,
"max_tokens": request.max_tokens,
}
log_event(name="start get_cached_llm_config", attributes=config_params)
request.llm_config = await self.get_cached_llm_config_async(actor=actor, **config_params)
log_event(name="end get_cached_llm_config", attributes=config_params)
log_event(name="start get_llm_config_from_handle", attributes=config_params)
request.llm_config = await self.get_llm_config_from_handle_async(actor=actor, **config_params)
log_event(name="end get_llm_config_from_handle", attributes=config_params)
# update with model_settings
if request.model_settings is not None:
@@ -584,6 +627,8 @@ class SyncServer(object):
)
async def create_sleeptime_agent_async(self, main_agent: AgentState, actor: User) -> AgentState:
if main_agent.embedding_config is None:
raise EmbeddingConfigRequiredError(agent_id=main_agent.id, operation="create_sleeptime_agent")
request = CreateAgent(
name=main_agent.name + "-sleeptime",
agent_type=AgentType.sleeptime_agent,
@@ -616,6 +661,8 @@ class SyncServer(object):
return await self.agent_manager.get_agent_by_id_async(agent_id=main_agent.id, actor=actor)
async def create_voice_sleeptime_agent_async(self, main_agent: AgentState, actor: User) -> AgentState:
if main_agent.embedding_config is None:
raise EmbeddingConfigRequiredError(agent_id=main_agent.id, operation="create_voice_sleeptime_agent")
# TODO: Inject system
request = CreateAgent(
name=main_agent.name + "-sleeptime",
@@ -771,6 +818,7 @@ class SyncServer(object):
assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
include_err: Optional[bool] = None,
conversation_id: Optional[str] = None,
) -> Union[List[Message], List[LettaMessage]]:
records = await self.message_manager.list_messages(
agent_id=agent_id,
@@ -781,6 +829,7 @@ class SyncServer(object):
ascending=not reverse,
group_id=group_id,
include_err=include_err,
conversation_id=conversation_id,
)
if not return_message_object:
@@ -816,6 +865,7 @@ class SyncServer(object):
assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
include_err: Optional[bool] = None,
conversation_id: Optional[str] = None,
) -> Union[List[Message], List[LettaMessage]]:
records = await self.message_manager.list_messages(
agent_id=None,
@@ -826,6 +876,7 @@ class SyncServer(object):
ascending=not reverse,
group_id=group_id,
include_err=include_err,
conversation_id=conversation_id,
)
if not return_message_object:
@@ -989,6 +1040,8 @@ class SyncServer(object):
async def create_document_sleeptime_agent_async(
self, main_agent: AgentState, source: Source, actor: User, clear_history: bool = False
) -> AgentState:
if main_agent.embedding_config is None:
raise EmbeddingConfigRequiredError(agent_id=main_agent.id, operation="create_document_sleeptime_agent")
try:
block = await self.agent_manager.get_block_with_label_async(agent_id=main_agent.id, block_label=source.name, actor=actor)
except:
@@ -1043,6 +1096,18 @@ class SyncServer(object):
passage_count, document_count = await load_data(connector, source, self.passage_manager, self.file_manager, actor=actor)
return passage_count, document_count
def _get_provider_sort_key(self, model: LLMConfig) -> Tuple[int, str, str]:
"""Get sort key for a model: (provider_priority, provider_name, model_name)"""
provider_priority = constants.PROVIDER_ORDER.get(model.provider_name, 999)
return (provider_priority, model.provider_name or "", model.model or "")
def _get_embedding_provider_sort_key(self, model: EmbeddingConfig) -> Tuple[int, str, str]:
"""Get sort key for an embedding model: (provider_priority, provider_name, model_name)"""
# Extract provider name from handle (format: "provider_name/model_name")
provider_name = model.handle.split("/")[0] if model.handle and "/" in model.handle else ""
provider_priority = constants.PROVIDER_ORDER.get(provider_name, 999)
return (provider_priority, provider_name, model.embedding_model or "")
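Both sort keys are plain tuples, so Python's lexicographic tuple ordering does the work: unknown providers get priority 999 and sort last, ties break on provider name, then model name. A worked example with an illustrative `PROVIDER_ORDER` (the real mapping lives in `letta.constants`):

```python
PROVIDER_ORDER = {"openai": 0, "anthropic": 1}  # illustrative, not the real table

models = [("zai", "glm-4.6"), ("anthropic", "claude"), ("openai", "gpt-4o")]
models.sort(key=lambda m: (PROVIDER_ORDER.get(m[0], 999), m[0], m[1]))

# Known providers first in priority order; unknown ("zai") falls to the end.
assert [m[0] for m in models] == ["openai", "anthropic", "zai"]
```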
@trace_method
async def list_llm_models_async(
self,
@@ -1051,73 +1116,121 @@ class SyncServer(object):
provider_name: Optional[str] = None,
provider_type: Optional[ProviderType] = None,
) -> List[LLMConfig]:
"""Asynchronously list available models with maximum concurrency"""
import asyncio
providers = await self.get_enabled_providers_async(
provider_category=provider_category,
provider_name=provider_name,
provider_type=provider_type,
actor=actor,
)
async def get_provider_models(provider: Provider) -> list[LLMConfig]:
try:
async with asyncio.timeout(constants.GET_PROVIDERS_TIMEOUT_SECONDS):
return await provider.list_llm_models_async()
except asyncio.TimeoutError:
logger.warning(f"Timeout while listing LLM models for provider {provider}")
return []
except Exception as e:
logger.exception(f"Error while listing LLM models for provider {provider}: {e}")
return []
# Execute all provider model listing tasks concurrently
provider_results = await asyncio.gather(*[get_provider_models(provider) for provider in providers])
# Flatten the results
"""List available LLM models - base from DB, BYOK from provider endpoints"""
llm_models = []
for models in provider_results:
llm_models.extend(models)
# Get local configs - if this is potentially slow, consider making it async too
local_configs = self.get_local_llm_configs()
llm_models.extend(local_configs)
# Determine which categories to include
include_base = not provider_category or ProviderCategory.base in provider_category
include_byok = not provider_category or ProviderCategory.byok in provider_category
# dedupe by handle for uniqueness
# Seems like this is required from the tests?
seen_handles = set()
unique_models = []
for model in llm_models:
if model.handle not in seen_handles:
seen_handles.add(model.handle)
unique_models.append(model)
# Get base provider models from database
if include_base:
provider_models = await self.provider_manager.list_models_async(
actor=actor,
model_type="llm",
enabled=True,
)
return unique_models
# Build LLMConfig objects from database
provider_cache: Dict[str, Provider] = {}
for model in provider_models:
# Get provider details (with caching to avoid N+1 queries)
if model.provider_id not in provider_cache:
provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor)
provider = provider_cache[model.provider_id]
# Skip non-base providers (they're handled separately)
if provider.provider_category != ProviderCategory.base:
continue
# Apply provider_name/provider_type filters if specified
if provider_name and provider.name != provider_name:
continue
if provider_type and provider.provider_type != provider_type:
continue
llm_config = LLMConfig(
model=model.name,
model_endpoint_type=model.model_endpoint_type,
model_endpoint=provider.base_url or model.model_endpoint_type,
context_window=model.max_context_window or 16384,
handle=model.handle,
provider_name=provider.name,
provider_category=provider.provider_category,
)
llm_models.append(llm_config)
# Get BYOK provider models by hitting provider endpoints directly
if include_byok:
byok_providers = await self.provider_manager.list_providers_async(
actor=actor,
name=provider_name,
provider_type=provider_type,
provider_category=[ProviderCategory.byok],
)
for provider in byok_providers:
try:
typed_provider = provider.cast_to_subtype()
models = await typed_provider.list_llm_models_async()
llm_models.extend(models)
except Exception as e:
logger.warning(f"Failed to fetch models from BYOK provider {provider.name}: {e}")
# Sort by provider order (matching old _enabled_providers order), then by model name
llm_models.sort(key=self._get_provider_sort_key)
return llm_models
async def list_embedding_models_async(self, actor: User) -> List[EmbeddingConfig]:
"""Asynchronously list available embedding models with maximum concurrency"""
import asyncio
# Get all eligible providers first
providers = await self.get_enabled_providers_async(actor=actor)
# Fetch embedding models from each provider concurrently
async def get_provider_embedding_models(provider):
try:
# All providers now have list_embedding_models_async
return await provider.list_embedding_models_async()
except Exception as e:
logger.exception(f"An error occurred while listing embedding models for provider {provider}: {e}")
return []
# Execute all provider model listing tasks concurrently
provider_results = await asyncio.gather(*[get_provider_embedding_models(provider) for provider in providers])
# Flatten the results
"""List available embedding models - base from DB, BYOK from provider endpoints"""
embedding_models = []
for models in provider_results:
embedding_models.extend(models)
# Get base provider models from database
provider_models = await self.provider_manager.list_models_async(
actor=actor,
model_type="embedding",
enabled=True,
)
# Build EmbeddingConfig objects from database (base providers only)
provider_cache: Dict[str, Provider] = {}
for model in provider_models:
# Get provider details (with caching to avoid N+1 queries)
if model.provider_id not in provider_cache:
provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor)
provider = provider_cache[model.provider_id]
# Skip non-base providers (they're handled separately)
if provider.provider_category != ProviderCategory.base:
continue
embedding_config = EmbeddingConfig(
embedding_model=model.name,
embedding_endpoint_type=model.model_endpoint_type,
embedding_endpoint=provider.base_url or model.model_endpoint_type,
embedding_dim=model.embedding_dim or 1536,
embedding_chunk_size=constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
handle=model.handle,
)
embedding_models.append(embedding_config)
# Get BYOK provider models by hitting provider endpoints directly
byok_providers = await self.provider_manager.list_providers_async(
actor=actor,
provider_category=[ProviderCategory.byok],
)
for provider in byok_providers:
try:
typed_provider = provider.cast_to_subtype()
models = await typed_provider.list_embedding_models_async()
embedding_models.extend(models)
except Exception as e:
logger.warning(f"Failed to fetch embedding models from BYOK provider {provider.name}: {e}")
# Sort by provider order (matching old _enabled_providers order), then by model name
embedding_models.sort(key=self._get_embedding_provider_sort_key)
return embedding_models
@@ -1128,25 +1241,17 @@ class SyncServer(object):
provider_name: Optional[str] = None,
provider_type: Optional[ProviderType] = None,
) -> List[Provider]:
providers = []
if not provider_category or ProviderCategory.base in provider_category:
providers_from_env = [p for p in self._enabled_providers]
providers.extend(providers_from_env)
# Query all persisted providers from database
persisted_providers = await self.provider_manager.list_providers_async(
name=provider_name,
provider_type=provider_type,
actor=actor,
)
providers = [p.cast_to_subtype() for p in persisted_providers]
if not provider_category or ProviderCategory.byok in provider_category:
providers_from_db = await self.provider_manager.list_providers_async(
name=provider_name,
provider_type=provider_type,
actor=actor,
)
providers_from_db = [p.cast_to_subtype() for p in providers_from_db if p.provider_category == ProviderCategory.byok]
providers.extend(providers_from_db)
if provider_name is not None:
providers = [p for p in providers if p.name == provider_name]
if provider_type is not None:
providers = [p for p in providers if p.provider_type == provider_type]
# Filter by category if specified
if provider_category:
providers = [p for p in providers if p.provider_category in provider_category]
return providers
@@ -1160,32 +1265,19 @@ class SyncServer(object):
max_reasoning_tokens: Optional[int] = None,
enable_reasoner: Optional[bool] = None,
) -> LLMConfig:
# Use provider_manager to get LLMConfig from handle
try:
provider_name, model_name = handle.split("/", 1)
provider = await self.get_provider_from_name_async(provider_name, actor)
all_llm_configs = await provider.list_llm_models_async()
llm_configs = [config for config in all_llm_configs if config.handle == handle]
if not llm_configs:
llm_configs = [config for config in all_llm_configs if config.model == model_name]
if not llm_configs:
available_handles = [config.handle for config in all_llm_configs]
raise HandleNotFoundError(handle, available_handles)
except ValueError as e:
llm_configs = [config for config in self.get_local_llm_configs() if config.handle == handle]
if not llm_configs:
llm_configs = [config for config in self.get_local_llm_configs() if config.model == model_name]
if not llm_configs:
raise e
if len(llm_configs) == 1:
llm_config = llm_configs[0]
elif len(llm_configs) > 1:
raise LettaInvalidArgumentError(
f"Multiple LLM models with name {model_name} supported by {provider_name}", argument_name="model_name"
llm_config = await self.provider_manager.get_llm_config_from_handle(
handle=handle,
actor=actor,
)
else:
llm_config = llm_configs[0]
except Exception as e:
# Convert to HandleNotFoundError for backwards compatibility
from letta.orm.errors import NoResultFound
if isinstance(e, NoResultFound):
raise HandleNotFoundError(handle, [])
raise
if context_window_limit is not None:
if context_window_limit > llm_config.context_window:
@@ -1217,33 +1309,22 @@ class SyncServer(object):
async def get_embedding_config_from_handle_async(
self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
) -> EmbeddingConfig:
# Use provider_manager to get EmbeddingConfig from handle
try:
provider_name, model_name = handle.split("/", 1)
provider = await self.get_provider_from_name_async(provider_name, actor)
all_embedding_configs = await provider.list_embedding_models_async()
embedding_configs = [config for config in all_embedding_configs if config.handle == handle]
if not embedding_configs:
raise LettaInvalidArgumentError(
f"Embedding model {model_name} is not supported by {provider_name}", argument_name="model_name"
)
except LettaInvalidArgumentError as e:
# search local configs
embedding_configs = [config for config in self.get_local_embedding_configs() if config.handle == handle]
if not embedding_configs:
raise e
if len(embedding_configs) == 1:
embedding_config = embedding_configs[0]
elif len(embedding_configs) > 1:
raise LettaInvalidArgumentError(
f"Multiple embedding models with name {model_name} supported by {provider_name}", argument_name="model_name"
embedding_config = await self.provider_manager.get_embedding_config_from_handle(
handle=handle,
actor=actor,
)
else:
embedding_config = embedding_configs[0]
except Exception as e:
# Convert to LettaInvalidArgumentError for backwards compatibility
from letta.orm.errors import NoResultFound
if embedding_chunk_size:
embedding_config.embedding_chunk_size = embedding_chunk_size
if isinstance(e, NoResultFound):
raise LettaInvalidArgumentError(f"Embedding model {handle} not found", argument_name="handle")
raise
# Override chunk size if provided
embedding_config.embedding_chunk_size = embedding_chunk_size
return embedding_config
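Handle resolution throughout this module relies on the `"<provider_name>/<model_name>"` convention, split once on the first slash so model names may themselves contain slashes. A tiny worked example (illustrative handles):

```python
handle = "openai/text-embedding-3-small"  # illustrative handle
provider_name, model_name = handle.split("/", 1)
assert provider_name == "openai"
assert model_name == "text-embedding-3-small"

# maxsplit=1 keeps slashes inside the model segment intact:
assert "openrouter/meta-llama/llama-3-70b".split("/", 1)[1] == "meta-llama/llama-3-70b"
```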
@@ -1252,57 +1333,17 @@ class SyncServer(object):
providers = [provider for provider in all_providers if provider.name == provider_name]
if not providers:
raise LettaInvalidArgumentError(
f"Provider {provider_name} is not supported (supported providers: {', '.join([provider.name for provider in self._enabled_providers])})",
f"Provider {provider_name} is not supported (supported providers: {', '.join([provider.name for provider in all_providers])})",
argument_name="provider_name",
)
elif len(providers) > 1:
logger.warning(f"Multiple providers with name {provider_name} supported", argument_name="provider_name")
logger.warning(f"Multiple providers with name {provider_name} supported")
provider = providers[0]
else:
provider = providers[0]
return provider
def get_local_llm_configs(self):
llm_models = []
# NOTE: deprecated
# try:
# llm_configs_dir = os.path.expanduser("~/.letta/llm_configs")
# if os.path.exists(llm_configs_dir):
# for filename in os.listdir(llm_configs_dir):
# if filename.endswith(".json"):
# filepath = os.path.join(llm_configs_dir, filename)
# try:
# with open(filepath, "r") as f:
# config_data = json.load(f)
# llm_config = LLMConfig(**config_data)
# llm_models.append(llm_config)
# except (json.JSONDecodeError, ValueError) as e:
# logger.warning(f"Error parsing LLM config file {filename}: {e}")
# except Exception as e:
# logger.warning(f"Error reading LLM configs directory: {e}")
return llm_models
def get_local_embedding_configs(self):
embedding_models = []
# NOTE: deprecated
# try:
# embedding_configs_dir = os.path.expanduser("~/.letta/embedding_configs")
# if os.path.exists(embedding_configs_dir):
# for filename in os.listdir(embedding_configs_dir):
# if filename.endswith(".json"):
# filepath = os.path.join(embedding_configs_dir, filename)
# try:
# with open(filepath, "r") as f:
# config_data = json.load(f)
# embedding_config = EmbeddingConfig(**config_data)
# embedding_models.append(embedding_config)
# except (json.JSONDecodeError, ValueError) as e:
# logger.warning(f"Error parsing embedding config file {filename}: {e}")
# except Exception as e:
# logger.warning(f"Error reading embedding configs directory: {e}")
return embedding_models
def add_llm_model(self, request: LLMConfig) -> LLMConfig:
"""Add a new LLM model"""
@@ -1533,7 +1574,7 @@ class SyncServer(object):
# Attempt to initialize the connection to the server
if server_config.type == MCPServerType.SSE:
new_mcp_client = AsyncSSEMCPClient(server_config)
new_mcp_client = AsyncFastMCPSSEClient(server_config)
elif server_config.type == MCPServerType.STDIO:
new_mcp_client = AsyncStdioMCPClient(server_config)
else:

View File

@@ -25,7 +25,7 @@ from letta.constants import (
INCLUDE_MODEL_KEYWORDS_BASE_TOOL_RULES,
RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE,
)
from letta.errors import LettaAgentNotFoundError
from letta.errors import LettaAgentNotFoundError, LettaInvalidArgumentError
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import get_utc_time
from letta.llm_api.llm_client import LLMClient
@@ -60,6 +60,7 @@ from letta.schemas.agent import (
from letta.schemas.block import DEFAULT_BLOCKS, Block as PydanticBlock, BlockUpdate
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import AgentType, PrimitiveType, ProviderType, TagMatchMode, ToolType, VectorDBProvider
from letta.schemas.environment_variables import AgentEnvironmentVariable as PydanticAgentEnvVar
from letta.schemas.file import FileMetadata as PydanticFileMetadata
from letta.schemas.group import Group as PydanticGroup, ManagerType
from letta.schemas.letta_stop_reason import StopReasonType
@@ -111,7 +112,13 @@ from letta.services.passage_manager import PassageManager
from letta.services.source_manager import SourceManager
from letta.services.tool_manager import ToolManager
from letta.settings import DatabaseChoice, model_settings, settings
from letta.utils import calculate_file_defaults_based_on_context_window, enforce_types, united_diff
from letta.utils import (
bounded_gather,
calculate_file_defaults_based_on_context_window,
decrypt_agent_secrets,
enforce_types,
united_diff,
)
from letta.validators import raise_on_invalid_id
logger = get_logger(__name__)
@@ -336,8 +343,8 @@ class AgentManager:
ignore_invalid_tools: bool = False,
) -> PydanticAgentState:
# validate required configs
if not agent_create.llm_config or not agent_create.embedding_config:
raise ValueError("llm_config and embedding_config are required")
if not agent_create.llm_config:
raise ValueError("llm_config is required")
# For v1 agents, enforce sane defaults even when reasoning is omitted
if agent_create.agent_type == AgentType.letta_v1_agent:
@@ -556,19 +563,18 @@ class AgentManager:
agent_secrets = agent_create.secrets or agent_create.tool_exec_environment_variables
if agent_secrets:
# Encrypt environment variable values
env_rows = []
for key, val in agent_secrets.items():
# Encrypt value (Secret.from_plaintext handles missing encryption key internally)
value_secret = Secret.from_plaintext(val)
row = {
# Encrypt environment variable values concurrently (async to avoid blocking event loop)
secrets_dict = await Secret.from_plaintexts_async(agent_secrets)
env_rows = [
{
"agent_id": aid,
"key": key,
"value": "", # Empty string for NOT NULL constraint (deprecated, use value_enc)
"value_enc": value_secret.get_encrypted(),
"value_enc": secret.get_encrypted(),
"organization_id": actor.organization_id,
}
env_rows.append(row)
for key, secret in secrets_dict.items()
]
result = await session.execute(insert(AgentEnvironmentVariable).values(env_rows).returning(AgentEnvironmentVariable.id))
env_rows = [{**row, "id": env_var_id} for row, env_var_id in zip(env_rows, result.scalars().all())]
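The `insert(...).values(env_rows).returning(...)` call above inserts all rows in one statement and zips the generated IDs back onto the row dicts. A self-contained sketch of the same pattern against in-memory SQLite (illustrative table, not the real `AgentEnvironmentVariable`; multi-row INSERT with RETURNING needs SQLAlchemy 2.0 and SQLite 3.35+):

```python
from sqlalchemy import Column, Integer, String, create_engine, insert
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class EnvVar(Base):  # stand-in for AgentEnvironmentVariable
    __tablename__ = "env_vars"
    id = Column(Integer, primary_key=True)
    key = Column(String, nullable=False)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

rows = [{"key": "OPENAI_API_KEY"}, {"key": "ZAI_API_KEY"}]
with engine.begin() as conn:
    result = conn.execute(insert(EnvVar).values(rows).returning(EnvVar.id))
    # Attach the database-generated IDs back onto the input rows, in order.
    rows = [{**row, "id": new_id} for row, new_id in zip(rows, result.scalars().all())]
```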
@@ -588,8 +594,10 @@ class AgentManager:
result = await new_agent.to_pydantic_async(include_relationships=include_relationships)
if agent_secrets and env_rows:
result.tool_exec_environment_variables = [AgentEnvironmentVariable(**row) for row in env_rows]
result.secrets = [AgentEnvironmentVariable(**row) for row in env_rows]
# Use Pydantic schema (not ORM model) with plaintext to avoid sync decryption in model validator
env_vars = [PydanticAgentEnvVar(**{**row, "value": agent_secrets[row["key"]]}) for row in env_rows]
result.tool_exec_environment_variables = env_vars
result.secrets = env_vars
# initial message sequence (skip if _init_with_no_messages is True)
if not _init_with_no_messages:
@@ -824,7 +832,7 @@ class AgentManager:
)
session.expire(agent, ["tags"])
agent_secrets = agent_update.secrets or agent_update.tool_exec_environment_variables
agent_secrets = agent_update.secrets if agent_update.secrets is not None else agent_update.tool_exec_environment_variables
if agent_secrets is not None:
# Fetch existing environment variables to check if values changed
result = await session.execute(select(AgentEnvironmentVariable).where(AgentEnvironmentVariable.agent_id == aid))
@@ -832,25 +840,35 @@ class AgentManager:
# TODO: do we need to delete each time or can we just upsert?
await session.execute(delete(AgentEnvironmentVariable).where(AgentEnvironmentVariable.agent_id == aid))
# Encrypt environment variable values
# Only re-encrypt if the value has actually changed
# Decrypt existing values to check for changes (async to avoid blocking)
existing_values: dict[str, str | None] = {}
for k, existing_env in existing_env_vars.items():
if existing_env.value_enc:
existing_secret = Secret.from_encrypted(existing_env.value_enc)
existing_values[k] = await existing_secret.get_plaintext_async()
else:
existing_values[k] = None
# Identify values that need encryption (new or changed)
to_encrypt = {
k: v
for k, v in agent_secrets.items()
if k not in existing_env_vars or existing_values.get(k) != v or not existing_env_vars[k].value_enc
}
# Batch encrypt new/changed values concurrently (async to avoid blocking event loop)
new_secrets = await Secret.from_plaintexts_async(to_encrypt) if to_encrypt else {}
# Build rows, reusing existing encrypted values where unchanged
env_rows = []
for k, v in agent_secrets.items():
# Check if value changed to avoid unnecessary re-encryption
existing_env = existing_env_vars.get(k)
existing_value = None
if existing_env and existing_env.value_enc:
existing_secret = Secret.from_encrypted(existing_env.value_enc)
existing_value = await existing_secret.get_plaintext_async()
# Encrypt value (reuse existing encrypted value if unchanged)
if existing_value == v and existing_env and existing_env.value_enc:
# Value unchanged, reuse existing encrypted value
value_enc = existing_env.value_enc
if k in new_secrets:
# New or changed value - use newly encrypted value
value_enc = new_secrets[k].get_encrypted()
else:
# Value changed or new, encrypt
value_secret = Secret.from_plaintext(v)
value_enc = value_secret.get_encrypted()
# Value unchanged - reuse existing encrypted value
value_enc = existing_env_vars[k].value_enc
row = {
"agent_id": aid,
@@ -875,7 +893,11 @@ class AgentManager:
await session.flush()
await session.refresh(agent)
return await agent.to_pydantic_async()
# Convert without decrypting to release DB connection before PBKDF2
agent_encrypted = await agent.to_pydantic_async(decrypt=False)
# Decrypt secrets outside session
return (await decrypt_agent_secrets([agent_encrypted]))[0]
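This two-step shape (convert with `decrypt=False` inside the session, decrypt after it closes) recurs across the managers in this commit: PBKDF2 key derivation is CPU-bound, so running it while holding a pooled connection starves the pool under load. A runnable sketch of the idea with a stand-in KDF (not the real `decrypt_agent_secrets`):

```python
import asyncio
import hashlib

def expensive_kdf(value: bytes) -> bytes:
    # Stand-in for the real decryption; PBKDF2 dominates the cost either way.
    return hashlib.pbkdf2_hmac("sha256", value, b"salt", 200_000)

async def decrypt_outside_session(encrypted: list[bytes]) -> list[bytes]:
    # Worker threads keep the event loop (and any awaited DB work) responsive.
    return await asyncio.gather(*(asyncio.to_thread(expensive_kdf, v) for v in encrypted))

asyncio.run(decrypt_outside_session([b"a", b"b", b"c"]))
```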
@enforce_types
@trace_method
@@ -899,7 +921,8 @@ class AgentManager:
agent.message_ids = message_ids
await agent.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
await session.commit()
# context manager now handles commits
# await session.commit()
@trace_method
async def list_agents_async(
@@ -969,10 +992,16 @@ class AgentManager:
query = query.limit(limit)
result = await session.execute(query)
agents = result.scalars().all()
return await asyncio.gather(
*[agent.to_pydantic_async(include_relationships=include_relationships, include=include) for agent in agents]
# Convert to pydantic without decrypting (keeps encrypted values)
# This allows us to release the DB connection before expensive PBKDF2 operations
agents_encrypted = await bounded_gather(
[agent.to_pydantic_async(include_relationships=include_relationships, include=include, decrypt=False) for agent in agents]
)
# DB session released - now decrypt secrets outside session to prevent connection holding
return await decrypt_agent_secrets(agents_encrypted)
@trace_method
async def count_agents_async(
self,
@@ -1067,7 +1096,12 @@ class AgentManager:
query = query.distinct(AgentModel.id).order_by(AgentModel.id).limit(limit)
result = await session.execute(query)
return await asyncio.gather(*[agent.to_pydantic_async() for agent in result.scalars()])
# Convert without decrypting to release DB connection before PBKDF2
agents_encrypted = await bounded_gather([agent.to_pydantic_async(decrypt=False) for agent in result.scalars()])
# Decrypt secrets outside session
return await decrypt_agent_secrets(agents_encrypted)
@trace_method
async def size_async(
@@ -1092,8 +1126,8 @@ class AgentManager:
) -> PydanticAgentState:
"""Fetch an agent by its ID."""
async with db_registry.async_session() as session:
try:
try:
async with db_registry.async_session() as session:
query = select(AgentModel)
query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
query = query.where(AgentModel.id == agent_id)
@@ -1105,13 +1139,17 @@ class AgentManager:
if agent is None:
raise NoResultFound(f"Agent with ID {agent_id} not found")
return await agent.to_pydantic_async(include_relationships=include_relationships, include=include)
except NoResultFound:
# Re-raise NoResultFound without logging to preserve 404 handling
raise
except Exception as e:
logger.error(f"Error fetching agent {agent_id}: {str(e)}")
raise
# Convert without decrypting to release DB connection before PBKDF2
agent_encrypted = await agent.to_pydantic_async(include_relationships=include_relationships, include=include, decrypt=False)
# Decrypt secrets outside session
return (await decrypt_agent_secrets([agent_encrypted]))[0]
except NoResultFound:
# Re-raise NoResultFound without logging to preserve 404 handling
raise
except Exception as e:
logger.error(f"Error fetching agent {agent_id}: {str(e)}")
raise
@enforce_types
@trace_method
@@ -1122,8 +1160,8 @@ class AgentManager:
include_relationships: Optional[List[str]] = None,
) -> list[PydanticAgentState]:
"""Fetch a list of agents by their IDs."""
async with db_registry.async_session() as session:
try:
try:
async with db_registry.async_session() as session:
query = select(AgentModel)
query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
query = query.where(AgentModel.id.in_(agent_ids))
@@ -1136,10 +1174,16 @@ class AgentManager:
logger.warning(f"No agents found with IDs: {agent_ids}")
return []
return await asyncio.gather(*[agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents])
except Exception as e:
logger.error(f"Error fetching agents with IDs {agent_ids}: {str(e)}")
raise
# Convert without decrypting to release DB connection before PBKDF2
agents_encrypted = await bounded_gather(
[agent.to_pydantic_async(include_relationships=include_relationships, decrypt=False) for agent in agents]
)
# Decrypt secrets outside session
return await decrypt_agent_secrets(agents_encrypted)
except Exception as e:
logger.error(f"Error fetching agents with IDs {agent_ids}: {str(e)}")
raise
@enforce_types
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
@@ -1216,7 +1260,8 @@ class AgentManager:
await session.commit()
for agent in agents_to_delete:
await session.delete(agent)
await session.commit()
# context manager now handles commits
# await session.commit()
except Exception as e:
await session.rollback()
logger.exception(f"Failed to hard delete Agent with ID {agent_id}")
@@ -1357,7 +1402,7 @@ class AgentManager:
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
if agent_state.message_ids == []:
if not agent_state.message_ids: # Handles both None and empty list
curr_system_message = None
else:
curr_system_message = await self.message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)
@@ -1694,7 +1739,12 @@ class AgentManager:
agent = await agent.update_async(session, actor=actor)
# TODO: This refresh is expensive. If we can find out which fields are needed, we can save cost by only refreshing those fields.
# or even better, not refresh at all.
return await agent.to_pydantic_async()
# Convert without decrypting to release DB connection before PBKDF2
agent_encrypted = await agent.to_pydantic_async(decrypt=False)
# Decrypt secrets outside session
return (await decrypt_agent_secrets([agent_encrypted]))[0]
@enforce_types
@trace_method
@@ -1856,7 +1906,12 @@ class AgentManager:
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
# TODO: This refresh is expensive. If we can find out which fields are needed, we can save cost by only refreshing those fields.
# or even better, not refresh at all.
return await agent.to_pydantic_async()
# Convert without decrypting to release DB connection before PBKDF2
agent_encrypted = await agent.to_pydantic_async(decrypt=False)
# Decrypt secrets outside session
return (await decrypt_agent_secrets([agent_encrypted]))[0]
# ======================================================================================================================
# Block management
@@ -1888,25 +1943,25 @@ class AgentManager:
) -> PydanticBlock:
"""Gets a block attached to an agent by its label."""
async with db_registry.async_session() as session:
block = None
matched_block = None
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
for block in agent.core_memory:
if block.label == block_label:
block = block
matched_block = block
break
if not block:
if not matched_block:
raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'")
update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
# Validate limit constraints before updating
validate_block_limit_constraint(update_data, block)
validate_block_limit_constraint(update_data, matched_block)
for key, value in update_data.items():
setattr(block, key, value)
setattr(matched_block, key, value)
await block.update_async(session, actor=actor)
return block.to_pydantic()
await matched_block.update_async(session, actor=actor)
return matched_block.to_pydantic()
@enforce_types
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
@@ -1944,7 +1999,11 @@ class AgentManager:
# TODO: I have too many things rn so lets look at this later
# await session.commit()
return await agent.to_pydantic_async()
# Convert without decrypting to release DB connection before PBKDF2
agent_encrypted = await agent.to_pydantic_async(decrypt=False)
# Decrypt secrets outside session
return (await decrypt_agent_secrets([agent_encrypted]))[0]
@enforce_types
@trace_method
@@ -1965,7 +2024,12 @@ class AgentManager:
raise NoResultFound(f"No block with id '{block_id}' found for agent '{agent_id}' with actor id: '{actor.id}'")
await agent.update_async(session, actor=actor)
return await agent.to_pydantic_async()
# Convert without decrypting to release DB connection before PBKDF2
agent_encrypted = await agent.to_pydantic_async(decrypt=False)
# Decrypt secrets outside session
return (await decrypt_agent_secrets([agent_encrypted]))[0]
# ======================================================================================================================
# Passage Management
@@ -2400,7 +2464,7 @@ class AgentManager:
# Use ISO format if no timezone is set
formatted_timestamp = str(timestamp) if timestamp else "Unknown"
result_dict = {"timestamp": formatted_timestamp, "content": passage.text, "tags": passage.tags or []}
result_dict = {"id": passage.id, "timestamp": formatted_timestamp, "content": passage.text, "tags": passage.tags or []}
# Add relevance metadata if available
if metadata:
@@ -2570,7 +2634,8 @@ class AgentManager:
agent.tool_rules = tool_rules
session.add(agent)
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
@@ -2643,7 +2708,8 @@ class AgentManager:
else:
logger.info(f"All {len(tool_ids)} tools already attached to agent {agent_id}")
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@trace_method
@@ -2767,7 +2833,8 @@ class AgentManager:
else:
logger.debug(f"Detached tool id={tool_id} from agent id={agent_id}")
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
@@ -2804,7 +2871,8 @@ class AgentManager:
else:
logger.info(f"Detached all {detached_count} tools from agent {agent_id}")
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
@@ -2832,7 +2900,8 @@ class AgentManager:
agent.tool_rules = tool_rules
session.add(agent)
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@trace_method
@@ -3067,10 +3136,14 @@ class AgentManager:
# Generate visible content for each file
line_chunker = LineChunker()
visible_content_map = {}
for file_metadata in file_metadata_with_content:
for i, file_metadata in enumerate(file_metadata_with_content):
content_lines = line_chunker.chunk_text(file_metadata=file_metadata)
visible_content_map[file_metadata.file_name] = "\n".join(content_lines)
# Yield to the event loop every 100 files so long loops don't starve other tasks
if i > 0 and i % 100 == 0:
await asyncio.sleep(0)
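`await asyncio.sleep(0)` is the standard cooperative-yield idiom: it suspends the current task for exactly one event-loop iteration so heartbeats and concurrent requests can run between batches of CPU-bound chunking. A minimal demonstration:

```python
import asyncio

async def process_all(items: list[str]) -> None:
    for i, item in enumerate(items):
        _ = item.upper()  # stand-in for CPU-bound chunking work
        if i > 0 and i % 100 == 0:
            await asyncio.sleep(0)  # let other scheduled tasks run

asyncio.run(process_all([f"file-{n}" for n in range(1_000)]))
```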
# Use bulk attach to avoid race conditions and duplicate LRU eviction decisions
closed_files = await self.file_agent_manager.attach_files_bulk(
agent_id=agent_state.id,

View File

@@ -4,6 +4,7 @@ from typing import Dict, List, Optional
from sqlalchemy import delete, or_, select
from letta.errors import EmbeddingConfigRequiredError
from letta.helpers.tpuf_client import should_use_tpuf
from letta.log import get_logger
from letta.orm import ArchivalPassage, Archive as ArchiveModel, ArchivesAgents
@@ -17,7 +18,7 @@ from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.services.helpers.agent_manager_helper import validate_agent_exists_async
from letta.settings import DatabaseChoice, settings
from letta.utils import enforce_types
from letta.utils import bounded_gather, decrypt_agent_secrets, enforce_types
from letta.validators import raise_on_invalid_id
logger = get_logger(__name__)
@@ -191,7 +192,8 @@ class ArchiveManager:
is_owner=is_owner,
)
session.add(archives_agents)
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
@@ -224,7 +226,8 @@ class ArchiveManager:
else:
logger.info(f"Detached agent {agent_id} from archive {archive_id}")
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
@@ -431,6 +434,8 @@ class ArchiveManager:
return archive
# Create a default archive for this agent
if agent_state.embedding_config is None:
raise EmbeddingConfigRequiredError(agent_id=agent_state.id, operation="create_default_archive")
archive_name = f"{agent_state.name}'s Archive"
archive = await self.create_archive_async(
name=archive_name,
@@ -549,8 +554,13 @@ class ArchiveManager:
result = await session.execute(query)
agents_orm = result.scalars().all()
agents = await asyncio.gather(*[agent.to_pydantic_async(include_relationships=[], include=include) for agent in agents_orm])
return agents
# Convert without decrypting to release DB connection before PBKDF2
agents_encrypted = await bounded_gather(
[agent.to_pydantic_async(include_relationships=[], include=include, decrypt=False) for agent in agents_orm]
)
# Decrypt secrets outside session
return await decrypt_agent_secrets(agents_encrypted)
@enforce_types
@trace_method
@@ -609,6 +619,7 @@ class ArchiveManager:
# update the archive with the namespace
await session.execute(update(ArchiveModel).where(ArchiveModel.id == archive_id).values(_vector_db_namespace=namespace_name))
await session.commit()
# context manager now handles commits
# await session.commit()
return namespace_name

View File

@@ -19,7 +19,7 @@ from letta.schemas.enums import ActorType, PrimitiveType
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.settings import DatabaseChoice, settings
from letta.utils import enforce_types
from letta.utils import bounded_gather, decrypt_agent_secrets, enforce_types
from letta.validators import raise_on_invalid_id
logger = get_logger(__name__)
@@ -100,7 +100,8 @@ class BlockManager:
block = BlockModel(**data, organization_id=actor.organization_id)
await block.create_async(session, actor=actor, no_commit=True, no_refresh=True)
pydantic_block = block.to_pydantic()
await session.commit()
# context manager now handles commits
# await session.commit()
return pydantic_block
@enforce_types
@@ -119,18 +120,19 @@ class BlockManager:
async with db_registry.async_session() as session:
# Validate all blocks before creating any
validated_data = []
for block in blocks:
block_data = block.model_dump(to_orm=True, exclude_none=True)
validate_block_creation(block_data)
validated_data.append(block_data)
block_models = [
BlockModel(**block.model_dump(to_orm=True, exclude_none=True), organization_id=actor.organization_id) for block in blocks
]
block_models = [BlockModel(**data, organization_id=actor.organization_id) for data in validated_data]
created_models = await BlockModel.batch_create_async(
items=block_models, db_session=session, actor=actor, no_commit=True, no_refresh=True
)
result = [m.to_pydantic() for m in created_models]
await session.commit()
# context manager now handles commits
# await session.commit()
return result
@enforce_types
@@ -150,7 +152,8 @@ class BlockManager:
await block.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
pydantic_block = block.to_pydantic()
await session.commit()
# context manager now handles commits
# await session.commit()
return pydantic_block
@enforce_types
@@ -502,10 +505,13 @@ class BlockManager:
result = await session.execute(query)
agents_orm = result.scalars().all()
agents = await asyncio.gather(
*[agent.to_pydantic_async(include_relationships=include_relationships, include=include) for agent in agents_orm]
# Convert without decrypting to release DB connection before PBKDF2
agents_encrypted = await bounded_gather(
[agent.to_pydantic_async(include_relationships=include_relationships, include=include, decrypt=False) for agent in agents_orm]
)
return agents
# Decrypt secrets outside session
return await decrypt_agent_secrets(agents_encrypted)
@enforce_types
@trace_method
@@ -591,7 +597,8 @@ class BlockManager:
new_val = new_val[: block.limit]
block.value = new_val
await session.commit()
# context manager now handles commits
# await session.commit()
if return_hydrated:
# TODO: implement for async
@@ -669,7 +676,8 @@ class BlockManager:
# 7) Flush changes, then commit once
block = await block.update_async(db_session=session, actor=actor, no_commit=True)
await session.commit()
# context manager now handles commits
# await session.commit()
return block.to_pydantic()
@@ -757,7 +765,8 @@ class BlockManager:
block = await self._move_block_to_sequence(session, block, previous_entry.sequence_number, actor)
# 4) Commit
await session.commit()
# context manager now handles commits
# await session.commit()
return block.to_pydantic()
@enforce_types
@@ -805,5 +814,6 @@ class BlockManager:
block = await self._move_block_to_sequence(session, block, next_entry.sequence_number, actor)
await session.commit()
# context manager now handles commits
# await session.commit()
return block.to_pydantic()

View File

@@ -0,0 +1,357 @@
from typing import List, Optional
from sqlalchemy import func, select
from letta.orm.conversation import Conversation as ConversationModel
from letta.orm.conversation_messages import ConversationMessage as ConversationMessageModel
from letta.orm.errors import NoResultFound
from letta.orm.message import Message as MessageModel
from letta.otel.tracing import trace_method
from letta.schemas.conversation import Conversation as PydanticConversation, CreateConversation, UpdateConversation
from letta.schemas.letta_message import LettaMessage
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types
class ConversationManager:
"""Manager class to handle business logic related to Conversations."""
@enforce_types
@trace_method
async def create_conversation(
self,
agent_id: str,
conversation_create: CreateConversation,
actor: PydanticUser,
) -> PydanticConversation:
"""Create a new conversation for an agent."""
async with db_registry.async_session() as session:
conversation = ConversationModel(
agent_id=agent_id,
summary=conversation_create.summary,
organization_id=actor.organization_id,
)
await conversation.create_async(session, actor=actor)
return conversation.to_pydantic()
@enforce_types
@trace_method
async def get_conversation_by_id(
self,
conversation_id: str,
actor: PydanticUser,
) -> PydanticConversation:
"""Retrieve a conversation by its ID, including in-context message IDs."""
async with db_registry.async_session() as session:
conversation = await ConversationModel.read_async(
db_session=session,
identifier=conversation_id,
actor=actor,
check_is_deleted=True,
)
# Get the in-context message IDs for this conversation
message_ids = await self.get_message_ids_for_conversation(
conversation_id=conversation_id,
actor=actor,
)
# Build the pydantic model with in_context_message_ids
pydantic_conversation = conversation.to_pydantic()
pydantic_conversation.in_context_message_ids = message_ids
return pydantic_conversation
@enforce_types
@trace_method
async def list_conversations(
self,
agent_id: str,
actor: PydanticUser,
limit: int = 50,
after: Optional[str] = None,
) -> List[PydanticConversation]:
"""List conversations for an agent with cursor-based pagination."""
async with db_registry.async_session() as session:
conversations = await ConversationModel.list_async(
db_session=session,
actor=actor,
agent_id=agent_id,
limit=limit,
after=after,
ascending=False,
)
return [conv.to_pydantic() for conv in conversations]
@enforce_types
@trace_method
async def update_conversation(
self,
conversation_id: str,
conversation_update: UpdateConversation,
actor: PydanticUser,
) -> PydanticConversation:
"""Update a conversation."""
async with db_registry.async_session() as session:
conversation = await ConversationModel.read_async(
db_session=session,
identifier=conversation_id,
actor=actor,
)
# Set attributes on the model
update_data = conversation_update.model_dump(exclude_none=True)
for key, value in update_data.items():
setattr(conversation, key, value)
# Commit the update
updated_conversation = await conversation.update_async(
db_session=session,
actor=actor,
)
return updated_conversation.to_pydantic()
@enforce_types
@trace_method
async def delete_conversation(
self,
conversation_id: str,
actor: PydanticUser,
) -> None:
"""Soft delete a conversation."""
async with db_registry.async_session() as session:
conversation = await ConversationModel.read_async(
db_session=session,
identifier=conversation_id,
actor=actor,
)
# Soft delete by setting is_deleted flag
conversation.is_deleted = True
await conversation.update_async(db_session=session, actor=actor)
# ==================== Message Management Methods ====================
@enforce_types
@trace_method
async def get_message_ids_for_conversation(
self,
conversation_id: str,
actor: PydanticUser,
) -> List[str]:
"""
Get ordered message IDs for a conversation.
Returns message IDs ordered by position in the conversation.
Only returns messages that are currently in_context.
"""
async with db_registry.async_session() as session:
query = (
select(ConversationMessageModel.message_id)
.where(
ConversationMessageModel.conversation_id == conversation_id,
ConversationMessageModel.organization_id == actor.organization_id,
ConversationMessageModel.in_context == True,
ConversationMessageModel.is_deleted == False,
)
.order_by(ConversationMessageModel.position)
)
result = await session.execute(query)
return list(result.scalars().all())
@enforce_types
@trace_method
async def get_messages_for_conversation(
self,
conversation_id: str,
actor: PydanticUser,
) -> List[PydanticMessage]:
"""
Get ordered Message objects for a conversation.
Returns full Message objects ordered by position in the conversation.
Only returns messages that are currently in_context.
"""
async with db_registry.async_session() as session:
query = (
select(MessageModel)
.join(
ConversationMessageModel,
MessageModel.id == ConversationMessageModel.message_id,
)
.where(
ConversationMessageModel.conversation_id == conversation_id,
ConversationMessageModel.organization_id == actor.organization_id,
ConversationMessageModel.in_context == True,
ConversationMessageModel.is_deleted == False,
)
.order_by(ConversationMessageModel.position)
)
result = await session.execute(query)
return [msg.to_pydantic() for msg in result.scalars().all()]
@enforce_types
@trace_method
async def add_messages_to_conversation(
self,
conversation_id: str,
agent_id: str,
message_ids: List[str],
actor: PydanticUser,
starting_position: Optional[int] = None,
) -> None:
"""
Add messages to a conversation's tracking table.
Creates ConversationMessage entries with auto-incrementing positions.
Args:
conversation_id: The conversation to add messages to
agent_id: The agent ID
message_ids: List of message IDs to add
actor: The user performing the action
starting_position: Optional starting position (defaults to next available)
"""
if not message_ids:
return
async with db_registry.async_session() as session:
# Get starting position if not provided
if starting_position is None:
query = select(func.coalesce(func.max(ConversationMessageModel.position), -1)).where(
ConversationMessageModel.conversation_id == conversation_id,
ConversationMessageModel.organization_id == actor.organization_id,
)
result = await session.execute(query)
max_position = result.scalar()
# Use explicit None check instead of `or` to handle position=0 correctly
if max_position is None:
max_position = -1
starting_position = max_position + 1
# Create ConversationMessage entries
for i, message_id in enumerate(message_ids):
conv_msg = ConversationMessageModel(
conversation_id=conversation_id,
agent_id=agent_id,
message_id=message_id,
position=starting_position + i,
in_context=True,
organization_id=actor.organization_id,
)
session.add(conv_msg)
await session.commit()
@enforce_types
@trace_method
async def update_in_context_messages(
self,
conversation_id: str,
in_context_message_ids: List[str],
actor: PydanticUser,
) -> None:
"""
Update which messages are in context for a conversation.
Sets in_context=True for messages in the list, False for others.
Args:
conversation_id: The conversation to update
in_context_message_ids: List of message IDs that should be in context
actor: The user performing the action
"""
async with db_registry.async_session() as session:
# Get all conversation messages for this conversation
query = select(ConversationMessageModel).where(
ConversationMessageModel.conversation_id == conversation_id,
ConversationMessageModel.organization_id == actor.organization_id,
ConversationMessageModel.is_deleted == False,
)
result = await session.execute(query)
conv_messages = result.scalars().all()
# Update in_context status
in_context_set = set(in_context_message_ids)
for conv_msg in conv_messages:
conv_msg.in_context = conv_msg.message_id in in_context_set
await session.commit()
@enforce_types
@trace_method
async def list_conversation_messages(
self,
conversation_id: str,
actor: PydanticUser,
limit: Optional[int] = 100,
before: Optional[str] = None,
after: Optional[str] = None,
) -> List[LettaMessage]:
"""
List all messages in a conversation with pagination support.
Unlike get_messages_for_conversation, this returns ALL messages
(not just in_context) and supports cursor-based pagination.
Messages are always ordered by position (oldest first).
Args:
conversation_id: The conversation to list messages for
actor: The user performing the action
limit: Maximum number of messages to return
before: Return messages before this message ID
after: Return messages after this message ID
Returns:
List of LettaMessage objects
"""
async with db_registry.async_session() as session:
# Build base query joining Message with ConversationMessage
query = (
select(MessageModel)
.join(
ConversationMessageModel,
MessageModel.id == ConversationMessageModel.message_id,
)
.where(
ConversationMessageModel.conversation_id == conversation_id,
ConversationMessageModel.organization_id == actor.organization_id,
ConversationMessageModel.is_deleted == False,
)
)
# Handle cursor-based pagination
if before:
# Get the position of the cursor message
cursor_query = select(ConversationMessageModel.position).where(
ConversationMessageModel.conversation_id == conversation_id,
ConversationMessageModel.message_id == before,
)
cursor_result = await session.execute(cursor_query)
cursor_position = cursor_result.scalar_one_or_none()
if cursor_position is not None:
query = query.where(ConversationMessageModel.position < cursor_position)
if after:
# Get the position of the cursor message
cursor_query = select(ConversationMessageModel.position).where(
ConversationMessageModel.conversation_id == conversation_id,
ConversationMessageModel.message_id == after,
)
cursor_result = await session.execute(cursor_query)
cursor_position = cursor_result.scalar_one_or_none()
if cursor_position is not None:
query = query.where(ConversationMessageModel.position > cursor_position)
# Order by position (oldest first)
query = query.order_by(ConversationMessageModel.position.asc())
# Apply limit
if limit is not None:
query = query.limit(limit)
result = await session.execute(query)
messages = [msg.to_pydantic() for msg in result.scalars().all()]
# Convert to LettaMessages
return PydanticMessage.to_letta_messages_from_list(messages, reverse=False, text_is_assistant_message=True)
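
A note on the cursor semantics above: `before`/`after` are message IDs that get resolved to positions, and `list_conversations` paginates the same way using conversation IDs. A minimal sketch of draining the listing, assuming a `ConversationManager` instance and a `PydanticUser` actor are already in hand (the helper name is illustrative, not part of the commit):

```python
async def read_all_conversations(manager, agent_id: str, actor) -> list:
    """Drain list_conversations by passing the last seen ID as the cursor."""
    cursor, results = None, []
    while True:
        page = await manager.list_conversations(
            agent_id=agent_id, actor=actor, limit=50, after=cursor
        )
        if not page:
            break
        results.extend(page)
        cursor = page[-1].id  # cursor is a conversation ID
    return results
```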

View File

@@ -22,7 +22,7 @@ from letta.schemas.source_metadata import FileStats, OrganizationSourcesStats, S
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.settings import settings
from letta.utils import enforce_types
from letta.utils import bounded_gather, enforce_types
from letta.validators import raise_on_invalid_id
logger = get_logger(__name__)
@@ -85,7 +85,7 @@ class FileManager:
# invalidate cache for this new file
await self._invalidate_file_caches(file_orm.id, actor, file_orm.original_file_name, file_orm.source_id)
return await file_orm.to_pydantic_async()
return file_orm.to_pydantic()
except IntegrityError:
await session.rollback()
@@ -124,14 +124,20 @@ class FileManager:
)
result = await session.execute(query)
file_orm = result.scalar_one()
file_orm = result.scalar_one_or_none()
else:
# fast path (metadata only)
file_orm = await FileMetadataModel.read_async(
db_session=session,
identifier=file_id,
actor=actor,
)
try:
file_orm = await FileMetadataModel.read_async(
db_session=session,
identifier=file_id,
actor=actor,
)
except NoResultFound:
return None
if file_orm is None:
return None
return await file_orm.to_pydantic_async(include_content=include_content, strip_directory_prefix=strip_directory_prefix)
@@ -278,7 +284,7 @@ class FileManager:
identifier=file_id,
actor=actor,
)
return await file_orm.to_pydantic_async()
return file_orm.to_pydantic()
@enforce_types
@trace_method
@@ -408,7 +414,7 @@ class FileManager:
actor: PydanticUser,
before: Optional[str] = None,
after: Optional[str] = None,
limit: Optional[int] = None,
limit: Optional[int] = 1000,
ascending: Optional[bool] = True,
include_content: bool = False,
strip_directory_prefix: bool = False,
@@ -445,9 +451,15 @@ class FileManager:
)
# convert all files to pydantic models
file_metadatas = await asyncio.gather(
*[file.to_pydantic_async(include_content=include_content, strip_directory_prefix=strip_directory_prefix) for file in files]
)
if include_content:
file_metadatas = await bounded_gather(
[
file.to_pydantic_async(include_content=include_content, strip_directory_prefix=strip_directory_prefix)
for file in files
]
)
else:
file_metadatas = [file.to_pydantic(strip_directory_prefix=strip_directory_prefix) for file in files]
# if status checking is enabled, check all files sequentially to avoid db pool exhaustion
# Each status check may update the file in the database, so concurrent checks with many
@@ -473,7 +485,7 @@ class FileManager:
await self._invalidate_file_caches(file_id, actor, file.original_file_name, file.source_id)
await file.hard_delete_async(db_session=session, actor=actor)
return await file.to_pydantic_async()
return file.to_pydantic()
@enforce_types
@trace_method
@@ -555,7 +567,7 @@ class FileManager:
file_orm = result.scalar_one_or_none()
if file_orm:
return await file_orm.to_pydantic_async()
return file_orm.to_pydantic()
return None
@enforce_types
@@ -664,7 +676,10 @@ class FileManager:
result = await session.execute(query)
files_orm = result.scalars().all()
return await asyncio.gather(*[file.to_pydantic_async(include_content=include_content) for file in files_orm])
if include_content:
return await bounded_gather([file.to_pydantic_async(include_content=include_content) for file in files_orm])
else:
return [file.to_pydantic() for file in files_orm]
@enforce_types
@trace_method
@@ -709,4 +724,7 @@ class FileManager:
result = await session.execute(query)
files_orm = result.scalars().all()
return await asyncio.gather(*[file.to_pydantic_async(include_content=include_content) for file in files_orm])
if include_content:
return await bounded_gather([file.to_pydantic_async(include_content=include_content) for file in files_orm])
else:
return [file.to_pydantic() for file in files_orm]
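
The `bounded_gather` swap above caps concurrency when every `to_pydantic_async` call can hit the database or CPU-heavy decryption. The real helper lives in `letta.utils`; a minimal semaphore-based sketch of the idea (signature assumed, not confirmed from the source):

```python
import asyncio
from typing import Awaitable, Iterable, List, TypeVar

T = TypeVar("T")

async def bounded_gather(coros: Iterable[Awaitable[T]], limit: int = 20) -> List[T]:
    sem = asyncio.Semaphore(limit)  # at most `limit` coroutines run at once

    async def _run(coro: Awaitable[T]) -> T:
        async with sem:
            return await coro

    return await asyncio.gather(*(_run(c) for c in coros))
```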

View File

@@ -50,8 +50,10 @@ class FileProcessor:
"""Chunk text and generate embeddings with fallback to default chunker if needed"""
filename = file_metadata.file_name
# Create file-type-specific chunker
text_chunker = LlamaIndexChunker(file_type=file_metadata.file_type, chunk_size=self.embedder.embedding_config.embedding_chunk_size)
# Create file-type-specific chunker in thread pool to avoid blocking event loop
text_chunker = await asyncio.to_thread(
LlamaIndexChunker, file_type=file_metadata.file_type, chunk_size=self.embedder.embedding_config.embedding_chunk_size
)
# First attempt with file-specific chunker
try:
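
`asyncio.to_thread` is what keeps the chunker construction off the event loop: the potentially blocking `__init__` runs in the default thread pool while other tasks keep executing. A self-contained sketch with a stand-in class (`LlamaIndexChunker` is assumed to block similarly during setup):

```python
import asyncio
import time

class SlowChunker:
    """Stand-in for a chunker whose constructor does blocking work."""
    def __init__(self, chunk_size: int):
        time.sleep(0.5)  # simulate loading parsers/models from disk
        self.chunk_size = chunk_size

async def main():
    # The constructor runs in a worker thread; the loop stays responsive.
    chunker = await asyncio.to_thread(SlowChunker, chunk_size=512)
    print(chunker.chunk_size)

asyncio.run(main())
```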

View File

@@ -200,7 +200,8 @@ class FileAgentManager:
stmt = delete(FileAgentModel).where(and_(or_(*conditions), FileAgentModel.organization_id == actor.organization_id))
result = await session.execute(stmt)
await session.commit()
# context manager now handles commits
# await session.commit()
return result.rowcount
@@ -291,6 +292,32 @@ class FileAgentManager:
else:
return [r.to_pydantic() for r in rows]
@enforce_types
@trace_method
async def get_file_ids_for_agent_by_source(
self,
agent_id: str,
source_id: str,
actor: PydanticUser,
) -> List[str]:
"""
Get all file IDs attached to an agent from a specific source.
This queries the files_agents junction table directly, ensuring we get
exactly the files that were attached, regardless of any changes to the
source's file list.
"""
async with db_registry.async_session() as session:
stmt = select(FileAgentModel.file_id).where(
and_(
FileAgentModel.agent_id == agent_id,
FileAgentModel.source_id == source_id,
FileAgentModel.organization_id == actor.organization_id,
)
)
result = await session.execute(stmt)
return list(result.scalars().all())
@enforce_types
@trace_method
async def list_files_for_agent_paginated(
@@ -405,7 +432,8 @@ class FileAgentManager:
.values(last_accessed_at=func.now())
)
await session.execute(stmt)
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@trace_method
@@ -425,7 +453,8 @@ class FileAgentManager:
.values(last_accessed_at=func.now())
)
await session.execute(stmt)
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@trace_method
@@ -458,7 +487,8 @@ class FileAgentManager:
)
closed_file_names = [row.file_name for row in (await session.execute(stmt))]
await session.commit()
# context manager now handles commits
# await session.commit()
return closed_file_names
@enforce_types
@@ -702,7 +732,8 @@ class FileAgentManager:
.values(is_open=False, visible_content=None)
)
await session.commit()
# context manager now handles commits
# await session.commit()
return closed_file_names
async def _get_association_by_file_id(self, session, agent_id: str, file_id: str, actor: PydanticUser) -> FileAgentModel:
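
The repeated `# context manager now handles commits` comments refer to moving the commit into the session context manager itself, so call sites stop committing explicitly. A minimal sketch of that pattern with plain SQLAlchemy async sessions (the real `db_registry.async_session` may differ in detail):

```python
from contextlib import asynccontextmanager

@asynccontextmanager
async def async_session(session_factory):
    session = session_factory()
    try:
        yield session
        await session.commit()    # single commit on successful exit
    except Exception:
        await session.rollback()  # discard partial work on error
        raise
    finally:
        await session.close()
```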

View File

@@ -246,7 +246,8 @@ class GroupManager:
)
await session.execute(delete_stmt)
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
@@ -258,7 +259,7 @@ class GroupManager:
# Update turns counter
group.turns_counter = (group.turns_counter + 1) % group.sleeptime_agent_frequency
await group.update_async(session, actor=actor)
await group.update_async(session, actor=actor, no_refresh=True)
return group.turns_counter
@enforce_types
@@ -275,7 +276,7 @@ class GroupManager:
# Update last processed message id
prev_last_processed_message_id = group.last_processed_message_id
group.last_processed_message_id = last_processed_message_id
await group.update_async(session, actor=actor)
await group.update_async(session, actor=actor, no_refresh=True)
return prev_last_processed_message_id
@@ -434,7 +435,8 @@ class GroupManager:
# Add block to group
session.add(GroupsBlocks(group_id=group_id, block_id=block_id))
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
@@ -452,7 +454,8 @@ class GroupManager:
# Remove block from group
delete_group_block = delete(GroupsBlocks).where(and_(GroupsBlocks.group_id == group_id, GroupsBlocks.block_id == block_id))
await session.execute(delete_group_block)
await session.commit()
# context manager now handles commits
# await session.commit()
@staticmethod
def ensure_buffer_length_range_valid(

View File

@@ -1099,8 +1099,8 @@ async def build_source_passage_query(
embedded_text = np.array(embeddings[0])
embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()
# Base query for source passages
query = select(SourcePassage).where(SourcePassage.organization_id == actor.organization_id)
# Base query for source passages - use noload to prevent lazy loading which can block the event loop
query = select(SourcePassage).options(noload(SourcePassage.organization)).where(SourcePassage.organization_id == actor.organization_id)
# If agent_id is specified, join with SourcesAgents to get only passages linked to that agent
if agent_id is not None:
@@ -1208,23 +1208,26 @@ async def build_agent_passage_query(
embedded_text = np.array(embeddings[0])
embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()
# Base query for passages
# Base query for passages - use noload to prevent lazy loading which can block the event loop
if agent_id:
# Query for agent passages - join through archives_agents
# Agent_id takes precedence if both agent_id and archive_id are provided
query = (
select(ArchivalPassage)
.options(noload(ArchivalPassage.organization), noload(ArchivalPassage.passage_tags))
.join(ArchivesAgents, ArchivalPassage.archive_id == ArchivesAgents.archive_id)
.where(ArchivesAgents.agent_id == agent_id, ArchivalPassage.organization_id == actor.organization_id)
)
elif archive_id:
# Query for archive passages directly
query = select(ArchivalPassage).where(
ArchivalPassage.archive_id == archive_id, ArchivalPassage.organization_id == actor.organization_id
query = (
select(ArchivalPassage)
.options(noload(ArchivalPassage.organization), noload(ArchivalPassage.passage_tags))
.where(ArchivalPassage.archive_id == archive_id, ArchivalPassage.organization_id == actor.organization_id)
)
else:
# Org-wide search - all passages in organization
query = select(ArchivalPassage).where(ArchivalPassage.organization_id == actor.organization_id)
query = (
select(ArchivalPassage)
.options(noload(ArchivalPassage.organization), noload(ArchivalPassage.passage_tags))
.where(ArchivalPassage.organization_id == actor.organization_id)
)
# Apply filters
if start_date:
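
The `noload()` option above pins the relationship attributes to empty instead of leaving them lazy, so touching `passage.organization` later can never fire an implicit synchronous SELECT inside async code. Reduced to its core, using the model names from the diff (imports of those models assumed available):

```python
from sqlalchemy import select
from sqlalchemy.orm import noload

def passages_for_org(org_id: str):
    # noload() disables lazy loading for these relationships entirely;
    # accessing them on a result yields None/empty rather than querying.
    return (
        select(ArchivalPassage)
        .options(noload(ArchivalPassage.organization), noload(ArchivalPassage.passage_tags))
        .where(ArchivalPassage.organization_id == org_id)
    )
```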

View File

@@ -24,7 +24,7 @@ from letta.schemas.identity import (
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.settings import DatabaseChoice, settings
from letta.utils import enforce_types
from letta.utils import bounded_gather, decrypt_agent_secrets, enforce_types
from letta.validators import raise_on_invalid_id
@@ -257,7 +257,8 @@ class IdentityManager:
if identity.organization_id != actor.organization_id:
raise HTTPException(status_code=403, detail="Forbidden")
await session.delete(identity)
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@trace_method
@@ -335,7 +336,14 @@ class IdentityManager:
ascending=ascending,
identity_id=identity.id,
)
return await asyncio.gather(*[agent.to_pydantic_async(include_relationships=[], include=include) for agent in agents])
# Convert without decrypting to release DB connection before PBKDF2
agents_encrypted = await bounded_gather(
[agent.to_pydantic_async(include_relationships=[], include=include, decrypt=False) for agent in agents]
)
# Decrypt secrets outside session
return await decrypt_agent_secrets(agents_encrypted)
@enforce_types
@raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)

View File

@@ -58,7 +58,8 @@ class JobManager:
job.organization_id = actor.organization_id
job = await job.create_async(session, actor=actor, no_commit=True, no_refresh=True) # Save job in the database
await session.commit()
# context manager now handles commits
# await session.commit()
# Convert to pydantic first, then add agent_id if needed
result = super(JobModel, job).to_pydantic()
@@ -122,7 +123,8 @@ class JobManager:
# Get the updated metadata for callback
final_metadata = job.metadata_
result = job.to_pydantic()
await session.commit()
# context manager now handles commits
# await session.commit()
# Dispatch callback outside of database session if needed
if needs_callback:
@@ -143,7 +145,8 @@ class JobManager:
job.callback_error = callback_result.get("callback_error")
await job.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
result = job.to_pydantic()
await session.commit()
# context manager now handles commits
# await session.commit()
return result
@@ -462,7 +465,8 @@ class JobManager:
job = await self._verify_job_access_async(session=session, job_id=job_id, actor=actor, access=["write"])
job.ttft_ns = ttft_ns
await job.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
await session.commit()
# context manager now handles commits
# await session.commit()
except Exception as e:
logger.warning(f"Failed to record TTFT for job {job_id}: {e}")
@@ -475,7 +479,8 @@ class JobManager:
job = await self._verify_job_access_async(session=session, job_id=job_id, actor=actor, access=["write"])
job.total_duration_ns = total_duration_ns
await job.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
await session.commit()
# context manager now handles commits
# await session.commit()
except Exception as e:
logger.warning(f"Failed to record response duration for job {job_id}: {e}")

View File

@@ -45,7 +45,8 @@ class LLMBatchManager:
)
await batch.create_async(session, actor=actor, no_commit=True, no_refresh=True)
pydantic_batch = batch.to_pydantic()
await session.commit()
# context manager now handles commits
# await session.commit()
return pydantic_batch
@enforce_types
@@ -98,7 +99,8 @@ class LLMBatchManager:
)
await session.run_sync(lambda ses: ses.bulk_update_mappings(LLMBatchJob, mappings))
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@trace_method
@@ -285,7 +287,8 @@ class LLMBatchManager:
created_items = await LLMBatchItem.batch_create_async(orm_items, session, actor=actor, no_commit=True, no_refresh=True)
pydantic_items = [item.to_pydantic() for item in created_items]
await session.commit()
# context manager now handles commits
# await session.commit()
return pydantic_items
@enforce_types
@@ -421,7 +424,8 @@ class LLMBatchManager:
if mappings:
await session.run_sync(lambda ses: ses.bulk_update_mappings(LLMBatchItem, mappings))
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@trace_method

View File

@@ -82,7 +82,10 @@ class AsyncBaseMCPClient:
result = await self.session.call_tool(tool_name, tool_args)
except Exception as e:
if e.__class__.__name__ == "McpError":
logger.warning(f"MCP tool '{tool_name}' execution failed: {str(e)}")
# MCP errors are typically user-facing issues from external MCP servers
# (e.g., resource not found, invalid arguments, permission errors)
# Log at debug level to avoid triggering production alerts for expected failures
logger.debug(f"MCP tool '{tool_name}' execution failed: {str(e)}")
raise
parsed_content = []

View File

@@ -0,0 +1,307 @@
"""FastMCP-based MCP clients with server-side OAuth support.
This module provides MCP client implementations using the FastMCP library,
with support for server-side OAuth flows where authorization URLs are
forwarded to web clients instead of opening a browser.
These clients replace the existing AsyncSSEMCPClient and AsyncStreamableHTTPMCPClient
implementations that used the lower-level MCP SDK directly.
"""
from contextlib import AsyncExitStack
from typing import List, Optional, Tuple
import httpx
from fastmcp import Client
from fastmcp.client.transports import SSETransport, StreamableHttpTransport
from mcp import Tool as MCPTool
from letta.functions.mcp_client.types import SSEServerConfig, StreamableHTTPServerConfig
from letta.log import get_logger
from letta.services.mcp.server_side_oauth import ServerSideOAuth
logger = get_logger(__name__)
class AsyncFastMCPSSEClient:
"""SSE MCP client using FastMCP with server-side OAuth support.
This client connects to MCP servers using Server-Sent Events (SSE) transport.
It supports both authenticated and unauthenticated connections, with OAuth
handled via the ServerSideOAuth class for server-side flows.
Args:
server_config: SSE server configuration including URL, headers, and auth settings
oauth: Optional ServerSideOAuth instance for OAuth authentication
agent_id: Optional agent ID to include in request headers
"""
AGENT_ID_HEADER = "X-Agent-Id"
def __init__(
self,
server_config: SSEServerConfig,
oauth: Optional[ServerSideOAuth] = None,
agent_id: Optional[str] = None,
):
self.server_config = server_config
self.oauth = oauth
self.agent_id = agent_id
self.client: Optional[Client] = None
self.initialized = False
self.exit_stack = AsyncExitStack()
async def connect_to_server(self):
"""Establish connection to the MCP server.
Raises:
ConnectionError: If connection to the server fails
"""
try:
headers = {}
if self.server_config.custom_headers:
headers.update(self.server_config.custom_headers)
if self.server_config.auth_header and self.server_config.auth_token:
headers[self.server_config.auth_header] = self.server_config.auth_token
if self.agent_id:
headers[self.AGENT_ID_HEADER] = self.agent_id
transport = SSETransport(
url=self.server_config.server_url,
headers=headers if headers else None,
auth=self.oauth, # Pass ServerSideOAuth instance (or None)
)
self.client = Client(transport)
await self.client._connect()
self.initialized = True
except httpx.HTTPStatusError as e:
# Re-raise HTTP status errors for OAuth flow handling
if e.response.status_code == 401:
raise ConnectionError("401 Unauthorized")
raise e
except Exception as e:
raise e
async def list_tools(self, serialize: bool = False) -> List[MCPTool]:
"""List available tools from the MCP server.
Args:
serialize: If True, return tools as dictionaries instead of MCPTool objects
Returns:
List of tools available on the server
Raises:
RuntimeError: If client has not been initialized
"""
self._check_initialized()
tools = await self.client.list_tools()
if serialize:
serializable_tools = []
for tool in tools:
if hasattr(tool, "model_dump"):
serializable_tools.append(tool.model_dump())
elif hasattr(tool, "dict"):
serializable_tools.append(tool.dict())
elif hasattr(tool, "__dict__"):
serializable_tools.append(tool.__dict__)
else:
serializable_tools.append(str(tool))
return serializable_tools
return tools
async def execute_tool(self, tool_name: str, tool_args: dict) -> Tuple[str, bool]:
"""Execute a tool on the MCP server.
Args:
tool_name: Name of the tool to execute
tool_args: Arguments to pass to the tool
Returns:
Tuple of (result_content, success_flag)
Raises:
RuntimeError: If client has not been initialized
"""
self._check_initialized()
try:
result = await self.client.call_tool(tool_name, tool_args)
except Exception as e:
if e.__class__.__name__ == "McpError":
logger.warning(f"MCP tool '{tool_name}' execution failed: {str(e)}")
raise
# Parse content from result
parsed_content = []
for content_piece in result.content:
if hasattr(content_piece, "text"):
parsed_content.append(content_piece.text)
logger.debug(f"MCP tool result parsed content (text): {parsed_content}")
else:
parsed_content.append(str(content_piece))
logger.debug(f"MCP tool result parsed content (other): {parsed_content}")
if parsed_content:
final_content = " ".join(parsed_content)
else:
final_content = "Empty response from tool"
return final_content, not result.is_error
def _check_initialized(self):
"""Check if the client has been initialized."""
if not self.initialized:
logger.error("MCPClient has not been initialized")
raise RuntimeError("MCPClient has not been initialized")
async def cleanup(self):
"""Clean up client resources."""
if self.client:
try:
await self.client.close()
except Exception as e:
logger.warning(f"Error during FastMCP client cleanup: {e}")
self.initialized = False
class AsyncFastMCPStreamableHTTPClient:
"""Streamable HTTP MCP client using FastMCP with server-side OAuth support.
This client connects to MCP servers using Streamable HTTP transport.
It supports both authenticated and unauthenticated connections, with OAuth
handled via the ServerSideOAuth class for server-side flows.
Args:
server_config: Streamable HTTP server configuration
oauth: Optional ServerSideOAuth instance for OAuth authentication
agent_id: Optional agent ID to include in request headers
"""
AGENT_ID_HEADER = "X-Agent-Id"
def __init__(
self,
server_config: StreamableHTTPServerConfig,
oauth: Optional[ServerSideOAuth] = None,
agent_id: Optional[str] = None,
):
self.server_config = server_config
self.oauth = oauth
self.agent_id = agent_id
self.client: Optional[Client] = None
self.initialized = False
self.exit_stack = AsyncExitStack()
async def connect_to_server(self):
"""Establish connection to the MCP server.
Raises:
ConnectionError: If connection to the server fails
"""
try:
headers = {}
if self.server_config.custom_headers:
headers.update(self.server_config.custom_headers)
if self.server_config.auth_header and self.server_config.auth_token:
headers[self.server_config.auth_header] = self.server_config.auth_token
if self.agent_id:
headers[self.AGENT_ID_HEADER] = self.agent_id
transport = StreamableHttpTransport(
url=self.server_config.server_url,
headers=headers if headers else None,
auth=self.oauth, # Pass ServerSideOAuth instance (or None)
)
self.client = Client(transport)
await self.client._connect()
self.initialized = True
except httpx.HTTPStatusError as e:
# Re-raise HTTP status errors for OAuth flow handling
if e.response.status_code == 401:
raise ConnectionError("401 Unauthorized")
raise e
except Exception as e:
raise e
async def list_tools(self, serialize: bool = False) -> List[MCPTool]:
"""List available tools from the MCP server.
Args:
serialize: If True, return tools as dictionaries instead of MCPTool objects
Returns:
List of tools available on the server
Raises:
RuntimeError: If client has not been initialized
"""
self._check_initialized()
tools = await self.client.list_tools()
if serialize:
serializable_tools = []
for tool in tools:
if hasattr(tool, "model_dump"):
serializable_tools.append(tool.model_dump())
elif hasattr(tool, "dict"):
serializable_tools.append(tool.dict())
elif hasattr(tool, "__dict__"):
serializable_tools.append(tool.__dict__)
else:
serializable_tools.append(str(tool))
return serializable_tools
return tools
async def execute_tool(self, tool_name: str, tool_args: dict) -> Tuple[str, bool]:
"""Execute a tool on the MCP server.
Args:
tool_name: Name of the tool to execute
tool_args: Arguments to pass to the tool
Returns:
Tuple of (result_content, success_flag)
Raises:
RuntimeError: If client has not been initialized
"""
self._check_initialized()
try:
result = await self.client.call_tool(tool_name, tool_args)
except Exception as e:
if e.__class__.__name__ == "McpError":
logger.warning(f"MCP tool '{tool_name}' execution failed: {str(e)}")
raise
# Parse content from result
parsed_content = []
for content_piece in result.content:
if hasattr(content_piece, "text"):
parsed_content.append(content_piece.text)
logger.debug(f"MCP tool result parsed content (text): {parsed_content}")
else:
parsed_content.append(str(content_piece))
logger.debug(f"MCP tool result parsed content (other): {parsed_content}")
if parsed_content:
final_content = " ".join(parsed_content)
else:
final_content = "Empty response from tool"
return final_content, not result.is_error
def _check_initialized(self):
"""Check if the client has been initialized."""
if not self.initialized:
logger.error("MCPClient has not been initialized")
raise RuntimeError("MCPClient has not been initialized")
async def cleanup(self):
"""Clean up client resources."""
if self.client:
try:
await self.client.close()
except Exception as e:
logger.warning(f"Error during FastMCP client cleanup: {e}")
self.initialized = False
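
Putting the new client together, a hedged usage sketch (the server URL, config field values, and the `echo` tool name are illustrative, not taken from the commit):

```python
import asyncio

from letta.functions.mcp_client.types import StreamableHTTPServerConfig

async def main():
    config = StreamableHTTPServerConfig(
        server_name="example",                    # illustrative values;
        server_url="https://mcp.example.com/mcp", # field names assumed
    )
    client = AsyncFastMCPStreamableHTTPClient(server_config=config)
    try:
        await client.connect_to_server()
        tools = await client.list_tools()
        print([t.name for t in tools])
        result, ok = await client.execute_tool("echo", {"text": "hi"})
        print(ok, result)
    finally:
        await client.cleanup()

asyncio.run(main())
```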

View File

@@ -6,7 +6,7 @@ import secrets
import time
import uuid
from datetime import datetime, timedelta
from typing import Callable, Optional, Tuple
from typing import TYPE_CHECKING, Callable, Optional, Tuple
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
@@ -18,7 +18,9 @@ from letta.schemas.mcp import MCPOAuthSessionUpdate
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.services.mcp.types import OauthStreamEvent
from letta.services.mcp_manager import MCPManager
if TYPE_CHECKING:
from letta.services.mcp_manager import MCPManager
logger = get_logger(__name__)
@@ -26,7 +28,7 @@ logger = get_logger(__name__)
class DatabaseTokenStorage(TokenStorage):
"""Database-backed token storage using MCPOAuth table via mcp_manager."""
def __init__(self, session_id: str, mcp_manager: MCPManager, actor: PydanticUser):
def __init__(self, session_id: str, mcp_manager: "MCPManager", actor: PydanticUser):
self.session_id = session_id
self.mcp_manager = mcp_manager
self.actor = actor
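
The `TYPE_CHECKING` change breaks an import cycle, presumably because `mcp_manager` imports this module: the guard plus the quoted annotation keeps the type information for checkers only and never imports `MCPManager` at runtime. The shape of the pattern:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by type checkers only; skipped at runtime, so no cycle.
    from letta.services.mcp_manager import MCPManager

class DatabaseTokenStorage:
    def __init__(self, mcp_manager: "MCPManager") -> None:
        # The string annotation is resolved lazily by type checkers.
        self.mcp_manager = mcp_manager
```
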
@@ -150,9 +152,10 @@ class MCPOAuthSession:
try:
oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None)
# Encrypt the authorization_code and store only in _enc column
# Encrypt the authorization_code and store only in _enc column (async to avoid blocking event loop)
if code is not None:
oauth_record.authorization_code_enc = Secret.from_plaintext(code).get_encrypted()
code_secret = await Secret.from_plaintext_async(code)
oauth_record.authorization_code_enc = code_secret.get_encrypted()
oauth_record.status = OAuthSessionStatus.AUTHORIZED
oauth_record.state = state
@@ -186,12 +189,17 @@ async def create_oauth_provider(
session_id: str,
server_url: str,
redirect_uri: str,
mcp_manager: MCPManager,
mcp_manager: "MCPManager",
actor: PydanticUser,
logo_uri: Optional[str] = None,
url_callback: Optional[Callable[[str], None]] = None,
) -> OAuthClientProvider:
"""Create an OAuth provider for MCP server authentication."""
"""Create an OAuth provider for MCP server authentication.
DEPRECATED: Use ServerSideOAuth from letta.services.mcp.server_side_oauth instead.
This function is kept for backwards compatibility but will be removed in a future version.
"""
logger.warning("create_oauth_provider is deprecated. Use ServerSideOAuth from letta.services.mcp.server_side_oauth instead.")
client_metadata_dict = {
"client_name": "Letta",

Some files were not shown because too many files have changed in this diff.