chore: bump version 0.16.2 (#3140)
This commit is contained in:
2
.github/scripts/model-sweep/conftest.py
vendored
2
.github/scripts/model-sweep/conftest.py
vendored
@@ -184,7 +184,7 @@ def _start_server_once() -> str:
|
||||
thread.start()
|
||||
|
||||
# Poll until up
|
||||
timeout_seconds = 30
|
||||
timeout_seconds = 60
|
||||
deadline = time.time() + timeout_seconds
|
||||
while time.time() < deadline:
|
||||
try:
|
||||
|
||||
1
.python-version
Normal file
1
.python-version
Normal file
@@ -0,0 +1 @@
|
||||
3.12
|
||||
@@ -90,6 +90,36 @@ Workflow for Postgres-targeted migration:
|
||||
- `uv run alembic upgrade head`
|
||||
- `uv run alembic revision --autogenerate -m "..."`
|
||||
|
||||
### 5. Resetting local Postgres for clean migration generation
|
||||
|
||||
If your local Postgres database has drifted from main (e.g., applied migrations
|
||||
that no longer exist, or has stale schema), you can reset it to generate a clean
|
||||
migration.
|
||||
|
||||
From the repo root (`/Users/sarahwooders/repos/letta-cloud`):
|
||||
|
||||
```bash
|
||||
# 1. Remove postgres data directory
|
||||
rm -rf ./data/postgres
|
||||
|
||||
# 2. Stop the running postgres container
|
||||
docker stop $(docker ps -q --filter ancestor=ankane/pgvector)
|
||||
|
||||
# 3. Restart services (creates fresh postgres)
|
||||
just start-services
|
||||
|
||||
# 4. Wait a moment for postgres to be ready, then apply all migrations
|
||||
cd apps/core
|
||||
export LETTA_PG_URI=postgresql+pg8000://postgres:postgres@localhost:5432/letta-core
|
||||
uv run alembic upgrade head
|
||||
|
||||
# 5. Now generate your new migration
|
||||
uv run alembic revision --autogenerate -m "your migration message"
|
||||
```
|
||||
|
||||
This ensures the migration is generated against a clean database state matching
|
||||
main, avoiding spurious diffs from local-only schema changes.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- **"Target database is not up to date" when autogenerating**
|
||||
@@ -101,7 +131,7 @@ Workflow for Postgres-targeted migration:
|
||||
changed model is imported in Alembic env context.
|
||||
- **Autogenerated migration has unexpected drops/renames**
|
||||
- Review model changes; consider explicit operations instead of relying on
|
||||
autogenerate.
|
||||
autogenerate. Reset local Postgres (see workflow 5) to get a clean baseline.
|
||||
|
||||
## References
|
||||
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
"""add conversations tables and run conversation_id
|
||||
|
||||
Revision ID: 27de0f58e076
|
||||
Revises: ee2b43eea55e
|
||||
Create Date: 2026-01-01 20:36:09.101274
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "27de0f58e076"
|
||||
down_revision: Union[str, None] = "ee2b43eea55e"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"conversations",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("agent_id", sa.String(), nullable=False),
|
||||
sa.Column("summary", sa.String(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
|
||||
sa.Column("_created_by_id", sa.String(), nullable=True),
|
||||
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
|
||||
sa.Column("organization_id", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"],
|
||||
["organizations.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index("ix_conversations_agent_id", "conversations", ["agent_id"], unique=False)
|
||||
op.create_index("ix_conversations_org_agent", "conversations", ["organization_id", "agent_id"], unique=False)
|
||||
op.create_table(
|
||||
"conversation_messages",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("conversation_id", sa.String(), nullable=True),
|
||||
sa.Column("agent_id", sa.String(), nullable=False),
|
||||
sa.Column("message_id", sa.String(), nullable=False),
|
||||
sa.Column("position", sa.Integer(), nullable=False),
|
||||
sa.Column("in_context", sa.Boolean(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
|
||||
sa.Column("_created_by_id", sa.String(), nullable=True),
|
||||
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
|
||||
sa.Column("organization_id", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["conversation_id"], ["conversations.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["message_id"], ["messages.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"],
|
||||
["organizations.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("conversation_id", "message_id", name="unique_conversation_message"),
|
||||
)
|
||||
op.create_index("ix_conv_msg_agent_conversation", "conversation_messages", ["agent_id", "conversation_id"], unique=False)
|
||||
op.create_index("ix_conv_msg_agent_id", "conversation_messages", ["agent_id"], unique=False)
|
||||
op.create_index("ix_conv_msg_conversation_position", "conversation_messages", ["conversation_id", "position"], unique=False)
|
||||
op.create_index("ix_conv_msg_message_id", "conversation_messages", ["message_id"], unique=False)
|
||||
op.add_column("messages", sa.Column("conversation_id", sa.String(), nullable=True))
|
||||
op.create_index(op.f("ix_messages_conversation_id"), "messages", ["conversation_id"], unique=False)
|
||||
op.create_foreign_key(None, "messages", "conversations", ["conversation_id"], ["id"], ondelete="SET NULL")
|
||||
op.add_column("runs", sa.Column("conversation_id", sa.String(), nullable=True))
|
||||
op.create_index("ix_runs_conversation_id", "runs", ["conversation_id"], unique=False)
|
||||
op.create_foreign_key(None, "runs", "conversations", ["conversation_id"], ["id"], ondelete="SET NULL")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint(None, "runs", type_="foreignkey")
|
||||
op.drop_index("ix_runs_conversation_id", table_name="runs")
|
||||
op.drop_column("runs", "conversation_id")
|
||||
op.drop_constraint(None, "messages", type_="foreignkey")
|
||||
op.drop_index(op.f("ix_messages_conversation_id"), table_name="messages")
|
||||
op.drop_column("messages", "conversation_id")
|
||||
op.drop_index("ix_conv_msg_message_id", table_name="conversation_messages")
|
||||
op.drop_index("ix_conv_msg_conversation_position", table_name="conversation_messages")
|
||||
op.drop_index("ix_conv_msg_agent_id", table_name="conversation_messages")
|
||||
op.drop_index("ix_conv_msg_agent_conversation", table_name="conversation_messages")
|
||||
op.drop_table("conversation_messages")
|
||||
op.drop_index("ix_conversations_org_agent", table_name="conversations")
|
||||
op.drop_index("ix_conversations_agent_id", table_name="conversations")
|
||||
op.drop_table("conversations")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,31 @@
|
||||
"""add request_id to steps table
|
||||
|
||||
Revision ID: ee2b43eea55e
|
||||
Revises: 39577145c45d
|
||||
Create Date: 2025-12-17 13:48:08.642245
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "ee2b43eea55e"
|
||||
down_revision: Union[str, None] = "39577145c45d"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("steps", sa.Column("request_id", sa.String(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("steps", "request_id")
|
||||
# ### end Alembic commands ###
|
||||
4317
fern/openapi.json
4317
fern/openapi.json
File diff suppressed because it is too large
Load Diff
@@ -5,7 +5,7 @@ try:
|
||||
__version__ = version("letta")
|
||||
except PackageNotFoundError:
|
||||
# Fallback for development installations
|
||||
__version__ = "0.16.1"
|
||||
__version__ = "0.16.2"
|
||||
|
||||
if os.environ.get("LETTA_VERSION"):
|
||||
__version__ = os.environ["LETTA_VERSION"]
|
||||
|
||||
@@ -87,9 +87,13 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
|
||||
raise self.llm_client.handle_llm_error(e)
|
||||
|
||||
# Process the stream and yield chunks immediately for TTFT
|
||||
async for chunk in self.interface.process(stream): # TODO: add ttft span
|
||||
# Yield each chunk immediately as it arrives
|
||||
yield chunk
|
||||
# Wrap in error handling to convert provider errors to common LLMError types
|
||||
try:
|
||||
async for chunk in self.interface.process(stream): # TODO: add ttft span
|
||||
# Yield each chunk immediately as it arrives
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
raise self.llm_client.handle_llm_error(e)
|
||||
|
||||
# After streaming completes, extract the accumulated data
|
||||
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
|
||||
|
||||
@@ -75,7 +75,7 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
|
||||
run_id=self.run_id,
|
||||
step_id=step_id,
|
||||
)
|
||||
elif self.llm_config.model_endpoint_type in [ProviderType.openai, ProviderType.deepseek]:
|
||||
elif self.llm_config.model_endpoint_type in [ProviderType.openai, ProviderType.deepseek, ProviderType.zai]:
|
||||
# Decide interface based on payload shape
|
||||
use_responses = "input" in request_data and "messages" not in request_data
|
||||
# No support for Responses API proxy
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import AsyncGenerator
|
||||
from typing import TYPE_CHECKING, AsyncGenerator
|
||||
|
||||
from letta.constants import DEFAULT_MAX_STEPS
|
||||
from letta.log import get_logger
|
||||
@@ -10,6 +10,9 @@ from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.user import User
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.schemas.letta_request import ClientToolSchema
|
||||
|
||||
|
||||
class BaseAgentV2(ABC):
|
||||
"""
|
||||
@@ -42,9 +45,14 @@ class BaseAgentV2(ABC):
|
||||
use_assistant_message: bool = True,
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
client_tools: list["ClientToolSchema"] | None = None,
|
||||
) -> LettaResponse:
|
||||
"""
|
||||
Execute the agent loop in blocking mode, returning all messages at once.
|
||||
|
||||
Args:
|
||||
client_tools: Optional list of client-side tools. When called, execution pauses
|
||||
for client to provide tool returns.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -58,11 +66,17 @@ class BaseAgentV2(ABC):
|
||||
use_assistant_message: bool = True,
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
conversation_id: str | None = None,
|
||||
client_tools: list["ClientToolSchema"] | None = None,
|
||||
) -> AsyncGenerator[LettaMessage | LegacyLettaMessage | MessageStreamStatus, None]:
|
||||
"""
|
||||
Execute the agent loop in streaming mode, yielding chunks as they become available.
|
||||
If stream_tokens is True, individual tokens are streamed as they arrive from the LLM,
|
||||
providing the lowest latency experience, otherwise each complete step (reasoning +
|
||||
tool call + tool return) is yielded as it completes.
|
||||
|
||||
Args:
|
||||
client_tools: Optional list of client-side tools. When called, execution pauses
|
||||
for client to provide tool returns.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -6,6 +6,7 @@ from uuid import UUID, uuid4
|
||||
|
||||
from letta.errors import PendingApprovalError
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentState
|
||||
@@ -131,16 +132,21 @@ async def _prepare_in_context_messages_no_persist_async(
|
||||
message_manager: MessageManager,
|
||||
actor: User,
|
||||
run_id: Optional[str] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
) -> Tuple[List[Message], List[Message]]:
|
||||
"""
|
||||
Prepares in-context messages for an agent, based on the current state and a new user input.
|
||||
|
||||
When conversation_id is provided, messages are loaded from the conversation_messages
|
||||
table instead of agent_state.message_ids.
|
||||
|
||||
Args:
|
||||
input_messages (List[MessageCreate]): The new user input messages to process.
|
||||
agent_state (AgentState): The current state of the agent, including message buffer config.
|
||||
message_manager (MessageManager): The manager used to retrieve and create messages.
|
||||
actor (User): The user performing the action, used for access control and attribution.
|
||||
run_id (str): The run ID associated with this message processing.
|
||||
conversation_id (str): Optional conversation ID to load messages from.
|
||||
|
||||
Returns:
|
||||
Tuple[List[Message], List[Message]]: A tuple containing:
|
||||
@@ -148,12 +154,74 @@ async def _prepare_in_context_messages_no_persist_async(
|
||||
- The new in-context messages (messages created from the new input).
|
||||
"""
|
||||
|
||||
if agent_state.message_buffer_autoclear:
|
||||
# If autoclear is enabled, only include the most recent system message (usually at index 0)
|
||||
current_in_context_messages = [await message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)]
|
||||
if conversation_id:
|
||||
# Conversation mode: load messages from conversation_messages table
|
||||
from letta.services.conversation_manager import ConversationManager
|
||||
|
||||
conversation_manager = ConversationManager()
|
||||
message_ids = await conversation_manager.get_message_ids_for_conversation(
|
||||
conversation_id=conversation_id,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
if agent_state.message_buffer_autoclear and message_ids:
|
||||
# If autoclear is enabled, only include the system message
|
||||
current_in_context_messages = [await message_manager.get_message_by_id_async(message_id=message_ids[0], actor=actor)]
|
||||
elif message_ids:
|
||||
# Otherwise, include the full list of messages from the conversation
|
||||
current_in_context_messages = await message_manager.get_messages_by_ids_async(message_ids=message_ids, actor=actor)
|
||||
else:
|
||||
# No messages in conversation yet - compile a new system message for this conversation
|
||||
# Each conversation gets its own system message (captures memory state at conversation start)
|
||||
from letta.prompts.prompt_generator import PromptGenerator
|
||||
from letta.services.passage_manager import PassageManager
|
||||
|
||||
num_messages = await message_manager.size_async(actor=actor, agent_id=agent_state.id)
|
||||
passage_manager = PassageManager()
|
||||
num_archival_memories = await passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_state.id)
|
||||
|
||||
system_message_str = await PromptGenerator.compile_system_message_async(
|
||||
system_prompt=agent_state.system,
|
||||
in_context_memory=agent_state.memory,
|
||||
in_context_memory_last_edit=get_utc_time(),
|
||||
timezone=agent_state.timezone,
|
||||
user_defined_variables=None,
|
||||
append_icm_if_missing=True,
|
||||
previous_message_count=num_messages,
|
||||
archival_memory_size=num_archival_memories,
|
||||
sources=agent_state.sources,
|
||||
max_files_open=agent_state.max_files_open,
|
||||
)
|
||||
system_message = Message.dict_to_message(
|
||||
agent_id=agent_state.id,
|
||||
model=agent_state.llm_config.model,
|
||||
openai_message_dict={"role": "system", "content": system_message_str},
|
||||
)
|
||||
|
||||
# Persist the new system message
|
||||
persisted_messages = await message_manager.create_many_messages_async([system_message], actor=actor)
|
||||
system_message = persisted_messages[0]
|
||||
|
||||
# Add it to the conversation tracking
|
||||
await conversation_manager.add_messages_to_conversation(
|
||||
conversation_id=conversation_id,
|
||||
agent_id=agent_state.id,
|
||||
message_ids=[system_message.id],
|
||||
actor=actor,
|
||||
starting_position=0,
|
||||
)
|
||||
|
||||
current_in_context_messages = [system_message]
|
||||
else:
|
||||
# Otherwise, include the full list of messages by ID for context
|
||||
current_in_context_messages = await message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor)
|
||||
# Default mode: load messages from agent_state.message_ids
|
||||
if agent_state.message_buffer_autoclear:
|
||||
# If autoclear is enabled, only include the most recent system message (usually at index 0)
|
||||
current_in_context_messages = [
|
||||
await message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)
|
||||
]
|
||||
else:
|
||||
# Otherwise, include the full list of messages by ID for context
|
||||
current_in_context_messages = await message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor)
|
||||
|
||||
# Check for approval-related message validation
|
||||
if input_messages[0].type == "approval":
|
||||
|
||||
@@ -34,6 +34,7 @@ from letta.schemas.agent import AgentState, UpdateAgent
|
||||
from letta.schemas.enums import AgentType, MessageStreamStatus, RunStatus, StepStatus
|
||||
from letta.schemas.letta_message import LettaMessage, MessageType
|
||||
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
|
||||
from letta.schemas.letta_request import ClientToolSchema
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||
from letta.schemas.message import Message, MessageCreate, MessageUpdate
|
||||
@@ -173,6 +174,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
use_assistant_message: bool = True,
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
client_tools: list[ClientToolSchema] | None = None,
|
||||
) -> LettaResponse:
|
||||
"""
|
||||
Execute the agent loop in blocking mode, returning all messages at once.
|
||||
@@ -184,6 +186,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
use_assistant_message: Whether to use assistant message format
|
||||
include_return_message_types: Filter for which message types to return
|
||||
request_start_timestamp_ns: Start time for tracking request duration
|
||||
client_tools: Optional list of client-side tools (not used in V2, for API compatibility)
|
||||
|
||||
Returns:
|
||||
LettaResponse: Complete response with all messages and metadata
|
||||
@@ -251,6 +254,8 @@ class LettaAgentV2(BaseAgentV2):
|
||||
use_assistant_message: bool = True,
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
conversation_id: str | None = None, # Not used in V2, but accepted for API compatibility
|
||||
client_tools: list[ClientToolSchema] | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Execute the agent loop in streaming mode, yielding chunks as they become available.
|
||||
@@ -268,6 +273,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
use_assistant_message: Whether to use assistant message format
|
||||
include_return_message_types: Filter for which message types to return
|
||||
request_start_timestamp_ns: Start time for tracking request duration
|
||||
client_tools: Optional list of client-side tools (not used in V2, for API compatibility)
|
||||
|
||||
Yields:
|
||||
str: JSON-formatted SSE data chunks for each completed step
|
||||
@@ -1168,16 +1174,15 @@ class LettaAgentV2(BaseAgentV2):
|
||||
"""
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
|
||||
tool_name = target_tool.name
|
||||
|
||||
# Special memory case
|
||||
# Check for None before accessing attributes
|
||||
if not target_tool:
|
||||
# TODO: fix this error message
|
||||
return ToolExecutionResult(
|
||||
func_return=f"Tool {tool_name} not found",
|
||||
func_return="Tool not found",
|
||||
status="error",
|
||||
)
|
||||
|
||||
tool_name = target_tool.name
|
||||
|
||||
# TODO: This temp. Move this logic and code to executors
|
||||
|
||||
if agent_step_span:
|
||||
|
||||
@@ -31,6 +31,7 @@ from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import ApprovalReturn, LettaErrorMessage, LettaMessage, MessageType
|
||||
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
|
||||
from letta.schemas.letta_request import ClientToolSchema
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
@@ -46,6 +47,7 @@ from letta.server.rest_api.utils import (
|
||||
create_parallel_tool_messages_from_llm_response,
|
||||
create_tool_returns_for_denials,
|
||||
)
|
||||
from letta.services.conversation_manager import ConversationManager
|
||||
from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema
|
||||
from letta.services.summarizer.summarizer_all import summarize_all
|
||||
from letta.services.summarizer.summarizer_config import CompactionSettings
|
||||
@@ -79,6 +81,10 @@ class LettaAgentV3(LettaAgentV2):
|
||||
# affecting step-level telemetry.
|
||||
self.context_token_estimate: int | None = None
|
||||
self.in_context_messages: list[Message] = [] # in-memory tracker
|
||||
# Conversation mode: when set, messages are tracked per-conversation
|
||||
self.conversation_id: str | None = None
|
||||
# Client-side tools passed in the request (executed by client, not server)
|
||||
self.client_tools: list[ClientToolSchema] = []
|
||||
|
||||
def _compute_tool_return_truncation_chars(self) -> int:
|
||||
"""Compute a dynamic cap for tool returns in requests.
|
||||
@@ -101,6 +107,8 @@ class LettaAgentV3(LettaAgentV2):
|
||||
use_assistant_message: bool = True, # NOTE: not used
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
conversation_id: str | None = None,
|
||||
client_tools: list[ClientToolSchema] | None = None,
|
||||
) -> LettaResponse:
|
||||
"""
|
||||
Execute the agent loop in blocking mode, returning all messages at once.
|
||||
@@ -112,16 +120,27 @@ class LettaAgentV3(LettaAgentV2):
|
||||
use_assistant_message: Whether to use assistant message format
|
||||
include_return_message_types: Filter for which message types to return
|
||||
request_start_timestamp_ns: Start time for tracking request duration
|
||||
conversation_id: Optional conversation ID for conversation-scoped messaging
|
||||
client_tools: Optional list of client-side tools. When called, execution pauses
|
||||
for client to provide tool returns.
|
||||
|
||||
Returns:
|
||||
LettaResponse: Complete response with all messages and metadata
|
||||
"""
|
||||
self._initialize_state()
|
||||
self.conversation_id = conversation_id
|
||||
self.client_tools = client_tools or []
|
||||
request_span = self._request_checkpoint_start(request_start_timestamp_ns=request_start_timestamp_ns)
|
||||
response_letta_messages = []
|
||||
|
||||
# Prepare in-context messages (conversation mode if conversation_id provided)
|
||||
curr_in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
|
||||
input_messages, self.agent_state, self.message_manager, self.actor, run_id
|
||||
input_messages,
|
||||
self.agent_state,
|
||||
self.message_manager,
|
||||
self.actor,
|
||||
run_id,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
follow_up_messages = []
|
||||
if len(input_messages_to_persist) > 1 and input_messages_to_persist[0].role == "approval":
|
||||
@@ -234,6 +253,8 @@ class LettaAgentV3(LettaAgentV2):
|
||||
use_assistant_message: bool = True, # NOTE: not used
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
conversation_id: str | None = None,
|
||||
client_tools: list[ClientToolSchema] | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Execute the agent loop in streaming mode, yielding chunks as they become available.
|
||||
@@ -251,11 +272,16 @@ class LettaAgentV3(LettaAgentV2):
|
||||
use_assistant_message: Whether to use assistant message format
|
||||
include_return_message_types: Filter for which message types to return
|
||||
request_start_timestamp_ns: Start time for tracking request duration
|
||||
conversation_id: Optional conversation ID for conversation-scoped messaging
|
||||
client_tools: Optional list of client-side tools. When called, execution pauses
|
||||
for client to provide tool returns.
|
||||
|
||||
Yields:
|
||||
str: JSON-formatted SSE data chunks for each completed step
|
||||
"""
|
||||
self._initialize_state()
|
||||
self.conversation_id = conversation_id
|
||||
self.client_tools = client_tools or []
|
||||
request_span = self._request_checkpoint_start(request_start_timestamp_ns=request_start_timestamp_ns)
|
||||
response_letta_messages = []
|
||||
first_chunk = True
|
||||
@@ -273,8 +299,14 @@ class LettaAgentV3(LettaAgentV2):
|
||||
)
|
||||
|
||||
try:
|
||||
# Prepare in-context messages (conversation mode if conversation_id provided)
|
||||
in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
|
||||
input_messages, self.agent_state, self.message_manager, self.actor, run_id
|
||||
input_messages,
|
||||
self.agent_state,
|
||||
self.message_manager,
|
||||
self.actor,
|
||||
run_id,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
follow_up_messages = []
|
||||
if len(input_messages_to_persist) > 1 and input_messages_to_persist[0].role == "approval":
|
||||
@@ -424,7 +456,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
This handles:
|
||||
- Persisting the new messages into the `messages` table
|
||||
- Updating the in-memory trackers for in-context messages (`self.in_context_messages`) and agent state (`self.agent_state.message_ids`)
|
||||
- Updating the DB with the current in-context messages (`self.agent_state.message_ids`)
|
||||
- Updating the DB with the current in-context messages (`self.agent_state.message_ids`) OR conversation_messages table
|
||||
|
||||
Args:
|
||||
run_id: The run ID to associate with the messages
|
||||
@@ -446,14 +478,33 @@ class LettaAgentV3(LettaAgentV2):
|
||||
template_id=self.agent_state.template_id,
|
||||
)
|
||||
|
||||
# persist the in-context messages
|
||||
# TODO: somehow make sure all the message ids are already persisted
|
||||
await self.agent_manager.update_message_ids_async(
|
||||
agent_id=self.agent_state.id,
|
||||
message_ids=[m.id for m in in_context_messages],
|
||||
actor=self.actor,
|
||||
)
|
||||
self.agent_state.message_ids = [m.id for m in in_context_messages] # update in-memory state
|
||||
if self.conversation_id:
|
||||
# Conversation mode: update conversation_messages table
|
||||
# Add new messages to conversation tracking
|
||||
new_message_ids = [m.id for m in new_messages]
|
||||
if new_message_ids:
|
||||
await ConversationManager().add_messages_to_conversation(
|
||||
conversation_id=self.conversation_id,
|
||||
agent_id=self.agent_state.id,
|
||||
message_ids=new_message_ids,
|
||||
actor=self.actor,
|
||||
)
|
||||
|
||||
# Update which messages are in context
|
||||
await ConversationManager().update_in_context_messages(
|
||||
conversation_id=self.conversation_id,
|
||||
in_context_message_ids=[m.id for m in in_context_messages],
|
||||
actor=self.actor,
|
||||
)
|
||||
else:
|
||||
# Default mode: update agent.message_ids
|
||||
await self.agent_manager.update_message_ids_async(
|
||||
agent_id=self.agent_state.id,
|
||||
message_ids=[m.id for m in in_context_messages],
|
||||
actor=self.actor,
|
||||
)
|
||||
self.agent_state.message_ids = [m.id for m in in_context_messages] # update in-memory state
|
||||
|
||||
self.in_context_messages = in_context_messages # update in-memory state
|
||||
|
||||
@trace_method
|
||||
@@ -658,7 +709,8 @@ class LettaAgentV3(LettaAgentV2):
|
||||
use_assistant_message=False, # NOTE: set to false
|
||||
requires_approval_tools=self.tool_rules_solver.get_requires_approval_tools(
|
||||
set([t["name"] for t in valid_tools])
|
||||
),
|
||||
)
|
||||
+ [ct.name for ct in self.client_tools],
|
||||
step_id=step_id,
|
||||
actor=self.actor,
|
||||
)
|
||||
@@ -684,7 +736,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
# checkpoint summarized messages
|
||||
# TODO: might want to delay this checkpoint in case of corrupated state
|
||||
try:
|
||||
summary_message, messages = await self.compact(
|
||||
summary_message, messages, _ = await self.compact(
|
||||
messages, trigger_threshold=self.agent_state.llm_config.context_window
|
||||
)
|
||||
self.logger.info("Summarization succeeded, continuing to retry LLM request")
|
||||
@@ -726,6 +778,15 @@ class LettaAgentV3(LettaAgentV2):
|
||||
else:
|
||||
tool_calls = []
|
||||
|
||||
# Enforce parallel_tool_calls=false by truncating to first tool call
|
||||
# Some providers (e.g. Gemini) don't respect this setting via API, so we enforce it client-side
|
||||
if len(tool_calls) > 1 and not self.agent_state.llm_config.parallel_tool_calls:
|
||||
self.logger.warning(
|
||||
f"LLM returned {len(tool_calls)} tool calls but parallel_tool_calls=false. "
|
||||
f"Truncating to first tool call: {tool_calls[0].function.name}"
|
||||
)
|
||||
tool_calls = [tool_calls[0]]
|
||||
|
||||
# get the new generated `Message` objects from handling the LLM response
|
||||
new_messages, self.should_continue, self.stop_reason = await self._handle_ai_response(
|
||||
tool_calls=tool_calls,
|
||||
@@ -795,7 +856,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
self.logger.info(
|
||||
f"Context window exceeded (current: {self.context_token_estimate}, threshold: {self.agent_state.llm_config.context_window}), trying to compact messages"
|
||||
)
|
||||
summary_message, messages = await self.compact(messages, trigger_threshold=self.agent_state.llm_config.context_window)
|
||||
summary_message, messages, _ = await self.compact(messages, trigger_threshold=self.agent_state.llm_config.context_window)
|
||||
# TODO: persist + return the summary message
|
||||
# TODO: convert this to a SummaryMessage
|
||||
self.response_messages.append(summary_message)
|
||||
@@ -964,10 +1025,22 @@ class LettaAgentV3(LettaAgentV2):
|
||||
messages_to_persist = (initial_messages or []) + assistant_message
|
||||
return messages_to_persist, continue_stepping, stop_reason
|
||||
|
||||
# 2. Check whether tool call requires approval
|
||||
# 2. Check whether tool call requires approval (includes client-side tools)
|
||||
if not is_approval_response:
|
||||
requested_tool_calls = [t for t in tool_calls if tool_rules_solver.is_requires_approval_tool(t.function.name)]
|
||||
allowed_tool_calls = [t for t in tool_calls if not tool_rules_solver.is_requires_approval_tool(t.function.name)]
|
||||
# Get names of client-side tools (these are executed by client, not server)
|
||||
client_tool_names = {ct.name for ct in self.client_tools} if self.client_tools else set()
|
||||
|
||||
# Tools requiring approval: requires_approval tools OR client-side tools
|
||||
requested_tool_calls = [
|
||||
t
|
||||
for t in tool_calls
|
||||
if tool_rules_solver.is_requires_approval_tool(t.function.name) or t.function.name in client_tool_names
|
||||
]
|
||||
allowed_tool_calls = [
|
||||
t
|
||||
for t in tool_calls
|
||||
if not tool_rules_solver.is_requires_approval_tool(t.function.name) and t.function.name not in client_tool_names
|
||||
]
|
||||
if requested_tool_calls:
|
||||
approval_messages = create_approval_request_message_from_llm_response(
|
||||
agent_id=self.agent_state.id,
|
||||
@@ -1037,15 +1110,11 @@ class LettaAgentV3(LettaAgentV2):
|
||||
|
||||
# 5. Unified tool execution path (works for both single and multiple tools)
|
||||
|
||||
# 5a. Validate parallel tool calling constraints
|
||||
if len(tool_calls) > 1:
|
||||
# No parallel tool calls with tool rules
|
||||
if self.agent_state.tool_rules and len([r for r in self.agent_state.tool_rules if r.type != "requires_approval"]) > 0:
|
||||
raise ValueError(
|
||||
"Parallel tool calling is not allowed when tool rules are present. Disable tool rules to use parallel tool calls."
|
||||
)
|
||||
# 5. Unified tool execution path (works for both single and multiple tools)
|
||||
# Note: Parallel tool calling with tool rules is validated at agent create/update time.
|
||||
# At runtime, we trust that if tool_rules exist, parallel_tool_calls=false is enforced earlier.
|
||||
|
||||
# 5b. Prepare execution specs for all tools
|
||||
# 5a. Prepare execution specs for all tools
|
||||
exec_specs = []
|
||||
for tc in tool_calls:
|
||||
call_id = tc.id or f"call_{uuid.uuid4().hex[:8]}"
|
||||
@@ -1321,7 +1390,25 @@ class LettaAgentV3(LettaAgentV2):
|
||||
last_function_response=self.last_function_response,
|
||||
error_on_empty=False, # Return empty list instead of raising error
|
||||
) or list(set(t.name for t in tools))
|
||||
allowed_tools = [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)]
|
||||
|
||||
# Get client tool names to filter out server tools with same name (client tools override)
|
||||
client_tool_names = {ct.name for ct in self.client_tools} if self.client_tools else set()
|
||||
|
||||
# Build allowed tools from server tools, excluding those overridden by client tools
|
||||
allowed_tools = [
|
||||
enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names) and t.name not in client_tool_names
|
||||
]
|
||||
|
||||
# Merge client-side tools (use flat format matching enable_strict_mode output)
|
||||
if self.client_tools:
|
||||
for ct in self.client_tools:
|
||||
client_tool_schema = {
|
||||
"name": ct.name,
|
||||
"description": ct.description,
|
||||
"parameters": ct.parameters or {"type": "object", "properties": {}},
|
||||
}
|
||||
allowed_tools.append(client_tool_schema)
|
||||
|
||||
terminal_tool_names = {rule.tool_name for rule in self.tool_rules_solver.terminal_tool_rules}
|
||||
allowed_tools = runtime_override_tool_json_schema(
|
||||
tool_list=allowed_tools,
|
||||
@@ -1332,7 +1419,9 @@ class LettaAgentV3(LettaAgentV2):
|
||||
return allowed_tools
|
||||
|
||||
@trace_method
|
||||
async def compact(self, messages, trigger_threshold: Optional[int] = None) -> Message:
|
||||
async def compact(
|
||||
self, messages, trigger_threshold: Optional[int] = None, compaction_settings: Optional["CompactionSettings"] = None
|
||||
) -> tuple[Message, list[Message], str]:
|
||||
"""Compact the current in-context messages for this agent.
|
||||
|
||||
Compaction uses a summarizer LLM configuration derived from
|
||||
@@ -1341,9 +1430,11 @@ class LettaAgentV3(LettaAgentV2):
|
||||
localized to summarization.
|
||||
"""
|
||||
|
||||
# Use agent's compaction_settings if set, otherwise fall back to
|
||||
# global defaults based on the agent's model handle.
|
||||
if self.agent_state.compaction_settings is not None:
|
||||
# Use the passed-in compaction_settings first, then agent's compaction_settings if set,
|
||||
# otherwise fall back to global defaults based on the agent's model handle.
|
||||
if compaction_settings is not None:
|
||||
summarizer_config = compaction_settings
|
||||
elif self.agent_state.compaction_settings is not None:
|
||||
summarizer_config = self.agent_state.compaction_settings
|
||||
else:
|
||||
# Prefer the new handle field if set, otherwise derive from llm_config
|
||||
@@ -1466,7 +1557,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
if len(compacted_messages) > 1:
|
||||
final_messages += compacted_messages[1:]
|
||||
|
||||
return summary_message_obj, final_messages
|
||||
return summary_message_obj, final_messages, summary
|
||||
|
||||
@staticmethod
|
||||
def _build_summarizer_llm_config(
|
||||
@@ -1489,17 +1580,16 @@ class LettaAgentV3(LettaAgentV2):
|
||||
# Parse provider/model from the handle, falling back to the agent's
|
||||
# provider type when only a model name is given.
|
||||
if "/" in summarizer_config.model:
|
||||
provider, model_name = summarizer_config.model.split("/", 1)
|
||||
if provider == "openai-proxy":
|
||||
# fix for pydantic LLMConfig validation
|
||||
provider = "openai"
|
||||
provider_name, model_name = summarizer_config.model.split("/", 1)
|
||||
else:
|
||||
provider = agent_llm_config.model_endpoint_type
|
||||
provider_name = agent_llm_config.provider_name
|
||||
model_name = summarizer_config.model
|
||||
|
||||
# Start from the agent's config and override model + provider + handle
|
||||
# Start from the agent's config and override model + provider_name + handle
|
||||
# Note: model_endpoint_type is NOT overridden - the parsed provider_name
|
||||
# is a custom label (e.g. "claude-pro-max"), not the endpoint type (e.g. "anthropic")
|
||||
base = agent_llm_config.model_copy()
|
||||
base.model_endpoint_type = provider
|
||||
base.provider_name = provider_name
|
||||
base.model = model_name
|
||||
base.handle = summarizer_config.model
|
||||
|
||||
|
||||
@@ -8,6 +8,26 @@ LETTA_TOOL_EXECUTION_DIR = os.path.join(LETTA_DIR, "tool_execution_dir")
|
||||
LETTA_MODEL_ENDPOINT = "https://inference.letta.com/v1/"
|
||||
DEFAULT_TIMEZONE = "UTC"
|
||||
|
||||
# Provider ordering for model listing (matches original _enabled_providers list order)
|
||||
PROVIDER_ORDER = {
|
||||
"letta": 0,
|
||||
"openai": 1,
|
||||
"anthropic": 2,
|
||||
"ollama": 3,
|
||||
"google_ai": 4,
|
||||
"google_vertex": 5,
|
||||
"azure": 6,
|
||||
"groq": 7,
|
||||
"together": 8,
|
||||
"vllm": 9,
|
||||
"bedrock": 10,
|
||||
"deepseek": 11,
|
||||
"xai": 12,
|
||||
"lmstudio": 13,
|
||||
"zai": 14,
|
||||
"openrouter": 15, # Note: OpenRouter uses OpenRouterProvider, not a ProviderType enum
|
||||
}
|
||||
|
||||
ADMIN_PREFIX = "/v1/admin"
|
||||
API_PREFIX = "/v1"
|
||||
OLLAMA_API_PREFIX = "/v1"
|
||||
@@ -432,6 +452,10 @@ REDIS_SET_DEFAULT_VAL = "None"
|
||||
REDIS_DEFAULT_CACHE_PREFIX = "letta_cache"
|
||||
REDIS_RUN_ID_PREFIX = "agent:send_message:run_id"
|
||||
|
||||
# Conversation lock constants
|
||||
CONVERSATION_LOCK_PREFIX = "conversation:lock:"
|
||||
CONVERSATION_LOCK_TTL_SECONDS = 300 # 5 minutes
|
||||
|
||||
# TODO: This is temporary, eventually use token-based eviction
|
||||
# File based controls
|
||||
DEFAULT_MAX_FILES_OPEN = 5
|
||||
|
||||
@@ -2,17 +2,20 @@ import asyncio
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
|
||||
from letta.constants import REDIS_EXCLUDE, REDIS_INCLUDE, REDIS_SET_DEFAULT_VAL
|
||||
from letta.constants import CONVERSATION_LOCK_PREFIX, CONVERSATION_LOCK_TTL_SECONDS, REDIS_EXCLUDE, REDIS_INCLUDE, REDIS_SET_DEFAULT_VAL
|
||||
from letta.errors import ConversationBusyError
|
||||
from letta.log import get_logger
|
||||
from letta.settings import settings
|
||||
|
||||
try:
|
||||
from redis import RedisError
|
||||
from redis.asyncio import ConnectionPool, Redis
|
||||
from redis.asyncio.lock import Lock
|
||||
except ImportError:
|
||||
RedisError = None
|
||||
Redis = None
|
||||
ConnectionPool = None
|
||||
Lock = None
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -171,6 +174,62 @@ class AsyncRedisClient:
|
||||
client = await self.get_client()
|
||||
return await client.delete(*keys)
|
||||
|
||||
async def acquire_conversation_lock(
|
||||
self,
|
||||
conversation_id: str,
|
||||
token: str,
|
||||
) -> Optional["Lock"]:
|
||||
"""
|
||||
Acquire a distributed lock for a conversation.
|
||||
|
||||
Args:
|
||||
conversation_id: The ID for the conversation
|
||||
token: Unique identifier for the lock holder (for debugging/tracing)
|
||||
|
||||
Returns:
|
||||
Lock object if acquired, raises ConversationBusyError if in use
|
||||
"""
|
||||
if Lock is None:
|
||||
return None
|
||||
client = await self.get_client()
|
||||
lock_key = f"{CONVERSATION_LOCK_PREFIX}{conversation_id}"
|
||||
lock = Lock(
|
||||
client,
|
||||
lock_key,
|
||||
timeout=CONVERSATION_LOCK_TTL_SECONDS,
|
||||
blocking=False,
|
||||
thread_local=False, # We manage token explicitly
|
||||
raise_on_release_error=False, # We handle release errors ourselves
|
||||
)
|
||||
|
||||
if await lock.acquire(token=token):
|
||||
return lock
|
||||
|
||||
lock_holder_token = await client.get(lock_key)
|
||||
raise ConversationBusyError(
|
||||
conversation_id=conversation_id,
|
||||
lock_holder_token=lock_holder_token,
|
||||
)
|
||||
|
||||
async def release_conversation_lock(self, conversation_id: str) -> bool:
|
||||
"""
|
||||
Release a conversation lock by conversation_id.
|
||||
|
||||
Args:
|
||||
conversation_id: The conversation ID to release the lock for
|
||||
|
||||
Returns:
|
||||
True if lock was released, False if release failed
|
||||
"""
|
||||
try:
|
||||
client = await self.get_client()
|
||||
lock_key = f"{CONVERSATION_LOCK_PREFIX}{conversation_id}"
|
||||
await client.delete(lock_key)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to release conversation lock for conversation {conversation_id}: {e}")
|
||||
return False
|
||||
|
||||
@with_retry()
|
||||
async def exists(self, *keys: str) -> int:
|
||||
"""Check if keys exist."""
|
||||
@@ -395,6 +454,16 @@ class NoopAsyncRedisClient(AsyncRedisClient):
|
||||
async def delete(self, *keys: str) -> int:
|
||||
return 0
|
||||
|
||||
async def acquire_conversation_lock(
|
||||
self,
|
||||
conversation_id: str,
|
||||
token: str,
|
||||
) -> Optional["Lock"]:
|
||||
return None
|
||||
|
||||
async def release_conversation_lock(self, conversation_id: str) -> bool:
|
||||
return False
|
||||
|
||||
async def check_inclusion_and_exclusion(self, member: str, group: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@@ -53,6 +53,42 @@ class PendingApprovalError(LettaError):
|
||||
super().__init__(message=message, code=code, details=details)
|
||||
|
||||
|
||||
class NoActiveRunsToCancelError(LettaError):
|
||||
"""Error raised when attempting to cancel but there are no active runs to cancel."""
|
||||
|
||||
def __init__(self, agent_id: Optional[str] = None):
|
||||
message = "No active runs to cancel"
|
||||
if agent_id:
|
||||
message = f"No active runs to cancel for agent {agent_id}"
|
||||
details = {"error_code": "NO_ACTIVE_RUNS_TO_CANCEL", "agent_id": agent_id}
|
||||
super().__init__(message=message, code=ErrorCode.CONFLICT, details=details)
|
||||
|
||||
|
||||
class ConcurrentUpdateError(LettaError):
|
||||
"""Error raised when a resource was updated by another transaction (optimistic locking conflict)."""
|
||||
|
||||
def __init__(self, resource_type: str, resource_id: str):
|
||||
message = f"{resource_type} with id '{resource_id}' was updated by another transaction. Please retry your request."
|
||||
details = {"error_code": "CONCURRENT_UPDATE", "resource_type": resource_type, "resource_id": resource_id}
|
||||
super().__init__(message=message, code=ErrorCode.CONFLICT, details=details)
|
||||
|
||||
|
||||
class ConversationBusyError(LettaError):
|
||||
"""Error raised when attempting to send a message while another request is already processing for the same conversation."""
|
||||
|
||||
def __init__(self, conversation_id: str, lock_holder_token: Optional[str] = None):
|
||||
self.conversation_id = conversation_id
|
||||
self.lock_holder_token = lock_holder_token
|
||||
message = "Cannot send a new message: Another request is currently being processed for this conversation. Please wait for the current request to complete."
|
||||
code = ErrorCode.CONFLICT
|
||||
details = {
|
||||
"error_code": "CONVERSATION_BUSY",
|
||||
"conversation_id": conversation_id,
|
||||
"lock_holder_token": lock_holder_token,
|
||||
}
|
||||
super().__init__(message=message, code=code, details=details)
|
||||
|
||||
|
||||
class LettaToolCreateError(LettaError):
|
||||
"""Error raised when a tool cannot be created."""
|
||||
|
||||
@@ -90,6 +126,19 @@ class LettaConfigurationError(LettaError):
|
||||
super().__init__(message=message, details={"missing_fields": self.missing_fields})
|
||||
|
||||
|
||||
class EmbeddingConfigRequiredError(LettaError):
|
||||
"""Error raised when an operation requires embedding_config but the agent doesn't have one configured."""
|
||||
|
||||
def __init__(self, agent_id: Optional[str] = None, operation: Optional[str] = None):
|
||||
self.agent_id = agent_id
|
||||
self.operation = operation
|
||||
message = "This operation requires an embedding configuration, but the agent does not have one configured."
|
||||
if operation:
|
||||
message = f"Operation '{operation}' requires an embedding configuration, but the agent does not have one configured."
|
||||
details = {"agent_id": agent_id, "operation": operation}
|
||||
super().__init__(message=message, code=ErrorCode.INVALID_ARGUMENT, details=details)
|
||||
|
||||
|
||||
class LettaAgentNotFoundError(LettaError):
|
||||
"""Error raised when an agent is not found."""
|
||||
|
||||
|
||||
@@ -217,7 +217,7 @@ async def archival_memory_search(
|
||||
top_k: Maximum number of results to return (default: 10)
|
||||
|
||||
Returns:
|
||||
A list of relevant memories with timestamps and content, ranked by similarity.
|
||||
A list of relevant memories with IDs, timestamps, and content, ranked by similarity.
|
||||
|
||||
Examples:
|
||||
# Search for project discussions
|
||||
|
||||
@@ -12,7 +12,9 @@ from letta.schemas.group import Group, ManagerType
|
||||
from letta.schemas.job import JobUpdate
|
||||
from letta.schemas.letta_message import MessageType
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.letta_request import ClientToolSchema
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.letta_stop_reason import StopReasonType
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.run import Run, RunUpdate
|
||||
from letta.schemas.user import User
|
||||
@@ -44,6 +46,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
|
||||
use_assistant_message: bool = False,
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
client_tools: list[ClientToolSchema] | None = None,
|
||||
) -> LettaResponse:
|
||||
self.run_ids = []
|
||||
|
||||
@@ -57,6 +60,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
|
||||
use_assistant_message=use_assistant_message,
|
||||
include_return_message_types=include_return_message_types,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
client_tools=client_tools,
|
||||
)
|
||||
|
||||
await self.run_sleeptime_agents()
|
||||
@@ -74,6 +78,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
|
||||
use_assistant_message: bool = True,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
client_tools: list[ClientToolSchema] | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
self.run_ids = []
|
||||
|
||||
@@ -90,6 +95,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
|
||||
use_assistant_message=use_assistant_message,
|
||||
include_return_message_types=include_return_message_types,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
client_tools=client_tools,
|
||||
):
|
||||
yield chunk
|
||||
finally:
|
||||
@@ -214,6 +220,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
|
||||
run_update = RunUpdate(
|
||||
status=RunStatus.completed,
|
||||
completed_at=datetime.now(timezone.utc).replace(tzinfo=None),
|
||||
stop_reason=result.stop_reason.stop_reason if result.stop_reason else StopReasonType.end_turn,
|
||||
metadata={
|
||||
"result": result.model_dump(mode="json"),
|
||||
"agent_id": sleeptime_agent_state.id,
|
||||
@@ -225,6 +232,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
|
||||
run_update = RunUpdate(
|
||||
status=RunStatus.failed,
|
||||
completed_at=datetime.now(timezone.utc).replace(tzinfo=None),
|
||||
stop_reason=StopReasonType.error,
|
||||
metadata={"error": str(e)},
|
||||
)
|
||||
await self.run_manager.update_run_by_id_async(run_id=run_id, update=run_update, actor=self.actor)
|
||||
|
||||
@@ -12,7 +12,9 @@ from letta.schemas.group import Group, ManagerType
|
||||
from letta.schemas.job import JobUpdate
|
||||
from letta.schemas.letta_message import MessageType
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.letta_request import ClientToolSchema
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.letta_stop_reason import StopReasonType
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.run import Run, RunUpdate
|
||||
from letta.schemas.user import User
|
||||
@@ -44,6 +46,8 @@ class SleeptimeMultiAgentV4(LettaAgentV3):
|
||||
use_assistant_message: bool = True,
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
conversation_id: str | None = None,
|
||||
client_tools: list[ClientToolSchema] | None = None,
|
||||
) -> LettaResponse:
|
||||
self.run_ids = []
|
||||
|
||||
@@ -57,6 +61,8 @@ class SleeptimeMultiAgentV4(LettaAgentV3):
|
||||
use_assistant_message=use_assistant_message,
|
||||
include_return_message_types=include_return_message_types,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
conversation_id=conversation_id,
|
||||
client_tools=client_tools,
|
||||
)
|
||||
|
||||
run_ids = await self.run_sleeptime_agents()
|
||||
@@ -73,6 +79,8 @@ class SleeptimeMultiAgentV4(LettaAgentV3):
|
||||
use_assistant_message: bool = True,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
conversation_id: str | None = None,
|
||||
client_tools: list[ClientToolSchema] | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
self.run_ids = []
|
||||
|
||||
@@ -89,6 +97,8 @@ class SleeptimeMultiAgentV4(LettaAgentV3):
|
||||
use_assistant_message=use_assistant_message,
|
||||
include_return_message_types=include_return_message_types,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
conversation_id=conversation_id,
|
||||
client_tools=client_tools,
|
||||
):
|
||||
yield chunk
|
||||
finally:
|
||||
@@ -231,6 +241,7 @@ class SleeptimeMultiAgentV4(LettaAgentV3):
|
||||
run_update = RunUpdate(
|
||||
status=RunStatus.completed,
|
||||
completed_at=datetime.now(timezone.utc).replace(tzinfo=None),
|
||||
stop_reason=result.stop_reason.stop_reason if result.stop_reason else StopReasonType.end_turn,
|
||||
metadata={
|
||||
"result": result.model_dump(mode="json"),
|
||||
"agent_id": sleeptime_agent_state.id,
|
||||
@@ -242,6 +253,7 @@ class SleeptimeMultiAgentV4(LettaAgentV3):
|
||||
run_update = RunUpdate(
|
||||
status=RunStatus.failed,
|
||||
completed_at=datetime.now(timezone.utc).replace(tzinfo=None),
|
||||
stop_reason=StopReasonType.error,
|
||||
metadata={"error": str(e)},
|
||||
)
|
||||
await self.run_manager.update_run_by_id_async(run_id=run_id, update=run_update, actor=self.actor)
|
||||
|
||||
@@ -6,6 +6,10 @@ from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMe
|
||||
from sqlalchemy import Dialect
|
||||
|
||||
from letta.functions.mcp_client.types import StdioServerConfig
|
||||
from letta.helpers.json_helpers import sanitize_null_bytes
|
||||
from letta.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ProviderType, ToolRuleType
|
||||
from letta.schemas.letta_message import ApprovalReturn, MessageReturnType
|
||||
@@ -184,16 +188,22 @@ def deserialize_tool_rule(
|
||||
|
||||
|
||||
def serialize_tool_calls(tool_calls: Optional[List[Union[OpenAIToolCall, dict]]]) -> List[Dict]:
|
||||
"""Convert a list of OpenAI ToolCall objects into JSON-serializable format."""
|
||||
"""Convert a list of OpenAI ToolCall objects into JSON-serializable format.
|
||||
|
||||
Note: Tool call arguments may contain null bytes from various sources.
|
||||
These are sanitized to prevent PostgreSQL errors.
|
||||
"""
|
||||
if not tool_calls:
|
||||
return []
|
||||
|
||||
serialized_calls = []
|
||||
for call in tool_calls:
|
||||
if isinstance(call, OpenAIToolCall):
|
||||
serialized_calls.append(call.model_dump(mode="json"))
|
||||
# Sanitize null bytes from tool call data to prevent PostgreSQL errors
|
||||
serialized_calls.append(sanitize_null_bytes(call.model_dump(mode="json")))
|
||||
elif isinstance(call, dict):
|
||||
serialized_calls.append(call) # Already a dictionary, leave it as-is
|
||||
# Sanitize null bytes from dictionary data
|
||||
serialized_calls.append(sanitize_null_bytes(call))
|
||||
else:
|
||||
raise TypeError(f"Unexpected tool call type: {type(call)}")
|
||||
|
||||
@@ -221,16 +231,22 @@ def deserialize_tool_calls(data: Optional[List[Dict]]) -> List[OpenAIToolCall]:
|
||||
|
||||
|
||||
def serialize_tool_returns(tool_returns: Optional[List[Union[ToolReturn, dict]]]) -> List[Dict]:
|
||||
"""Convert a list of ToolReturn objects into JSON-serializable format."""
|
||||
"""Convert a list of ToolReturn objects into JSON-serializable format.
|
||||
|
||||
Note: Tool returns may contain null bytes from sandbox execution or binary data.
|
||||
These are sanitized to prevent PostgreSQL errors.
|
||||
"""
|
||||
if not tool_returns:
|
||||
return []
|
||||
|
||||
serialized_tool_returns = []
|
||||
for tool_return in tool_returns:
|
||||
if isinstance(tool_return, ToolReturn):
|
||||
serialized_tool_returns.append(tool_return.model_dump(mode="json"))
|
||||
# Sanitize null bytes from tool return data to prevent PostgreSQL errors
|
||||
serialized_tool_returns.append(sanitize_null_bytes(tool_return.model_dump(mode="json")))
|
||||
elif isinstance(tool_return, dict):
|
||||
serialized_tool_returns.append(tool_return) # Already a dictionary, leave it as-is
|
||||
# Sanitize null bytes from dictionary data
|
||||
serialized_tool_returns.append(sanitize_null_bytes(tool_return))
|
||||
else:
|
||||
raise TypeError(f"Unexpected tool return type: {type(tool_return)}")
|
||||
|
||||
@@ -256,18 +272,24 @@ def deserialize_tool_returns(data: Optional[List[Dict]]) -> List[ToolReturn]:
|
||||
|
||||
|
||||
def serialize_approvals(approvals: Optional[List[Union[ApprovalReturn, ToolReturn, dict]]]) -> List[Dict]:
|
||||
"""Convert a list of ToolReturn objects into JSON-serializable format."""
|
||||
"""Convert a list of ToolReturn objects into JSON-serializable format.
|
||||
|
||||
Note: Approval data may contain null bytes from various sources.
|
||||
These are sanitized to prevent PostgreSQL errors.
|
||||
"""
|
||||
if not approvals:
|
||||
return []
|
||||
|
||||
serialized_approvals = []
|
||||
for approval in approvals:
|
||||
if isinstance(approval, ApprovalReturn):
|
||||
serialized_approvals.append(approval.model_dump(mode="json"))
|
||||
# Sanitize null bytes from approval data to prevent PostgreSQL errors
|
||||
serialized_approvals.append(sanitize_null_bytes(approval.model_dump(mode="json")))
|
||||
elif isinstance(approval, ToolReturn):
|
||||
serialized_approvals.append(approval.model_dump(mode="json"))
|
||||
serialized_approvals.append(sanitize_null_bytes(approval.model_dump(mode="json")))
|
||||
elif isinstance(approval, dict):
|
||||
serialized_approvals.append(approval) # Already a dictionary, leave it as-is
|
||||
# Sanitize null bytes from dictionary data
|
||||
serialized_approvals.append(sanitize_null_bytes(approval))
|
||||
else:
|
||||
raise TypeError(f"Unexpected approval type: {type(approval)}")
|
||||
|
||||
@@ -318,7 +340,11 @@ def deserialize_approvals(data: Optional[List[Dict]]) -> List[Union[ApprovalRetu
|
||||
|
||||
|
||||
def serialize_message_content(message_content: Optional[List[Union[MessageContent, dict]]]) -> List[Dict]:
|
||||
"""Convert a list of MessageContent objects into JSON-serializable format."""
|
||||
"""Convert a list of MessageContent objects into JSON-serializable format.
|
||||
|
||||
Note: Message content may contain null bytes from various sources.
|
||||
These are sanitized to prevent PostgreSQL errors.
|
||||
"""
|
||||
if not message_content:
|
||||
return []
|
||||
|
||||
@@ -327,9 +353,11 @@ def serialize_message_content(message_content: Optional[List[Union[MessageConten
|
||||
if isinstance(content, MessageContent):
|
||||
if content.type == MessageContentType.image:
|
||||
assert content.source.type == ImageSourceType.letta, f"Invalid image source type: {content.source.type}"
|
||||
serialized_message_content.append(content.model_dump(mode="json"))
|
||||
# Sanitize null bytes from message content to prevent PostgreSQL errors
|
||||
serialized_message_content.append(sanitize_null_bytes(content.model_dump(mode="json")))
|
||||
elif isinstance(content, dict):
|
||||
serialized_message_content.append(content) # Already a dictionary, leave it as-is
|
||||
# Sanitize null bytes from dictionary data
|
||||
serialized_message_content.append(sanitize_null_bytes(content))
|
||||
else:
|
||||
raise TypeError(f"Unexpected message content type: {type(content)}")
|
||||
return serialized_message_content
|
||||
|
||||
@@ -1,16 +1,23 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
|
||||
from letta.settings import settings
|
||||
|
||||
# Eagerly load the cryptography backend at module import time.
|
||||
_CRYPTO_BACKEND = default_backend()
|
||||
|
||||
# Dedicated thread pool for CPU-intensive crypto operations
|
||||
# Prevents crypto from blocking health checks and other operations
|
||||
_crypto_executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="CryptoWorker")
|
||||
|
||||
# Common API key prefixes that should not be considered encrypted
|
||||
# These are plaintext credentials that happen to be long strings
|
||||
PLAINTEXT_PREFIXES = (
|
||||
@@ -46,6 +53,11 @@ class CryptoUtils:
|
||||
# Salt size for key derivation
|
||||
SALT_SIZE = 16
|
||||
|
||||
# WARNING: DO NOT CHANGE THIS VALUE UNLESS YOU ARE SURE WHAT YOU ARE DOING
|
||||
# EXISTING ENCRYPTED SECRETS MUST BE DECRYPTED WITH THE SAME ITERATIONS
|
||||
# Number of PBKDF2 iterations
|
||||
PBKDF2_ITERATIONS = 100000
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=256)
|
||||
def _derive_key_cached(cls, master_key: str, salt: bytes) -> bytes:
|
||||
@@ -55,11 +67,19 @@ class CryptoUtils:
|
||||
This is a CPU-intensive operation (100k iterations of PBKDF2-HMAC-SHA256)
|
||||
that can take 100-500ms. Results are cached since key derivation is deterministic.
|
||||
|
||||
Uses Python's standard hashlib.pbkdf2_hmac which produces identical output
|
||||
to the cryptography library's PBKDF2HMAC for the same parameters.
|
||||
|
||||
WARNING: This is a synchronous blocking operation. Use _derive_key_async()
|
||||
in async contexts to avoid blocking the event loop.
|
||||
"""
|
||||
kdf = PBKDF2HMAC(algorithm=hashes.SHA256(), length=cls.KEY_SIZE, salt=salt, iterations=100000, backend=default_backend())
|
||||
return kdf.derive(master_key.encode())
|
||||
return hashlib.pbkdf2_hmac(
|
||||
hash_name="sha256",
|
||||
password=master_key.encode(),
|
||||
salt=salt,
|
||||
iterations=cls.PBKDF2_ITERATIONS,
|
||||
dklen=cls.KEY_SIZE,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _derive_key(cls, master_key: str, salt: bytes) -> bytes:
|
||||
@@ -69,13 +89,16 @@ class CryptoUtils:
|
||||
@classmethod
|
||||
async def _derive_key_async(cls, master_key: str, salt: bytes) -> bytes:
|
||||
"""
|
||||
Async version of _derive_key that runs PBKDF2 in a thread pool.
|
||||
Async version of _derive_key that runs PBKDF2 in a dedicated thread pool.
|
||||
|
||||
This prevents PBKDF2 (a CPU-intensive operation) from blocking the event loop.
|
||||
PBKDF2 with 100k iterations typically takes 100-500ms, which would freeze
|
||||
the event loop and prevent all other requests from being processed.
|
||||
Uses a dedicated crypto thread pool (8 workers) to prevent PBKDF2 operations
|
||||
from exhausting the default ThreadPoolExecutor (16 threads) and blocking
|
||||
health checks and other operations during high load.
|
||||
|
||||
PBKDF2 with 100k iterations typically takes 100-500ms per operation.
|
||||
"""
|
||||
return await asyncio.to_thread(cls._derive_key, master_key, salt)
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(_crypto_executor, cls._derive_key, master_key, salt)
|
||||
|
||||
@classmethod
|
||||
def encrypt(cls, plaintext: str, master_key: Optional[str] = None) -> str:
|
||||
@@ -111,7 +134,7 @@ class CryptoUtils:
|
||||
key = cls._derive_key(master_key, salt)
|
||||
|
||||
# Create cipher
|
||||
cipher = Cipher(algorithms.AES(key), modes.GCM(iv), backend=default_backend())
|
||||
cipher = Cipher(algorithms.AES(key), modes.GCM(iv), backend=_CRYPTO_BACKEND)
|
||||
encryptor = cipher.encryptor()
|
||||
|
||||
# Encrypt the plaintext
|
||||
@@ -160,7 +183,7 @@ class CryptoUtils:
|
||||
key = await cls._derive_key_async(master_key, salt)
|
||||
|
||||
# Create cipher
|
||||
cipher = Cipher(algorithms.AES(key), modes.GCM(iv), backend=default_backend())
|
||||
cipher = Cipher(algorithms.AES(key), modes.GCM(iv), backend=_CRYPTO_BACKEND)
|
||||
encryptor = cipher.encryptor()
|
||||
|
||||
# Encrypt the plaintext
|
||||
@@ -215,7 +238,7 @@ class CryptoUtils:
|
||||
key = cls._derive_key(master_key, salt)
|
||||
|
||||
# Create cipher
|
||||
cipher = Cipher(algorithms.AES(key), modes.GCM(iv, tag), backend=default_backend())
|
||||
cipher = Cipher(algorithms.AES(key), modes.GCM(iv, tag), backend=_CRYPTO_BACKEND)
|
||||
decryptor = cipher.decryptor()
|
||||
|
||||
# Decrypt the ciphertext
|
||||
@@ -266,7 +289,7 @@ class CryptoUtils:
|
||||
key = await cls._derive_key_async(master_key, salt)
|
||||
|
||||
# Create cipher
|
||||
cipher = Cipher(algorithms.AES(key), modes.GCM(iv, tag), backend=default_backend())
|
||||
cipher = Cipher(algorithms.AES(key), modes.GCM(iv, tag), backend=_CRYPTO_BACKEND)
|
||||
decryptor = cipher.decryptor()
|
||||
|
||||
# Decrypt the ciphertext
|
||||
|
||||
@@ -9,6 +9,7 @@ from pydantic import BaseModel
|
||||
from letta.constants import REDIS_DEFAULT_CACHE_PREFIX
|
||||
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import tracer
|
||||
from letta.plugins.plugins import get_experimental_checker
|
||||
from letta.settings import settings
|
||||
|
||||
@@ -109,35 +110,59 @@ def async_redis_cache(
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
redis_client = await get_redis_client()
|
||||
with tracer.start_as_current_span("redis_cache", attributes={"cache.function": func.__name__}) as span:
|
||||
# 1. Get Redis client
|
||||
with tracer.start_as_current_span("redis_cache.get_client"):
|
||||
redis_client = await get_redis_client()
|
||||
|
||||
# Don't bother going through other operations for no reason.
|
||||
if isinstance(redis_client, NoopAsyncRedisClient):
|
||||
return await func(*args, **kwargs)
|
||||
cache_key = get_cache_key(*args, **kwargs)
|
||||
cached_value = await redis_client.get(cache_key)
|
||||
# Don't bother going through other operations for no reason.
|
||||
if isinstance(redis_client, NoopAsyncRedisClient):
|
||||
span.set_attribute("cache.noop", True)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
try:
|
||||
if cached_value is not None:
|
||||
stats.hits += 1
|
||||
if model_class:
|
||||
return model_class.model_validate_json(cached_value)
|
||||
return json.loads(cached_value)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to retrieve value from cache: {e}")
|
||||
cache_key = get_cache_key(*args, **kwargs)
|
||||
span.set_attribute("cache.key", cache_key)
|
||||
|
||||
stats.misses += 1
|
||||
result = await func(*args, **kwargs)
|
||||
try:
|
||||
if model_class:
|
||||
await redis_client.set(cache_key, result.model_dump_json(), ex=ttl_s)
|
||||
elif isinstance(result, (dict, list, str, int, float, bool)):
|
||||
await redis_client.set(cache_key, json.dumps(result), ex=ttl_s)
|
||||
else:
|
||||
logger.warning(f"Cannot cache result of type {type(result).__name__} for {func.__name__}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis cache set failed: {e}")
|
||||
return result
|
||||
# 2. Try cache read
|
||||
with tracer.start_as_current_span("redis_cache.get") as get_span:
|
||||
cached_value = await redis_client.get(cache_key)
|
||||
get_span.set_attribute("cache.hit", cached_value is not None)
|
||||
|
||||
try:
|
||||
if cached_value is not None:
|
||||
stats.hits += 1
|
||||
span.set_attribute("cache.result", "hit")
|
||||
# 3. Deserialize cache hit
|
||||
with tracer.start_as_current_span("redis_cache.deserialize"):
|
||||
if model_class:
|
||||
return model_class.model_validate_json(cached_value)
|
||||
return json.loads(cached_value)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to retrieve value from cache: {e}")
|
||||
span.record_exception(e)
|
||||
|
||||
stats.misses += 1
|
||||
span.set_attribute("cache.result", "miss")
|
||||
|
||||
# 4. Call original function
|
||||
with tracer.start_as_current_span("redis_cache.call_original"):
|
||||
result = await func(*args, **kwargs)
|
||||
|
||||
# 5. Write to cache
|
||||
try:
|
||||
with tracer.start_as_current_span("redis_cache.set") as set_span:
|
||||
if model_class:
|
||||
await redis_client.set(cache_key, result.model_dump_json(), ex=ttl_s)
|
||||
elif isinstance(result, (dict, list, str, int, float, bool)):
|
||||
await redis_client.set(cache_key, json.dumps(result), ex=ttl_s)
|
||||
else:
|
||||
set_span.set_attribute("cache.set_skipped", True)
|
||||
logger.warning(f"Cannot cache result of type {type(result).__name__} for {func.__name__}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis cache set failed: {e}")
|
||||
span.record_exception(e)
|
||||
|
||||
return result
|
||||
|
||||
async def invalidate(*args, **kwargs) -> bool:
|
||||
stats.invalidations += 1
|
||||
|
||||
@@ -1,6 +1,42 @@
|
||||
import base64
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
||||
def sanitize_null_bytes(value: Any) -> Any:
|
||||
"""Recursively remove null bytes (0x00) from strings.
|
||||
|
||||
PostgreSQL TEXT columns don't accept null bytes in UTF-8 encoding, which causes
|
||||
asyncpg.exceptions.CharacterNotInRepertoireError when data with null bytes is inserted.
|
||||
|
||||
This function sanitizes:
|
||||
- Strings: removes all null bytes
|
||||
- Dicts: recursively sanitizes all string values
|
||||
- Lists: recursively sanitizes all elements
|
||||
- Other types: returned as-is
|
||||
|
||||
Args:
|
||||
value: The value to sanitize
|
||||
|
||||
Returns:
|
||||
The sanitized value with null bytes removed from all strings
|
||||
"""
|
||||
if isinstance(value, str):
|
||||
# Remove null bytes from strings
|
||||
return value.replace("\x00", "")
|
||||
elif isinstance(value, dict):
|
||||
# Recursively sanitize dictionary keys and values
|
||||
return {sanitize_null_bytes(k): sanitize_null_bytes(v) for k, v in value.items()}
|
||||
elif isinstance(value, list):
|
||||
# Recursively sanitize list elements
|
||||
return [sanitize_null_bytes(item) for item in value]
|
||||
elif isinstance(value, tuple):
|
||||
# Recursively sanitize tuple elements (return as tuple)
|
||||
return tuple(sanitize_null_bytes(item) for item in value)
|
||||
else:
|
||||
# Return other types as-is (int, float, bool, None, etc.)
|
||||
return value
|
||||
|
||||
|
||||
def json_loads(data):
|
||||
@@ -8,15 +44,33 @@ def json_loads(data):
|
||||
|
||||
|
||||
def json_dumps(data, indent=2) -> str:
|
||||
"""Serialize data to JSON string, sanitizing null bytes to prevent PostgreSQL errors.
|
||||
|
||||
PostgreSQL TEXT columns reject null bytes (0x00) in UTF-8 encoding. This function
|
||||
sanitizes all strings in the data structure before JSON serialization to prevent
|
||||
asyncpg.exceptions.CharacterNotInRepertoireError.
|
||||
|
||||
Args:
|
||||
data: The data to serialize
|
||||
indent: JSON indentation level (default: 2)
|
||||
|
||||
Returns:
|
||||
JSON string with null bytes removed from all string values
|
||||
"""
|
||||
# Sanitize null bytes before serialization to prevent PostgreSQL errors
|
||||
sanitized_data = sanitize_null_bytes(data)
|
||||
|
||||
def safe_serializer(obj):
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
if isinstance(obj, bytes):
|
||||
try:
|
||||
return obj.decode("utf-8")
|
||||
decoded = obj.decode("utf-8")
|
||||
# Also sanitize decoded bytes
|
||||
return decoded.replace("\x00", "")
|
||||
except Exception:
|
||||
# TODO: this is to handle Gemini thought signatures, b64 decode this back to bytes when sending back to Gemini
|
||||
return base64.b64encode(obj).decode("utf-8")
|
||||
raise TypeError(f"Type {type(obj)} not serializable")
|
||||
|
||||
return json.dumps(data, indent=indent, default=safe_serializer, ensure_ascii=False)
|
||||
return json.dumps(sanitized_data, indent=indent, default=safe_serializer, ensure_ascii=False)
|
||||
|
||||
@@ -12,23 +12,52 @@ from letta.schemas.letta_message_content import Base64Image, ImageContent, Image
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
|
||||
|
||||
async def _fetch_image_from_url(url: str) -> tuple[bytes, str | None]:
|
||||
async def _fetch_image_from_url(url: str, max_retries: int = 1, timeout_seconds: float = 5.0) -> tuple[bytes, str | None]:
|
||||
"""
|
||||
Async helper to fetch image from URL without blocking the event loop.
|
||||
Retries once on timeout to handle transient network issues.
|
||||
|
||||
Args:
|
||||
url: URL of the image to fetch
|
||||
max_retries: Number of retry attempts (default: 1)
|
||||
timeout_seconds: Total timeout in seconds (default: 5.0)
|
||||
|
||||
Returns:
|
||||
Tuple of (image_bytes, media_type)
|
||||
|
||||
Raises:
|
||||
LettaImageFetchError: If image fetch fails after all retries
|
||||
"""
|
||||
timeout = httpx.Timeout(15.0, connect=5.0)
|
||||
# Connect timeout is half of total timeout, capped at 3 seconds
|
||||
connect_timeout = min(timeout_seconds / 2, 3.0)
|
||||
timeout = httpx.Timeout(timeout_seconds, connect=connect_timeout)
|
||||
headers = {"User-Agent": f"Letta/{__version__}"}
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=timeout, headers=headers) as client:
|
||||
image_response = await client.get(url, follow_redirects=True)
|
||||
image_response.raise_for_status()
|
||||
image_bytes = image_response.content
|
||||
image_media_type = image_response.headers.get("content-type")
|
||||
return image_bytes, image_media_type
|
||||
except (httpx.RemoteProtocolError, httpx.TimeoutException, httpx.HTTPStatusError) as e:
|
||||
raise LettaImageFetchError(url=url, reason=str(e))
|
||||
except Exception as e:
|
||||
raise LettaImageFetchError(url=url, reason=f"Unexpected error: {e}")
|
||||
|
||||
last_exception = None
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=timeout, headers=headers) as client:
|
||||
image_response = await client.get(url, follow_redirects=True)
|
||||
image_response.raise_for_status()
|
||||
image_bytes = image_response.content
|
||||
image_media_type = image_response.headers.get("content-type")
|
||||
return image_bytes, image_media_type
|
||||
except httpx.TimeoutException as e:
|
||||
last_exception = e
|
||||
if attempt < max_retries:
|
||||
# Brief delay before retry
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
# Final attempt failed
|
||||
raise LettaImageFetchError(url=url, reason=f"Timeout after {max_retries + 1} attempts: {e}")
|
||||
except (httpx.RemoteProtocolError, httpx.HTTPStatusError) as e:
|
||||
# Don't retry on protocol errors or HTTP errors (4xx, 5xx)
|
||||
raise LettaImageFetchError(url=url, reason=str(e))
|
||||
except Exception as e:
|
||||
raise LettaImageFetchError(url=url, reason=f"Unexpected error: {e}")
|
||||
|
||||
# Should never reach here, but just in case
|
||||
raise LettaImageFetchError(url=url, reason=f"Failed after {max_retries + 1} attempts: {last_exception}")
|
||||
|
||||
|
||||
async def convert_message_creates_to_messages(
|
||||
|
||||
@@ -400,6 +400,7 @@ class TurbopufferClient:
|
||||
created_ats: List[datetime],
|
||||
project_id: Optional[str] = None,
|
||||
template_id: Optional[str] = None,
|
||||
conversation_ids: Optional[List[Optional[str]]] = None,
|
||||
) -> bool:
|
||||
"""Insert messages into Turbopuffer.
|
||||
|
||||
@@ -413,6 +414,7 @@ class TurbopufferClient:
|
||||
created_ats: List of creation timestamps for each message
|
||||
project_id: Optional project ID for all messages
|
||||
template_id: Optional template ID for all messages
|
||||
conversation_ids: Optional list of conversation IDs (one per message, must match 1:1 with message_texts)
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
@@ -441,22 +443,26 @@ class TurbopufferClient:
|
||||
raise ValueError(f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})")
|
||||
if len(message_ids) != len(created_ats):
|
||||
raise ValueError(f"message_ids length ({len(message_ids)}) must match created_ats length ({len(created_ats)})")
|
||||
if conversation_ids is not None and len(conversation_ids) != len(message_ids):
|
||||
raise ValueError(f"conversation_ids length ({len(conversation_ids)}) must match message_ids length ({len(message_ids)})")
|
||||
|
||||
# prepare column-based data for turbopuffer - optimized for batch insert
|
||||
ids = []
|
||||
vectors = []
|
||||
texts = []
|
||||
organization_ids = []
|
||||
agent_ids = []
|
||||
organization_ids_list = []
|
||||
agent_ids_list = []
|
||||
message_roles = []
|
||||
created_at_timestamps = []
|
||||
project_ids = []
|
||||
template_ids = []
|
||||
project_ids_list = []
|
||||
template_ids_list = []
|
||||
conversation_ids_list = []
|
||||
|
||||
for (original_idx, text), embedding in zip(filtered_messages, embeddings):
|
||||
message_id = message_ids[original_idx]
|
||||
role = roles[original_idx]
|
||||
created_at = created_ats[original_idx]
|
||||
conversation_id = conversation_ids[original_idx] if conversation_ids else None
|
||||
|
||||
# ensure the provided timestamp is timezone-aware and in UTC
|
||||
if created_at.tzinfo is None:
|
||||
@@ -470,31 +476,36 @@ class TurbopufferClient:
|
||||
ids.append(message_id)
|
||||
vectors.append(embedding)
|
||||
texts.append(text)
|
||||
organization_ids.append(organization_id)
|
||||
agent_ids.append(agent_id)
|
||||
organization_ids_list.append(organization_id)
|
||||
agent_ids_list.append(agent_id)
|
||||
message_roles.append(role.value)
|
||||
created_at_timestamps.append(timestamp)
|
||||
project_ids.append(project_id)
|
||||
template_ids.append(template_id)
|
||||
project_ids_list.append(project_id)
|
||||
template_ids_list.append(template_id)
|
||||
conversation_ids_list.append(conversation_id)
|
||||
|
||||
# build column-based upsert data
|
||||
upsert_columns = {
|
||||
"id": ids,
|
||||
"vector": vectors,
|
||||
"text": texts,
|
||||
"organization_id": organization_ids,
|
||||
"agent_id": agent_ids,
|
||||
"organization_id": organization_ids_list,
|
||||
"agent_id": agent_ids_list,
|
||||
"role": message_roles,
|
||||
"created_at": created_at_timestamps,
|
||||
}
|
||||
|
||||
# only include conversation_id if it's provided
|
||||
if conversation_ids is not None:
|
||||
upsert_columns["conversation_id"] = conversation_ids_list
|
||||
|
||||
# only include project_id if it's provided
|
||||
if project_id is not None:
|
||||
upsert_columns["project_id"] = project_ids
|
||||
upsert_columns["project_id"] = project_ids_list
|
||||
|
||||
# only include template_id if it's provided
|
||||
if template_id is not None:
|
||||
upsert_columns["template_id"] = template_ids
|
||||
upsert_columns["template_id"] = template_ids_list
|
||||
|
||||
try:
|
||||
# use global semaphore to limit concurrent Turbopuffer writes
|
||||
@@ -506,7 +517,10 @@ class TurbopufferClient:
|
||||
await namespace.write(
|
||||
upsert_columns=upsert_columns,
|
||||
distance_metric="cosine_distance",
|
||||
schema={"text": {"type": "string", "full_text_search": True}},
|
||||
schema={
|
||||
"text": {"type": "string", "full_text_search": True},
|
||||
"conversation_id": {"type": "string"},
|
||||
},
|
||||
)
|
||||
logger.info(f"Successfully inserted {len(ids)} messages to Turbopuffer for agent {agent_id}")
|
||||
return True
|
||||
@@ -561,67 +575,80 @@ class TurbopufferClient:
|
||||
if search_mode not in ["vector", "fts", "hybrid", "timestamp"]:
|
||||
raise ValueError(f"Invalid search_mode: {search_mode}. Must be 'vector', 'fts', 'hybrid', or 'timestamp'")
|
||||
|
||||
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
|
||||
namespace = client.namespace(namespace_name)
|
||||
try:
|
||||
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
|
||||
namespace = client.namespace(namespace_name)
|
||||
|
||||
if search_mode == "timestamp":
|
||||
# retrieve most recent items by timestamp
|
||||
query_params = {
|
||||
"rank_by": ("created_at", "desc"),
|
||||
"top_k": top_k,
|
||||
"include_attributes": include_attributes,
|
||||
}
|
||||
if filters:
|
||||
query_params["filters"] = filters
|
||||
return await namespace.query(**query_params)
|
||||
if search_mode == "timestamp":
|
||||
# retrieve most recent items by timestamp
|
||||
query_params = {
|
||||
"rank_by": ("created_at", "desc"),
|
||||
"top_k": top_k,
|
||||
"include_attributes": include_attributes,
|
||||
}
|
||||
if filters:
|
||||
query_params["filters"] = filters
|
||||
return await namespace.query(**query_params)
|
||||
|
||||
elif search_mode == "vector":
|
||||
# vector search query
|
||||
query_params = {
|
||||
"rank_by": ("vector", "ANN", query_embedding),
|
||||
"top_k": top_k,
|
||||
"include_attributes": include_attributes,
|
||||
}
|
||||
if filters:
|
||||
query_params["filters"] = filters
|
||||
return await namespace.query(**query_params)
|
||||
elif search_mode == "vector":
|
||||
# vector search query
|
||||
query_params = {
|
||||
"rank_by": ("vector", "ANN", query_embedding),
|
||||
"top_k": top_k,
|
||||
"include_attributes": include_attributes,
|
||||
}
|
||||
if filters:
|
||||
query_params["filters"] = filters
|
||||
return await namespace.query(**query_params)
|
||||
|
||||
elif search_mode == "fts":
|
||||
# full-text search query
|
||||
query_params = {
|
||||
"rank_by": ("text", "BM25", query_text),
|
||||
"top_k": top_k,
|
||||
"include_attributes": include_attributes,
|
||||
}
|
||||
if filters:
|
||||
query_params["filters"] = filters
|
||||
return await namespace.query(**query_params)
|
||||
elif search_mode == "fts":
|
||||
# full-text search query
|
||||
query_params = {
|
||||
"rank_by": ("text", "BM25", query_text),
|
||||
"top_k": top_k,
|
||||
"include_attributes": include_attributes,
|
||||
}
|
||||
if filters:
|
||||
query_params["filters"] = filters
|
||||
return await namespace.query(**query_params)
|
||||
|
||||
else: # hybrid mode
|
||||
queries = []
|
||||
else: # hybrid mode
|
||||
queries = []
|
||||
|
||||
# vector search query
|
||||
vector_query = {
|
||||
"rank_by": ("vector", "ANN", query_embedding),
|
||||
"top_k": top_k,
|
||||
"include_attributes": include_attributes,
|
||||
}
|
||||
if filters:
|
||||
vector_query["filters"] = filters
|
||||
queries.append(vector_query)
|
||||
# vector search query
|
||||
vector_query = {
|
||||
"rank_by": ("vector", "ANN", query_embedding),
|
||||
"top_k": top_k,
|
||||
"include_attributes": include_attributes,
|
||||
}
|
||||
if filters:
|
||||
vector_query["filters"] = filters
|
||||
queries.append(vector_query)
|
||||
|
||||
# full-text search query
|
||||
fts_query = {
|
||||
"rank_by": ("text", "BM25", query_text),
|
||||
"top_k": top_k,
|
||||
"include_attributes": include_attributes,
|
||||
}
|
||||
if filters:
|
||||
fts_query["filters"] = filters
|
||||
queries.append(fts_query)
|
||||
# full-text search query
|
||||
fts_query = {
|
||||
"rank_by": ("text", "BM25", query_text),
|
||||
"top_k": top_k,
|
||||
"include_attributes": include_attributes,
|
||||
}
|
||||
if filters:
|
||||
fts_query["filters"] = filters
|
||||
queries.append(fts_query)
|
||||
|
||||
# execute multi-query
|
||||
return await namespace.multi_query(queries=[QueryParam(**q) for q in queries])
|
||||
# execute multi-query
|
||||
return await namespace.multi_query(queries=[QueryParam(**q) for q in queries])
|
||||
except Exception as e:
|
||||
# Wrap turbopuffer errors with user-friendly messages
|
||||
from turbopuffer import NotFoundError
|
||||
|
||||
if isinstance(e, NotFoundError):
|
||||
# Extract just the error message without implementation details
|
||||
error_msg = str(e)
|
||||
if "namespace" in error_msg.lower() and "not found" in error_msg.lower():
|
||||
raise ValueError("No conversation history found. Please send a message first to enable search.") from e
|
||||
raise ValueError(f"Search data not found: {error_msg}") from e
|
||||
# Re-raise other errors as-is
|
||||
raise
|
||||
|
||||
@trace_method
|
||||
async def query_passages(
|
||||
@@ -779,6 +806,7 @@ class TurbopufferClient:
|
||||
roles: Optional[List[MessageRole]] = None,
|
||||
project_id: Optional[str] = None,
|
||||
template_id: Optional[str] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
vector_weight: float = 0.5,
|
||||
fts_weight: float = 0.5,
|
||||
start_date: Optional[datetime] = None,
|
||||
@@ -796,6 +824,7 @@ class TurbopufferClient:
|
||||
roles: Optional list of message roles to filter by
|
||||
project_id: Optional project ID to filter messages by
|
||||
template_id: Optional template ID to filter messages by
|
||||
conversation_id: Optional conversation ID to filter messages by (use "default" for NULL)
|
||||
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
|
||||
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
|
||||
start_date: Optional datetime to filter messages created after this date
|
||||
@@ -862,6 +891,19 @@ class TurbopufferClient:
|
||||
if template_id:
|
||||
template_filter = ("template_id", "Eq", template_id)
|
||||
|
||||
# build conversation_id filter if provided
|
||||
# three cases:
|
||||
# 1. conversation_id=None (omitted) -> return all messages (no filter)
|
||||
# 2. conversation_id="default" -> return only default messages (conversation_id is none), for backward compatibility
|
||||
# 3. conversation_id="xyz" -> return only messages in that conversation
|
||||
conversation_filter = None
|
||||
if conversation_id == "default":
|
||||
# "default" is reserved for default messages only (conversation_id is none)
|
||||
conversation_filter = ("conversation_id", "Eq", None)
|
||||
elif conversation_id is not None:
|
||||
# Specific conversation
|
||||
conversation_filter = ("conversation_id", "Eq", conversation_id)
|
||||
|
||||
# combine all filters
|
||||
all_filters = [agent_filter] # always include agent_id filter
|
||||
if role_filter:
|
||||
@@ -870,6 +912,8 @@ class TurbopufferClient:
|
||||
all_filters.append(project_filter)
|
||||
if template_filter:
|
||||
all_filters.append(template_filter)
|
||||
if conversation_filter:
|
||||
all_filters.append(conversation_filter)
|
||||
if date_filters:
|
||||
all_filters.extend(date_filters)
|
||||
|
||||
@@ -888,7 +932,7 @@ class TurbopufferClient:
|
||||
query_embedding=query_embedding,
|
||||
query_text=query_text,
|
||||
top_k=top_k,
|
||||
include_attributes=["text", "organization_id", "agent_id", "role", "created_at"],
|
||||
include_attributes=True,
|
||||
filters=final_filter,
|
||||
vector_weight=vector_weight,
|
||||
fts_weight=fts_weight,
|
||||
@@ -939,6 +983,7 @@ class TurbopufferClient:
|
||||
agent_id: Optional[str] = None,
|
||||
project_id: Optional[str] = None,
|
||||
template_id: Optional[str] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
vector_weight: float = 0.5,
|
||||
fts_weight: float = 0.5,
|
||||
start_date: Optional[datetime] = None,
|
||||
@@ -956,6 +1001,10 @@ class TurbopufferClient:
|
||||
agent_id: Optional agent ID to filter messages by
|
||||
project_id: Optional project ID to filter messages by
|
||||
template_id: Optional template ID to filter messages by
|
||||
conversation_id: Optional conversation ID to filter messages by. Special values:
|
||||
- None (omitted): Return all messages
|
||||
- "default": Return only default messages (conversation_id IS NULL)
|
||||
- Any other value: Return messages in that specific conversation
|
||||
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
|
||||
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
|
||||
start_date: Optional datetime to filter messages created after this date
|
||||
@@ -1004,6 +1053,18 @@ class TurbopufferClient:
|
||||
if template_id:
|
||||
all_filters.append(("template_id", "Eq", template_id))
|
||||
|
||||
# conversation filter
|
||||
# three cases:
|
||||
# 1. conversation_id=None (omitted) -> return all messages (no filter)
|
||||
# 2. conversation_id="default" -> return only default messages (conversation_id is none), for backward compatibility
|
||||
# 3. conversation_id="xyz" -> return only messages in that conversation
|
||||
if conversation_id == "default":
|
||||
# "default" is reserved for default messages only (conversation_id is none)
|
||||
all_filters.append(("conversation_id", "Eq", None))
|
||||
elif conversation_id is not None:
|
||||
# Specific conversation
|
||||
all_filters.append(("conversation_id", "Eq", conversation_id))
|
||||
|
||||
# date filters
|
||||
if start_date:
|
||||
# Convert to UTC to match stored timestamps
|
||||
@@ -1036,7 +1097,7 @@ class TurbopufferClient:
|
||||
query_embedding=query_embedding,
|
||||
query_text=query_text,
|
||||
top_k=top_k,
|
||||
include_attributes=["text", "organization_id", "agent_id", "role", "created_at"],
|
||||
include_attributes=True,
|
||||
filters=final_filter,
|
||||
vector_weight=vector_weight,
|
||||
fts_weight=fts_weight,
|
||||
@@ -1121,6 +1182,7 @@ class TurbopufferClient:
|
||||
"agent_id": getattr(row, "agent_id", None),
|
||||
"role": getattr(row, "role", None),
|
||||
"created_at": getattr(row, "created_at", None),
|
||||
"conversation_id": getattr(row, "conversation_id", None),
|
||||
}
|
||||
messages.append(message_dict)
|
||||
|
||||
|
||||
@@ -88,6 +88,7 @@ class SimpleAnthropicStreamingInterface:
|
||||
self.tool_call_name = None
|
||||
self.accumulated_tool_call_args = ""
|
||||
self.previous_parse = {}
|
||||
self.thinking_signature = None
|
||||
|
||||
# usage trackers
|
||||
self.input_tokens = 0
|
||||
@@ -426,20 +427,23 @@ class SimpleAnthropicStreamingInterface:
|
||||
f"Streaming integrity failed - received BetaThinkingBlock object while not in THINKING EventMode: {delta}"
|
||||
)
|
||||
|
||||
if prev_message_type and prev_message_type != "reasoning_message":
|
||||
message_index += 1
|
||||
reasoning_message = ReasoningMessage(
|
||||
id=self.letta_message_id,
|
||||
source="reasoner_model",
|
||||
reasoning=delta.thinking,
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
prev_message_type = reasoning_message.message_type
|
||||
yield reasoning_message
|
||||
# Only emit reasoning message if we have actual content
|
||||
if delta.thinking and delta.thinking.strip():
|
||||
if prev_message_type and prev_message_type != "reasoning_message":
|
||||
message_index += 1
|
||||
reasoning_message = ReasoningMessage(
|
||||
id=self.letta_message_id,
|
||||
source="reasoner_model",
|
||||
reasoning=delta.thinking,
|
||||
signature=self.thinking_signature,
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
prev_message_type = reasoning_message.message_type
|
||||
yield reasoning_message
|
||||
|
||||
elif isinstance(delta, BetaSignatureDelta):
|
||||
# Safety check
|
||||
@@ -448,21 +452,15 @@ class SimpleAnthropicStreamingInterface:
|
||||
f"Streaming integrity failed - received BetaSignatureDelta object while not in THINKING EventMode: {delta}"
|
||||
)
|
||||
|
||||
if prev_message_type and prev_message_type != "reasoning_message":
|
||||
message_index += 1
|
||||
reasoning_message = ReasoningMessage(
|
||||
id=self.letta_message_id,
|
||||
source="reasoner_model",
|
||||
reasoning="",
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
signature=delta.signature,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
prev_message_type = reasoning_message.message_type
|
||||
yield reasoning_message
|
||||
# Store signature but don't emit empty reasoning message
|
||||
# Signature will be attached when actual thinking content arrives
|
||||
self.thinking_signature = delta.signature
|
||||
|
||||
# Update the last reasoning message with the signature so it gets persisted
|
||||
if self.reasoning_messages:
|
||||
last_msg = self.reasoning_messages[-1]
|
||||
if isinstance(last_msg, ReasoningMessage):
|
||||
last_msg.signature = delta.signature
|
||||
|
||||
elif isinstance(event, BetaRawMessageStartEvent):
|
||||
self.message_id = event.message.id
|
||||
|
||||
@@ -224,40 +224,32 @@ class SimpleGeminiStreamingInterface:
|
||||
# NOTE: the thought_signature comes on the Part with the function_call
|
||||
thought_signature = part.thought_signature
|
||||
self.thinking_signature = base64.b64encode(thought_signature).decode("utf-8")
|
||||
if prev_message_type and prev_message_type != "reasoning_message":
|
||||
message_index += 1
|
||||
yield ReasoningMessage(
|
||||
id=self.letta_message_id,
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
source="reasoner_model",
|
||||
reasoning="",
|
||||
signature=self.thinking_signature,
|
||||
)
|
||||
prev_message_type = "reasoning_message"
|
||||
# Don't emit empty reasoning message - signature will be attached to actual reasoning content
|
||||
|
||||
# Thinking summary content part (bool means text is thought part)
|
||||
if part.thought:
|
||||
reasoning_summary = part.text
|
||||
if prev_message_type and prev_message_type != "reasoning_message":
|
||||
message_index += 1
|
||||
yield ReasoningMessage(
|
||||
id=self.letta_message_id,
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
source="reasoner_model",
|
||||
reasoning=reasoning_summary,
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
prev_message_type = "reasoning_message"
|
||||
self.content_parts.append(
|
||||
ReasoningContent(
|
||||
is_native=True,
|
||||
# Only emit reasoning message if we have actual content
|
||||
if reasoning_summary and reasoning_summary.strip():
|
||||
if prev_message_type and prev_message_type != "reasoning_message":
|
||||
message_index += 1
|
||||
yield ReasoningMessage(
|
||||
id=self.letta_message_id,
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
source="reasoner_model",
|
||||
reasoning=reasoning_summary,
|
||||
signature=self.thinking_signature,
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
prev_message_type = "reasoning_message"
|
||||
self.content_parts.append(
|
||||
ReasoningContent(
|
||||
is_native=True,
|
||||
reasoning=reasoning_summary,
|
||||
signature=self.thinking_signature,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Plain text content part
|
||||
elif part.text:
|
||||
|
||||
@@ -5,6 +5,7 @@ import re
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import anthropic
|
||||
import httpx
|
||||
from anthropic import AsyncStream
|
||||
from anthropic.types.beta import BetaMessage as AnthropicMessage, BetaRawMessageStreamEvent
|
||||
from anthropic.types.beta.message_create_params import MessageCreateParamsNonStreaming
|
||||
@@ -28,6 +29,7 @@ from letta.errors import (
|
||||
)
|
||||
from letta.helpers.datetime_helpers import get_utc_time_int
|
||||
from letta.helpers.decorators import deprecated
|
||||
from letta.llm_api.anthropic_constants import ANTHROPIC_MAX_STRICT_TOOLS, ANTHROPIC_STRICT_MODE_ALLOWLIST
|
||||
from letta.llm_api.helpers import add_inner_thoughts_to_functions, unpack_all_inner_thoughts_from_kwargs
|
||||
from letta.llm_api.llm_client_base import LLMClientBase
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
|
||||
@@ -81,9 +83,17 @@ class AnthropicClient(LLMClientBase):
|
||||
if llm_config.model.startswith("claude-opus-4-5") and llm_config.enable_reasoner:
|
||||
betas.append("context-management-2025-06-27")
|
||||
|
||||
# Structured outputs beta
|
||||
if hasattr(llm_config, "response_format") and isinstance(llm_config.response_format, JsonSchemaResponseFormat):
|
||||
betas.append("structured-outputs-2025-11-13")
|
||||
# Structured outputs beta - only for supported models
|
||||
# Supported: Claude Sonnet 4.5, Opus 4.1, Opus 4.5, Haiku 4.5
|
||||
# DISABLED: Commenting out structured outputs to investigate TTFT latency impact
|
||||
# See PR #7495 for original implementation
|
||||
# supports_structured_outputs = _supports_structured_outputs(llm_config.model)
|
||||
#
|
||||
# if supports_structured_outputs:
|
||||
# # Always enable structured outputs beta on supported models.
|
||||
# # NOTE: We do NOT send `strict` on tool schemas because the current Anthropic SDK
|
||||
# # typed tool params reject unknown fields (e.g., `tools.0.custom.strict`).
|
||||
# betas.append("structured-outputs-2025-11-13")
|
||||
|
||||
if betas:
|
||||
response = client.beta.messages.create(**request_data, betas=betas)
|
||||
@@ -94,7 +104,6 @@ class AnthropicClient(LLMClientBase):
|
||||
@trace_method
|
||||
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
||||
client = await self._get_anthropic_client_async(llm_config, async_client=True)
|
||||
|
||||
betas: list[str] = []
|
||||
# interleaved thinking for reasoner
|
||||
if llm_config.enable_reasoner:
|
||||
@@ -119,9 +128,13 @@ class AnthropicClient(LLMClientBase):
|
||||
if llm_config.model.startswith("claude-opus-4-5") and llm_config.enable_reasoner:
|
||||
betas.append("context-management-2025-06-27")
|
||||
|
||||
# Structured outputs beta
|
||||
if hasattr(llm_config, "response_format") and isinstance(llm_config.response_format, JsonSchemaResponseFormat):
|
||||
betas.append("structured-outputs-2025-11-13")
|
||||
# Structured outputs beta - only for supported models
|
||||
# DISABLED: Commenting out structured outputs to investigate TTFT latency impact
|
||||
# See PR #7495 for original implementation
|
||||
# supports_structured_outputs = _supports_structured_outputs(llm_config.model)
|
||||
#
|
||||
# if supports_structured_outputs:
|
||||
# betas.append("structured-outputs-2025-11-13")
|
||||
|
||||
if betas:
|
||||
response = await client.beta.messages.create(**request_data, betas=betas)
|
||||
@@ -164,9 +177,13 @@ class AnthropicClient(LLMClientBase):
|
||||
if llm_config.model.startswith("claude-opus-4-5") and llm_config.enable_reasoner:
|
||||
betas.append("context-management-2025-06-27")
|
||||
|
||||
# Structured outputs beta
|
||||
if hasattr(llm_config, "response_format") and isinstance(llm_config.response_format, JsonSchemaResponseFormat):
|
||||
betas.append("structured-outputs-2025-11-13")
|
||||
# Structured outputs beta - only for supported models
|
||||
# DISABLED: Commenting out structured outputs to investigate TTFT latency impact
|
||||
# See PR #7495 for original implementation
|
||||
# supports_structured_outputs = _supports_structured_outputs(llm_config.model)
|
||||
#
|
||||
# if supports_structured_outputs:
|
||||
# betas.append("structured-outputs-2025-11-13")
|
||||
|
||||
# log failed requests
|
||||
try:
|
||||
@@ -234,17 +251,35 @@ class AnthropicClient(LLMClientBase):
|
||||
) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
|
||||
api_key, _, _ = self.get_byok_overrides(llm_config)
|
||||
|
||||
# For claude-pro-max provider, use OAuth Bearer token instead of api_key
|
||||
is_oauth_provider = llm_config.provider_name == "claude-pro-max"
|
||||
|
||||
if async_client:
|
||||
return (
|
||||
anthropic.AsyncAnthropic(api_key=api_key, max_retries=model_settings.anthropic_max_retries)
|
||||
if api_key
|
||||
else anthropic.AsyncAnthropic(max_retries=model_settings.anthropic_max_retries)
|
||||
)
|
||||
return (
|
||||
anthropic.Anthropic(api_key=api_key, max_retries=model_settings.anthropic_max_retries)
|
||||
if api_key
|
||||
else anthropic.Anthropic(max_retries=model_settings.anthropic_max_retries)
|
||||
)
|
||||
if api_key:
|
||||
if is_oauth_provider:
|
||||
return anthropic.AsyncAnthropic(
|
||||
max_retries=model_settings.anthropic_max_retries,
|
||||
default_headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"anthropic-version": "2023-06-01",
|
||||
"anthropic-beta": "oauth-2025-04-20",
|
||||
},
|
||||
)
|
||||
return anthropic.AsyncAnthropic(api_key=api_key, max_retries=model_settings.anthropic_max_retries)
|
||||
return anthropic.AsyncAnthropic(max_retries=model_settings.anthropic_max_retries)
|
||||
|
||||
if api_key:
|
||||
if is_oauth_provider:
|
||||
return anthropic.Anthropic(
|
||||
max_retries=model_settings.anthropic_max_retries,
|
||||
default_headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"anthropic-version": "2023-06-01",
|
||||
"anthropic-beta": "oauth-2025-04-20",
|
||||
},
|
||||
)
|
||||
return anthropic.Anthropic(api_key=api_key, max_retries=model_settings.anthropic_max_retries)
|
||||
return anthropic.Anthropic(max_retries=model_settings.anthropic_max_retries)
|
||||
|
||||
@trace_method
|
||||
async def _get_anthropic_client_async(
|
||||
@@ -252,17 +287,35 @@ class AnthropicClient(LLMClientBase):
|
||||
) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
|
||||
api_key, _, _ = await self.get_byok_overrides_async(llm_config)
|
||||
|
||||
# For claude-pro-max provider, use OAuth Bearer token instead of api_key
|
||||
is_oauth_provider = llm_config.provider_name == "claude-pro-max"
|
||||
|
||||
if async_client:
|
||||
return (
|
||||
anthropic.AsyncAnthropic(api_key=api_key, max_retries=model_settings.anthropic_max_retries)
|
||||
if api_key
|
||||
else anthropic.AsyncAnthropic(max_retries=model_settings.anthropic_max_retries)
|
||||
)
|
||||
return (
|
||||
anthropic.Anthropic(api_key=api_key, max_retries=model_settings.anthropic_max_retries)
|
||||
if api_key
|
||||
else anthropic.Anthropic(max_retries=model_settings.anthropic_max_retries)
|
||||
)
|
||||
if api_key:
|
||||
if is_oauth_provider:
|
||||
return anthropic.AsyncAnthropic(
|
||||
max_retries=model_settings.anthropic_max_retries,
|
||||
default_headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"anthropic-version": "2023-06-01",
|
||||
"anthropic-beta": "oauth-2025-04-20",
|
||||
},
|
||||
)
|
||||
return anthropic.AsyncAnthropic(api_key=api_key, max_retries=model_settings.anthropic_max_retries)
|
||||
return anthropic.AsyncAnthropic(max_retries=model_settings.anthropic_max_retries)
|
||||
|
||||
if api_key:
|
||||
if is_oauth_provider:
|
||||
return anthropic.Anthropic(
|
||||
max_retries=model_settings.anthropic_max_retries,
|
||||
default_headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"anthropic-version": "2023-06-01",
|
||||
"anthropic-beta": "oauth-2025-04-20",
|
||||
},
|
||||
)
|
||||
return anthropic.Anthropic(api_key=api_key, max_retries=model_settings.anthropic_max_retries)
|
||||
return anthropic.Anthropic(max_retries=model_settings.anthropic_max_retries)
|
||||
|
||||
@trace_method
|
||||
def build_request_data(
|
||||
@@ -331,11 +384,13 @@ class AnthropicClient(LLMClientBase):
|
||||
}
|
||||
|
||||
# Structured outputs via response_format
|
||||
if hasattr(llm_config, "response_format") and isinstance(llm_config.response_format, JsonSchemaResponseFormat):
|
||||
data["output_format"] = {
|
||||
"type": "json_schema",
|
||||
"schema": llm_config.response_format.json_schema["schema"],
|
||||
}
|
||||
# DISABLED: Commenting out structured outputs to investigate TTFT latency impact
|
||||
# See PR #7495 for original implementation
|
||||
# if hasattr(llm_config, "response_format") and isinstance(llm_config.response_format, JsonSchemaResponseFormat):
|
||||
# data["output_format"] = {
|
||||
# "type": "json_schema",
|
||||
# "schema": llm_config.response_format.json_schema["schema"],
|
||||
# }
|
||||
|
||||
# Tools
|
||||
# For an overview on tool choice:
|
||||
@@ -385,7 +440,12 @@ class AnthropicClient(LLMClientBase):
|
||||
|
||||
if tools_for_request and len(tools_for_request) > 0:
|
||||
# TODO eventually enable parallel tool use
|
||||
data["tools"] = convert_tools_to_anthropic_format(tools_for_request)
|
||||
# DISABLED: use_strict=False to disable structured outputs (TTFT latency impact)
|
||||
# See PR #7495 for original implementation
|
||||
data["tools"] = convert_tools_to_anthropic_format(
|
||||
tools_for_request,
|
||||
use_strict=False, # Was: _supports_structured_outputs(llm_config.model)
|
||||
)
|
||||
# Add cache control to the last tool for caching tool definitions
|
||||
if len(data["tools"]) > 0:
|
||||
data["tools"][-1]["cache_control"] = {"type": "ephemeral"}
|
||||
@@ -522,13 +582,18 @@ class AnthropicClient(LLMClientBase):
|
||||
|
||||
async def count_tokens(self, messages: List[dict] = None, model: str = None, tools: List[OpenAITool] = None) -> int:
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||
|
||||
# Use the default client; token counting is lightweight and does not require BYOK overrides
|
||||
client = anthropic.AsyncAnthropic()
|
||||
if messages and len(messages) == 0:
|
||||
messages = None
|
||||
if tools and len(tools) > 0:
|
||||
anthropic_tools = convert_tools_to_anthropic_format(tools)
|
||||
# Token counting endpoint requires additionalProperties: false (use_strict=True)
|
||||
# but does NOT support the `strict` field on tools (add_strict_field=False)
|
||||
anthropic_tools = convert_tools_to_anthropic_format(
|
||||
tools,
|
||||
use_strict=True,
|
||||
add_strict_field=False,
|
||||
)
|
||||
else:
|
||||
anthropic_tools = None
|
||||
|
||||
@@ -637,6 +702,12 @@ class AnthropicClient(LLMClientBase):
|
||||
if thinking_enabled:
|
||||
betas.append("context-management-2025-06-27")
|
||||
|
||||
# Structured outputs beta - only for supported models
|
||||
# DISABLED: Commenting out structured outputs to investigate TTFT latency impact
|
||||
# See PR #7495 for original implementation
|
||||
# if model and _supports_structured_outputs(model):
|
||||
# betas.append("structured-outputs-2025-11-13")
|
||||
|
||||
if betas:
|
||||
result = await client.beta.messages.count_tokens(**count_params, betas=betas)
|
||||
else:
|
||||
@@ -669,6 +740,8 @@ class AnthropicClient(LLMClientBase):
|
||||
or "exceeds context" in error_str
|
||||
or "too many total text bytes" in error_str
|
||||
or "total text bytes" in error_str
|
||||
or "request_too_large" in error_str
|
||||
or "request exceeds the maximum size" in error_str
|
||||
):
|
||||
logger.warning(f"[Anthropic] Context window exceeded: {str(e)}")
|
||||
return ContextWindowExceededError(
|
||||
@@ -691,6 +764,27 @@ class AnthropicClient(LLMClientBase):
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None},
|
||||
)
|
||||
|
||||
# Handle httpx.RemoteProtocolError which can occur during streaming
|
||||
# when the remote server closes the connection unexpectedly
|
||||
# (e.g., "peer closed connection without sending complete message body")
|
||||
if isinstance(e, httpx.RemoteProtocolError):
|
||||
logger.warning(f"[Anthropic] Remote protocol error during streaming: {e}")
|
||||
return LLMConnectionError(
|
||||
message=f"Connection error during Anthropic streaming: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None},
|
||||
)
|
||||
|
||||
# Handle httpx network errors which can occur during streaming
|
||||
# when the connection is unexpectedly closed while reading/writing
|
||||
if isinstance(e, (httpx.ReadError, httpx.WriteError, httpx.ConnectError)):
|
||||
logger.warning(f"[Anthropic] Network error during streaming: {type(e).__name__}: {e}")
|
||||
return LLMConnectionError(
|
||||
message=f"Network error during Anthropic streaming: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__},
|
||||
)
|
||||
|
||||
if isinstance(e, anthropic.RateLimitError):
|
||||
logger.warning("[Anthropic] Rate limited (429). Consider backoff.")
|
||||
return LLMRateLimitError(
|
||||
@@ -750,6 +844,12 @@ class AnthropicClient(LLMClientBase):
|
||||
|
||||
if isinstance(e, anthropic.APIStatusError):
|
||||
logger.warning(f"[Anthropic] API status error: {str(e)}")
|
||||
# Handle 413 Request Entity Too Large - request payload exceeds size limits
|
||||
if hasattr(e, "status_code") and e.status_code == 413:
|
||||
logger.warning(f"[Anthropic] Request too large (413): {str(e)}")
|
||||
return ContextWindowExceededError(
|
||||
message=f"Request too large for Anthropic (413): {str(e)}",
|
||||
)
|
||||
if "overloaded" in str(e).lower():
|
||||
return LLMProviderOverloaded(
|
||||
message=f"Anthropic API is overloaded: {str(e)}",
|
||||
@@ -827,7 +927,13 @@ class AnthropicClient(LLMClientBase):
|
||||
if content_part.type == "tool_use":
|
||||
# hack for incorrect tool format
|
||||
tool_input = json.loads(json.dumps(content_part.input))
|
||||
if "id" in tool_input and tool_input["id"].startswith("toolu_") and "function" in tool_input:
|
||||
# Check if id is a string before calling startswith (sometimes it's an int)
|
||||
if (
|
||||
"id" in tool_input
|
||||
and isinstance(tool_input["id"], str)
|
||||
and tool_input["id"].startswith("toolu_")
|
||||
and "function" in tool_input
|
||||
):
|
||||
if isinstance(tool_input["function"], str):
|
||||
tool_input["function"] = json.loads(tool_input["function"])
|
||||
arguments = json.dumps(tool_input["function"]["arguments"], indent=2)
|
||||
@@ -964,7 +1070,34 @@ class AnthropicClient(LLMClientBase):
|
||||
return messages
|
||||
|
||||
|
||||
def convert_tools_to_anthropic_format(tools: List[OpenAITool]) -> List[dict]:
|
||||
def _supports_structured_outputs(model: str) -> bool:
|
||||
"""Check if the model supports structured outputs (strict mode).
|
||||
|
||||
Only these 4 models are supported:
|
||||
- Claude Sonnet 4.5
|
||||
- Claude Opus 4.1
|
||||
- Claude Opus 4.5
|
||||
- Claude Haiku 4.5
|
||||
"""
|
||||
model_lower = model.lower()
|
||||
|
||||
if "sonnet-4-5" in model_lower:
|
||||
return True
|
||||
elif "opus-4-1" in model_lower:
|
||||
return True
|
||||
elif "opus-4-5" in model_lower:
|
||||
return True
|
||||
elif "haiku-4-5" in model_lower:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def convert_tools_to_anthropic_format(
|
||||
tools: List[OpenAITool],
|
||||
use_strict: bool = False,
|
||||
add_strict_field: bool = True,
|
||||
) -> List[dict]:
|
||||
"""See: https://docs.anthropic.com/claude/docs/tool-use
|
||||
|
||||
OpenAI style:
|
||||
@@ -975,18 +1108,11 @@ def convert_tools_to_anthropic_format(tools: List[OpenAITool]) -> List[dict]:
|
||||
"description": "find ....",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
PARAM: {
|
||||
"type": PARAM_TYPE, # eg "string"
|
||||
"description": PARAM_DESCRIPTION,
|
||||
},
|
||||
...
|
||||
},
|
||||
"properties": {...},
|
||||
"required": List[str],
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}]
|
||||
|
||||
Anthropic style:
|
||||
"tools": [{
|
||||
@@ -994,89 +1120,114 @@ def convert_tools_to_anthropic_format(tools: List[OpenAITool]) -> List[dict]:
|
||||
"description": "find ....",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
PARAM: {
|
||||
"type": PARAM_TYPE, # eg "string"
|
||||
"description": PARAM_DESCRIPTION,
|
||||
},
|
||||
...
|
||||
},
|
||||
"properties": {...},
|
||||
"required": List[str],
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
}]
|
||||
|
||||
Two small differences:
|
||||
- 1 level less of nesting
|
||||
- "parameters" -> "input_schema"
|
||||
Args:
|
||||
tools: List of OpenAI-style tools to convert
|
||||
use_strict: If True, add additionalProperties: false to all object schemas
|
||||
add_strict_field: If True (and use_strict=True), add strict: true to allowlisted tools.
|
||||
Set to False for token counting endpoint which doesn't support this field.
|
||||
"""
|
||||
formatted_tools = []
|
||||
strict_count = 0
|
||||
|
||||
for tool in tools:
|
||||
# Get the input schema
|
||||
input_schema = tool.function.parameters or {"type": "object", "properties": {}, "required": []}
|
||||
|
||||
# Clean up the properties in the schema
|
||||
# The presence of union types / default fields seems Anthropic to produce invalid JSON for tool calls
|
||||
if isinstance(input_schema, dict) and "properties" in input_schema:
|
||||
cleaned_properties = {}
|
||||
for prop_name, prop_schema in input_schema.get("properties", {}).items():
|
||||
if isinstance(prop_schema, dict):
|
||||
cleaned_properties[prop_name] = _clean_property_schema(prop_schema)
|
||||
else:
|
||||
cleaned_properties[prop_name] = prop_schema
|
||||
|
||||
# Create cleaned input schema
|
||||
cleaned_input_schema = {
|
||||
"type": input_schema.get("type", "object"),
|
||||
"properties": cleaned_properties,
|
||||
}
|
||||
|
||||
# Only add required field if it exists and is non-empty
|
||||
if "required" in input_schema and input_schema["required"]:
|
||||
cleaned_input_schema["required"] = input_schema["required"]
|
||||
else:
|
||||
cleaned_input_schema = input_schema
|
||||
|
||||
formatted_tool = {
|
||||
# Use the older lightweight cleanup: remove defaults and simplify union-with-null.
|
||||
# When using structured outputs (use_strict=True), also add additionalProperties: false to all object types.
|
||||
cleaned_schema = (
|
||||
_clean_property_schema(input_schema, add_additional_properties_false=use_strict)
|
||||
if isinstance(input_schema, dict)
|
||||
else input_schema
|
||||
)
|
||||
# Normalize to a safe "object" schema shape to avoid downstream assumptions failing.
|
||||
if isinstance(cleaned_schema, dict):
|
||||
if cleaned_schema.get("type") != "object":
|
||||
cleaned_schema["type"] = "object"
|
||||
if not isinstance(cleaned_schema.get("properties"), dict):
|
||||
cleaned_schema["properties"] = {}
|
||||
# Ensure additionalProperties: false for structured outputs on the top-level schema
|
||||
# Must override any existing additionalProperties: true as well
|
||||
if use_strict:
|
||||
cleaned_schema["additionalProperties"] = False
|
||||
formatted_tool: dict = {
|
||||
"name": tool.function.name,
|
||||
"description": tool.function.description if tool.function.description else "",
|
||||
"input_schema": cleaned_input_schema,
|
||||
"input_schema": cleaned_schema,
|
||||
}
|
||||
|
||||
# Structured outputs "strict" mode: always attach `strict` for allowlisted tools
|
||||
# when we are using structured outputs models. Limit the number of strict tools
|
||||
# to avoid exceeding Anthropic constraints.
|
||||
# NOTE: The token counting endpoint does NOT support `strict` - only the messages endpoint does.
|
||||
if (
|
||||
use_strict
|
||||
and add_strict_field
|
||||
and tool.function.name in ANTHROPIC_STRICT_MODE_ALLOWLIST
|
||||
and strict_count < ANTHROPIC_MAX_STRICT_TOOLS
|
||||
):
|
||||
formatted_tool["strict"] = True
|
||||
strict_count += 1
|
||||
|
||||
formatted_tools.append(formatted_tool)
|
||||
|
||||
return formatted_tools
|
||||
|
||||
|
||||
def _clean_property_schema(prop_schema: dict) -> dict:
|
||||
"""Clean up a property schema by removing defaults and simplifying union types."""
|
||||
cleaned = {}
|
||||
def _clean_property_schema(schema: dict, add_additional_properties_false: bool = False) -> dict:
|
||||
"""Older schema cleanup used for Anthropic tools.
|
||||
|
||||
# Handle type field - simplify union types like ["null", "string"] to just "string"
|
||||
if "type" in prop_schema:
|
||||
prop_type = prop_schema["type"]
|
||||
if isinstance(prop_type, list):
|
||||
# Remove "null" from union types to simplify
|
||||
# e.g., ["null", "string"] becomes "string"
|
||||
non_null_types = [t for t in prop_type if t != "null"]
|
||||
if len(non_null_types) == 1:
|
||||
cleaned["type"] = non_null_types[0]
|
||||
elif len(non_null_types) > 1:
|
||||
# Keep as array if multiple non-null types
|
||||
cleaned["type"] = non_null_types
|
||||
Removes / simplifies fields that commonly cause Anthropic tool schema issues:
|
||||
- Remove `default` values
|
||||
- Simplify nullable unions like {"type": ["null", "string"]} -> {"type": "string"}
|
||||
- Recurse through nested schemas (properties/items/anyOf/oneOf/allOf/etc.)
|
||||
- Optionally add additionalProperties: false to object types (required for structured outputs)
|
||||
"""
|
||||
if not isinstance(schema, dict):
|
||||
return schema
|
||||
|
||||
cleaned: dict = {}
|
||||
|
||||
# Simplify union types like ["null", "string"] to "string"
|
||||
if "type" in schema:
|
||||
t = schema.get("type")
|
||||
if isinstance(t, list):
|
||||
non_null = [x for x in t if x != "null"]
|
||||
if len(non_null) == 1:
|
||||
cleaned["type"] = non_null[0]
|
||||
elif len(non_null) > 1:
|
||||
cleaned["type"] = non_null
|
||||
else:
|
||||
# If only "null" was in the list, default to string
|
||||
cleaned["type"] = "string"
|
||||
else:
|
||||
cleaned["type"] = prop_type
|
||||
cleaned["type"] = t
|
||||
|
||||
# Copy over other fields except 'default'
|
||||
for key, value in prop_schema.items():
|
||||
if key not in ["type", "default"]: # Skip 'default' field
|
||||
if key == "properties" and isinstance(value, dict):
|
||||
# Recursively clean nested properties
|
||||
cleaned["properties"] = {k: _clean_property_schema(v) if isinstance(v, dict) else v for k, v in value.items()}
|
||||
else:
|
||||
cleaned[key] = value
|
||||
for key, value in schema.items():
|
||||
if key == "type":
|
||||
continue
|
||||
if key == "default":
|
||||
continue
|
||||
|
||||
if key == "properties" and isinstance(value, dict):
|
||||
cleaned["properties"] = {k: _clean_property_schema(v, add_additional_properties_false) for k, v in value.items()}
|
||||
elif key == "items" and isinstance(value, dict):
|
||||
cleaned["items"] = _clean_property_schema(value, add_additional_properties_false)
|
||||
elif key in ("anyOf", "oneOf", "allOf") and isinstance(value, list):
|
||||
cleaned[key] = [_clean_property_schema(v, add_additional_properties_false) if isinstance(v, dict) else v for v in value]
|
||||
elif key in ("additionalProperties",) and isinstance(value, dict):
|
||||
cleaned[key] = _clean_property_schema(value, add_additional_properties_false)
|
||||
else:
|
||||
cleaned[key] = value
|
||||
|
||||
# For structured outputs, Anthropic requires additionalProperties: false on all object types
|
||||
# We must override any existing additionalProperties: true as well
|
||||
if add_additional_properties_false and cleaned.get("type") == "object":
|
||||
cleaned["additionalProperties"] = False
|
||||
|
||||
return cleaned
|
||||
|
||||
|
||||
27
letta/llm_api/anthropic_constants.py
Normal file
27
letta/llm_api/anthropic_constants.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# Anthropic-specific constants for the Letta LLM API
|
||||
|
||||
# Allowlist of simple tools that work with Anthropic's structured outputs (strict mode).
|
||||
# These tools have few parameters and no complex nesting, making them safe for strict mode.
|
||||
# Tools with many optional params or deeply nested structures should use non-strict mode.
|
||||
#
|
||||
# Anthropic limitations for strict mode:
|
||||
# - Max 15 tools can use strict mode per request
|
||||
# - Max 24 optional parameters per tool (counted recursively in undocumented ways)
|
||||
# - Schema complexity limits
|
||||
#
|
||||
# Rather than trying to count parameters correctly, we allowlist simple tools that we know work.
|
||||
ANTHROPIC_STRICT_MODE_ALLOWLIST = {
|
||||
"Write", # 2 required params, no optional
|
||||
"Read", # 1 required, 2 simple optional
|
||||
"Edit", # 3 required, 1 simple optional
|
||||
"Glob", # 1 required, 1 simple optional
|
||||
"KillBash", # 1 required, no optional
|
||||
"fetch_webpage", # 1 required, no optional
|
||||
"EnterPlanMode", # no params
|
||||
"ExitPlanMode", # no params
|
||||
"Skill", # 1 required, 1 optional array
|
||||
"conversation_search", # 1 required, 4 simple optional
|
||||
}
|
||||
|
||||
# Maximum number of tools that can use strict mode in a single request
|
||||
ANTHROPIC_MAX_STRICT_TOOLS = 15
|
||||
@@ -3,6 +3,7 @@ import json
|
||||
import uuid
|
||||
from typing import AsyncIterator, List, Optional
|
||||
|
||||
import httpx
|
||||
from google.genai import Client, errors
|
||||
from google.genai.types import (
|
||||
FunctionCallingConfig,
|
||||
@@ -860,6 +861,27 @@ class GoogleVertexClient(LLMClientBase):
|
||||
},
|
||||
)
|
||||
|
||||
# Handle httpx.RemoteProtocolError which can occur during streaming
|
||||
# when the remote server closes the connection unexpectedly
|
||||
# (e.g., "peer closed connection without sending complete message body")
|
||||
if isinstance(e, httpx.RemoteProtocolError):
|
||||
logger.warning(f"{self._provider_prefix()} Remote protocol error during streaming: {e}")
|
||||
return LLMConnectionError(
|
||||
message=f"Connection error during {self._provider_name()} streaming: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None},
|
||||
)
|
||||
|
||||
# Handle httpx network errors which can occur during streaming
|
||||
# when the connection is unexpectedly closed while reading/writing
|
||||
if isinstance(e, (httpx.ReadError, httpx.WriteError, httpx.ConnectError)):
|
||||
logger.warning(f"{self._provider_prefix()} Network error during streaming: {type(e).__name__}: {e}")
|
||||
return LLMConnectionError(
|
||||
message=f"Network error during {self._provider_name()} streaming: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__},
|
||||
)
|
||||
|
||||
# Handle connection-related errors
|
||||
if "connection" in str(e).lower() or "timeout" in str(e).lower():
|
||||
logger.warning(f"{self._provider_prefix()} Connection/timeout error: {e}")
|
||||
|
||||
@@ -79,6 +79,13 @@ class LLMClient:
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
actor=actor,
|
||||
)
|
||||
case ProviderType.zai:
|
||||
from letta.llm_api.zai_client import ZAIClient
|
||||
|
||||
return ZAIClient(
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
actor=actor,
|
||||
)
|
||||
case ProviderType.groq:
|
||||
from letta.llm_api.groq_client import GroqClient
|
||||
|
||||
|
||||
@@ -2,11 +2,12 @@ import json
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
from anthropic.types.beta.messages import BetaMessageBatch
|
||||
from openai import AsyncStream, Stream
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
|
||||
from letta.errors import LLMError
|
||||
from letta.errors import ErrorCode, LLMConnectionError, LLMError
|
||||
from letta.otel.tracing import log_event, trace_method
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import AgentType, ProviderCategory
|
||||
@@ -215,29 +216,59 @@ class LLMClientBase:
|
||||
Returns:
|
||||
An LLMError subclass that represents the error in a provider-agnostic way
|
||||
"""
|
||||
# Handle httpx.RemoteProtocolError which can occur during streaming
|
||||
# when the remote server closes the connection unexpectedly
|
||||
# (e.g., "peer closed connection without sending complete message body")
|
||||
if isinstance(e, httpx.RemoteProtocolError):
|
||||
from letta.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.warning(f"[LLM] Remote protocol error during streaming: {e}")
|
||||
return LLMConnectionError(
|
||||
message=f"Connection error during streaming: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None},
|
||||
)
|
||||
|
||||
return LLMError(f"Unhandled LLM error: {str(e)}")
|
||||
|
||||
def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
"""
|
||||
Returns the override key for the given llm config.
|
||||
Only fetches API key from database for BYOK providers.
|
||||
Base providers use environment variables directly.
|
||||
"""
|
||||
api_key = None
|
||||
# Only fetch API key from database for BYOK providers
|
||||
# Base providers should always use environment variables
|
||||
if llm_config.provider_category == ProviderCategory.byok:
|
||||
from letta.services.provider_manager import ProviderManager
|
||||
|
||||
api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor)
|
||||
# If we got an empty string from the database, treat it as None
|
||||
# so the client can fall back to environment variables or default behavior
|
||||
if api_key == "":
|
||||
api_key = None
|
||||
|
||||
return api_key, None, None
|
||||
|
||||
async def get_byok_overrides_async(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
"""
|
||||
Returns the override key for the given llm config.
|
||||
Only fetches API key from database for BYOK providers.
|
||||
Base providers use environment variables directly.
|
||||
"""
|
||||
api_key = None
|
||||
# Only fetch API key from database for BYOK providers
|
||||
# Base providers should always use environment variables
|
||||
if llm_config.provider_category == ProviderCategory.byok:
|
||||
from letta.services.provider_manager import ProviderManager
|
||||
|
||||
api_key = await ProviderManager().get_override_key_async(llm_config.provider_name, actor=self.actor)
|
||||
# If we got an empty string from the database, treat it as None
|
||||
# so the client can fall back to environment variables or default behavior
|
||||
if api_key == "":
|
||||
api_key = None
|
||||
|
||||
return api_key, None, None
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import os
|
||||
import time
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import httpx
|
||||
import openai
|
||||
from openai import AsyncOpenAI, AsyncStream, OpenAI
|
||||
from openai.types import Reasoning
|
||||
@@ -960,6 +961,27 @@ class OpenAIClient(LLMClientBase):
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None},
|
||||
)
|
||||
|
||||
# Handle httpx.RemoteProtocolError which can occur during streaming
|
||||
# when the remote server closes the connection unexpectedly
|
||||
# (e.g., "peer closed connection without sending complete message body")
|
||||
if isinstance(e, httpx.RemoteProtocolError):
|
||||
logger.warning(f"[OpenAI] Remote protocol error during streaming: {e}")
|
||||
return LLMConnectionError(
|
||||
message=f"Connection error during OpenAI streaming: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None},
|
||||
)
|
||||
|
||||
# Handle httpx network errors which can occur during streaming
|
||||
# when the connection is unexpectedly closed while reading/writing
|
||||
if isinstance(e, (httpx.ReadError, httpx.WriteError, httpx.ConnectError)):
|
||||
logger.warning(f"[OpenAI] Network error during streaming: {type(e).__name__}: {e}")
|
||||
return LLMConnectionError(
|
||||
message=f"Network error during OpenAI streaming: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__},
|
||||
)
|
||||
|
||||
if isinstance(e, openai.RateLimitError):
|
||||
logger.warning(f"[OpenAI] Rate limited (429). Consider backoff. Error: {e}")
|
||||
return LLMRateLimitError(
|
||||
|
||||
81
letta/llm_api/zai_client.py
Normal file
81
letta/llm_api/zai_client.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from openai import AsyncOpenAI, AsyncStream, OpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
|
||||
from letta.llm_api.openai_client import OpenAIClient
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import AgentType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.settings import model_settings
|
||||
|
||||
|
||||
class ZAIClient(OpenAIClient):
|
||||
"""Z.ai (ZhipuAI) client - uses OpenAI-compatible API."""
|
||||
|
||||
def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
|
||||
return False
|
||||
|
||||
def supports_structured_output(self, llm_config: LLMConfig) -> bool:
|
||||
return False
|
||||
|
||||
@trace_method
|
||||
def build_request_data(
|
||||
self,
|
||||
agent_type: AgentType,
|
||||
messages: List[PydanticMessage],
|
||||
llm_config: LLMConfig,
|
||||
tools: Optional[List[dict]] = None,
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
tool_return_truncation_chars: Optional[int] = None,
|
||||
) -> dict:
|
||||
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)
|
||||
return data
|
||||
|
||||
@trace_method
|
||||
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
||||
"""
|
||||
Performs underlying synchronous request to Z.ai API and returns raw response dict.
|
||||
"""
|
||||
api_key = model_settings.zai_api_key
|
||||
client = OpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
|
||||
|
||||
response: ChatCompletion = client.chat.completions.create(**request_data)
|
||||
return response.model_dump()
|
||||
|
||||
@trace_method
|
||||
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
||||
"""
|
||||
Performs underlying asynchronous request to Z.ai API and returns raw response dict.
|
||||
"""
|
||||
api_key = model_settings.zai_api_key
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
|
||||
|
||||
response: ChatCompletion = await client.chat.completions.create(**request_data)
|
||||
return response.model_dump()
|
||||
|
||||
@trace_method
|
||||
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
|
||||
"""
|
||||
Performs underlying asynchronous streaming request to Z.ai and returns the async stream iterator.
|
||||
"""
|
||||
api_key = model_settings.zai_api_key
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
|
||||
response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
|
||||
**request_data, stream=True, stream_options={"include_usage": True}
|
||||
)
|
||||
return response_stream
|
||||
|
||||
@trace_method
|
||||
async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
|
||||
"""Request embeddings given texts and embedding config"""
|
||||
api_key = model_settings.zai_api_key
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=embedding_config.embedding_endpoint)
|
||||
response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)
|
||||
|
||||
return [r.embedding for r in response.data]
|
||||
@@ -7,6 +7,7 @@ import asyncio
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
|
||||
from letta.log import get_logger
|
||||
@@ -31,9 +32,12 @@ class EventLoopWatchdog:
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._stop_event = threading.Event()
|
||||
self._last_heartbeat = time.time()
|
||||
self._heartbeat_scheduled_at = time.time()
|
||||
self._heartbeat_lock = threading.Lock()
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self._monitoring = False
|
||||
self._last_dump_time = 0.0 # Cooldown between task dumps
|
||||
self._saturation_start: Optional[float] = None # Track when saturation began
|
||||
|
||||
def start(self, loop: asyncio.AbstractEventLoop):
|
||||
"""Start the watchdog thread."""
|
||||
@@ -43,7 +47,9 @@ class EventLoopWatchdog:
|
||||
self._loop = loop
|
||||
self._monitoring = True
|
||||
self._stop_event.clear()
|
||||
self._last_heartbeat = time.time()
|
||||
now = time.time()
|
||||
self._last_heartbeat = now
|
||||
self._heartbeat_scheduled_at = now
|
||||
|
||||
self._thread = threading.Thread(target=self._watch_loop, daemon=True, name="EventLoopWatchdog")
|
||||
self._thread.start()
|
||||
@@ -51,7 +57,10 @@ class EventLoopWatchdog:
|
||||
# Schedule periodic heartbeats on the event loop
|
||||
loop.call_soon(self._schedule_heartbeats)
|
||||
|
||||
logger.info(f"Watchdog started (timeout: {self.timeout_threshold}s)")
|
||||
logger.info(
|
||||
f"Event loop watchdog started - monitoring thread running, heartbeat every 1s, "
|
||||
f"checks every {self.check_interval}s, hang threshold: {self.timeout_threshold}s"
|
||||
)
|
||||
|
||||
def stop(self):
|
||||
"""Stop the watchdog thread."""
|
||||
@@ -66,8 +75,16 @@ class EventLoopWatchdog:
|
||||
if not self._monitoring:
|
||||
return
|
||||
|
||||
now = time.time()
|
||||
with self._heartbeat_lock:
|
||||
self._last_heartbeat = time.time()
|
||||
# Calculate event loop lag: time between when we scheduled this callback and when it ran
|
||||
lag = now - self._heartbeat_scheduled_at
|
||||
self._last_heartbeat = now
|
||||
self._heartbeat_scheduled_at = now + 1.0
|
||||
|
||||
# Log if lag is significant (> 2 seconds means event loop is saturated)
|
||||
if lag > 2.0:
|
||||
logger.warning(f"Event loop lag in heartbeat: {lag:.2f}s (expected ~1.0s)")
|
||||
|
||||
if self._loop and self._monitoring:
|
||||
self._loop.call_later(1.0, self._schedule_heartbeats)
|
||||
@@ -75,6 +92,7 @@ class EventLoopWatchdog:
|
||||
def _watch_loop(self):
|
||||
"""Main watchdog loop running in separate thread."""
|
||||
consecutive_hangs = 0
|
||||
max_lag_seen = 0.0
|
||||
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
@@ -82,8 +100,13 @@ class EventLoopWatchdog:
|
||||
|
||||
with self._heartbeat_lock:
|
||||
last_beat = self._last_heartbeat
|
||||
scheduled_at = self._heartbeat_scheduled_at
|
||||
|
||||
time_since_heartbeat = time.time() - last_beat
|
||||
now = time.time()
|
||||
time_since_heartbeat = now - last_beat
|
||||
# Calculate current lag: how far behind schedule is the heartbeat?
|
||||
current_lag = now - scheduled_at
|
||||
max_lag_seen = max(max_lag_seen, current_lag)
|
||||
|
||||
# Try to estimate event loop load (safe from separate thread)
|
||||
task_count = -1
|
||||
@@ -98,9 +121,34 @@ class EventLoopWatchdog:
|
||||
|
||||
# ALWAYS log every check to prove watchdog is alive
|
||||
logger.debug(
|
||||
f"WATCHDOG_CHECK: heartbeat_age={time_since_heartbeat:.1f}s, consecutive_hangs={consecutive_hangs}, tasks={task_count}"
|
||||
f"WATCHDOG_CHECK: heartbeat_age={time_since_heartbeat:.1f}s, current_lag={current_lag:.2f}s, "
|
||||
f"max_lag={max_lag_seen:.2f}s, consecutive_hangs={consecutive_hangs}, tasks={task_count}"
|
||||
)
|
||||
|
||||
# Log at INFO if we see significant lag (> 2 seconds indicates saturation)
|
||||
if current_lag > 2.0:
|
||||
# Track saturation duration
|
||||
if self._saturation_start is None:
|
||||
self._saturation_start = now
|
||||
saturation_duration = now - self._saturation_start
|
||||
|
||||
logger.info(
|
||||
f"Event loop saturation detected: lag={current_lag:.2f}s, duration={saturation_duration:.1f}s, "
|
||||
f"tasks={task_count}, max_lag_seen={max_lag_seen:.2f}s"
|
||||
)
|
||||
|
||||
# Only dump stack traces with 60s cooldown to avoid spam
|
||||
if (now - self._last_dump_time) > 60.0:
|
||||
self._dump_asyncio_tasks() # Dump async tasks
|
||||
self._dump_state() # Dump thread stacks
|
||||
self._last_dump_time = now
|
||||
else:
|
||||
# Reset saturation tracking when recovered
|
||||
if self._saturation_start is not None:
|
||||
duration = now - self._saturation_start
|
||||
logger.info(f"Event loop saturation ended after {duration:.1f}s")
|
||||
self._saturation_start = None
|
||||
|
||||
if time_since_heartbeat > self.timeout_threshold:
|
||||
consecutive_hangs += 1
|
||||
logger.error(
|
||||
@@ -108,7 +156,8 @@ class EventLoopWatchdog:
|
||||
f"tasks={task_count}"
|
||||
)
|
||||
|
||||
# Dump basic state
|
||||
# Dump both thread state and asyncio tasks
|
||||
self._dump_asyncio_tasks()
|
||||
self._dump_state()
|
||||
|
||||
if consecutive_hangs >= 2:
|
||||
@@ -153,6 +202,93 @@ class EventLoopWatchdog:
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to dump state: {e}")
|
||||
|
||||
def _dump_asyncio_tasks(self):
|
||||
"""Dump asyncio task stack traces to diagnose event loop saturation."""
|
||||
try:
|
||||
if not self._loop or self._loop.is_closed():
|
||||
return
|
||||
|
||||
active_tasks = asyncio.all_tasks(self._loop)
|
||||
if not active_tasks:
|
||||
return
|
||||
|
||||
logger.warning(f"Severe lag detected - dumping active tasks ({len(active_tasks)} total):")
|
||||
|
||||
# Collect task data in single pass
|
||||
tasks_by_location = defaultdict(list)
|
||||
|
||||
for task in active_tasks:
|
||||
try:
|
||||
if task.done():
|
||||
continue
|
||||
stack = task.get_stack()
|
||||
if not stack:
|
||||
continue
|
||||
|
||||
# Find top letta frame for grouping
|
||||
for frame in reversed(stack):
|
||||
if "letta" in frame.f_code.co_filename:
|
||||
idx = frame.f_code.co_filename.find("letta/")
|
||||
path = frame.f_code.co_filename[idx + 6 :] if idx != -1 else frame.f_code.co_filename
|
||||
location = f"{path}:{frame.f_lineno}:{frame.f_code.co_name}"
|
||||
|
||||
# For bounded tasks, use wrapped coroutine location instead
|
||||
if frame.f_code.co_name == "bounded_coro":
|
||||
task_name = task.get_name()
|
||||
if task_name and task_name.startswith("bounded["):
|
||||
location = task_name[8:-1] # Extract "file:line:func" from "bounded[...]"
|
||||
|
||||
tasks_by_location[location].append((task, stack))
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not tasks_by_location:
|
||||
return
|
||||
|
||||
total_tasks = sum(len(tasks) for tasks in tasks_by_location.values())
|
||||
logger.warning(f" Letta tasks: {total_tasks} total")
|
||||
|
||||
# Sort by task count (most blocked first) and show detailed stacks for top 3
|
||||
sorted_patterns = sorted(tasks_by_location.items(), key=lambda x: len(x[1]), reverse=True)
|
||||
num_patterns = len(sorted_patterns)
|
||||
|
||||
logger.warning(f" Task patterns ({num_patterns} unique locations):")
|
||||
|
||||
# Show detailed stacks for top 3, summary for rest
|
||||
for i, (location, tasks) in enumerate(sorted_patterns, 1):
|
||||
count = len(tasks)
|
||||
pct = (count / total_tasks) * 100 if total_tasks > 0 else 0
|
||||
|
||||
if i <= 3:
|
||||
# Top 3: show detailed vertical stack trace
|
||||
logger.warning(f" [{i}] {count} tasks ({pct:.0f}%) at: {location}")
|
||||
_, sample_stack = tasks[0]
|
||||
# Show up to 8 frames vertically for better context
|
||||
for frame in sample_stack[-8:]:
|
||||
filename = frame.f_code.co_filename
|
||||
letta_idx = filename.find("letta/")
|
||||
if letta_idx != -1:
|
||||
short_path = filename[letta_idx + 6 :]
|
||||
logger.warning(f" {short_path}:{frame.f_lineno} in {frame.f_code.co_name}")
|
||||
else:
|
||||
pkg_idx = filename.find("site-packages/")
|
||||
if pkg_idx != -1:
|
||||
lib_path = filename[pkg_idx + 14 :]
|
||||
logger.warning(f" [{lib_path}:{frame.f_lineno}] {frame.f_code.co_name}")
|
||||
elif i <= 10:
|
||||
# Positions 4-10: show location only
|
||||
logger.warning(f" [{i}] {count} tasks ({pct:.0f}%) at: {location}")
|
||||
else:
|
||||
# Beyond 10: just show count in summary
|
||||
if i == 11:
|
||||
remaining = sum(len(t) for _, t in sorted_patterns[10:])
|
||||
remaining_patterns = num_patterns - 10
|
||||
logger.warning(f" ... and {remaining} more tasks across {remaining_patterns} other locations")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to dump asyncio tasks: {e}")
|
||||
|
||||
|
||||
_global_watchdog: Optional[EventLoopWatchdog] = None
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ from letta.orm.base import Base
|
||||
from letta.orm.block import Block
|
||||
from letta.orm.block_history import BlockHistory
|
||||
from letta.orm.blocks_agents import BlocksAgents
|
||||
from letta.orm.conversation import Conversation
|
||||
from letta.orm.conversation_messages import ConversationMessage
|
||||
from letta.orm.file import FileMetadata
|
||||
from letta.orm.files_agents import FileAgent
|
||||
from letta.orm.group import Group
|
||||
|
||||
@@ -1,19 +1,22 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, List, Optional, Set
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Set
|
||||
|
||||
from sqlalchemy import JSON, Boolean, DateTime, Index, Integer, String
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlalchemy import JSON, Boolean, DateTime, Index, Integer, String, select
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs, async_object_session
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.block import Block
|
||||
from letta.orm.custom_columns import CompactionSettingsColumn, EmbeddingConfigColumn, LLMConfigColumn, ResponseFormatColumn, ToolRulesColumn
|
||||
from letta.orm.identity import Identity
|
||||
from letta.orm.message import Message as MessageModel
|
||||
from letta.orm.mixins import OrganizationMixin, ProjectMixin, TemplateEntityMixin, TemplateMixin
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
|
||||
ENCRYPTED_PLACEHOLDER = "<encrypted>"
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import AgentType
|
||||
from letta.schemas.environment_variables import AgentEnvironmentVariable as PydanticAgentEnvVar
|
||||
@@ -22,11 +25,12 @@ from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.schemas.response_format import ResponseFormatUnion
|
||||
from letta.schemas.tool_rule import ToolRule
|
||||
from letta.utils import calculate_file_defaults_based_on_context_window
|
||||
from letta.utils import bounded_gather, calculate_file_defaults_based_on_context_window
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.agents_tags import AgentsTags
|
||||
from letta.orm.archives_agents import ArchivesAgents
|
||||
from letta.orm.conversation import Conversation
|
||||
from letta.orm.files_agents import FileAgent
|
||||
from letta.orm.identity import Identity
|
||||
from letta.orm.organization import Organization
|
||||
@@ -74,7 +78,7 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin
|
||||
LLMConfigColumn, nullable=True, doc="the LLM backend configuration object for this agent."
|
||||
)
|
||||
embedding_config: Mapped[Optional[EmbeddingConfig]] = mapped_column(
|
||||
EmbeddingConfigColumn, doc="the embedding configuration object for this agent."
|
||||
EmbeddingConfigColumn, nullable=True, doc="the embedding configuration object for this agent."
|
||||
)
|
||||
compaction_settings: Mapped[Optional[dict]] = mapped_column(
|
||||
CompactionSettingsColumn, nullable=True, doc="the compaction settings configuration object for compaction."
|
||||
@@ -147,7 +151,7 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin
|
||||
"Run",
|
||||
back_populates="agent",
|
||||
cascade="all, delete-orphan",
|
||||
lazy="selectin",
|
||||
lazy="raise",
|
||||
doc="Runs associated with the agent.",
|
||||
)
|
||||
identities: Mapped[List["Identity"]] = relationship(
|
||||
@@ -186,6 +190,13 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin
|
||||
lazy="noload",
|
||||
doc="Archives accessible by this agent.",
|
||||
)
|
||||
conversations: Mapped[List["Conversation"]] = relationship(
|
||||
"Conversation",
|
||||
back_populates="agent",
|
||||
cascade="all, delete-orphan",
|
||||
lazy="raise",
|
||||
doc="Conversations for concurrent messaging on this agent.",
|
||||
)
|
||||
|
||||
def _get_per_file_view_window_char_limit(self) -> int:
|
||||
"""Get the per_file_view_window_char_limit, calculating defaults if None."""
|
||||
@@ -298,10 +309,33 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin
|
||||
|
||||
return self.__pydantic_model__(**state)
|
||||
|
||||
async def _get_pending_approval_async(self) -> Optional[Any]:
|
||||
if self.message_ids and len(self.message_ids) > 0:
|
||||
# Try to get the async session this object is attached to
|
||||
session = async_object_session(self)
|
||||
if not session:
|
||||
# Object is detached, can't safely query
|
||||
return None
|
||||
|
||||
latest_message_id = self.message_ids[-1]
|
||||
result = await session.execute(select(MessageModel).where(MessageModel.id == latest_message_id))
|
||||
latest_message = result.scalar_one_or_none()
|
||||
|
||||
if (
|
||||
latest_message
|
||||
and latest_message.role == "approval"
|
||||
and latest_message.tool_calls is not None
|
||||
and len(latest_message.tool_calls) > 0
|
||||
):
|
||||
pydantic_message = latest_message.to_pydantic()
|
||||
return pydantic_message._convert_approval_request_message()
|
||||
return None
|
||||
|
||||
async def to_pydantic_async(
|
||||
self,
|
||||
include_relationships: Optional[Set[str]] = None,
|
||||
include: Optional[List[str]] = None,
|
||||
decrypt: bool = True,
|
||||
) -> PydanticAgentState:
|
||||
"""
|
||||
Converts the SQLAlchemy Agent model into its Pydantic counterpart.
|
||||
@@ -368,6 +402,7 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin
|
||||
"managed_group": None,
|
||||
"tool_exec_environment_variables": [],
|
||||
"secrets": [],
|
||||
"pending_approval": None,
|
||||
}
|
||||
|
||||
# Initialize include_relationships to an empty set if it's None
|
||||
@@ -411,9 +446,20 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin
|
||||
file_agents = (
|
||||
self.awaitable_attrs.file_agents if "memory" in include_relationships or "agent.blocks" in include_set else empty_list_async()
|
||||
)
|
||||
pending_approval = self._get_pending_approval_async() if "agent.pending_approval" in include_set else none_async()
|
||||
|
||||
(tags, tools, sources, memory, identities, multi_agent_group, tool_exec_environment_variables, file_agents) = await asyncio.gather(
|
||||
tags, tools, sources, memory, identities, multi_agent_group, tool_exec_environment_variables, file_agents
|
||||
(
|
||||
tags,
|
||||
tools,
|
||||
sources,
|
||||
memory,
|
||||
identities,
|
||||
multi_agent_group,
|
||||
tool_exec_environment_variables,
|
||||
file_agents,
|
||||
pending_approval,
|
||||
) = await asyncio.gather(
|
||||
tags, tools, sources, memory, identities, multi_agent_group, tool_exec_environment_variables, file_agents, pending_approval
|
||||
)
|
||||
|
||||
state["tags"] = [t.tag for t in tags]
|
||||
@@ -433,12 +479,29 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin
|
||||
state["identities"] = [i.to_pydantic() for i in identities]
|
||||
state["multi_agent_group"] = multi_agent_group
|
||||
state["managed_group"] = multi_agent_group
|
||||
# Convert ORM env vars to Pydantic with async decryption
|
||||
env_vars_pydantic = []
|
||||
for e in tool_exec_environment_variables:
|
||||
env_vars_pydantic.append(await PydanticAgentEnvVar.from_orm_async(e))
|
||||
# Convert ORM env vars to Pydantic, optionally skipping decryption
|
||||
if decrypt:
|
||||
env_vars_pydantic = await bounded_gather([PydanticAgentEnvVar.from_orm_async(e) for e in tool_exec_environment_variables])
|
||||
else:
|
||||
# Skip decryption - return with encrypted values (faster, no PBKDF2)
|
||||
from letta.schemas.environment_variables import AgentEnvironmentVariable
|
||||
from letta.schemas.secret import Secret
|
||||
|
||||
env_vars_pydantic = []
|
||||
for e in tool_exec_environment_variables:
|
||||
data = {
|
||||
"id": e.id,
|
||||
"key": e.key,
|
||||
"description": e.description,
|
||||
"organization_id": e.organization_id,
|
||||
"agent_id": e.agent_id,
|
||||
"value": ENCRYPTED_PLACEHOLDER,
|
||||
"value_enc": Secret.from_encrypted(e.value_enc) if e.value_enc else None,
|
||||
}
|
||||
env_vars_pydantic.append(AgentEnvironmentVariable.model_validate(data))
|
||||
state["tool_exec_environment_variables"] = env_vars_pydantic
|
||||
state["secrets"] = env_vars_pydantic
|
||||
state["pending_approval"] = pending_approval
|
||||
state["model"] = self.llm_config.handle if self.llm_config else None
|
||||
state["model_settings"] = self.llm_config._to_model_settings() if self.llm_config else None
|
||||
state["embedding"] = self.embedding_config.handle if self.embedding_config else None
|
||||
|
||||
49
letta/orm/conversation.py
Normal file
49
letta/orm/conversation.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from sqlalchemy import ForeignKey, Index, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.conversation import Conversation as PydanticConversation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.agent import Agent
|
||||
from letta.orm.conversation_messages import ConversationMessage
|
||||
|
||||
|
||||
class Conversation(SqlalchemyBase, OrganizationMixin):
|
||||
"""Conversations that can be created on an agent for concurrent messaging."""
|
||||
|
||||
__tablename__ = "conversations"
|
||||
__pydantic_model__ = PydanticConversation
|
||||
__table_args__ = (
|
||||
Index("ix_conversations_agent_id", "agent_id"),
|
||||
Index("ix_conversations_org_agent", "organization_id", "agent_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"conv-{uuid.uuid4()}")
|
||||
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), nullable=False)
|
||||
summary: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="Summary of the conversation")
|
||||
|
||||
# Relationships
|
||||
agent: Mapped["Agent"] = relationship("Agent", back_populates="conversations", lazy="raise")
|
||||
message_associations: Mapped[List["ConversationMessage"]] = relationship(
|
||||
"ConversationMessage",
|
||||
back_populates="conversation",
|
||||
cascade="all, delete-orphan",
|
||||
lazy="selectin",
|
||||
)
|
||||
|
||||
def to_pydantic(self) -> PydanticConversation:
|
||||
"""Converts the SQLAlchemy model to its Pydantic counterpart."""
|
||||
return self.__pydantic_model__(
|
||||
id=self.id,
|
||||
agent_id=self.agent_id,
|
||||
summary=self.summary,
|
||||
created_at=self.created_at,
|
||||
updated_at=self.updated_at,
|
||||
created_by_id=self.created_by_id,
|
||||
last_updated_by_id=self.last_updated_by_id,
|
||||
)
|
||||
73
letta/orm/conversation_messages.py
Normal file
73
letta/orm/conversation_messages.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from sqlalchemy import Boolean, ForeignKey, Index, Integer, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.conversation import Conversation
|
||||
from letta.orm.message import Message
|
||||
|
||||
|
||||
class ConversationMessage(SqlalchemyBase, OrganizationMixin):
|
||||
"""
|
||||
Track in-context messages for a conversation.
|
||||
|
||||
This replaces the message_ids JSON list on agents with proper relational modeling.
|
||||
- conversation_id=NULL represents the "default" conversation (backward compatible)
|
||||
- conversation_id=<id> represents a named conversation for concurrent messaging
|
||||
"""
|
||||
|
||||
__tablename__ = "conversation_messages"
|
||||
__table_args__ = (
|
||||
Index("ix_conv_msg_conversation_position", "conversation_id", "position"),
|
||||
Index("ix_conv_msg_message_id", "message_id"),
|
||||
Index("ix_conv_msg_agent_id", "agent_id"),
|
||||
Index("ix_conv_msg_agent_conversation", "agent_id", "conversation_id"),
|
||||
UniqueConstraint("conversation_id", "message_id", name="unique_conversation_message"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"conv_msg-{uuid.uuid4()}")
|
||||
conversation_id: Mapped[Optional[str]] = mapped_column(
|
||||
String,
|
||||
ForeignKey("conversations.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
doc="NULL for default conversation, otherwise FK to conversation",
|
||||
)
|
||||
agent_id: Mapped[str] = mapped_column(
|
||||
String,
|
||||
ForeignKey("agents.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
doc="The agent this message association belongs to",
|
||||
)
|
||||
message_id: Mapped[str] = mapped_column(
|
||||
String,
|
||||
ForeignKey("messages.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
doc="The message being tracked",
|
||||
)
|
||||
position: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
doc="Position in conversation (for ordering)",
|
||||
)
|
||||
in_context: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
default=True,
|
||||
nullable=False,
|
||||
doc="Whether message is currently in the agent's context window",
|
||||
)
|
||||
|
||||
# Relationships
|
||||
conversation: Mapped[Optional["Conversation"]] = relationship(
|
||||
"Conversation",
|
||||
back_populates="message_associations",
|
||||
lazy="raise",
|
||||
)
|
||||
message: Mapped["Message"] = relationship(
|
||||
"Message",
|
||||
lazy="selectin",
|
||||
)
|
||||
@@ -69,6 +69,34 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin, AsyncAttrs):
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
def to_pydantic(self, strip_directory_prefix: bool = False) -> PydanticFileMetadata:
|
||||
"""
|
||||
Convert to Pydantic model without any relationship loading.
|
||||
"""
|
||||
file_name = self.file_name
|
||||
if strip_directory_prefix and "/" in file_name:
|
||||
file_name = "/".join(file_name.split("/")[1:])
|
||||
|
||||
return PydanticFileMetadata(
|
||||
id=self.id,
|
||||
organization_id=self.organization_id,
|
||||
source_id=self.source_id,
|
||||
file_name=file_name,
|
||||
original_file_name=self.original_file_name,
|
||||
file_path=self.file_path,
|
||||
file_type=self.file_type,
|
||||
file_size=self.file_size,
|
||||
file_creation_date=self.file_creation_date,
|
||||
file_last_modified_date=self.file_last_modified_date,
|
||||
processing_status=self.processing_status,
|
||||
error_message=self.error_message,
|
||||
total_chunks=self.total_chunks,
|
||||
chunks_embedded=self.chunks_embedded,
|
||||
created_at=self.created_at,
|
||||
updated_at=self.updated_at,
|
||||
content=None,
|
||||
)
|
||||
|
||||
async def to_pydantic_async(self, include_content: bool = False, strip_directory_prefix: bool = False) -> PydanticFileMetadata:
|
||||
"""
|
||||
Async version of `to_pydantic` that supports optional relationship loading
|
||||
|
||||
@@ -55,6 +55,12 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
nullable=True,
|
||||
doc="The id of the LLMBatchItem that this message is associated with",
|
||||
)
|
||||
conversation_id: Mapped[Optional[str]] = mapped_column(
|
||||
ForeignKey("conversations.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
doc="The conversation this message belongs to (NULL = default conversation)",
|
||||
)
|
||||
is_err: Mapped[Optional[bool]] = mapped_column(
|
||||
nullable=True, doc="Whether this message is part of an error step. Used only for debugging purposes."
|
||||
)
|
||||
|
||||
@@ -40,8 +40,8 @@ class BasePassage(SqlalchemyBase, OrganizationMixin):
|
||||
|
||||
@declared_attr
|
||||
def organization(cls) -> Mapped["Organization"]:
|
||||
"""Relationship to organization"""
|
||||
return relationship("Organization", back_populates="passages", lazy="selectin")
|
||||
"""Relationship to organization - use lazy='raise' to prevent accidental blocking in async contexts"""
|
||||
return relationship("Organization", back_populates="passages", lazy="raise")
|
||||
|
||||
|
||||
class SourcePassage(BasePassage, FileMixin, SourceMixin):
|
||||
@@ -53,7 +53,7 @@ class SourcePassage(BasePassage, FileMixin, SourceMixin):
|
||||
|
||||
@declared_attr
|
||||
def organization(cls) -> Mapped["Organization"]:
|
||||
return relationship("Organization", back_populates="source_passages", lazy="selectin")
|
||||
return relationship("Organization", back_populates="source_passages", lazy="raise")
|
||||
|
||||
@declared_attr
|
||||
def __table_args__(cls):
|
||||
@@ -84,7 +84,7 @@ class ArchivalPassage(BasePassage, ArchiveMixin):
|
||||
|
||||
@declared_attr
|
||||
def organization(cls) -> Mapped["Organization"]:
|
||||
return relationship("Organization", back_populates="archival_passages", lazy="selectin")
|
||||
return relationship("Organization", back_populates="archival_passages", lazy="raise")
|
||||
|
||||
@declared_attr
|
||||
def __table_args__(cls):
|
||||
|
||||
@@ -30,6 +30,7 @@ class Run(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateMixin):
|
||||
Index("ix_runs_created_at", "created_at", "id"),
|
||||
Index("ix_runs_agent_id", "agent_id"),
|
||||
Index("ix_runs_organization_id", "organization_id"),
|
||||
Index("ix_runs_conversation_id", "conversation_id"),
|
||||
)
|
||||
|
||||
# Generate run ID with run- prefix
|
||||
@@ -50,6 +51,11 @@ class Run(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateMixin):
|
||||
# Agent relationship - A run belongs to one agent
|
||||
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), nullable=False, doc="The agent that owns this run.")
|
||||
|
||||
# Conversation relationship - Optional, a run may be associated with a conversation
|
||||
conversation_id: Mapped[Optional[str]] = mapped_column(
|
||||
String, ForeignKey("conversations.id", ondelete="SET NULL"), nullable=True, doc="The conversation this run belongs to."
|
||||
)
|
||||
|
||||
# Callback related columns
|
||||
callback_url: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="When set, POST to this URL after run completion.")
|
||||
callback_sent_at: Mapped[Optional[datetime]] = mapped_column(nullable=True, doc="Timestamp when the callback was last attempted.")
|
||||
|
||||
@@ -12,6 +12,7 @@ from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
from sqlalchemy.orm.exc import StaleDataError
|
||||
from sqlalchemy.orm.interfaces import ORMOption
|
||||
|
||||
from letta.errors import ConcurrentUpdateError
|
||||
from letta.log import get_logger
|
||||
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
|
||||
from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError
|
||||
@@ -259,28 +260,40 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
|
||||
if before_obj and after_obj:
|
||||
# Window-based query - get records between before and after
|
||||
conditions.append(
|
||||
or_(cls.created_at < before_obj.created_at, and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id))
|
||||
)
|
||||
conditions.append(
|
||||
or_(cls.created_at > after_obj.created_at, and_(cls.created_at == after_obj.created_at, cls.id > after_obj.id))
|
||||
)
|
||||
# Skip pagination if either object has null created_at
|
||||
if before_obj.created_at is not None and after_obj.created_at is not None:
|
||||
conditions.append(
|
||||
or_(cls.created_at < before_obj.created_at, and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id))
|
||||
)
|
||||
conditions.append(
|
||||
or_(cls.created_at > after_obj.created_at, and_(cls.created_at == after_obj.created_at, cls.id > after_obj.id))
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Skipping pagination: before_obj.created_at={before_obj.created_at}, after_obj.created_at={after_obj.created_at}"
|
||||
)
|
||||
else:
|
||||
# Pure pagination query
|
||||
if before_obj:
|
||||
conditions.append(
|
||||
or_(
|
||||
cls.created_at < before_obj.created_at if ascending else cls.created_at > before_obj.created_at,
|
||||
and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id),
|
||||
if before_obj.created_at is not None:
|
||||
conditions.append(
|
||||
or_(
|
||||
cls.created_at < before_obj.created_at if ascending else cls.created_at > before_obj.created_at,
|
||||
and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id),
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Skipping 'before' pagination: before_obj.created_at is None (id={before_obj.id})")
|
||||
if after_obj:
|
||||
conditions.append(
|
||||
or_(
|
||||
cls.created_at > after_obj.created_at if ascending else cls.created_at < after_obj.created_at,
|
||||
and_(cls.created_at == after_obj.created_at, cls.id > after_obj.id),
|
||||
if after_obj.created_at is not None:
|
||||
conditions.append(
|
||||
or_(
|
||||
cls.created_at > after_obj.created_at if ascending else cls.created_at < after_obj.created_at,
|
||||
and_(cls.created_at == after_obj.created_at, cls.id > after_obj.id),
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Skipping 'after' pagination: after_obj.created_at is None (id={after_obj.id})")
|
||||
|
||||
if conditions:
|
||||
query = query.where(and_(*conditions))
|
||||
@@ -619,6 +632,11 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
if actor:
|
||||
self._set_created_and_updated_by_fields(actor.id)
|
||||
self.set_updated_at()
|
||||
|
||||
# Capture id before try block to avoid accessing expired attributes after rollback
|
||||
object_id = self.id
|
||||
class_name = self.__class__.__name__
|
||||
|
||||
try:
|
||||
db_session.add(self)
|
||||
if no_commit:
|
||||
@@ -633,8 +651,10 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
# This can occur when using optimistic locking (version_id_col) and:
|
||||
# 1. The row doesn't exist (0 rows matched)
|
||||
# 2. The version has changed (concurrent update)
|
||||
# We convert this to NoResultFound to return a proper 404 error
|
||||
raise NoResultFound(f"{self.__class__.__name__} with id '{self.id}' not found or was updated by another transaction") from e
|
||||
# In practice, case 1 is rare (blocks aren't frequently deleted), so we always
|
||||
# return 409 ConcurrentUpdateError. If it was actually deleted, the retry will get 404.
|
||||
# Not worth performing another db query to check if the row exists.
|
||||
raise ConcurrentUpdateError(resource_type=class_name, resource_id=object_id) from e
|
||||
except (DBAPIError, IntegrityError) as e:
|
||||
self._handle_dbapi_error(e)
|
||||
|
||||
|
||||
@@ -60,6 +60,9 @@ class Step(SqlalchemyBase, ProjectMixin):
|
||||
tags: Mapped[Optional[List]] = mapped_column(JSON, doc="Metadata tags.")
|
||||
tid: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="Transaction ID that processed the step.")
|
||||
trace_id: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The trace id of the agent step.")
|
||||
request_id: Mapped[Optional[str]] = mapped_column(
|
||||
None, nullable=True, doc="The API request log ID from cloud-api for correlating steps with API requests."
|
||||
)
|
||||
feedback: Mapped[Optional[str]] = mapped_column(
|
||||
None, nullable=True, doc="The feedback for this step. Must be either 'positive' or 'negative'."
|
||||
)
|
||||
@@ -72,9 +75,9 @@ class Step(SqlalchemyBase, ProjectMixin):
|
||||
status: Mapped[Optional[StepStatus]] = mapped_column(None, nullable=True, doc="Step status: pending, success, or failed")
|
||||
|
||||
# Relationships (foreign keys)
|
||||
organization: Mapped[Optional["Organization"]] = relationship("Organization")
|
||||
provider: Mapped[Optional["Provider"]] = relationship("Provider")
|
||||
run: Mapped[Optional["Run"]] = relationship("Run", back_populates="steps")
|
||||
organization: Mapped[Optional["Organization"]] = relationship("Organization", lazy="raise")
|
||||
provider: Mapped[Optional["Provider"]] = relationship("Provider", lazy="raise")
|
||||
run: Mapped[Optional["Run"]] = relationship("Run", back_populates="steps", lazy="raise")
|
||||
|
||||
# Relationships (backrefs)
|
||||
messages: Mapped[List["Message"]] = relationship("Message", back_populates="step", cascade="save-update", lazy="noload")
|
||||
|
||||
@@ -82,8 +82,8 @@ class StepMetrics(SqlalchemyBase, ProjectMixin, AgentMixin):
|
||||
|
||||
# Relationships (foreign keys)
|
||||
step: Mapped["Step"] = relationship("Step", back_populates="metrics", uselist=False)
|
||||
run: Mapped[Optional["Run"]] = relationship("Run")
|
||||
agent: Mapped[Optional["Agent"]] = relationship("Agent")
|
||||
run: Mapped[Optional["Run"]] = relationship("Run", lazy="raise")
|
||||
agent: Mapped[Optional["Agent"]] = relationship("Agent", lazy="raise")
|
||||
|
||||
def create(
|
||||
self,
|
||||
|
||||
@@ -37,6 +37,9 @@ _excluded_v1_endpoints_regex: List[str] = [
|
||||
|
||||
|
||||
async def _trace_request_middleware(request: Request, call_next):
|
||||
# Capture earliest possible timestamp when request enters application
|
||||
entry_time = time.time()
|
||||
|
||||
if not _is_tracing_initialized:
|
||||
return await call_next(request)
|
||||
initial_span_name = f"{request.method} {request.url.path}"
|
||||
@@ -47,8 +50,17 @@ async def _trace_request_middleware(request: Request, call_next):
|
||||
initial_span_name,
|
||||
kind=trace.SpanKind.SERVER,
|
||||
) as span:
|
||||
# Record when we entered the application (useful for detecting worker queuing)
|
||||
span.set_attribute("entry.timestamp_ms", int(entry_time * 1000))
|
||||
|
||||
try:
|
||||
response = await call_next(request)
|
||||
|
||||
# Update span name with route pattern after FastAPI has matched the route
|
||||
route = request.scope.get("route")
|
||||
if route and hasattr(route, "path"):
|
||||
span.update_name(f"{request.method} {route.path}")
|
||||
|
||||
span.set_attribute("http.status_code", response.status_code)
|
||||
span.set_status(Status(StatusCode.OK if response.status_code < 400 else StatusCode.ERROR))
|
||||
return response
|
||||
@@ -67,44 +79,50 @@ async def _update_trace_attributes(request: Request):
|
||||
if not span:
|
||||
return
|
||||
|
||||
# Update span name with route pattern
|
||||
route = request.scope.get("route")
|
||||
if route and hasattr(route, "path"):
|
||||
span.update_name(f"{request.method} {route.path}")
|
||||
# Wrap attribute-setting work in a span to measure time before body parsing
|
||||
with tracer.start_as_current_span("trace.set_attributes"):
|
||||
# Update span name with route pattern
|
||||
route = request.scope.get("route")
|
||||
if route and hasattr(route, "path"):
|
||||
span.update_name(f"{request.method} {route.path}")
|
||||
|
||||
# Add request info
|
||||
span.set_attribute("http.method", request.method)
|
||||
span.set_attribute("http.url", str(request.url))
|
||||
# Add request info
|
||||
span.set_attribute("http.method", request.method)
|
||||
span.set_attribute("http.url", str(request.url))
|
||||
|
||||
# Add path params
|
||||
for key, value in request.path_params.items():
|
||||
span.set_attribute(f"http.{key}", value)
|
||||
# Add path params
|
||||
for key, value in request.path_params.items():
|
||||
span.set_attribute(f"http.{key}", value)
|
||||
|
||||
# Add the following headers to span if available
|
||||
header_attributes = {
|
||||
"user_id": "user.id",
|
||||
"x-organization-id": "organization.id",
|
||||
"x-project-id": "project.id",
|
||||
"x-agent-id": "agent.id",
|
||||
"x-template-id": "template.id",
|
||||
"x-base-template-id": "base_template.id",
|
||||
"user-agent": "client",
|
||||
"x-stainless-package-version": "sdk.version",
|
||||
"x-stainless-lang": "sdk.language",
|
||||
"x-letta-source": "source",
|
||||
}
|
||||
for header_key, span_key in header_attributes.items():
|
||||
header_value = request.headers.get(header_key)
|
||||
if header_value:
|
||||
span.set_attribute(span_key, header_value)
|
||||
# Add the following headers to span if available
|
||||
header_attributes = {
|
||||
"user_id": "user.id",
|
||||
"x-organization-id": "organization.id",
|
||||
"x-project-id": "project.id",
|
||||
"x-agent-id": "agent.id",
|
||||
"x-template-id": "template.id",
|
||||
"x-base-template-id": "base_template.id",
|
||||
"user-agent": "client",
|
||||
"x-stainless-package-version": "sdk.version",
|
||||
"x-stainless-lang": "sdk.language",
|
||||
"x-letta-source": "source",
|
||||
}
|
||||
for header_key, span_key in header_attributes.items():
|
||||
header_value = request.headers.get(header_key)
|
||||
if header_value:
|
||||
span.set_attribute(span_key, header_value)
|
||||
|
||||
# Add request body if available
|
||||
try:
|
||||
body = await request.json()
|
||||
for key, value in body.items():
|
||||
span.set_attribute(f"http.request.body.{key}", str(value))
|
||||
except Exception:
|
||||
pass
|
||||
# Add request body if available (only for JSON requests)
|
||||
content_type = request.headers.get("content-type", "")
|
||||
if "application/json" in content_type and request.method in ("POST", "PUT", "PATCH"):
|
||||
try:
|
||||
with tracer.start_as_current_span("trace.request_body"):
|
||||
body = await request.json()
|
||||
for key, value in body.items():
|
||||
span.set_attribute(f"http.request.body.{key}", str(value))
|
||||
except Exception:
|
||||
# Ignore JSON parsing errors (empty body, invalid JSON, etc.)
|
||||
pass
|
||||
|
||||
|
||||
async def _trace_error_handler(_request: Request, exc: Exception) -> JSONResponse:
|
||||
|
||||
@@ -14,6 +14,7 @@ from letta.schemas.file import FileStatus
|
||||
from letta.schemas.group import Group
|
||||
from letta.schemas.identity import Identity
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
from letta.schemas.letta_message import ApprovalRequestMessage
|
||||
from letta.schemas.letta_stop_reason import StopReasonType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import Memory
|
||||
@@ -51,6 +52,7 @@ AgentRelationships = Literal[
|
||||
"agent.blocks",
|
||||
"agent.identities",
|
||||
"agent.managed_group",
|
||||
"agent.pending_approval",
|
||||
"agent.secrets",
|
||||
"agent.sources",
|
||||
"agent.tags",
|
||||
@@ -81,8 +83,8 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
|
||||
llm_config: LLMConfig = Field(
|
||||
..., description="Deprecated: Use `model` field instead. The LLM configuration used by the agent.", deprecated=True
|
||||
)
|
||||
embedding_config: EmbeddingConfig = Field(
|
||||
..., description="Deprecated: Use `embedding` field instead. The embedding configuration used by the agent.", deprecated=True
|
||||
embedding_config: Optional[EmbeddingConfig] = Field(
|
||||
None, description="Deprecated: Use `embedding` field instead. The embedding configuration used by the agent.", deprecated=True
|
||||
)
|
||||
model: Optional[str] = Field(None, description="The model handle used by the agent (format: provider/model-name).")
|
||||
embedding: Optional[str] = Field(None, description="The embedding model handle used by the agent (format: provider/model-name).")
|
||||
@@ -125,6 +127,9 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
|
||||
[], description="Deprecated: Use `identities` field instead. The ids of the identities associated with this agent.", deprecated=True
|
||||
)
|
||||
identities: List[Identity] = Field([], description="The identities associated with this agent.")
|
||||
pending_approval: Optional[ApprovalRequestMessage] = Field(
|
||||
None, description="The latest approval request message pending for this agent, if any."
|
||||
)
|
||||
|
||||
# An advanced configuration that makes it so this agent does not remember any previous messages
|
||||
message_buffer_autoclear: bool = Field(
|
||||
|
||||
28
letta/schemas/conversation.py
Normal file
28
letta/schemas/conversation.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
|
||||
|
||||
class Conversation(OrmMetadataBase):
|
||||
"""Represents a conversation on an agent for concurrent messaging."""
|
||||
|
||||
__id_prefix__ = "conv"
|
||||
|
||||
id: str = Field(..., description="The unique identifier of the conversation.")
|
||||
agent_id: str = Field(..., description="The ID of the agent this conversation belongs to.")
|
||||
summary: Optional[str] = Field(None, description="A summary of the conversation.")
|
||||
in_context_message_ids: List[str] = Field(default_factory=list, description="The IDs of in-context messages for the conversation.")
|
||||
|
||||
|
||||
class CreateConversation(BaseModel):
|
||||
"""Request model for creating a new conversation."""
|
||||
|
||||
summary: Optional[str] = Field(None, description="A summary of the conversation.")
|
||||
|
||||
|
||||
class UpdateConversation(BaseModel):
|
||||
"""Request model for updating a conversation."""
|
||||
|
||||
summary: Optional[str] = Field(None, description="A summary of the conversation.")
|
||||
@@ -26,6 +26,7 @@ class PrimitiveType(str, Enum):
|
||||
SANDBOX_CONFIG = "sandbox" # Note: sandbox_config IDs use "sandbox" prefix
|
||||
STEP = "step"
|
||||
IDENTITY = "identity"
|
||||
CONVERSATION = "conv"
|
||||
|
||||
# Infrastructure types
|
||||
MCP_SERVER = "mcp_server"
|
||||
@@ -67,6 +68,7 @@ class ProviderType(str, Enum):
|
||||
together = "together"
|
||||
vllm = "vllm"
|
||||
xai = "xai"
|
||||
zai = "zai"
|
||||
|
||||
|
||||
class AgentType(str, Enum):
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
import traceback
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import PrimitiveType
|
||||
from letta.schemas.letta_base import LettaBase, OrmMetadataBase
|
||||
from letta.schemas.secret import Secret
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# Base Environment Variable
|
||||
class EnvironmentVariableBase(OrmMetadataBase):
|
||||
@@ -24,18 +28,30 @@ class EnvironmentVariableBase(OrmMetadataBase):
|
||||
# This validator syncs `value` and `value_enc` for backward compatibility:
|
||||
# - If `value_enc` is set but `value` is empty -> populate `value` from decrypted `value_enc`
|
||||
# - If `value` is set but `value_enc` is empty -> populate `value_enc` from encrypted `value`
|
||||
@model_validator(mode="after")
|
||||
def sync_value_and_value_enc(self):
|
||||
"""Sync deprecated `value` field with `value_enc` for backward compatibility."""
|
||||
if self.value_enc and not self.value:
|
||||
# Decrypt value_enc -> value (for API responses)
|
||||
plaintext = self.value_enc.get_plaintext()
|
||||
if plaintext:
|
||||
self.value = plaintext
|
||||
elif self.value and not self.value_enc:
|
||||
# Encrypt value -> value_enc (for backward compat when value is provided directly)
|
||||
self.value_enc = Secret.from_plaintext(self.value)
|
||||
return self
|
||||
# @model_validator(mode="after")
|
||||
# def sync_value_and_value_enc(self):
|
||||
# """Sync deprecated `value` field with `value_enc` for backward compatibility."""
|
||||
# if self.value_enc and not self.value:
|
||||
# # ERROR: This should not happen - all code paths should populate value via async decryption
|
||||
# # Log error with stack trace to identify the caller that bypassed async decryption
|
||||
# logger.warning(
|
||||
# f"Sync decryption fallback triggered for env var key={self.key}. "
|
||||
# f"This indicates a code path that bypassed async decryption. Stack trace:\n{''.join(traceback.format_stack())}"
|
||||
# )
|
||||
# # Decrypt value_enc -> value (for API responses)
|
||||
# plaintext = self.value_enc.get_plaintext()
|
||||
# if plaintext:
|
||||
# self.value = plaintext
|
||||
# elif self.value and not self.value_enc:
|
||||
# # WARNING: This triggers sync encryption - should use async encryption where possible
|
||||
# # Log warning with stack trace to identify the caller
|
||||
# logger.warning(
|
||||
# f"Sync encryption fallback triggered for env var key={self.key}. "
|
||||
# f"This indicates a code path that bypassed async encryption. Stack trace:\n{''.join(traceback.format_stack())}"
|
||||
# )
|
||||
# # Encrypt value -> value_enc (for backward compat when value is provided directly)
|
||||
# self.value_enc = Secret.from_plaintext(self.value)
|
||||
# return self
|
||||
|
||||
|
||||
class EnvironmentVariableCreateBase(LettaBase):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import uuid
|
||||
from typing import List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator
|
||||
|
||||
@@ -9,6 +9,19 @@ from letta.schemas.letta_message_content import LettaMessageContentUnion
|
||||
from letta.schemas.message import MessageCreate, MessageCreateUnion, MessageRole
|
||||
|
||||
|
||||
class ClientToolSchema(BaseModel):
|
||||
"""Schema for a client-side tool passed in the request.
|
||||
|
||||
Client-side tools are executed by the client, not the server. When the agent
|
||||
calls a client-side tool, execution pauses and returns control to the client
|
||||
to execute the tool and provide the result.
|
||||
"""
|
||||
|
||||
name: str = Field(..., description="The name of the tool function")
|
||||
description: Optional[str] = Field(None, description="Description of what the tool does")
|
||||
parameters: Optional[Dict[str, Any]] = Field(None, description="JSON Schema for the function parameters")
|
||||
|
||||
|
||||
class LettaRequest(BaseModel):
|
||||
messages: Optional[List[MessageCreateUnion]] = Field(None, description="The messages to be sent to the agent.")
|
||||
input: Optional[Union[str, List[LettaMessageContentUnion]]] = Field(
|
||||
@@ -45,6 +58,13 @@ class LettaRequest(BaseModel):
|
||||
deprecated=True,
|
||||
)
|
||||
|
||||
# Client-side tools
|
||||
client_tools: Optional[List[ClientToolSchema]] = Field(
|
||||
None,
|
||||
description="Client-side tools that the agent can call. When the agent calls a client-side tool, "
|
||||
"execution pauses and returns control to the client to execute the tool and provide the result via a ToolReturn.",
|
||||
)
|
||||
|
||||
@field_validator("messages", mode="before")
|
||||
@classmethod
|
||||
def add_default_type_to_messages(cls, v):
|
||||
|
||||
@@ -48,6 +48,7 @@ class LLMConfig(BaseModel):
|
||||
"bedrock",
|
||||
"deepseek",
|
||||
"xai",
|
||||
"zai",
|
||||
] = Field(..., description="The endpoint type for the model.")
|
||||
model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
|
||||
provider_name: Optional[str] = Field(None, description="The provider name for the model.")
|
||||
@@ -317,6 +318,7 @@ class LLMConfig(BaseModel):
|
||||
OpenAIReasoning,
|
||||
TogetherModelSettings,
|
||||
XAIModelSettings,
|
||||
ZAIModelSettings,
|
||||
)
|
||||
|
||||
if self.model_endpoint_type == "openai":
|
||||
@@ -359,6 +361,11 @@ class LLMConfig(BaseModel):
|
||||
max_output_tokens=self.max_tokens or 4096,
|
||||
temperature=self.temperature,
|
||||
)
|
||||
elif self.model_endpoint_type == "zai":
|
||||
return ZAIModelSettings(
|
||||
max_output_tokens=self.max_tokens or 4096,
|
||||
temperature=self.temperature,
|
||||
)
|
||||
elif self.model_endpoint_type == "groq":
|
||||
return GroqModelSettings(
|
||||
max_output_tokens=self.max_tokens or 4096,
|
||||
|
||||
@@ -3,7 +3,9 @@ import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -51,6 +53,21 @@ class MCPServer(BaseMCPServer):
|
||||
last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
|
||||
metadata_: Optional[Dict[str, Any]] = Field(default_factory=dict, description="A dictionary of additional metadata for the tool.")
|
||||
|
||||
@field_validator("server_url")
|
||||
@classmethod
|
||||
def validate_server_url(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""Validate that server_url is a valid HTTP(S) URL if provided."""
|
||||
if v is None:
|
||||
return v
|
||||
if not v:
|
||||
raise ValueError("server_url cannot be empty")
|
||||
parsed = urlparse(v)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError(f"server_url must start with 'http://' or 'https://', got: '{v}'")
|
||||
if not parsed.netloc:
|
||||
raise ValueError(f"server_url must have a valid host, got: '{v}'")
|
||||
return v
|
||||
|
||||
def get_token_secret(self) -> Optional[Secret]:
|
||||
"""Get the token as a Secret object."""
|
||||
return self.token_enc
|
||||
@@ -199,6 +216,21 @@ class UpdateSSEMCPServer(LettaBase):
|
||||
token: Optional[str] = Field(None, description="The access token or API key for the MCP server (used for SSE authentication)")
|
||||
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs")
|
||||
|
||||
@field_validator("server_url")
|
||||
@classmethod
|
||||
def validate_server_url(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""Validate that server_url is a valid HTTP(S) URL if provided."""
|
||||
if v is None:
|
||||
return v
|
||||
if not v:
|
||||
raise ValueError("server_url cannot be empty")
|
||||
parsed = urlparse(v)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError(f"server_url must start with 'http://' or 'https://', got: '{v}'")
|
||||
if not parsed.netloc:
|
||||
raise ValueError(f"server_url must have a valid host, got: '{v}'")
|
||||
return v
|
||||
|
||||
|
||||
class UpdateStdioMCPServer(LettaBase):
|
||||
"""Update a Stdio MCP server"""
|
||||
@@ -218,6 +250,21 @@ class UpdateStreamableHTTPMCPServer(LettaBase):
|
||||
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
|
||||
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs")
|
||||
|
||||
@field_validator("server_url")
|
||||
@classmethod
|
||||
def validate_server_url(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""Validate that server_url is a valid HTTP(S) URL if provided."""
|
||||
if v is None:
|
||||
return v
|
||||
if not v:
|
||||
raise ValueError("server_url cannot be empty")
|
||||
parsed = urlparse(v)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError(f"server_url must start with 'http://' or 'https://', got: '{v}'")
|
||||
if not parsed.netloc:
|
||||
raise ValueError(f"server_url must have a valid host, got: '{v}'")
|
||||
return v
|
||||
|
||||
|
||||
UpdateMCPServer = Union[UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer]
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from letta.functions.mcp_client.types import (
|
||||
MCP_AUTH_HEADER_AUTHORIZATION,
|
||||
@@ -41,6 +42,19 @@ class CreateSSEMCPServer(LettaBase):
|
||||
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
|
||||
custom_headers: Optional[dict[str, str]] = Field(None, description="Custom HTTP headers to include with requests")
|
||||
|
||||
@field_validator("server_url")
|
||||
@classmethod
|
||||
def validate_server_url(cls, v: str) -> str:
|
||||
"""Validate that server_url is a valid HTTP(S) URL."""
|
||||
if not v:
|
||||
raise ValueError("server_url cannot be empty")
|
||||
parsed = urlparse(v)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError(f"server_url must start with 'http://' or 'https://', got: '{v}'")
|
||||
if not parsed.netloc:
|
||||
raise ValueError(f"server_url must have a valid host, got: '{v}'")
|
||||
return v
|
||||
|
||||
|
||||
class CreateStreamableHTTPMCPServer(LettaBase):
|
||||
"""Create a new Streamable HTTP MCP server"""
|
||||
@@ -51,6 +65,19 @@ class CreateStreamableHTTPMCPServer(LettaBase):
|
||||
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
|
||||
custom_headers: Optional[dict[str, str]] = Field(None, description="Custom HTTP headers to include with requests")
|
||||
|
||||
@field_validator("server_url")
|
||||
@classmethod
|
||||
def validate_server_url(cls, v: str) -> str:
|
||||
"""Validate that server_url is a valid HTTP(S) URL."""
|
||||
if not v:
|
||||
raise ValueError("server_url cannot be empty")
|
||||
parsed = urlparse(v)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError(f"server_url must start with 'http://' or 'https://', got: '{v}'")
|
||||
if not parsed.netloc:
|
||||
raise ValueError(f"server_url must have a valid host, got: '{v}'")
|
||||
return v
|
||||
|
||||
|
||||
CreateMCPServerUnion = Union[CreateStdioMCPServer, CreateSSEMCPServer, CreateStreamableHTTPMCPServer]
|
||||
|
||||
@@ -99,6 +126,21 @@ class UpdateSSEMCPServer(LettaBase):
|
||||
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
|
||||
custom_headers: Optional[dict[str, str]] = Field(None, description="Custom HTTP headers to include with requests")
|
||||
|
||||
@field_validator("server_url")
|
||||
@classmethod
|
||||
def validate_server_url(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""Validate that server_url is a valid HTTP(S) URL if provided."""
|
||||
if v is None:
|
||||
return v
|
||||
if not v:
|
||||
raise ValueError("server_url cannot be empty")
|
||||
parsed = urlparse(v)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError(f"server_url must start with 'http://' or 'https://', got: '{v}'")
|
||||
if not parsed.netloc:
|
||||
raise ValueError(f"server_url must have a valid host, got: '{v}'")
|
||||
return v
|
||||
|
||||
|
||||
class UpdateStreamableHTTPMCPServer(LettaBase):
|
||||
"""Update schema for Streamable HTTP MCP server - all fields optional"""
|
||||
@@ -109,6 +151,21 @@ class UpdateStreamableHTTPMCPServer(LettaBase):
|
||||
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
|
||||
custom_headers: Optional[dict[str, str]] = Field(None, description="Custom HTTP headers to include with requests")
|
||||
|
||||
@field_validator("server_url")
|
||||
@classmethod
|
||||
def validate_server_url(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""Validate that server_url is a valid HTTP(S) URL if provided."""
|
||||
if v is None:
|
||||
return v
|
||||
if not v:
|
||||
raise ValueError("server_url cannot be empty")
|
||||
parsed = urlparse(v)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError(f"server_url must start with 'http://' or 'https://', got: '{v}'")
|
||||
if not parsed.netloc:
|
||||
raise ValueError(f"server_url must have a valid host, got: '{v}'")
|
||||
return v
|
||||
|
||||
|
||||
UpdateMCPServerUnion = Union[UpdateStdioMCPServer, UpdateSSEMCPServer, UpdateStreamableHTTPMCPServer]
|
||||
|
||||
|
||||
@@ -456,6 +456,7 @@ class CreateArchivalMemory(BaseModel):
|
||||
|
||||
|
||||
class ArchivalMemorySearchResult(BaseModel):
|
||||
id: str = Field(..., description="Unique identifier of the archival memory passage")
|
||||
timestamp: str = Field(..., description="Timestamp of when the memory was created, formatted in agent's timezone")
|
||||
content: str = Field(..., description="Text content of the archival memory passage")
|
||||
tags: List[str] = Field(default_factory=list, description="List of tags associated with this memory")
|
||||
|
||||
@@ -211,6 +211,7 @@ class Message(BaseMessage):
|
||||
tool_returns (List[ToolReturn]): The list of tool returns requested.
|
||||
group_id (str): The multi-agent group that the message was sent in.
|
||||
sender_id (str): The id of the sender of the message, can be an identity id or agent id.
|
||||
conversation_id (str): The conversation this message belongs to.
|
||||
t
|
||||
"""
|
||||
|
||||
@@ -237,6 +238,7 @@ class Message(BaseMessage):
|
||||
group_id: Optional[str] = Field(default=None, description="The multi-agent group that the message was sent in")
|
||||
sender_id: Optional[str] = Field(default=None, description="The id of the sender of the message, can be an identity id or agent id")
|
||||
batch_item_id: Optional[str] = Field(default=None, description="The id of the LLMBatchItem that this message is associated with")
|
||||
conversation_id: Optional[str] = Field(default=None, description="The conversation this message belongs to")
|
||||
is_err: Optional[bool] = Field(
|
||||
default=None, description="Whether this message is part of an error step. Used only for debugging purposes."
|
||||
)
|
||||
@@ -1639,13 +1641,13 @@ class Message(BaseMessage):
|
||||
# TextContent, ImageContent, ToolCallContent, ToolReturnContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent
|
||||
if isinstance(content_part, ReasoningContent):
|
||||
if current_model == self.model:
|
||||
content.append(
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": content_part.reasoning,
|
||||
"signature": content_part.signature,
|
||||
}
|
||||
)
|
||||
block = {
|
||||
"type": "thinking",
|
||||
"thinking": content_part.reasoning,
|
||||
}
|
||||
if content_part.signature:
|
||||
block["signature"] = content_part.signature
|
||||
content.append(block)
|
||||
elif isinstance(content_part, RedactedReasoningContent):
|
||||
if current_model == self.model:
|
||||
content.append(
|
||||
@@ -1671,13 +1673,13 @@ class Message(BaseMessage):
|
||||
for content_part in self.content:
|
||||
if isinstance(content_part, ReasoningContent):
|
||||
if current_model == self.model:
|
||||
content.append(
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": content_part.reasoning,
|
||||
"signature": content_part.signature,
|
||||
}
|
||||
)
|
||||
block = {
|
||||
"type": "thinking",
|
||||
"thinking": content_part.reasoning,
|
||||
}
|
||||
if content_part.signature:
|
||||
block["signature"] = content_part.signature
|
||||
content.append(block)
|
||||
if isinstance(content_part, RedactedReasoningContent):
|
||||
if current_model == self.model:
|
||||
content.append(
|
||||
@@ -1729,29 +1731,42 @@ class Message(BaseMessage):
|
||||
elif self.role == "tool":
|
||||
# NOTE: Anthropic uses role "user" for "tool" responses
|
||||
content = []
|
||||
for tool_return in self.tool_returns:
|
||||
if not tool_return.tool_call_id:
|
||||
from letta.log import get_logger
|
||||
# Handle the case where tool_returns is None or empty
|
||||
if self.tool_returns:
|
||||
# For single tool returns, we can use the message's tool_call_id as fallback
|
||||
# since self.tool_call_id == tool_returns[0].tool_call_id for legacy compatibility.
|
||||
# For multiple tool returns (parallel tool calls), each must have its own ID
|
||||
# to correctly map results to their corresponding tool invocations.
|
||||
use_message_fallback = len(self.tool_returns) == 1
|
||||
for idx, tool_return in enumerate(self.tool_returns):
|
||||
# Get tool_call_id from tool_return; only use message fallback for single returns
|
||||
resolved_tool_call_id = tool_return.tool_call_id
|
||||
if not resolved_tool_call_id and use_message_fallback:
|
||||
resolved_tool_call_id = self.tool_call_id
|
||||
if not resolved_tool_call_id:
|
||||
from letta.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.error(
|
||||
f"Missing tool_call_id in tool return. "
|
||||
f"Message ID: {self.id}, "
|
||||
f"Tool name: {getattr(tool_return, 'name', 'unknown')}, "
|
||||
f"Tool return: {tool_return}"
|
||||
logger = get_logger(__name__)
|
||||
logger.error(
|
||||
f"Missing tool_call_id in tool return and no fallback available. "
|
||||
f"Message ID: {self.id}, "
|
||||
f"Tool name: {self.name or 'unknown'}, "
|
||||
f"Tool return index: {idx}/{len(self.tool_returns)}, "
|
||||
f"Tool return status: {tool_return.status}"
|
||||
)
|
||||
raise TypeError(
|
||||
f"Anthropic API requires tool_use_id to be set. "
|
||||
f"Message ID: {self.id}, Tool: {self.name or 'unknown'}, "
|
||||
f"Tool return index: {idx}/{len(self.tool_returns)}"
|
||||
)
|
||||
func_response = truncate_tool_return(tool_return.func_response, tool_return_truncation_chars)
|
||||
content.append(
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": resolved_tool_call_id,
|
||||
"content": func_response,
|
||||
}
|
||||
)
|
||||
raise TypeError(
|
||||
f"Anthropic API requires tool_use_id to be set. "
|
||||
f"Message ID: {self.id}, Tool: {getattr(tool_return, 'name', 'unknown')}"
|
||||
)
|
||||
func_response = truncate_tool_return(tool_return.func_response, tool_return_truncation_chars)
|
||||
content.append(
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_return.tool_call_id,
|
||||
"content": func_response,
|
||||
}
|
||||
)
|
||||
if content:
|
||||
anthropic_message = {
|
||||
"role": "user",
|
||||
@@ -2302,6 +2317,7 @@ class MessageSearchRequest(BaseModel):
|
||||
agent_id: Optional[str] = Field(None, description="Filter messages by agent ID")
|
||||
project_id: Optional[str] = Field(None, description="Filter messages by project ID")
|
||||
template_id: Optional[str] = Field(None, description="Filter messages by template ID")
|
||||
conversation_id: Optional[str] = Field(None, description="Filter messages by conversation ID")
|
||||
limit: int = Field(50, description="Maximum number of results to return", ge=1, le=100)
|
||||
start_date: Optional[datetime] = Field(None, description="Filter messages created after this date")
|
||||
end_date: Optional[datetime] = Field(None, description="Filter messages created on or before this date")
|
||||
@@ -2310,6 +2326,8 @@ class MessageSearchRequest(BaseModel):
|
||||
class SearchAllMessagesRequest(BaseModel):
|
||||
query: str = Field(..., description="Text query for full-text search")
|
||||
search_mode: Literal["vector", "fts", "hybrid"] = Field("hybrid", description="Search mode to use")
|
||||
agent_id: Optional[str] = Field(None, description="Filter messages by agent ID")
|
||||
conversation_id: Optional[str] = Field(None, description="Filter messages by conversation ID")
|
||||
limit: int = Field(50, description="Maximum number of results to return", ge=1, le=100)
|
||||
start_date: Optional[datetime] = Field(None, description="Filter messages created after this date")
|
||||
end_date: Optional[datetime] = Field(None, description="Filter messages created on or before this date")
|
||||
|
||||
@@ -47,6 +47,7 @@ class Model(LLMConfig, ModelBase):
|
||||
"bedrock",
|
||||
"deepseek",
|
||||
"xai",
|
||||
"zai",
|
||||
] = Field(..., description="Deprecated: Use 'provider_type' field instead. The endpoint type for the model.", deprecated=True)
|
||||
context_window: int = Field(
|
||||
..., description="Deprecated: Use 'max_context_window' field instead. The context window size for the model.", deprecated=True
|
||||
@@ -131,6 +132,7 @@ class Model(LLMConfig, ModelBase):
|
||||
ProviderType.google_vertex: GoogleVertexModelSettings,
|
||||
ProviderType.azure: AzureModelSettings,
|
||||
ProviderType.xai: XAIModelSettings,
|
||||
ProviderType.zai: ZAIModelSettings,
|
||||
ProviderType.groq: GroqModelSettings,
|
||||
ProviderType.deepseek: DeepseekModelSettings,
|
||||
ProviderType.together: TogetherModelSettings,
|
||||
@@ -352,6 +354,22 @@ class XAIModelSettings(ModelSettings):
|
||||
}
|
||||
|
||||
|
||||
class ZAIModelSettings(ModelSettings):
|
||||
"""Z.ai (ZhipuAI) model configuration (OpenAI-compatible)."""
|
||||
|
||||
provider_type: Literal[ProviderType.zai] = Field(ProviderType.zai, description="The type of the provider.")
|
||||
temperature: float = Field(0.7, description="The temperature of the model.")
|
||||
response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the model.")
|
||||
|
||||
def _to_legacy_config_params(self) -> dict:
|
||||
return {
|
||||
"temperature": self.temperature,
|
||||
"max_tokens": self.max_output_tokens,
|
||||
"response_format": self.response_format,
|
||||
"parallel_tool_calls": self.parallel_tool_calls,
|
||||
}
|
||||
|
||||
|
||||
class GroqModelSettings(ModelSettings):
|
||||
"""Groq model configuration (OpenAI-compatible)."""
|
||||
|
||||
@@ -424,6 +442,7 @@ ModelSettingsUnion = Annotated[
|
||||
GoogleVertexModelSettings,
|
||||
AzureModelSettings,
|
||||
XAIModelSettings,
|
||||
ZAIModelSettings,
|
||||
GroqModelSettings,
|
||||
DeepseekModelSettings,
|
||||
TogetherModelSettings,
|
||||
|
||||
@@ -18,6 +18,7 @@ from .openrouter import OpenRouterProvider
|
||||
from .together import TogetherProvider
|
||||
from .vllm import VLLMProvider
|
||||
from .xai import XAIProvider
|
||||
from .zai import ZAIProvider
|
||||
|
||||
__all__ = [
|
||||
# Base classes
|
||||
@@ -43,5 +44,6 @@ __all__ = [
|
||||
"TogetherProvider",
|
||||
"VLLMProvider", # Replaces ChatCompletions and Completions
|
||||
"XAIProvider",
|
||||
"ZAIProvider",
|
||||
"OpenRouterProvider",
|
||||
]
|
||||
|
||||
@@ -109,18 +109,19 @@ class AnthropicProvider(Provider):
|
||||
|
||||
async def check_api_key(self):
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
if api_key:
|
||||
anthropic_client = anthropic.Anthropic(api_key=api_key)
|
||||
try:
|
||||
# just use a cheap model to count some tokens - as of 5/7/2025 this is faster than fetching the list of models
|
||||
anthropic_client.messages.count_tokens(model=MODEL_LIST[-1]["name"], messages=[{"role": "user", "content": "a"}])
|
||||
except anthropic.AuthenticationError as e:
|
||||
raise LLMAuthenticationError(message=f"Failed to authenticate with Anthropic: {e}", code=ErrorCode.UNAUTHENTICATED)
|
||||
except Exception as e:
|
||||
raise LLMError(message=f"{e}", code=ErrorCode.INTERNAL_SERVER_ERROR)
|
||||
else:
|
||||
if not api_key:
|
||||
raise ValueError("No API key provided")
|
||||
|
||||
try:
|
||||
# Use async Anthropic client
|
||||
anthropic_client = anthropic.AsyncAnthropic(api_key=api_key)
|
||||
# just use a cheap model to count some tokens - as of 5/7/2025 this is faster than fetching the list of models
|
||||
await anthropic_client.messages.count_tokens(model=MODEL_LIST[-1]["name"], messages=[{"role": "user", "content": "a"}])
|
||||
except anthropic.AuthenticationError as e:
|
||||
raise LLMAuthenticationError(message=f"Failed to authenticate with Anthropic: {e}", code=ErrorCode.UNAUTHENTICATED)
|
||||
except Exception as e:
|
||||
raise LLMError(message=f"{e}", code=ErrorCode.INTERNAL_SERVER_ERROR)
|
||||
|
||||
def get_default_max_output_tokens(self, model_name: str) -> int:
|
||||
"""Get the default max output tokens for Anthropic models."""
|
||||
if "opus" in model_name:
|
||||
@@ -138,8 +139,21 @@ class AnthropicProvider(Provider):
|
||||
NOTE: currently there is no GET /models, so we need to hardcode
|
||||
"""
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
|
||||
# For claude-pro-max provider, use OAuth Bearer token instead of api_key
|
||||
is_oauth_provider = self.name == "claude-pro-max"
|
||||
|
||||
if api_key:
|
||||
anthropic_client = anthropic.AsyncAnthropic(api_key=api_key)
|
||||
if is_oauth_provider:
|
||||
anthropic_client = anthropic.AsyncAnthropic(
|
||||
default_headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"anthropic-version": "2023-06-01",
|
||||
"anthropic-beta": "oauth-2025-04-20",
|
||||
},
|
||||
)
|
||||
else:
|
||||
anthropic_client = anthropic.AsyncAnthropic(api_key=api_key)
|
||||
elif model_settings.anthropic_api_key:
|
||||
anthropic_client = anthropic.AsyncAnthropic()
|
||||
else:
|
||||
|
||||
@@ -196,6 +196,7 @@ class Provider(ProviderBase):
|
||||
TogetherProvider,
|
||||
VLLMProvider,
|
||||
XAIProvider,
|
||||
ZAIProvider,
|
||||
)
|
||||
|
||||
if self.base_url == "":
|
||||
@@ -230,6 +231,8 @@ class Provider(ProviderBase):
|
||||
return CerebrasProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.xai:
|
||||
return XAIProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.zai:
|
||||
return ZAIProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.lmstudio_openai:
|
||||
return LMStudioOpenAIProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.bedrock:
|
||||
|
||||
@@ -8,10 +8,13 @@ from letta.schemas.enums import ProviderCategory, ProviderType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.providers.base import Provider
|
||||
|
||||
LETTA_EMBEDDING_ENDPOINT = "https://embeddings.letta.com/"
|
||||
|
||||
|
||||
class LettaProvider(Provider):
|
||||
provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.")
|
||||
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
|
||||
base_url: str = Field(LETTA_EMBEDDING_ENDPOINT, description="Base URL for the Letta API (used for embeddings).")
|
||||
|
||||
async def list_llm_models_async(self) -> list[LLMConfig]:
|
||||
return [
|
||||
@@ -32,7 +35,7 @@ class LettaProvider(Provider):
|
||||
EmbeddingConfig(
|
||||
embedding_model="letta-free", # NOTE: renamed
|
||||
embedding_endpoint_type="openai",
|
||||
embedding_endpoint="https://embeddings.letta.com/",
|
||||
embedding_endpoint=self.base_url,
|
||||
embedding_dim=1536,
|
||||
embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
|
||||
handle=self.get_handle("letta-free", is_embedding=True),
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from typing import Literal
|
||||
|
||||
from openai import AsyncOpenAI, AuthenticationError
|
||||
from pydantic import Field
|
||||
|
||||
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LLM_MAX_CONTEXT_WINDOW
|
||||
from letta.errors import ErrorCode, LLMAuthenticationError, LLMError
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ProviderCategory, ProviderType
|
||||
@@ -23,11 +25,21 @@ class OpenAIProvider(Provider):
|
||||
base_url: str = Field("https://api.openai.com/v1", description="Base URL for the OpenAI API.")
|
||||
|
||||
async def check_api_key(self):
|
||||
from letta.llm_api.openai import openai_check_valid_api_key # TODO: DO NOT USE THIS - old code path
|
||||
|
||||
# Decrypt API key before using
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
openai_check_valid_api_key(self.base_url, api_key)
|
||||
|
||||
if not api_key:
|
||||
raise ValueError("No API key provided")
|
||||
|
||||
try:
|
||||
# Use async OpenAI client to check API key validity
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=self.base_url)
|
||||
# Just list models to verify API key works
|
||||
await client.models.list()
|
||||
except AuthenticationError as e:
|
||||
raise LLMAuthenticationError(message=f"Failed to authenticate with OpenAI: {e}", code=ErrorCode.UNAUTHENTICATED)
|
||||
except Exception as e:
|
||||
raise LLMError(message=f"{e}", code=ErrorCode.INTERNAL_SERVER_ERROR)
|
||||
|
||||
def get_default_max_output_tokens(self, model_name: str) -> int:
|
||||
"""Get the default max output tokens for OpenAI models."""
|
||||
|
||||
71
letta/schemas/providers/zai.py
Normal file
71
letta/schemas/providers/zai.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from typing import Literal
|
||||
|
||||
from letta.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.enums import ProviderCategory, ProviderType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.providers.openai import OpenAIProvider
|
||||
|
||||
# Z.ai model context windows
|
||||
# Reference: https://docs.z.ai/
|
||||
MODEL_CONTEXT_WINDOWS = {
|
||||
"glm-4.5": 128000,
|
||||
"glm-4.6": 200000,
|
||||
"glm-4.7": 200000,
|
||||
}
|
||||
|
||||
|
||||
class ZAIProvider(OpenAIProvider):
|
||||
"""Z.ai (ZhipuAI) provider - https://docs.z.ai/"""
|
||||
|
||||
provider_type: Literal[ProviderType.zai] = Field(ProviderType.zai, description="The type of the provider.")
|
||||
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
|
||||
api_key: str | None = Field(None, description="API key for the Z.ai API.", deprecated=True)
|
||||
base_url: str = Field("https://api.z.ai/api/paas/v4/", description="Base URL for the Z.ai API.")
|
||||
|
||||
def get_model_context_window_size(self, model_name: str) -> int | None:
|
||||
# Z.ai doesn't return context window in the model listing,
|
||||
# this is hardcoded from documentation
|
||||
return MODEL_CONTEXT_WINDOWS.get(model_name)
|
||||
|
||||
async def list_llm_models_async(self) -> list[LLMConfig]:
|
||||
from letta.llm_api.openai import openai_get_model_list_async
|
||||
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
response = await openai_get_model_list_async(self.base_url, api_key=api_key)
|
||||
|
||||
data = response.get("data", response)
|
||||
|
||||
configs = []
|
||||
for model in data:
|
||||
assert "id" in model, f"Z.ai model missing 'id' field: {model}"
|
||||
model_name = model["id"]
|
||||
|
||||
# In case Z.ai starts supporting it in the future:
|
||||
if "context_length" in model:
|
||||
context_window_size = model["context_length"]
|
||||
else:
|
||||
context_window_size = self.get_model_context_window_size(model_name)
|
||||
|
||||
if not context_window_size:
|
||||
logger.warning(f"Couldn't find context window size for model {model_name}")
|
||||
continue
|
||||
|
||||
configs.append(
|
||||
LLMConfig(
|
||||
model=model_name,
|
||||
model_endpoint_type=self.provider_type.value,
|
||||
model_endpoint=self.base_url,
|
||||
context_window=context_window_size,
|
||||
handle=self.get_handle(model_name),
|
||||
max_tokens=self.get_default_max_output_tokens(model_name),
|
||||
provider_name=self.name,
|
||||
provider_category=self.provider_category,
|
||||
)
|
||||
)
|
||||
|
||||
return configs
|
||||
@@ -27,6 +27,9 @@ class Run(RunBase):
|
||||
# Agent relationship
|
||||
agent_id: str = Field(..., description="The unique identifier of the agent associated with the run.")
|
||||
|
||||
# Conversation relationship
|
||||
conversation_id: Optional[str] = Field(None, description="The unique identifier of the conversation associated with the run.")
|
||||
|
||||
# Template fields
|
||||
base_template_id: Optional[str] = Field(None, description="The base template ID that the run belongs to.")
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from pydantic_core import core_schema
|
||||
|
||||
from letta.helpers.crypto_utils import CryptoUtils
|
||||
from letta.log import get_logger
|
||||
from letta.utils import bounded_gather
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -67,6 +68,72 @@ class Secret(BaseModel):
|
||||
return instance
|
||||
raise # Re-raise if it's a different error
|
||||
|
||||
@classmethod
|
||||
async def from_plaintext_async(cls, value: Optional[str]) -> "Secret":
|
||||
"""
|
||||
Create a Secret from a plaintext value, encrypting it asynchronously.
|
||||
|
||||
This async version runs encryption in a thread pool to avoid blocking
|
||||
the event loop during the CPU-intensive PBKDF2 key derivation (100-500ms).
|
||||
|
||||
Use this method in all async contexts (FastAPI endpoints, async services, etc.)
|
||||
to avoid blocking the event loop.
|
||||
|
||||
Args:
|
||||
value: The plaintext value to encrypt
|
||||
|
||||
Returns:
|
||||
A Secret instance with the encrypted (or plaintext) value
|
||||
"""
|
||||
if value is None:
|
||||
return cls.model_construct(encrypted_value=None)
|
||||
|
||||
# Guard against double encryption - check if value is already encrypted
|
||||
if CryptoUtils.is_encrypted(value):
|
||||
logger.warning("Creating Secret from already-encrypted value. This can be dangerous.")
|
||||
|
||||
# Try to encrypt asynchronously, but fall back to storing plaintext if no encryption key
|
||||
try:
|
||||
encrypted = await CryptoUtils.encrypt_async(value)
|
||||
return cls.model_construct(encrypted_value=encrypted)
|
||||
except ValueError as e:
|
||||
# No encryption key available, store as plaintext in the _enc column
|
||||
if "No encryption key configured" in str(e):
|
||||
logger.warning(
|
||||
"No encryption key configured. Storing Secret value as plaintext in _enc column. "
|
||||
"Set LETTA_ENCRYPTION_KEY environment variable to enable encryption."
|
||||
)
|
||||
instance = cls.model_construct(encrypted_value=value)
|
||||
instance._plaintext_cache = value # Cache it since we know the plaintext
|
||||
return instance
|
||||
raise # Re-raise if it's a different error
|
||||
|
||||
@classmethod
|
||||
async def from_plaintexts_async(cls, values: dict[str, str], max_concurrency: int = 10) -> dict[str, "Secret"]:
|
||||
"""
|
||||
Create multiple Secrets from plaintexts concurrently with bounded concurrency.
|
||||
|
||||
Uses bounded_gather() to encrypt values in parallel while limiting
|
||||
concurrent operations to prevent overwhelming the event loop.
|
||||
|
||||
Args:
|
||||
values: Dict of key -> plaintext value
|
||||
max_concurrency: Maximum number of concurrent encryption operations (default: 10)
|
||||
|
||||
Returns:
|
||||
Dict of key -> Secret
|
||||
"""
|
||||
if not values:
|
||||
return {}
|
||||
|
||||
keys = list(values.keys())
|
||||
|
||||
async def encrypt_one(key: str) -> "Secret":
|
||||
return await cls.from_plaintext_async(values[key])
|
||||
|
||||
secrets = await bounded_gather([encrypt_one(k) for k in keys], max_concurrency=max_concurrency)
|
||||
return dict(zip(keys, secrets))
|
||||
|
||||
@classmethod
|
||||
def from_encrypted(cls, encrypted_value: Optional[str]) -> "Secret":
|
||||
"""
|
||||
|
||||
@@ -38,6 +38,7 @@ class Step(StepBase):
|
||||
tags: List[str] = Field([], description="Metadata tags.")
|
||||
tid: Optional[str] = Field(None, description="The unique identifier of the transaction that processed this step.")
|
||||
trace_id: Optional[str] = Field(None, description="The trace id of the agent step.")
|
||||
request_id: Optional[str] = Field(None, description="The API request log ID from cloud-api for correlating steps with API requests.")
|
||||
messages: List[Message] = Field(
|
||||
[],
|
||||
description="The messages generated during this step. Deprecated: use `GET /v1/steps/{step_id}/messages` endpoint instead",
|
||||
|
||||
@@ -39,12 +39,17 @@ else:
|
||||
|
||||
# Add asyncpg-specific settings for connection
|
||||
if not settings.disable_sqlalchemy_pooling:
|
||||
engine_args["connect_args"] = {
|
||||
connect_args = {
|
||||
"timeout": settings.pg_pool_timeout,
|
||||
"prepared_statement_name_func": lambda: f"__asyncpg_{uuid.uuid4()}__",
|
||||
"statement_cache_size": 0,
|
||||
"prepared_statement_cache_size": 0,
|
||||
}
|
||||
# Only add SSL if not already specified in connection string
|
||||
if "sslmode" not in async_pg_uri and "ssl" not in async_pg_uri:
|
||||
connect_args["ssl"] = "require"
|
||||
|
||||
engine_args["connect_args"] = connect_args
|
||||
|
||||
# Create the engine once at module level
|
||||
engine: AsyncEngine = create_async_engine(async_pg_uri, **engine_args)
|
||||
|
||||
@@ -18,7 +18,7 @@ faulthandler.enable()
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.responses import JSONResponse, ORJSONResponse
|
||||
from marshmallow import ValidationError
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
@@ -32,6 +32,9 @@ from letta.errors import (
|
||||
AgentFileImportError,
|
||||
AgentNotFoundForExportError,
|
||||
BedrockPermissionError,
|
||||
ConcurrentUpdateError,
|
||||
ConversationBusyError,
|
||||
EmbeddingConfigRequiredError,
|
||||
HandleNotFoundError,
|
||||
LettaAgentNotFoundError,
|
||||
LettaExpiredError,
|
||||
@@ -49,6 +52,7 @@ from letta.errors import (
|
||||
LLMProviderOverloaded,
|
||||
LLMRateLimitError,
|
||||
LLMTimeoutError,
|
||||
NoActiveRunsToCancelError,
|
||||
PendingApprovalError,
|
||||
)
|
||||
from letta.helpers.pinecone_utils import get_pinecone_indices, should_use_pinecone, upsert_pinecone_indices
|
||||
@@ -69,7 +73,7 @@ from letta.server.global_exception_handler import setup_global_exception_handler
|
||||
# NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests
|
||||
from letta.server.rest_api.auth.index import setup_auth_router # TODO: probably remove right?
|
||||
from letta.server.rest_api.interface import StreamingServerInterface
|
||||
from letta.server.rest_api.middleware import CheckPasswordMiddleware, LoggingMiddleware
|
||||
from letta.server.rest_api.middleware import CheckPasswordMiddleware, LoggingMiddleware, RequestIdMiddleware
|
||||
from letta.server.rest_api.routers.v1 import ROUTERS as v1_routes
|
||||
from letta.server.rest_api.routers.v1.organizations import router as organizations_router
|
||||
from letta.server.rest_api.routers.v1.users import router as users_router # TODO: decide on admin
|
||||
@@ -241,10 +245,6 @@ def create_application() -> "FastAPI":
|
||||
os.environ.setdefault("DD_PROFILING_MEMORY_ENABLED", str(telemetry_settings.datadog_profiling_memory_enabled).lower())
|
||||
os.environ.setdefault("DD_PROFILING_HEAP_ENABLED", str(telemetry_settings.datadog_profiling_heap_enabled).lower())
|
||||
|
||||
# Enable LLM Observability for tracking LLM calls, prompts, and completions
|
||||
os.environ.setdefault("DD_LLMOBS_ENABLED", "1")
|
||||
os.environ.setdefault("DD_LLMOBS_ML_APP", "memgpt-server")
|
||||
|
||||
# Note: DD_LOGS_INJECTION, DD_APPSEC_ENABLED, DD_IAST_ENABLED, DD_APPSEC_SCA_ENABLED
|
||||
# are set via deployment configs and automatically picked up by ddtrace
|
||||
|
||||
@@ -252,6 +252,49 @@ def create_application() -> "FastAPI":
|
||||
import ddtrace
|
||||
|
||||
ddtrace.patch_all() # Auto-instrument FastAPI, HTTP, DB, etc.
|
||||
|
||||
llmobs_flag = os.getenv("DD_LLMOBS_ENABLED", "")
|
||||
from ddtrace.llmobs import LLMObs
|
||||
|
||||
try:
|
||||
from ddtrace.llmobs._constants import MODEL_PROVIDER
|
||||
from ddtrace.llmobs._integrations.openai import OpenAIIntegration
|
||||
|
||||
if not getattr(OpenAIIntegration, "_letta_provider_patch_done", False):
|
||||
original_set_tags = OpenAIIntegration._llmobs_set_tags
|
||||
|
||||
def _letta_set_tags(self, span, args, kwargs, response=None, operation=""):
|
||||
original_set_tags(self, span, args, kwargs, response=response, operation=operation)
|
||||
|
||||
base_url = span.get_tag("openai.api_base")
|
||||
if not base_url:
|
||||
try:
|
||||
client = getattr(self, "_client", None)
|
||||
base_url = str(getattr(client, "_base_url", "") or "")
|
||||
except Exception:
|
||||
base_url = ""
|
||||
|
||||
u = (base_url or "").lower()
|
||||
provider = None
|
||||
if "openrouter" in u:
|
||||
provider = "openrouter"
|
||||
elif "groq" in u:
|
||||
provider = "groq"
|
||||
|
||||
if provider:
|
||||
span._set_ctx_item(MODEL_PROVIDER, provider)
|
||||
span._set_tag_str("openai.request.provider", provider)
|
||||
|
||||
OpenAIIntegration._llmobs_set_tags = _letta_set_tags
|
||||
OpenAIIntegration._letta_provider_patch_done = True
|
||||
except Exception:
|
||||
logger.exception("Failed to patch ddtrace OpenAI LLMObs provider detection")
|
||||
|
||||
if llmobs_flag:
|
||||
LLMObs.enable(
|
||||
ml_app=os.getenv("DD_LLMOBS_ML_APP") or telemetry_settings.datadog_service_name,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Datadog tracer initialized: env={dd_env}, "
|
||||
f"service={telemetry_settings.datadog_service_name}, "
|
||||
@@ -296,6 +339,7 @@ def create_application() -> "FastAPI":
|
||||
version=letta_version,
|
||||
debug=debug_mode, # if True, the stack trace will be printed in the response
|
||||
lifespan=lifespan,
|
||||
default_response_class=ORJSONResponse, # Use orjson for 10x faster JSON serialization
|
||||
)
|
||||
|
||||
# === Global Exception Handlers ===
|
||||
@@ -431,6 +475,7 @@ def create_application() -> "FastAPI":
|
||||
app.add_exception_handler(LettaToolCreateError, _error_handler_400)
|
||||
app.add_exception_handler(LettaToolNameConflictError, _error_handler_400)
|
||||
app.add_exception_handler(AgentFileImportError, _error_handler_400)
|
||||
app.add_exception_handler(EmbeddingConfigRequiredError, _error_handler_400)
|
||||
app.add_exception_handler(ValueError, _error_handler_400)
|
||||
|
||||
# 404 Not Found errors
|
||||
@@ -451,7 +496,10 @@ def create_application() -> "FastAPI":
|
||||
app.add_exception_handler(ForeignKeyConstraintViolationError, _error_handler_409)
|
||||
app.add_exception_handler(UniqueConstraintViolationError, _error_handler_409)
|
||||
app.add_exception_handler(IntegrityError, _error_handler_409)
|
||||
app.add_exception_handler(ConcurrentUpdateError, _error_handler_409)
|
||||
app.add_exception_handler(ConversationBusyError, _error_handler_409)
|
||||
app.add_exception_handler(PendingApprovalError, _error_handler_409)
|
||||
app.add_exception_handler(NoActiveRunsToCancelError, _error_handler_409)
|
||||
|
||||
# 415 Unsupported Media Type errors
|
||||
app.add_exception_handler(LettaUnsupportedFileUploadError, _error_handler_415)
|
||||
@@ -586,6 +634,10 @@ def create_application() -> "FastAPI":
|
||||
# Add unified logging middleware - enriches log context and logs exceptions
|
||||
app.add_middleware(LoggingMiddleware)
|
||||
|
||||
# Add request ID middleware - extracts x-api-request-log-id header and sets it in contextvar
|
||||
# This is a pure ASGI middleware to properly propagate contextvars to streaming responses
|
||||
app.add_middleware(RequestIdMiddleware)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins,
|
||||
@@ -634,9 +686,6 @@ def create_application() -> "FastAPI":
|
||||
# app.include_router(route, prefix="", include_in_schema=False)
|
||||
app.include_router(route, prefix="/latest", include_in_schema=False)
|
||||
|
||||
# NOTE: ethan these are the extra routes
|
||||
# TODO(ethan) remove
|
||||
|
||||
# admin/users
|
||||
app.include_router(users_router, prefix=ADMIN_PREFIX)
|
||||
app.include_router(organizations_router, prefix=ADMIN_PREFIX)
|
||||
|
||||
@@ -3,6 +3,8 @@ from typing import TYPE_CHECKING, Optional
|
||||
from fastapi import Header
|
||||
from pydantic import BaseModel
|
||||
|
||||
from letta.otel.tracing import tracer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
@@ -39,25 +41,27 @@ def get_headers(
|
||||
modal_sandbox: Optional[str] = Header(None, alias="X-Experimental-Modal-Sandbox"),
|
||||
) -> HeaderParams:
|
||||
"""Dependency injection function to extract common headers from requests."""
|
||||
return HeaderParams(
|
||||
actor_id=actor_id,
|
||||
user_agent=user_agent,
|
||||
project_id=project_id,
|
||||
letta_source=letta_source,
|
||||
sdk_version=sdk_version,
|
||||
experimental_params=ExperimentalParams(
|
||||
message_async=(message_async == "true") if message_async else None,
|
||||
letta_v1_agent=(letta_v1_agent == "true") if letta_v1_agent else None,
|
||||
letta_v1_agent_message_async=(letta_v1_agent_message_async == "true") if letta_v1_agent_message_async else None,
|
||||
modal_sandbox=(modal_sandbox == "true") if modal_sandbox else None,
|
||||
),
|
||||
)
|
||||
with tracer.start_as_current_span("dependency.get_headers"):
|
||||
return HeaderParams(
|
||||
actor_id=actor_id,
|
||||
user_agent=user_agent,
|
||||
project_id=project_id,
|
||||
letta_source=letta_source,
|
||||
sdk_version=sdk_version,
|
||||
experimental_params=ExperimentalParams(
|
||||
message_async=(message_async == "true") if message_async else None,
|
||||
letta_v1_agent=(letta_v1_agent == "true") if letta_v1_agent else None,
|
||||
letta_v1_agent_message_async=(letta_v1_agent_message_async == "true") if letta_v1_agent_message_async else None,
|
||||
modal_sandbox=(modal_sandbox == "true") if modal_sandbox else None,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# TODO: why does this double up the interface?
|
||||
async def get_letta_server() -> "SyncServer":
|
||||
# Check if a global server is already instantiated
|
||||
from letta.server.rest_api.app import server
|
||||
with tracer.start_as_current_span("dependency.get_letta_server"):
|
||||
# Check if a global server is already instantiated
|
||||
from letta.server.rest_api.app import server
|
||||
|
||||
# assert isinstance(server, SyncServer)
|
||||
return server
|
||||
# assert isinstance(server, SyncServer)
|
||||
return server
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from letta.server.rest_api.middleware.check_password import CheckPasswordMiddleware
|
||||
from letta.server.rest_api.middleware.logging import LoggingMiddleware
|
||||
from letta.server.rest_api.middleware.request_id import RequestIdMiddleware
|
||||
|
||||
__all__ = ["CheckPasswordMiddleware", "LoggingMiddleware"]
|
||||
__all__ = ["CheckPasswordMiddleware", "LoggingMiddleware", "RequestIdMiddleware"]
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
Unified logging middleware that enriches log context and ensures exceptions are logged.
|
||||
"""
|
||||
|
||||
import re
|
||||
import traceback
|
||||
from typing import Callable
|
||||
|
||||
@@ -11,6 +10,7 @@ from starlette.requests import Request
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.log_context import clear_log_context, update_log_context
|
||||
from letta.otel.tracing import tracer
|
||||
from letta.schemas.enums import PrimitiveType
|
||||
from letta.validators import PRIMITIVE_ID_PATTERNS
|
||||
|
||||
@@ -33,95 +33,96 @@ class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
clear_log_context()
|
||||
|
||||
try:
|
||||
# Extract and set log context
|
||||
context = {}
|
||||
with tracer.start_as_current_span("middleware.logging"):
|
||||
# Extract and set log context
|
||||
context = {}
|
||||
with tracer.start_as_current_span("middleware.logging.context"):
|
||||
# Headers
|
||||
actor_id = request.headers.get("user_id")
|
||||
if actor_id:
|
||||
context["actor_id"] = actor_id
|
||||
|
||||
# Headers
|
||||
actor_id = request.headers.get("user_id")
|
||||
if actor_id:
|
||||
context["actor_id"] = actor_id
|
||||
project_id = request.headers.get("x-project-id")
|
||||
if project_id:
|
||||
context["project_id"] = project_id
|
||||
|
||||
project_id = request.headers.get("x-project-id")
|
||||
if project_id:
|
||||
context["project_id"] = project_id
|
||||
org_id = request.headers.get("x-organization-id")
|
||||
if org_id:
|
||||
context["org_id"] = org_id
|
||||
|
||||
org_id = request.headers.get("x-organization-id")
|
||||
if org_id:
|
||||
context["org_id"] = org_id
|
||||
user_agent = request.headers.get("x-agent-id")
|
||||
if user_agent:
|
||||
context["agent_id"] = user_agent
|
||||
|
||||
user_agent = request.headers.get("x-agent-id")
|
||||
if user_agent:
|
||||
context["agent_id"] = user_agent
|
||||
run_id_header = request.headers.get("x-run-id") or request.headers.get("run-id")
|
||||
if run_id_header:
|
||||
context["run_id"] = run_id_header
|
||||
|
||||
run_id_header = request.headers.get("x-run-id") or request.headers.get("run-id")
|
||||
if run_id_header:
|
||||
context["run_id"] = run_id_header
|
||||
path = request.url.path
|
||||
path_parts = [p for p in path.split("/") if p]
|
||||
|
||||
path = request.url.path
|
||||
path_parts = [p for p in path.split("/") if p]
|
||||
# Path
|
||||
matched_parts = set()
|
||||
for part in path_parts:
|
||||
if part in matched_parts:
|
||||
continue
|
||||
|
||||
# Path
|
||||
matched_parts = set()
|
||||
for part in path_parts:
|
||||
if part in matched_parts:
|
||||
continue
|
||||
for primitive_type in PrimitiveType:
|
||||
prefix = primitive_type.value
|
||||
pattern = PRIMITIVE_ID_PATTERNS.get(prefix)
|
||||
|
||||
for primitive_type in PrimitiveType:
|
||||
prefix = primitive_type.value
|
||||
pattern = PRIMITIVE_ID_PATTERNS.get(prefix)
|
||||
if pattern and pattern.match(part):
|
||||
context_key = f"{primitive_type.name.lower()}_id"
|
||||
|
||||
if pattern and pattern.match(part):
|
||||
context_key = f"{primitive_type.name.lower()}_id"
|
||||
if primitive_type == PrimitiveType.ORGANIZATION:
|
||||
context_key = "org_id"
|
||||
elif primitive_type == PrimitiveType.USER:
|
||||
context_key = "user_id"
|
||||
|
||||
if primitive_type == PrimitiveType.ORGANIZATION:
|
||||
context_key = "org_id"
|
||||
elif primitive_type == PrimitiveType.USER:
|
||||
context_key = "user_id"
|
||||
context[context_key] = part
|
||||
matched_parts.add(part)
|
||||
break
|
||||
|
||||
context[context_key] = part
|
||||
matched_parts.add(part)
|
||||
break
|
||||
# Query Parameters
|
||||
for param_value in request.query_params.values():
|
||||
if param_value in matched_parts:
|
||||
continue
|
||||
|
||||
# Query Parameters
|
||||
for param_value in request.query_params.values():
|
||||
if param_value in matched_parts:
|
||||
continue
|
||||
for primitive_type in PrimitiveType:
|
||||
prefix = primitive_type.value
|
||||
pattern = PRIMITIVE_ID_PATTERNS.get(prefix)
|
||||
|
||||
for primitive_type in PrimitiveType:
|
||||
prefix = primitive_type.value
|
||||
pattern = PRIMITIVE_ID_PATTERNS.get(prefix)
|
||||
if pattern and pattern.match(param_value):
|
||||
context_key = f"{primitive_type.name.lower()}_id"
|
||||
|
||||
if pattern and pattern.match(param_value):
|
||||
context_key = f"{primitive_type.name.lower()}_id"
|
||||
if primitive_type == PrimitiveType.ORGANIZATION:
|
||||
context_key = "org_id"
|
||||
elif primitive_type == PrimitiveType.USER:
|
||||
context_key = "user_id"
|
||||
|
||||
if primitive_type == PrimitiveType.ORGANIZATION:
|
||||
context_key = "org_id"
|
||||
elif primitive_type == PrimitiveType.USER:
|
||||
context_key = "user_id"
|
||||
# Only set if not already set from path (path takes precedence over query params)
|
||||
# Query params can overwrite headers, but path values take precedence
|
||||
if context_key not in context:
|
||||
context[context_key] = param_value
|
||||
matched_parts.add(param_value)
|
||||
break
|
||||
|
||||
# Only set if not already set from path (path takes precedence over query params)
|
||||
# Query params can overwrite headers, but path values take precedence
|
||||
if context_key not in context:
|
||||
context[context_key] = param_value
|
||||
matched_parts.add(param_value)
|
||||
break
|
||||
if context:
|
||||
update_log_context(**context)
|
||||
|
||||
if context:
|
||||
update_log_context(**context)
|
||||
logger.debug(
|
||||
f"Incoming request: {request.method} {request.url.path}",
|
||||
extra={
|
||||
"method": request.method,
|
||||
"url": str(request.url),
|
||||
"path": request.url.path,
|
||||
"query_params": dict(request.query_params),
|
||||
"client_host": request.client.host if request.client else None,
|
||||
},
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Incoming request: {request.method} {request.url.path}",
|
||||
extra={
|
||||
"method": request.method,
|
||||
"url": str(request.url),
|
||||
"path": request.url.path,
|
||||
"query_params": dict(request.query_params),
|
||||
"client_host": request.client.host if request.client else None,
|
||||
},
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
except Exception as exc:
|
||||
# Extract request context
|
||||
|
||||
66
letta/server/rest_api/middleware/request_id.py
Normal file
66
letta/server/rest_api/middleware/request_id.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
Middleware for extracting and propagating API request IDs from cloud-api.
|
||||
|
||||
Uses a pure ASGI middleware pattern to properly propagate the request_id
|
||||
to streaming responses. BaseHTTPMiddleware has a known limitation where
|
||||
contextvars are not propagated to streaming response generators.
|
||||
See: https://github.com/encode/starlette/discussions/1729
|
||||
|
||||
This middleware:
|
||||
1. Extracts the x-api-request-log-id header from cloud-api
|
||||
2. Sets it in the contextvar (for non-streaming code)
|
||||
3. Stores it in request.state (for streaming responses where contextvars don't propagate)
|
||||
"""
|
||||
|
||||
from contextvars import ContextVar
|
||||
from typing import Optional
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from letta.otel.tracing import tracer
|
||||
|
||||
# Contextvar for storing the request ID across async boundaries
|
||||
request_id_var: ContextVar[Optional[str]] = ContextVar("request_id", default=None)
|
||||
|
||||
|
||||
def get_request_id() -> Optional[str]:
|
||||
"""Get the request ID from the current context."""
|
||||
return request_id_var.get()
|
||||
|
||||
|
||||
class RequestIdMiddleware:
|
||||
"""
|
||||
Pure ASGI middleware that extracts and propagates the API request ID.
|
||||
|
||||
The request ID comes from cloud-api via the x-api-request-log-id header
|
||||
and is used to correlate steps with API request logs.
|
||||
|
||||
This middleware stores the request_id in:
|
||||
- The request_id_var contextvar (works for non-streaming responses)
|
||||
- request.state.request_id (works for streaming responses where contextvars may not propagate)
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
with tracer.start_as_current_span("middleware.request_id"):
|
||||
# Create a Request object for easier header access
|
||||
request = Request(scope)
|
||||
|
||||
# Extract request_id from header
|
||||
request_id = request.headers.get("x-api-request-log-id")
|
||||
|
||||
# Set in contextvar (for non-streaming code paths)
|
||||
request_id_var.set(request_id)
|
||||
|
||||
# Also store in request.state for streaming responses where contextvars don't propagate
|
||||
# This is accessible via request.state.request_id throughout the request lifecycle
|
||||
request.state.request_id = request_id
|
||||
|
||||
await self.app(scope, receive, send)
|
||||
@@ -4,7 +4,9 @@ import asyncio
|
||||
import json
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import AsyncIterator, Dict, List, Optional
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from contextlib import aclosing
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from letta.data_sources.redis_client import AsyncRedisClient
|
||||
from letta.log import get_logger
|
||||
@@ -194,12 +196,13 @@ class RedisSSEStreamWriter:
|
||||
|
||||
|
||||
async def create_background_stream_processor(
|
||||
stream_generator,
|
||||
stream_generator: AsyncGenerator[str | bytes | tuple[str | bytes, int], None],
|
||||
redis_client: AsyncRedisClient,
|
||||
run_id: str,
|
||||
writer: Optional[RedisSSEStreamWriter] = None,
|
||||
run_manager: Optional[RunManager] = None,
|
||||
actor: Optional[User] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Process a stream in the background and store chunks to Redis.
|
||||
@@ -214,10 +217,12 @@ async def create_background_stream_processor(
|
||||
writer: Optional pre-configured writer (creates new if not provided)
|
||||
run_manager: Optional run manager for updating run status
|
||||
actor: Optional actor for run status updates
|
||||
conversation_id: Optional conversation ID for releasing lock on terminal states
|
||||
"""
|
||||
stop_reason = None
|
||||
saw_done = False
|
||||
saw_error = False
|
||||
error_metadata = None
|
||||
|
||||
if writer is None:
|
||||
writer = RedisSSEStreamWriter(redis_client)
|
||||
@@ -227,32 +232,52 @@ async def create_background_stream_processor(
|
||||
should_stop_writer = False
|
||||
|
||||
try:
|
||||
async for chunk in stream_generator:
|
||||
if isinstance(chunk, tuple):
|
||||
chunk = chunk[0]
|
||||
# Always close the upstream async generator so its `finally` blocks run.
|
||||
# (e.g., stream adapters may persist terminal error metadata on close)
|
||||
async with aclosing(stream_generator):
|
||||
async for chunk in stream_generator:
|
||||
if isinstance(chunk, tuple):
|
||||
chunk = chunk[0]
|
||||
|
||||
# Track terminal events
|
||||
if isinstance(chunk, str):
|
||||
if "data: [DONE]" in chunk:
|
||||
saw_done = True
|
||||
if "event: error" in chunk:
|
||||
saw_error = True
|
||||
# Track terminal events
|
||||
if isinstance(chunk, str):
|
||||
if "data: [DONE]" in chunk:
|
||||
saw_done = True
|
||||
if "event: error" in chunk:
|
||||
saw_error = True
|
||||
|
||||
is_done = saw_done or saw_error
|
||||
# Best-effort extraction of the error payload so we can persist it on the run.
|
||||
# Chunk format is typically: "event: error\ndata: {json}\n\n"
|
||||
if saw_error and error_metadata is None:
|
||||
try:
|
||||
# Grab the first `data:` line after `event: error`
|
||||
for line in chunk.splitlines():
|
||||
if line.startswith("data: "):
|
||||
maybe_json = line[len("data: ") :].strip()
|
||||
if maybe_json and maybe_json[0] in "[{":
|
||||
error_metadata = {"error": json.loads(maybe_json)}
|
||||
else:
|
||||
error_metadata = {"error": {"message": maybe_json}}
|
||||
break
|
||||
except Exception:
|
||||
# Don't let parsing failures interfere with streaming
|
||||
error_metadata = {"error": {"message": "Failed to parse error payload from stream."}}
|
||||
|
||||
await writer.write_chunk(run_id=run_id, data=chunk, is_complete=is_done)
|
||||
is_done = saw_done or saw_error
|
||||
|
||||
if is_done:
|
||||
break
|
||||
await writer.write_chunk(run_id=run_id, data=chunk, is_complete=is_done)
|
||||
|
||||
try:
|
||||
# Extract stop_reason from stop_reason chunks
|
||||
maybe_json_chunk = chunk.split("data: ")[1]
|
||||
maybe_stop_reason = json.loads(maybe_json_chunk) if maybe_json_chunk and maybe_json_chunk[0] == "{" else None
|
||||
if maybe_stop_reason and maybe_stop_reason.get("message_type") == "stop_reason":
|
||||
stop_reason = maybe_stop_reason.get("stop_reason")
|
||||
except:
|
||||
pass
|
||||
if is_done:
|
||||
break
|
||||
|
||||
try:
|
||||
# Extract stop_reason from stop_reason chunks
|
||||
maybe_json_chunk = chunk.split("data: ")[1]
|
||||
maybe_stop_reason = json.loads(maybe_json_chunk) if maybe_json_chunk and maybe_json_chunk[0] == "{" else None
|
||||
if maybe_stop_reason and maybe_stop_reason.get("message_type") == "stop_reason":
|
||||
stop_reason = maybe_stop_reason.get("stop_reason")
|
||||
except:
|
||||
pass
|
||||
|
||||
# Stream ended naturally - check if we got a proper terminal
|
||||
if not saw_done and not saw_error:
|
||||
@@ -319,6 +344,7 @@ async def create_background_stream_processor(
|
||||
run_id=run_id,
|
||||
update=RunUpdate(status=RunStatus.failed, stop_reason=StopReasonType.error.value, metadata={"error": str(e)}),
|
||||
actor=actor,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
finally:
|
||||
if should_stop_writer:
|
||||
@@ -357,10 +383,15 @@ async def create_background_stream_processor(
|
||||
logger.warning(f"Unknown stop_reason '{final_stop_reason}' for run {run_id}, defaulting to completed")
|
||||
run_status = RunStatus.completed
|
||||
|
||||
update_kwargs = {"status": run_status, "stop_reason": final_stop_reason}
|
||||
if run_status == RunStatus.failed and error_metadata is not None:
|
||||
update_kwargs["metadata"] = error_metadata
|
||||
|
||||
await run_manager.update_run_by_id_async(
|
||||
run_id=run_id,
|
||||
update=RunUpdate(status=run_status, stop_reason=final_stop_reason),
|
||||
update=RunUpdate(**update_kwargs),
|
||||
actor=actor,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
# Belt-and-suspenders: always append a terminal [DONE] chunk to ensure clients terminate
|
||||
|
||||
@@ -3,6 +3,7 @@ from letta.server.rest_api.routers.v1.anthropic import router as anthropic_route
|
||||
from letta.server.rest_api.routers.v1.archives import router as archives_router
|
||||
from letta.server.rest_api.routers.v1.blocks import router as blocks_router
|
||||
from letta.server.rest_api.routers.v1.chat_completions import router as chat_completions_router, router as openai_chat_completions_router
|
||||
from letta.server.rest_api.routers.v1.conversations import router as conversations_router
|
||||
from letta.server.rest_api.routers.v1.embeddings import router as embeddings_router
|
||||
from letta.server.rest_api.routers.v1.folders import router as folders_router
|
||||
from letta.server.rest_api.routers.v1.groups import router as groups_router
|
||||
@@ -36,6 +37,7 @@ ROUTERS = [
|
||||
sources_router,
|
||||
folders_router,
|
||||
agents_router,
|
||||
conversations_router,
|
||||
chat_completions_router,
|
||||
groups_router,
|
||||
identities_router,
|
||||
|
||||
@@ -24,8 +24,8 @@ from letta.errors import (
|
||||
AgentExportProcessingError,
|
||||
AgentFileImportError,
|
||||
AgentNotFoundForExportError,
|
||||
NoActiveRunsToCancelError,
|
||||
PendingApprovalError,
|
||||
RunCancelError,
|
||||
)
|
||||
from letta.groups.sleeptime_multi_agent_v4 import SleeptimeMultiAgentV4
|
||||
from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns
|
||||
@@ -66,6 +66,7 @@ from letta.server.server import SyncServer
|
||||
from letta.services.lettuce import LettuceClient
|
||||
from letta.services.run_manager import RunManager
|
||||
from letta.services.streaming_service import StreamingService
|
||||
from letta.services.summarizer.summarizer_config import CompactionSettings
|
||||
from letta.settings import settings
|
||||
from letta.utils import is_1_0_sdk_version, safe_create_shielded_task, safe_create_task, truncate_file_visible_content
|
||||
from letta.validators import AgentId, BlockId, FileId, MessageId, SourceId, ToolId
|
||||
@@ -389,7 +390,7 @@ async def import_agent(
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
|
||||
try:
|
||||
serialized_data = file.file.read()
|
||||
serialized_data = await file.read()
|
||||
file_size_mb = len(serialized_data) / (1024 * 1024)
|
||||
logger.info(f"Agent import: loaded {file_size_mb:.2f} MB into memory")
|
||||
agent_json = json.loads(serialized_data)
|
||||
@@ -744,9 +745,10 @@ async def detach_source(
|
||||
if not agent_state.sources:
|
||||
agent_state = await server.agent_manager.detach_all_files_tools_async(agent_state=agent_state, actor=actor)
|
||||
|
||||
files = await server.file_manager.list_files(source_id, actor)
|
||||
file_ids = [f.id for f in files]
|
||||
await server.remove_files_from_context_window(agent_state=agent_state, file_ids=file_ids, actor=actor)
|
||||
# Query files_agents directly to get exactly what was attached, regardless of source changes
|
||||
file_ids = await server.file_agent_manager.get_file_ids_for_agent_by_source(agent_id=agent_id, source_id=source_id, actor=actor)
|
||||
if file_ids:
|
||||
await server.remove_files_from_context_window(agent_state=agent_state, file_ids=file_ids, actor=actor)
|
||||
|
||||
if agent_state.enable_sleeptime:
|
||||
try:
|
||||
@@ -775,9 +777,10 @@ async def detach_folder_from_agent(
|
||||
if not agent_state.sources:
|
||||
agent_state = await server.agent_manager.detach_all_files_tools_async(agent_state=agent_state, actor=actor)
|
||||
|
||||
files = await server.file_manager.list_files(folder_id, actor)
|
||||
file_ids = [f.id for f in files]
|
||||
await server.remove_files_from_context_window(agent_state=agent_state, file_ids=file_ids, actor=actor)
|
||||
# Query files_agents directly to get exactly what was attached, regardless of source changes
|
||||
file_ids = await server.file_agent_manager.get_file_ids_for_agent_by_source(agent_id=agent_id, source_id=folder_id, actor=actor)
|
||||
if file_ids:
|
||||
await server.remove_files_from_context_window(agent_state=agent_state, file_ids=file_ids, actor=actor)
|
||||
|
||||
if agent_state.enable_sleeptime:
|
||||
try:
|
||||
@@ -1382,6 +1385,7 @@ async def list_messages(
|
||||
),
|
||||
order_by: Literal["created_at"] = Query("created_at", description="Field to sort by"),
|
||||
group_id: str | None = Query(None, description="Group ID to filter messages by."),
|
||||
conversation_id: str | None = Query(None, description="Conversation ID to filter messages by."),
|
||||
use_assistant_message: bool = Query(True, description="Whether to use assistant messages", deprecated=True),
|
||||
assistant_message_tool_name: str = Query(DEFAULT_MESSAGE_TOOL, description="The name of the designated message tool.", deprecated=True),
|
||||
assistant_message_tool_kwarg: str = Query(DEFAULT_MESSAGE_TOOL_KWARG, description="The name of the message argument.", deprecated=True),
|
||||
@@ -1401,6 +1405,7 @@ async def list_messages(
|
||||
before=before,
|
||||
limit=limit,
|
||||
group_id=group_id,
|
||||
conversation_id=conversation_id,
|
||||
reverse=(order == "desc"),
|
||||
return_message_object=False,
|
||||
use_assistant_message=use_assistant_message,
|
||||
@@ -1521,6 +1526,7 @@ async def send_message(
|
||||
"ollama",
|
||||
"azure",
|
||||
"xai",
|
||||
"zai",
|
||||
"groq",
|
||||
"deepseek",
|
||||
]
|
||||
@@ -1558,6 +1564,7 @@ async def send_message(
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
client_tools=request.client_tools,
|
||||
)
|
||||
else:
|
||||
result = await server.send_message_to_agent(
|
||||
@@ -1615,6 +1622,7 @@ async def send_message(
|
||||
},
|
||||
}
|
||||
},
|
||||
deprecated=True,
|
||||
)
|
||||
async def send_message_streaming(
|
||||
request_obj: Request, # FastAPI Request
|
||||
@@ -1625,6 +1633,9 @@ async def send_message_streaming(
|
||||
) -> StreamingResponse | LettaResponse:
|
||||
"""
|
||||
Process a user message and return the agent's response.
|
||||
|
||||
Deprecated: Use the `POST /{agent_id}/messages` endpoint with `streaming=true` in the request body instead.
|
||||
|
||||
This endpoint accepts a message from a user and processes it through the agent.
|
||||
It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True.
|
||||
"""
|
||||
@@ -1668,40 +1679,50 @@ async def cancel_message(
|
||||
raise HTTPException(status_code=400, detail="Agent run tracking is disabled")
|
||||
run_ids = request.run_ids if request else None
|
||||
if not run_ids:
|
||||
redis_client = await get_redis_client()
|
||||
run_id = await redis_client.get(f"{REDIS_RUN_ID_PREFIX}:{agent_id}")
|
||||
run_id = None
|
||||
try:
|
||||
redis_client = await get_redis_client()
|
||||
run_id = await redis_client.get(f"{REDIS_RUN_ID_PREFIX}:{agent_id}")
|
||||
except Exception as e:
|
||||
# Redis is optional; fall back to DB to avoid surfacing 5XXs for cancellation.
|
||||
logger.warning(f"Failed to look up run to cancel in redis for agent {agent_id}, falling back to DB: {e}")
|
||||
|
||||
if run_id is None:
|
||||
logger.warning("Cannot find run associated with agent to cancel in redis, fetching from db.")
|
||||
run_ids = await server.run_manager.list_runs(
|
||||
runs = await server.run_manager.list_runs(
|
||||
actor=actor,
|
||||
statuses=[RunStatus.created, RunStatus.running],
|
||||
ascending=False,
|
||||
agent_id=agent_id, # NOTE: this will override agent_ids if provided
|
||||
limit=100, # Limit to 10 most recent active runs for cancellation
|
||||
limit=100, # Limit to 100 most recent active runs for cancellation
|
||||
)
|
||||
run_ids = [run.id for run in run_ids]
|
||||
run_ids = [run.id for run in runs]
|
||||
else:
|
||||
run_ids = [run_id]
|
||||
|
||||
if not run_ids:
|
||||
raise NoActiveRunsToCancelError(agent_id=agent_id)
|
||||
|
||||
results = {}
|
||||
failed_to_cancel = []
|
||||
for run_id in run_ids:
|
||||
run = await server.run_manager.get_run_by_id(run_id=run_id, actor=actor)
|
||||
if run.metadata.get("lettuce"):
|
||||
lettuce_client = await LettuceClient.create()
|
||||
await lettuce_client.cancel(run_id)
|
||||
try:
|
||||
run = await server.run_manager.cancel_run(actor=actor, agent_id=agent_id, run_id=run_id)
|
||||
run = await server.run_manager.get_run_by_id(run_id=run_id, actor=actor)
|
||||
if run.metadata and run.metadata.get("lettuce"):
|
||||
try:
|
||||
lettuce_client = await LettuceClient.create()
|
||||
await lettuce_client.cancel(run_id)
|
||||
except Exception as e:
|
||||
# Do not surface cancellation failures as 5XXs.
|
||||
logger.error(f"Failed to cancel Lettuce run {run_id}: {e}")
|
||||
|
||||
await server.run_manager.cancel_run(actor=actor, agent_id=agent_id, run_id=run_id)
|
||||
except Exception as e:
|
||||
results[run_id] = "failed"
|
||||
# Cancellation failures should not raise errors back to the client.
|
||||
logger.error(f"Failed to cancel run {run_id}: {str(e)}")
|
||||
failed_to_cancel.append(run_id)
|
||||
continue
|
||||
results[run_id] = "cancelled"
|
||||
logger.info(f"Cancelled run {run_id}")
|
||||
|
||||
if failed_to_cancel:
|
||||
raise RunCancelError(f"Failed to cancel runs: {failed_to_cancel}")
|
||||
return results
|
||||
|
||||
|
||||
@@ -1732,6 +1753,7 @@ async def search_messages(
|
||||
agent_id=request.agent_id,
|
||||
project_id=request.project_id,
|
||||
template_id=request.template_id,
|
||||
conversation_id=request.conversation_id,
|
||||
limit=request.limit,
|
||||
start_date=request.start_date,
|
||||
end_date=request.end_date,
|
||||
@@ -1771,6 +1793,7 @@ async def _process_message_background(
|
||||
"ollama",
|
||||
"azure",
|
||||
"xai",
|
||||
"zai",
|
||||
"groq",
|
||||
"deepseek",
|
||||
]
|
||||
@@ -2075,6 +2098,7 @@ async def preview_model_request(
|
||||
"ollama",
|
||||
"azure",
|
||||
"xai",
|
||||
"zai",
|
||||
"groq",
|
||||
"deepseek",
|
||||
]
|
||||
@@ -2091,9 +2115,23 @@ async def preview_model_request(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{agent_id}/summarize", status_code=204, operation_id="summarize_messages")
|
||||
class CompactionRequest(BaseModel):
|
||||
compaction_settings: Optional[CompactionSettings] = Field(
|
||||
default=None,
|
||||
description="Optional compaction settings to use for this summarization request. If not provided, the agent's default settings will be used.",
|
||||
)
|
||||
|
||||
|
||||
class CompactionResponse(BaseModel):
|
||||
summary: str
|
||||
num_messages_before: int
|
||||
num_messages_after: int
|
||||
|
||||
|
||||
@router.post("/{agent_id}/summarize", response_model=CompactionResponse, operation_id="summarize_messages")
|
||||
async def summarize_messages(
|
||||
agent_id: AgentId,
|
||||
request: Optional[CompactionRequest] = Body(default=None),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
headers: HeaderParams = Depends(get_headers),
|
||||
):
|
||||
@@ -2114,6 +2152,7 @@ async def summarize_messages(
|
||||
"ollama",
|
||||
"azure",
|
||||
"xai",
|
||||
"zai",
|
||||
"groq",
|
||||
"deepseek",
|
||||
]
|
||||
@@ -2121,12 +2160,27 @@ async def summarize_messages(
|
||||
if agent_eligible and model_compatible:
|
||||
agent_loop = LettaAgentV3(agent_state=agent, actor=actor)
|
||||
in_context_messages = await server.message_manager.get_messages_by_ids_async(message_ids=agent.message_ids, actor=actor)
|
||||
summary_message, messages = await agent_loop.compact(
|
||||
compaction_settings = request.compaction_settings if request else None
|
||||
num_messages_before = len(in_context_messages)
|
||||
summary_message, messages, summary = await agent_loop.compact(
|
||||
messages=in_context_messages,
|
||||
compaction_settings=compaction_settings,
|
||||
)
|
||||
num_messages_after = len(messages)
|
||||
|
||||
# update the agent state
|
||||
logger.info(f"Summarized {num_messages_before} messages to {num_messages_after}")
|
||||
if num_messages_before <= num_messages_after:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Summarization failed to reduce the number of messages. You may need to use a different CompactionSettings (e.g. using `all` mode).",
|
||||
)
|
||||
await agent_loop._checkpoint_messages(run_id=None, step_id=None, new_messages=[summary_message], in_context_messages=messages)
|
||||
return CompactionResponse(
|
||||
summary=summary,
|
||||
num_messages_before=num_messages_before,
|
||||
num_messages_after=num_messages_after,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
@@ -2185,6 +2239,7 @@ async def capture_messages(
|
||||
|
||||
response_messages = await server.message_manager.create_many_messages_async(messages_to_persist, actor=actor)
|
||||
|
||||
run_ids = []
|
||||
sleeptime_group = agent.multi_agent_group if agent.multi_agent_group and agent.multi_agent_group.manager_type == "sleeptime" else None
|
||||
if sleeptime_group:
|
||||
sleeptime_agent_loop = SleeptimeMultiAgentV4(agent_state=agent, actor=actor, group=sleeptime_group)
|
||||
|
||||
273
letta/server/rest_api/routers/v1/conversations.py
Normal file
273
letta/server/rest_api/routers/v1/conversations.py
Normal file
@@ -0,0 +1,273 @@
|
||||
from datetime import timedelta
|
||||
from typing import Annotated, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query
|
||||
from pydantic import Field
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
|
||||
from letta.errors import LettaExpiredError, LettaInvalidArgumentError
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.schemas.conversation import Conversation, CreateConversation
|
||||
from letta.schemas.enums import RunStatus
|
||||
from letta.schemas.letta_message import LettaMessageUnion
|
||||
from letta.schemas.letta_request import LettaStreamingRequest, RetrieveStreamRequest
|
||||
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
|
||||
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
|
||||
from letta.server.rest_api.redis_stream_manager import redis_sse_stream_generator
|
||||
from letta.server.rest_api.streaming_response import (
|
||||
StreamingResponseWithStatusCode,
|
||||
add_keepalive_to_stream,
|
||||
cancellation_aware_stream_wrapper,
|
||||
)
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.conversation_manager import ConversationManager
|
||||
from letta.services.run_manager import RunManager
|
||||
from letta.services.streaming_service import StreamingService
|
||||
from letta.settings import settings
|
||||
from letta.validators import ConversationId
|
||||
|
||||
router = APIRouter(prefix="/conversations", tags=["conversations"])
|
||||
|
||||
# Instantiate manager
|
||||
conversation_manager = ConversationManager()
|
||||
|
||||
|
||||
@router.post("/", response_model=Conversation, operation_id="create_conversation")
|
||||
async def create_conversation(
|
||||
agent_id: str = Query(..., description="The agent ID to create a conversation for"),
|
||||
conversation_create: CreateConversation = Body(default_factory=CreateConversation),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
headers: HeaderParams = Depends(get_headers),
|
||||
):
|
||||
"""Create a new conversation for an agent."""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await conversation_manager.create_conversation(
|
||||
agent_id=agent_id,
|
||||
conversation_create=conversation_create,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/", response_model=List[Conversation], operation_id="list_conversations")
|
||||
async def list_conversations(
|
||||
agent_id: str = Query(..., description="The agent ID to list conversations for"),
|
||||
limit: int = Query(50, description="Maximum number of conversations to return"),
|
||||
after: Optional[str] = Query(None, description="Cursor for pagination (conversation ID)"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
headers: HeaderParams = Depends(get_headers),
|
||||
):
|
||||
"""List all conversations for an agent."""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await conversation_manager.list_conversations(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
limit=limit,
|
||||
after=after,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{conversation_id}", response_model=Conversation, operation_id="retrieve_conversation")
|
||||
async def retrieve_conversation(
|
||||
conversation_id: ConversationId,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
headers: HeaderParams = Depends(get_headers),
|
||||
):
|
||||
"""Retrieve a specific conversation."""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await conversation_manager.get_conversation_by_id(
|
||||
conversation_id=conversation_id,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
|
||||
ConversationMessagesResponse = Annotated[
|
||||
List[LettaMessageUnion], Field(json_schema_extra={"type": "array", "items": {"$ref": "#/components/schemas/LettaMessageUnion"}})
|
||||
]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{conversation_id}/messages",
|
||||
response_model=ConversationMessagesResponse,
|
||||
operation_id="list_conversation_messages",
|
||||
)
|
||||
async def list_conversation_messages(
|
||||
conversation_id: ConversationId,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
headers: HeaderParams = Depends(get_headers),
|
||||
before: Optional[str] = Query(
|
||||
None, description="Message ID cursor for pagination. Returns messages that come before this message ID in the conversation"
|
||||
),
|
||||
after: Optional[str] = Query(
|
||||
None, description="Message ID cursor for pagination. Returns messages that come after this message ID in the conversation"
|
||||
),
|
||||
limit: Optional[int] = Query(100, description="Maximum number of messages to return"),
|
||||
):
|
||||
"""
|
||||
List all messages in a conversation.
|
||||
|
||||
Returns LettaMessage objects (UserMessage, AssistantMessage, etc.) for all
|
||||
messages in the conversation, ordered by position (oldest first),
|
||||
with support for cursor-based pagination.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await conversation_manager.list_conversation_messages(
|
||||
conversation_id=conversation_id,
|
||||
actor=actor,
|
||||
limit=limit,
|
||||
before=before,
|
||||
after=after,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{conversation_id}/messages",
|
||||
response_model=LettaStreamingResponse,
|
||||
operation_id="send_conversation_message",
|
||||
responses={
|
||||
200: {
|
||||
"description": "Successful response",
|
||||
"content": {
|
||||
"text/event-stream": {"description": "Server-Sent Events stream"},
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
async def send_conversation_message(
|
||||
conversation_id: ConversationId,
|
||||
request: LettaStreamingRequest = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
headers: HeaderParams = Depends(get_headers),
|
||||
) -> StreamingResponse | LettaResponse:
|
||||
"""
|
||||
Send a message to a conversation and get a streaming response.
|
||||
|
||||
This endpoint sends a message to an existing conversation and streams
|
||||
the agent's response back.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
|
||||
# Get the conversation to find the agent_id
|
||||
conversation = await conversation_manager.get_conversation_by_id(
|
||||
conversation_id=conversation_id,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# Force streaming mode for this endpoint
|
||||
request.streaming = True
|
||||
|
||||
# Use streaming service
|
||||
streaming_service = StreamingService(server)
|
||||
run, result = await streaming_service.create_agent_stream(
|
||||
agent_id=conversation.agent_id,
|
||||
actor=actor,
|
||||
request=request,
|
||||
run_type="send_conversation_message",
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{conversation_id}/stream",
|
||||
response_model=None,
|
||||
operation_id="retrieve_conversation_stream",
|
||||
responses={
|
||||
200: {
|
||||
"description": "Successful response",
|
||||
"content": {
|
||||
"text/event-stream": {
|
||||
"description": "Server-Sent Events stream",
|
||||
"schema": {
|
||||
"oneOf": [
|
||||
{"$ref": "#/components/schemas/SystemMessage"},
|
||||
{"$ref": "#/components/schemas/UserMessage"},
|
||||
{"$ref": "#/components/schemas/ReasoningMessage"},
|
||||
{"$ref": "#/components/schemas/HiddenReasoningMessage"},
|
||||
{"$ref": "#/components/schemas/ToolCallMessage"},
|
||||
{"$ref": "#/components/schemas/ToolReturnMessage"},
|
||||
{"$ref": "#/components/schemas/AssistantMessage"},
|
||||
{"$ref": "#/components/schemas/ApprovalRequestMessage"},
|
||||
{"$ref": "#/components/schemas/ApprovalResponseMessage"},
|
||||
{"$ref": "#/components/schemas/LettaPing"},
|
||||
{"$ref": "#/components/schemas/LettaErrorMessage"},
|
||||
{"$ref": "#/components/schemas/LettaStopReason"},
|
||||
{"$ref": "#/components/schemas/LettaUsageStatistics"},
|
||||
]
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
async def retrieve_conversation_stream(
|
||||
conversation_id: ConversationId,
|
||||
request: RetrieveStreamRequest = Body(None),
|
||||
headers: HeaderParams = Depends(get_headers),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Resume the stream for the most recent active run in a conversation.
|
||||
|
||||
This endpoint allows you to reconnect to an active background stream
|
||||
for a conversation, enabling recovery from network interruptions.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
runs_manager = RunManager()
|
||||
|
||||
# Find the most recent active run for this conversation
|
||||
active_runs = await runs_manager.list_runs(
|
||||
actor=actor,
|
||||
conversation_id=conversation_id,
|
||||
statuses=[RunStatus.created, RunStatus.running],
|
||||
limit=1,
|
||||
ascending=False,
|
||||
)
|
||||
|
||||
if not active_runs:
|
||||
raise LettaInvalidArgumentError("No active runs found for this conversation.")
|
||||
|
||||
run = active_runs[0]
|
||||
|
||||
if not run.background:
|
||||
raise LettaInvalidArgumentError("Run was not created in background mode, so it cannot be retrieved.")
|
||||
|
||||
if run.created_at < get_utc_time() - timedelta(hours=3):
|
||||
raise LettaExpiredError("Run was created more than 3 hours ago, and is now expired.")
|
||||
|
||||
redis_client = await get_redis_client()
|
||||
|
||||
if isinstance(redis_client, NoopAsyncRedisClient):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=(
|
||||
"Background streaming requires Redis to be running. "
|
||||
"Please ensure Redis is properly configured. "
|
||||
f"LETTA_REDIS_HOST: {settings.redis_host}, LETTA_REDIS_PORT: {settings.redis_port}"
|
||||
),
|
||||
)
|
||||
|
||||
stream = redis_sse_stream_generator(
|
||||
redis_client=redis_client,
|
||||
run_id=run.id,
|
||||
starting_after=request.starting_after if request else None,
|
||||
poll_interval=request.poll_interval if request else None,
|
||||
batch_size=request.batch_size if request else None,
|
||||
)
|
||||
|
||||
if settings.enable_cancellation_aware_streaming:
|
||||
stream = cancellation_aware_stream_wrapper(
|
||||
stream_generator=stream,
|
||||
run_manager=server.run_manager,
|
||||
run_id=run.id,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
if request and request.include_pings and settings.enable_keepalive:
|
||||
stream = add_keepalive_to_stream(stream, keepalive_interval=settings.keepalive_interval, run_id=run.id)
|
||||
|
||||
return StreamingResponseWithStatusCode(
|
||||
stream,
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
@@ -204,8 +204,6 @@ async def delete_folder(
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
folder = await server.source_manager.get_source_by_id(source_id=folder_id, actor=actor)
|
||||
agent_states = await server.source_manager.list_attached_agents(source_id=folder_id, actor=actor)
|
||||
files = await server.file_manager.list_files(folder_id, actor)
|
||||
file_ids = [f.id for f in files]
|
||||
|
||||
if should_use_tpuf():
|
||||
logger.info(f"Deleting folder {folder_id} from Turbopuffer")
|
||||
@@ -218,7 +216,12 @@ async def delete_folder(
|
||||
await delete_source_records_from_pinecone_index(source_id=folder_id, actor=actor)
|
||||
|
||||
for agent_state in agent_states:
|
||||
await server.remove_files_from_context_window(agent_state=agent_state, file_ids=file_ids, actor=actor)
|
||||
# Query files_agents directly to get exactly what was attached to this agent
|
||||
file_ids = await server.file_agent_manager.get_file_ids_for_agent_by_source(
|
||||
agent_id=agent_state.id, source_id=folder_id, actor=actor
|
||||
)
|
||||
if file_ids:
|
||||
await server.remove_files_from_context_window(agent_state=agent_state, file_ids=file_ids, actor=actor)
|
||||
|
||||
if agent_state.enable_sleeptime:
|
||||
block = await server.agent_manager.get_block_with_label_async(agent_id=agent_state.id, block_label=folder.name, actor=actor)
|
||||
@@ -470,6 +473,30 @@ async def list_files_for_folder(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{folder_id}/files/{file_id}", response_model=FileMetadata, operation_id="retrieve_file")
|
||||
async def retrieve_file(
|
||||
folder_id: FolderId,
|
||||
file_id: FileId,
|
||||
include_content: bool = Query(False, description="Whether to include full file content"),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
headers: HeaderParams = Depends(get_headers),
|
||||
):
|
||||
"""
|
||||
Retrieve a file from a folder by ID.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
|
||||
# NoResultFound will propagate and be handled as 404 by the global exception handler
|
||||
file_metadata = await server.file_manager.get_file_by_id(
|
||||
file_id=file_id, actor=actor, include_content=include_content, strip_directory_prefix=True
|
||||
)
|
||||
|
||||
if file_metadata.source_id != folder_id:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"File with id={file_id} not found in folder {folder_id}")
|
||||
|
||||
return file_metadata
|
||||
|
||||
|
||||
# @router.get("/{folder_id}/files/{file_id}", response_model=FileMetadata, operation_id="get_file_metadata")
|
||||
# async def get_file_metadata(
|
||||
# folder_id: str,
|
||||
|
||||
@@ -64,6 +64,7 @@ async def list_runs(
|
||||
deprecated=True,
|
||||
),
|
||||
project_id: Optional[str] = Query(None, description="Filter runs by project ID."),
|
||||
conversation_id: Optional[str] = Query(None, description="Filter runs by conversation ID."),
|
||||
duration_percentile: Optional[int] = Query(
|
||||
None, description="Filter runs by duration percentile (1-100). Returns runs slower than this percentile."
|
||||
),
|
||||
@@ -122,6 +123,7 @@ async def list_runs(
|
||||
step_count_operator=step_count_operator,
|
||||
tools_used=tools_used,
|
||||
project_id=project_id,
|
||||
conversation_id=conversation_id,
|
||||
order_by=order_by,
|
||||
duration_percentile=duration_percentile,
|
||||
duration_filter=duration_filter,
|
||||
|
||||
@@ -40,6 +40,7 @@ async def list_all_messages(
|
||||
order: Literal["asc", "desc"] = Query(
|
||||
"desc", description="Sort order for messages by creation time. 'asc' for oldest first, 'desc' for newest first"
|
||||
),
|
||||
conversation_id: Optional[str] = Query(None, description="Conversation ID to filter messages by"),
|
||||
):
|
||||
"""
|
||||
List messages across all agents for the current user.
|
||||
@@ -51,6 +52,7 @@ async def list_all_messages(
|
||||
limit=limit,
|
||||
reverse=(order == "desc"),
|
||||
return_message_object=False,
|
||||
conversation_id=conversation_id,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
@@ -73,6 +75,8 @@ async def search_all_messages(
|
||||
actor=actor,
|
||||
query_text=request.query,
|
||||
search_mode=request.search_mode,
|
||||
agent_id=request.agent_id,
|
||||
conversation_id=request.conversation_id,
|
||||
limit=request.limit,
|
||||
start_date=request.start_date,
|
||||
end_date=request.end_date,
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, List, Literal, Optional
|
||||
from fastapi import APIRouter, Body, Depends, Query, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from letta.schemas.enums import ProviderType
|
||||
from letta.schemas.enums import ProviderCategory, ProviderType
|
||||
from letta.schemas.providers import Provider, ProviderBase, ProviderCheck, ProviderCreate, ProviderUpdate
|
||||
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
|
||||
from letta.validators import ProviderId
|
||||
@@ -39,7 +39,14 @@ async def list_providers(
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
providers = await server.provider_manager.list_providers_async(
|
||||
before=before, after=after, limit=limit, actor=actor, name=name, provider_type=provider_type, ascending=(order == "asc")
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
actor=actor,
|
||||
name=name,
|
||||
provider_type=provider_type,
|
||||
provider_category=[ProviderCategory.byok],
|
||||
ascending=(order == "asc"),
|
||||
)
|
||||
return providers
|
||||
|
||||
|
||||
@@ -55,6 +55,7 @@ async def list_runs(
|
||||
statuses: Optional[List[str]] = Query(None, description="Filter runs by status. Can specify multiple statuses."),
|
||||
background: Optional[bool] = Query(None, description="If True, filters for runs that were created in background mode."),
|
||||
stop_reason: Optional[StopReasonType] = Query(None, description="Filter runs by stop reason."),
|
||||
conversation_id: Optional[str] = Query(None, description="Filter runs by conversation ID."),
|
||||
before: Optional[str] = Query(
|
||||
None, description="Run ID cursor for pagination. Returns runs that come before this run ID in the specified sort order"
|
||||
),
|
||||
@@ -109,6 +110,7 @@ async def list_runs(
|
||||
ascending=sort_ascending,
|
||||
stop_reason=stop_reason,
|
||||
background=background,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
return runs
|
||||
|
||||
|
||||
@@ -184,8 +184,6 @@ async def delete_source(
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
agent_states = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor)
|
||||
files = await server.file_manager.list_files(source_id, actor)
|
||||
file_ids = [f.id for f in files]
|
||||
|
||||
if should_use_tpuf():
|
||||
logger.info(f"Deleting source {source_id} from Turbopuffer")
|
||||
@@ -198,7 +196,12 @@ async def delete_source(
|
||||
await delete_source_records_from_pinecone_index(source_id=source_id, actor=actor)
|
||||
|
||||
for agent_state in agent_states:
|
||||
await server.remove_files_from_context_window(agent_state=agent_state, file_ids=file_ids, actor=actor)
|
||||
# Query files_agents directly to get exactly what was attached to this agent
|
||||
file_ids = await server.file_agent_manager.get_file_ids_for_agent_by_source(
|
||||
agent_id=agent_state.id, source_id=source_id, actor=actor
|
||||
)
|
||||
if file_ids:
|
||||
await server.remove_files_from_context_window(agent_state=agent_state, file_ids=file_ids, actor=actor)
|
||||
|
||||
if agent_state.enable_sleeptime:
|
||||
block = await server.agent_manager.get_block_with_label_async(agent_id=agent_state.id, block_label=source.name, actor=actor)
|
||||
|
||||
@@ -689,7 +689,6 @@ async def connect_mcp_server(
|
||||
yield oauth_stream_event(OauthStreamEvent.CONNECTION_ATTEMPT, server_name=request.server_name)
|
||||
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
|
||||
# Create MCP client with respective transport type
|
||||
try:
|
||||
request.resolve_environment_variables()
|
||||
@@ -704,8 +703,18 @@ async def connect_mcp_server(
|
||||
tools = await client.list_tools(serialize=True)
|
||||
yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools)
|
||||
return
|
||||
except ConnectionError:
|
||||
# TODO: jnjpng make this connection error check more specific to the 401 unauthorized error
|
||||
except ConnectionError as e:
|
||||
# Only trigger OAuth flow on explicit unauthorized failures
|
||||
unauthorized = False
|
||||
if isinstance(e.__cause__, HTTPStatusError):
|
||||
unauthorized = e.__cause__.response.status_code == 401
|
||||
elif "401" in str(e) or "Unauthorized" in str(e):
|
||||
unauthorized = True
|
||||
|
||||
if not unauthorized:
|
||||
yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Connection failed: {str(e)}")
|
||||
return
|
||||
|
||||
if isinstance(client, AsyncStdioMCPClient):
|
||||
logger.warning("OAuth not supported for stdio")
|
||||
yield oauth_stream_event(OauthStreamEvent.ERROR, message="OAuth not supported for stdio")
|
||||
|
||||
@@ -18,7 +18,13 @@ import letta.system as system
|
||||
from letta.config import LettaConfig
|
||||
from letta.constants import LETTA_TOOL_EXECUTION_DIR
|
||||
from letta.data_sources.connectors import DataConnector, load_data
|
||||
from letta.errors import HandleNotFoundError, LettaInvalidArgumentError, LettaMCPConnectionError, LettaMCPTimeoutError
|
||||
from letta.errors import (
|
||||
EmbeddingConfigRequiredError,
|
||||
HandleNotFoundError,
|
||||
LettaInvalidArgumentError,
|
||||
LettaMCPConnectionError,
|
||||
LettaMCPTimeoutError,
|
||||
)
|
||||
from letta.functions.mcp_client.types import MCPServerType, MCPTool, MCPToolHealth, SSEServerConfig, StdioServerConfig
|
||||
from letta.functions.schema_validator import validate_complete_json_schema
|
||||
from letta.groups.helpers import load_multi_agent
|
||||
@@ -69,6 +75,7 @@ from letta.schemas.providers import (
|
||||
TogetherProvider,
|
||||
VLLMProvider,
|
||||
XAIProvider,
|
||||
ZAIProvider,
|
||||
)
|
||||
from letta.schemas.sandbox_config import LocalSandboxConfig, SandboxConfigCreate
|
||||
from letta.schemas.secret import Secret
|
||||
@@ -91,7 +98,8 @@ from letta.services.identity_manager import IdentityManager
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.services.llm_batch_manager import LLMBatchManager
|
||||
from letta.services.mcp.base_client import AsyncBaseMCPClient
|
||||
from letta.services.mcp.sse_client import MCP_CONFIG_TOPLEVEL_KEY, AsyncSSEMCPClient
|
||||
from letta.services.mcp.fastmcp_client import AsyncFastMCPSSEClient
|
||||
from letta.services.mcp.sse_client import MCP_CONFIG_TOPLEVEL_KEY
|
||||
from letta.services.mcp.stdio_client import AsyncStdioMCPClient
|
||||
from letta.services.mcp_manager import MCPManager
|
||||
from letta.services.mcp_server_manager import MCPServerManager
|
||||
@@ -109,7 +117,7 @@ from letta.services.tool_manager import ToolManager
|
||||
from letta.services.user_manager import UserManager
|
||||
from letta.settings import DatabaseChoice, model_settings, settings, tool_settings
|
||||
from letta.streaming_interface import AgentChunkStreamingInterface
|
||||
from letta.utils import get_friendly_error_msg, get_persona_text, make_key, safe_create_task
|
||||
from letta.utils import get_friendly_error_msg, get_persona_text, safe_create_task
|
||||
|
||||
config = LettaConfig.load()
|
||||
logger = get_logger(__name__)
|
||||
@@ -203,12 +211,10 @@ class SyncServer(object):
|
||||
"""Initialize the MCP clients (there may be multiple)"""
|
||||
self.mcp_clients: Dict[str, AsyncBaseMCPClient] = {}
|
||||
|
||||
# TODO: Remove these in memory caches
|
||||
self._llm_config_cache = {}
|
||||
self._embedding_config_cache = {}
|
||||
|
||||
# collect providers (always has Letta as a default)
|
||||
self._enabled_providers: List[Provider] = [LettaProvider(name="letta")]
|
||||
from letta.constants import LETTA_MODEL_ENDPOINT
|
||||
|
||||
self._enabled_providers: List[Provider] = [LettaProvider(name="letta", base_url=LETTA_MODEL_ENDPOINT)]
|
||||
if model_settings.openai_api_key:
|
||||
self._enabled_providers.append(
|
||||
OpenAIProvider(
|
||||
@@ -316,6 +322,14 @@ class SyncServer(object):
|
||||
api_key_enc=Secret.from_plaintext(model_settings.xai_api_key),
|
||||
)
|
||||
)
|
||||
if model_settings.zai_api_key:
|
||||
self._enabled_providers.append(
|
||||
ZAIProvider(
|
||||
name="zai",
|
||||
api_key_enc=Secret.from_plaintext(model_settings.zai_api_key),
|
||||
base_url=model_settings.zai_base_url,
|
||||
)
|
||||
)
|
||||
if model_settings.openrouter_api_key:
|
||||
self._enabled_providers.append(
|
||||
OpenRouterProvider(
|
||||
@@ -332,6 +346,12 @@ class SyncServer(object):
|
||||
print(f"Default user: {self.default_user} and org: {self.default_org}")
|
||||
await self.tool_manager.upsert_base_tools_async(actor=self.default_user)
|
||||
|
||||
# Sync environment-based providers to database (idempotent, safe for multi-pod startup)
|
||||
await self.provider_manager.sync_base_providers(base_providers=self._enabled_providers, actor=self.default_user)
|
||||
|
||||
# Sync provider models to database
|
||||
await self._sync_provider_models_async()
|
||||
|
||||
# For OSS users, create a local sandbox config
|
||||
oss_default_user = await self.user_manager.get_default_actor_async()
|
||||
use_venv = False if not tool_settings.tool_exec_venv_name else True
|
||||
@@ -368,13 +388,72 @@ class SyncServer(object):
|
||||
force_recreate=True,
|
||||
)
|
||||
|
||||
def _get_enabled_provider(self, provider_name: str) -> Optional[Provider]:
|
||||
"""Find and return an enabled provider by name.
|
||||
|
||||
Args:
|
||||
provider_name: The name of the provider to find
|
||||
|
||||
Returns:
|
||||
The matching enabled provider, or None if not found
|
||||
"""
|
||||
for provider in self._enabled_providers:
|
||||
if provider.name == provider_name:
|
||||
return provider
|
||||
return None
|
||||
|
||||
async def _sync_provider_models_async(self):
|
||||
"""Sync all provider models to database at startup."""
|
||||
logger.info("Syncing provider models to database")
|
||||
|
||||
# Get persisted providers from database (they now have IDs)
|
||||
persisted_providers = await self.provider_manager.list_providers_async(actor=self.default_user)
|
||||
|
||||
for persisted_provider in persisted_providers:
|
||||
try:
|
||||
# Find the matching enabled provider instance to call list_models on
|
||||
enabled_provider = self._get_enabled_provider(persisted_provider.name)
|
||||
|
||||
if not enabled_provider:
|
||||
# Only delete base providers that are no longer enabled
|
||||
# BYOK providers are user-created and should not be automatically deleted
|
||||
if persisted_provider.provider_category == ProviderCategory.base:
|
||||
logger.info(f"Base provider {persisted_provider.name} is no longer enabled, deleting from database")
|
||||
try:
|
||||
await self.provider_manager.delete_provider_by_id_async(
|
||||
provider_id=persisted_provider.id, actor=self.default_user
|
||||
)
|
||||
except NoResultFound:
|
||||
# Provider was already deleted (race condition in multi-pod startup)
|
||||
logger.debug(f"Provider {persisted_provider.name} was already deleted, skipping")
|
||||
else:
|
||||
logger.debug(f"No enabled provider for BYOK provider {persisted_provider.name}, skipping model sync")
|
||||
continue
|
||||
|
||||
# Fetch models from provider
|
||||
llm_models = await enabled_provider.list_llm_models_async()
|
||||
embedding_models = await enabled_provider.list_embedding_models_async()
|
||||
|
||||
# Save to database with the persisted provider (which has an ID)
|
||||
await self.provider_manager.sync_provider_models_async(
|
||||
provider=persisted_provider,
|
||||
llm_models=llm_models,
|
||||
embedding_models=embedding_models,
|
||||
organization_id=None, # Global models
|
||||
)
|
||||
logger.info(
|
||||
f"Synced {len(llm_models)} LLM models and {len(embedding_models)} embedding models for provider {persisted_provider.name}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to sync models for provider {persisted_provider.name}: {e}", exc_info=True)
|
||||
|
||||
async def init_mcp_clients(self):
|
||||
# TODO: remove this
|
||||
mcp_server_configs = self.get_mcp_servers()
|
||||
|
||||
for server_name, server_config in mcp_server_configs.items():
|
||||
if server_config.type == MCPServerType.SSE:
|
||||
self.mcp_clients[server_name] = AsyncSSEMCPClient(server_config)
|
||||
self.mcp_clients[server_name] = AsyncFastMCPSSEClient(server_config)
|
||||
elif server_config.type == MCPServerType.STDIO:
|
||||
self.mcp_clients[server_name] = AsyncStdioMCPClient(server_config)
|
||||
else:
|
||||
@@ -395,39 +474,6 @@ class SyncServer(object):
|
||||
logger.info(f"MCP tools connected: {', '.join([t.name for t in mcp_tools])}")
|
||||
logger.debug(f"MCP tools: {', '.join([str(t) for t in mcp_tools])}")
|
||||
|
||||
@trace_method
|
||||
def get_cached_llm_config(self, actor: User, **kwargs):
|
||||
key = make_key(**kwargs)
|
||||
if key not in self._llm_config_cache:
|
||||
self._llm_config_cache[key] = self.get_llm_config_from_handle(actor=actor, **kwargs)
|
||||
logger.info(f"LLM config cache size: {len(self._llm_config_cache)} entries")
|
||||
return self._llm_config_cache[key]
|
||||
|
||||
@trace_method
|
||||
async def get_cached_llm_config_async(self, actor: User, **kwargs):
|
||||
key = make_key(**kwargs)
|
||||
if key not in self._llm_config_cache:
|
||||
self._llm_config_cache[key] = await self.get_llm_config_from_handle_async(actor=actor, **kwargs)
|
||||
logger.info(f"LLM config cache size: {len(self._llm_config_cache)} entries")
|
||||
return self._llm_config_cache[key]
|
||||
|
||||
@trace_method
|
||||
def get_cached_embedding_config(self, actor: User, **kwargs):
|
||||
key = make_key(**kwargs)
|
||||
if key not in self._embedding_config_cache:
|
||||
self._embedding_config_cache[key] = self.get_embedding_config_from_handle(actor=actor, **kwargs)
|
||||
logger.info(f"Embedding config cache size: {len(self._embedding_config_cache)} entries")
|
||||
return self._embedding_config_cache[key]
|
||||
|
||||
# @async_redis_cache(key_func=lambda (actor, **kwargs): actor.id + hash(kwargs))
|
||||
@trace_method
|
||||
async def get_cached_embedding_config_async(self, actor: User, **kwargs):
|
||||
key = make_key(**kwargs)
|
||||
if key not in self._embedding_config_cache:
|
||||
self._embedding_config_cache[key] = await self.get_embedding_config_from_handle_async(actor=actor, **kwargs)
|
||||
logger.info(f"Embedding config cache size: {len(self._embedding_config_cache)} entries")
|
||||
return self._embedding_config_cache[key]
|
||||
|
||||
@trace_method
|
||||
async def create_agent_async(
|
||||
self,
|
||||
@@ -461,10 +507,9 @@ class SyncServer(object):
|
||||
"max_reasoning_tokens": request.max_reasoning_tokens,
|
||||
"enable_reasoner": request.enable_reasoner,
|
||||
}
|
||||
config_params.update(additional_config_params)
|
||||
log_event(name="start get_cached_llm_config", attributes=config_params)
|
||||
request.llm_config = await self.get_cached_llm_config_async(actor=actor, **config_params)
|
||||
log_event(name="end get_cached_llm_config", attributes=config_params)
|
||||
log_event(name="start get_llm_config_from_handle", attributes=config_params)
|
||||
request.llm_config = await self.get_llm_config_from_handle_async(actor=actor, **config_params)
|
||||
log_event(name="end get_llm_config_from_handle", attributes=config_params)
|
||||
if request.model and isinstance(request.model, str):
|
||||
assert request.llm_config.handle == request.model, (
|
||||
f"LLM config handle {request.llm_config.handle} does not match request handle {request.model}"
|
||||
@@ -484,19 +529,17 @@ class SyncServer(object):
|
||||
|
||||
if request.embedding_config is None:
|
||||
if request.embedding is None:
|
||||
if settings.default_embedding_handle is None:
|
||||
raise LettaInvalidArgumentError(
|
||||
"Must specify either embedding or embedding_config in request", argument_name="embedding"
|
||||
)
|
||||
else:
|
||||
if settings.default_embedding_handle is not None:
|
||||
request.embedding = settings.default_embedding_handle
|
||||
embedding_config_params = {
|
||||
"handle": request.embedding,
|
||||
"embedding_chunk_size": request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
|
||||
}
|
||||
log_event(name="start get_cached_embedding_config", attributes=embedding_config_params)
|
||||
request.embedding_config = await self.get_cached_embedding_config_async(actor=actor, **embedding_config_params)
|
||||
log_event(name="end get_cached_embedding_config", attributes=embedding_config_params)
|
||||
# Only resolve embedding config if we have an embedding handle
|
||||
if request.embedding is not None:
|
||||
embedding_config_params = {
|
||||
"handle": request.embedding,
|
||||
"embedding_chunk_size": request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
|
||||
}
|
||||
log_event(name="start get_embedding_config_from_handle", attributes=embedding_config_params)
|
||||
request.embedding_config = await self.get_embedding_config_from_handle_async(actor=actor, **embedding_config_params)
|
||||
log_event(name="end get_embedding_config_from_handle", attributes=embedding_config_params)
|
||||
|
||||
log_event(name="start create_agent db")
|
||||
main_agent = await self.agent_manager.create_agent_async(
|
||||
@@ -545,9 +588,9 @@ class SyncServer(object):
|
||||
"context_window_limit": request.context_window_limit,
|
||||
"max_tokens": request.max_tokens,
|
||||
}
|
||||
log_event(name="start get_cached_llm_config", attributes=config_params)
|
||||
request.llm_config = await self.get_cached_llm_config_async(actor=actor, **config_params)
|
||||
log_event(name="end get_cached_llm_config", attributes=config_params)
|
||||
log_event(name="start get_llm_config_from_handle", attributes=config_params)
|
||||
request.llm_config = await self.get_llm_config_from_handle_async(actor=actor, **config_params)
|
||||
log_event(name="end get_llm_config_from_handle", attributes=config_params)
|
||||
|
||||
# update with model_settings
|
||||
if request.model_settings is not None:
|
||||
@@ -584,6 +627,8 @@ class SyncServer(object):
|
||||
)
|
||||
|
||||
async def create_sleeptime_agent_async(self, main_agent: AgentState, actor: User) -> AgentState:
|
||||
if main_agent.embedding_config is None:
|
||||
raise EmbeddingConfigRequiredError(agent_id=main_agent.id, operation="create_sleeptime_agent")
|
||||
request = CreateAgent(
|
||||
name=main_agent.name + "-sleeptime",
|
||||
agent_type=AgentType.sleeptime_agent,
|
||||
@@ -616,6 +661,8 @@ class SyncServer(object):
|
||||
return await self.agent_manager.get_agent_by_id_async(agent_id=main_agent.id, actor=actor)
|
||||
|
||||
async def create_voice_sleeptime_agent_async(self, main_agent: AgentState, actor: User) -> AgentState:
|
||||
if main_agent.embedding_config is None:
|
||||
raise EmbeddingConfigRequiredError(agent_id=main_agent.id, operation="create_voice_sleeptime_agent")
|
||||
# TODO: Inject system
|
||||
request = CreateAgent(
|
||||
name=main_agent.name + "-sleeptime",
|
||||
@@ -771,6 +818,7 @@ class SyncServer(object):
|
||||
assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
|
||||
assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
include_err: Optional[bool] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
) -> Union[List[Message], List[LettaMessage]]:
|
||||
records = await self.message_manager.list_messages(
|
||||
agent_id=agent_id,
|
||||
@@ -781,6 +829,7 @@ class SyncServer(object):
|
||||
ascending=not reverse,
|
||||
group_id=group_id,
|
||||
include_err=include_err,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
if not return_message_object:
|
||||
@@ -816,6 +865,7 @@ class SyncServer(object):
|
||||
assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
|
||||
assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
include_err: Optional[bool] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
) -> Union[List[Message], List[LettaMessage]]:
|
||||
records = await self.message_manager.list_messages(
|
||||
agent_id=None,
|
||||
@@ -826,6 +876,7 @@ class SyncServer(object):
|
||||
ascending=not reverse,
|
||||
group_id=group_id,
|
||||
include_err=include_err,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
if not return_message_object:
|
||||
@@ -989,6 +1040,8 @@ class SyncServer(object):
|
||||
async def create_document_sleeptime_agent_async(
|
||||
self, main_agent: AgentState, source: Source, actor: User, clear_history: bool = False
|
||||
) -> AgentState:
|
||||
if main_agent.embedding_config is None:
|
||||
raise EmbeddingConfigRequiredError(agent_id=main_agent.id, operation="create_document_sleeptime_agent")
|
||||
try:
|
||||
block = await self.agent_manager.get_block_with_label_async(agent_id=main_agent.id, block_label=source.name, actor=actor)
|
||||
except:
|
||||
@@ -1043,6 +1096,18 @@ class SyncServer(object):
|
||||
passage_count, document_count = await load_data(connector, source, self.passage_manager, self.file_manager, actor=actor)
|
||||
return passage_count, document_count
|
||||
|
||||
def _get_provider_sort_key(self, model: LLMConfig) -> Tuple[int, str, str]:
|
||||
"""Get sort key for a model: (provider_priority, provider_name, model_name)"""
|
||||
provider_priority = constants.PROVIDER_ORDER.get(model.provider_name, 999)
|
||||
return (provider_priority, model.provider_name or "", model.model or "")
|
||||
|
||||
def _get_embedding_provider_sort_key(self, model: EmbeddingConfig) -> Tuple[int, str, str]:
|
||||
"""Get sort key for an embedding model: (provider_priority, provider_name, model_name)"""
|
||||
# Extract provider name from handle (format: "provider_name/model_name")
|
||||
provider_name = model.handle.split("/")[0] if model.handle and "/" in model.handle else ""
|
||||
provider_priority = constants.PROVIDER_ORDER.get(provider_name, 999)
|
||||
return (provider_priority, provider_name, model.embedding_model or "")
|
||||
|
||||
@trace_method
|
||||
async def list_llm_models_async(
|
||||
self,
|
||||
@@ -1051,73 +1116,121 @@ class SyncServer(object):
|
||||
provider_name: Optional[str] = None,
|
||||
provider_type: Optional[ProviderType] = None,
|
||||
) -> List[LLMConfig]:
|
||||
"""Asynchronously list available models with maximum concurrency"""
|
||||
import asyncio
|
||||
|
||||
providers = await self.get_enabled_providers_async(
|
||||
provider_category=provider_category,
|
||||
provider_name=provider_name,
|
||||
provider_type=provider_type,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
async def get_provider_models(provider: Provider) -> list[LLMConfig]:
|
||||
try:
|
||||
async with asyncio.timeout(constants.GET_PROVIDERS_TIMEOUT_SECONDS):
|
||||
return await provider.list_llm_models_async()
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Timeout while listing LLM models for provider {provider}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.exception(f"Error while listing LLM models for provider {provider}: {e}")
|
||||
return []
|
||||
|
||||
# Execute all provider model listing tasks concurrently
|
||||
provider_results = await asyncio.gather(*[get_provider_models(provider) for provider in providers])
|
||||
|
||||
# Flatten the results
|
||||
"""List available LLM models - base from DB, BYOK from provider endpoints"""
|
||||
llm_models = []
|
||||
for models in provider_results:
|
||||
llm_models.extend(models)
|
||||
|
||||
# Get local configs - if this is potentially slow, consider making it async too
|
||||
local_configs = self.get_local_llm_configs()
|
||||
llm_models.extend(local_configs)
|
||||
# Determine which categories to include
|
||||
include_base = not provider_category or ProviderCategory.base in provider_category
|
||||
include_byok = not provider_category or ProviderCategory.byok in provider_category
|
||||
|
||||
# dedupe by handle for uniqueness
|
||||
# Seems like this is required from the tests?
|
||||
seen_handles = set()
|
||||
unique_models = []
|
||||
for model in llm_models:
|
||||
if model.handle not in seen_handles:
|
||||
seen_handles.add(model.handle)
|
||||
unique_models.append(model)
|
||||
# Get base provider models from database
|
||||
if include_base:
|
||||
provider_models = await self.provider_manager.list_models_async(
|
||||
actor=actor,
|
||||
model_type="llm",
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
return unique_models
|
||||
# Build LLMConfig objects from database
|
||||
provider_cache: Dict[str, Provider] = {}
|
||||
for model in provider_models:
|
||||
# Get provider details (with caching to avoid N+1 queries)
|
||||
if model.provider_id not in provider_cache:
|
||||
provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor)
|
||||
provider = provider_cache[model.provider_id]
|
||||
|
||||
# Skip non-base providers (they're handled separately)
|
||||
if provider.provider_category != ProviderCategory.base:
|
||||
continue
|
||||
|
||||
# Apply provider_name/provider_type filters if specified
|
||||
if provider_name and provider.name != provider_name:
|
||||
continue
|
||||
if provider_type and provider.provider_type != provider_type:
|
||||
continue
|
||||
|
||||
llm_config = LLMConfig(
|
||||
model=model.name,
|
||||
model_endpoint_type=model.model_endpoint_type,
|
||||
model_endpoint=provider.base_url or model.model_endpoint_type,
|
||||
context_window=model.max_context_window or 16384,
|
||||
handle=model.handle,
|
||||
provider_name=provider.name,
|
||||
provider_category=provider.provider_category,
|
||||
)
|
||||
llm_models.append(llm_config)
|
||||
|
||||
# Get BYOK provider models by hitting provider endpoints directly
|
||||
if include_byok:
|
||||
byok_providers = await self.provider_manager.list_providers_async(
|
||||
actor=actor,
|
||||
name=provider_name,
|
||||
provider_type=provider_type,
|
||||
provider_category=[ProviderCategory.byok],
|
||||
)
|
||||
|
||||
for provider in byok_providers:
|
||||
try:
|
||||
typed_provider = provider.cast_to_subtype()
|
||||
models = await typed_provider.list_llm_models_async()
|
||||
llm_models.extend(models)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch models from BYOK provider {provider.name}: {e}")
|
||||
|
||||
# Sort by provider order (matching old _enabled_providers order), then by model name
|
||||
llm_models.sort(key=self._get_provider_sort_key)
|
||||
|
||||
return llm_models
|
||||
|
||||
async def list_embedding_models_async(self, actor: User) -> List[EmbeddingConfig]:
|
||||
"""Asynchronously list available embedding models with maximum concurrency"""
|
||||
import asyncio
|
||||
|
||||
# Get all eligible providers first
|
||||
providers = await self.get_enabled_providers_async(actor=actor)
|
||||
|
||||
# Fetch embedding models from each provider concurrently
|
||||
async def get_provider_embedding_models(provider):
|
||||
try:
|
||||
# All providers now have list_embedding_models_async
|
||||
return await provider.list_embedding_models_async()
|
||||
except Exception as e:
|
||||
logger.exception(f"An error occurred while listing embedding models for provider {provider}: {e}")
|
||||
return []
|
||||
|
||||
# Execute all provider model listing tasks concurrently
|
||||
provider_results = await asyncio.gather(*[get_provider_embedding_models(provider) for provider in providers])
|
||||
|
||||
# Flatten the results
|
||||
"""List available embedding models - base from DB, BYOK from provider endpoints"""
|
||||
embedding_models = []
|
||||
for models in provider_results:
|
||||
embedding_models.extend(models)
|
||||
|
||||
# Get base provider models from database
|
||||
provider_models = await self.provider_manager.list_models_async(
|
||||
actor=actor,
|
||||
model_type="embedding",
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
# Build EmbeddingConfig objects from database (base providers only)
|
||||
provider_cache: Dict[str, Provider] = {}
|
||||
for model in provider_models:
|
||||
# Get provider details (with caching to avoid N+1 queries)
|
||||
if model.provider_id not in provider_cache:
|
||||
provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor)
|
||||
provider = provider_cache[model.provider_id]
|
||||
|
||||
# Skip non-base providers (they're handled separately)
|
||||
if provider.provider_category != ProviderCategory.base:
|
||||
continue
|
||||
|
||||
embedding_config = EmbeddingConfig(
|
||||
embedding_model=model.name,
|
||||
embedding_endpoint_type=model.model_endpoint_type,
|
||||
embedding_endpoint=provider.base_url or model.model_endpoint_type,
|
||||
embedding_dim=model.embedding_dim or 1536,
|
||||
embedding_chunk_size=constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
|
||||
handle=model.handle,
|
||||
)
|
||||
embedding_models.append(embedding_config)
|
||||
|
||||
# Get BYOK provider models by hitting provider endpoints directly
|
||||
byok_providers = await self.provider_manager.list_providers_async(
|
||||
actor=actor,
|
||||
provider_category=[ProviderCategory.byok],
|
||||
)
|
||||
|
||||
for provider in byok_providers:
|
||||
try:
|
||||
typed_provider = provider.cast_to_subtype()
|
||||
models = await typed_provider.list_embedding_models_async()
|
||||
embedding_models.extend(models)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch embedding models from BYOK provider {provider.name}: {e}")
|
||||
|
||||
# Sort by provider order (matching old _enabled_providers order), then by model name
|
||||
embedding_models.sort(key=self._get_embedding_provider_sort_key)
|
||||
|
||||
return embedding_models
|
||||
|
||||
@@ -1128,25 +1241,17 @@ class SyncServer(object):
|
||||
provider_name: Optional[str] = None,
|
||||
provider_type: Optional[ProviderType] = None,
|
||||
) -> List[Provider]:
|
||||
providers = []
|
||||
if not provider_category or ProviderCategory.base in provider_category:
|
||||
providers_from_env = [p for p in self._enabled_providers]
|
||||
providers.extend(providers_from_env)
|
||||
# Query all persisted providers from database
|
||||
persisted_providers = await self.provider_manager.list_providers_async(
|
||||
name=provider_name,
|
||||
provider_type=provider_type,
|
||||
actor=actor,
|
||||
)
|
||||
providers = [p.cast_to_subtype() for p in persisted_providers]
|
||||
|
||||
if not provider_category or ProviderCategory.byok in provider_category:
|
||||
providers_from_db = await self.provider_manager.list_providers_async(
|
||||
name=provider_name,
|
||||
provider_type=provider_type,
|
||||
actor=actor,
|
||||
)
|
||||
providers_from_db = [p.cast_to_subtype() for p in providers_from_db if p.provider_category == ProviderCategory.byok]
|
||||
providers.extend(providers_from_db)
|
||||
|
||||
if provider_name is not None:
|
||||
providers = [p for p in providers if p.name == provider_name]
|
||||
|
||||
if provider_type is not None:
|
||||
providers = [p for p in providers if p.provider_type == provider_type]
|
||||
# Filter by category if specified
|
||||
if provider_category:
|
||||
providers = [p for p in providers if p.provider_category in provider_category]
|
||||
|
||||
return providers
|
||||
|
||||
@@ -1160,32 +1265,19 @@ class SyncServer(object):
|
||||
max_reasoning_tokens: Optional[int] = None,
|
||||
enable_reasoner: Optional[bool] = None,
|
||||
) -> LLMConfig:
|
||||
# Use provider_manager to get LLMConfig from handle
|
||||
try:
|
||||
provider_name, model_name = handle.split("/", 1)
|
||||
provider = await self.get_provider_from_name_async(provider_name, actor)
|
||||
|
||||
all_llm_configs = await provider.list_llm_models_async()
|
||||
llm_configs = [config for config in all_llm_configs if config.handle == handle]
|
||||
if not llm_configs:
|
||||
llm_configs = [config for config in all_llm_configs if config.model == model_name]
|
||||
if not llm_configs:
|
||||
available_handles = [config.handle for config in all_llm_configs]
|
||||
raise HandleNotFoundError(handle, available_handles)
|
||||
except ValueError as e:
|
||||
llm_configs = [config for config in self.get_local_llm_configs() if config.handle == handle]
|
||||
if not llm_configs:
|
||||
llm_configs = [config for config in self.get_local_llm_configs() if config.model == model_name]
|
||||
if not llm_configs:
|
||||
raise e
|
||||
|
||||
if len(llm_configs) == 1:
|
||||
llm_config = llm_configs[0]
|
||||
elif len(llm_configs) > 1:
|
||||
raise LettaInvalidArgumentError(
|
||||
f"Multiple LLM models with name {model_name} supported by {provider_name}", argument_name="model_name"
|
||||
llm_config = await self.provider_manager.get_llm_config_from_handle(
|
||||
handle=handle,
|
||||
actor=actor,
|
||||
)
|
||||
else:
|
||||
llm_config = llm_configs[0]
|
||||
except Exception as e:
|
||||
# Convert to HandleNotFoundError for backwards compatibility
|
||||
from letta.orm.errors import NoResultFound
|
||||
|
||||
if isinstance(e, NoResultFound):
|
||||
raise HandleNotFoundError(handle, [])
|
||||
raise
|
||||
|
||||
if context_window_limit is not None:
|
||||
if context_window_limit > llm_config.context_window:
|
||||
@@ -1217,33 +1309,22 @@ class SyncServer(object):
|
||||
async def get_embedding_config_from_handle_async(
|
||||
self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
|
||||
) -> EmbeddingConfig:
|
||||
# Use provider_manager to get EmbeddingConfig from handle
|
||||
try:
|
||||
provider_name, model_name = handle.split("/", 1)
|
||||
provider = await self.get_provider_from_name_async(provider_name, actor)
|
||||
|
||||
all_embedding_configs = await provider.list_embedding_models_async()
|
||||
embedding_configs = [config for config in all_embedding_configs if config.handle == handle]
|
||||
if not embedding_configs:
|
||||
raise LettaInvalidArgumentError(
|
||||
f"Embedding model {model_name} is not supported by {provider_name}", argument_name="model_name"
|
||||
)
|
||||
except LettaInvalidArgumentError as e:
|
||||
# search local configs
|
||||
embedding_configs = [config for config in self.get_local_embedding_configs() if config.handle == handle]
|
||||
if not embedding_configs:
|
||||
raise e
|
||||
|
||||
if len(embedding_configs) == 1:
|
||||
embedding_config = embedding_configs[0]
|
||||
elif len(embedding_configs) > 1:
|
||||
raise LettaInvalidArgumentError(
|
||||
f"Multiple embedding models with name {model_name} supported by {provider_name}", argument_name="model_name"
|
||||
embedding_config = await self.provider_manager.get_embedding_config_from_handle(
|
||||
handle=handle,
|
||||
actor=actor,
|
||||
)
|
||||
else:
|
||||
embedding_config = embedding_configs[0]
|
||||
except Exception as e:
|
||||
# Convert to LettaInvalidArgumentError for backwards compatibility
|
||||
from letta.orm.errors import NoResultFound
|
||||
|
||||
if embedding_chunk_size:
|
||||
embedding_config.embedding_chunk_size = embedding_chunk_size
|
||||
if isinstance(e, NoResultFound):
|
||||
raise LettaInvalidArgumentError(f"Embedding model {handle} not found", argument_name="handle")
|
||||
raise
|
||||
|
||||
# Override chunk size if provided
|
||||
embedding_config.embedding_chunk_size = embedding_chunk_size
|
||||
|
||||
return embedding_config
|
||||
|
||||
@@ -1252,57 +1333,17 @@ class SyncServer(object):
|
||||
providers = [provider for provider in all_providers if provider.name == provider_name]
|
||||
if not providers:
|
||||
raise LettaInvalidArgumentError(
|
||||
f"Provider {provider_name} is not supported (supported providers: {', '.join([provider.name for provider in self._enabled_providers])})",
|
||||
f"Provider {provider_name} is not supported (supported providers: {', '.join([provider.name for provider in all_providers])})",
|
||||
argument_name="provider_name",
|
||||
)
|
||||
elif len(providers) > 1:
|
||||
logger.warning(f"Multiple providers with name {provider_name} supported", argument_name="provider_name")
|
||||
logger.warning(f"Multiple providers with name {provider_name} supported")
|
||||
provider = providers[0]
|
||||
else:
|
||||
provider = providers[0]
|
||||
|
||||
return provider
|
||||
|
||||
def get_local_llm_configs(self):
|
||||
llm_models = []
|
||||
# NOTE: deprecated
|
||||
# try:
|
||||
# llm_configs_dir = os.path.expanduser("~/.letta/llm_configs")
|
||||
# if os.path.exists(llm_configs_dir):
|
||||
# for filename in os.listdir(llm_configs_dir):
|
||||
# if filename.endswith(".json"):
|
||||
# filepath = os.path.join(llm_configs_dir, filename)
|
||||
# try:
|
||||
# with open(filepath, "r") as f:
|
||||
# config_data = json.load(f)
|
||||
# llm_config = LLMConfig(**config_data)
|
||||
# llm_models.append(llm_config)
|
||||
# except (json.JSONDecodeError, ValueError) as e:
|
||||
# logger.warning(f"Error parsing LLM config file {filename}: {e}")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Error reading LLM configs directory: {e}")
|
||||
return llm_models
|
||||
|
||||
def get_local_embedding_configs(self):
|
||||
embedding_models = []
|
||||
# NOTE: deprecated
|
||||
# try:
|
||||
# embedding_configs_dir = os.path.expanduser("~/.letta/embedding_configs")
|
||||
# if os.path.exists(embedding_configs_dir):
|
||||
# for filename in os.listdir(embedding_configs_dir):
|
||||
# if filename.endswith(".json"):
|
||||
# filepath = os.path.join(embedding_configs_dir, filename)
|
||||
# try:
|
||||
# with open(filepath, "r") as f:
|
||||
# config_data = json.load(f)
|
||||
# embedding_config = EmbeddingConfig(**config_data)
|
||||
# embedding_models.append(embedding_config)
|
||||
# except (json.JSONDecodeError, ValueError) as e:
|
||||
# logger.warning(f"Error parsing embedding config file {filename}: {e}")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Error reading embedding configs directory: {e}")
|
||||
return embedding_models
|
||||
|
||||
def add_llm_model(self, request: LLMConfig) -> LLMConfig:
|
||||
"""Add a new LLM model"""
|
||||
|
||||
@@ -1533,7 +1574,7 @@ class SyncServer(object):
|
||||
|
||||
# Attempt to initialize the connection to the server
|
||||
if server_config.type == MCPServerType.SSE:
|
||||
new_mcp_client = AsyncSSEMCPClient(server_config)
|
||||
new_mcp_client = AsyncFastMCPSSEClient(server_config)
|
||||
elif server_config.type == MCPServerType.STDIO:
|
||||
new_mcp_client = AsyncStdioMCPClient(server_config)
|
||||
else:
|
||||
|
||||
@@ -25,7 +25,7 @@ from letta.constants import (
|
||||
INCLUDE_MODEL_KEYWORDS_BASE_TOOL_RULES,
|
||||
RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE,
|
||||
)
|
||||
from letta.errors import LettaAgentNotFoundError
|
||||
from letta.errors import LettaAgentNotFoundError, LettaInvalidArgumentError
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
@@ -60,6 +60,7 @@ from letta.schemas.agent import (
|
||||
from letta.schemas.block import DEFAULT_BLOCKS, Block as PydanticBlock, BlockUpdate
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import AgentType, PrimitiveType, ProviderType, TagMatchMode, ToolType, VectorDBProvider
|
||||
from letta.schemas.environment_variables import AgentEnvironmentVariable as PydanticAgentEnvVar
|
||||
from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
||||
from letta.schemas.group import Group as PydanticGroup, ManagerType
|
||||
from letta.schemas.letta_stop_reason import StopReasonType
|
||||
@@ -111,7 +112,13 @@ from letta.services.passage_manager import PassageManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.settings import DatabaseChoice, model_settings, settings
|
||||
from letta.utils import calculate_file_defaults_based_on_context_window, enforce_types, united_diff
|
||||
from letta.utils import (
|
||||
bounded_gather,
|
||||
calculate_file_defaults_based_on_context_window,
|
||||
decrypt_agent_secrets,
|
||||
enforce_types,
|
||||
united_diff,
|
||||
)
|
||||
from letta.validators import raise_on_invalid_id
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -336,8 +343,8 @@ class AgentManager:
|
||||
ignore_invalid_tools: bool = False,
|
||||
) -> PydanticAgentState:
|
||||
# validate required configs
|
||||
if not agent_create.llm_config or not agent_create.embedding_config:
|
||||
raise ValueError("llm_config and embedding_config are required")
|
||||
if not agent_create.llm_config:
|
||||
raise ValueError("llm_config is required")
|
||||
|
||||
# For v1 agents, enforce sane defaults even when reasoning is omitted
|
||||
if agent_create.agent_type == AgentType.letta_v1_agent:
|
||||
@@ -556,19 +563,18 @@ class AgentManager:
|
||||
agent_secrets = agent_create.secrets or agent_create.tool_exec_environment_variables
|
||||
|
||||
if agent_secrets:
|
||||
# Encrypt environment variable values
|
||||
env_rows = []
|
||||
for key, val in agent_secrets.items():
|
||||
# Encrypt value (Secret.from_plaintext handles missing encryption key internally)
|
||||
value_secret = Secret.from_plaintext(val)
|
||||
row = {
|
||||
# Encrypt environment variable values concurrently (async to avoid blocking event loop)
|
||||
secrets_dict = await Secret.from_plaintexts_async(agent_secrets)
|
||||
env_rows = [
|
||||
{
|
||||
"agent_id": aid,
|
||||
"key": key,
|
||||
"value": "", # Empty string for NOT NULL constraint (deprecated, use value_enc)
|
||||
"value_enc": value_secret.get_encrypted(),
|
||||
"value_enc": secret.get_encrypted(),
|
||||
"organization_id": actor.organization_id,
|
||||
}
|
||||
env_rows.append(row)
|
||||
for key, secret in secrets_dict.items()
|
||||
]
|
||||
|
||||
result = await session.execute(insert(AgentEnvironmentVariable).values(env_rows).returning(AgentEnvironmentVariable.id))
|
||||
env_rows = [{**row, "id": env_var_id} for row, env_var_id in zip(env_rows, result.scalars().all())]
|
||||
@@ -588,8 +594,10 @@ class AgentManager:
|
||||
result = await new_agent.to_pydantic_async(include_relationships=include_relationships)
|
||||
|
||||
if agent_secrets and env_rows:
|
||||
result.tool_exec_environment_variables = [AgentEnvironmentVariable(**row) for row in env_rows]
|
||||
result.secrets = [AgentEnvironmentVariable(**row) for row in env_rows]
|
||||
# Use Pydantic schema (not ORM model) with plaintext to avoid sync decryption in model validator
|
||||
env_vars = [PydanticAgentEnvVar(**{**row, "value": agent_secrets[row["key"]]}) for row in env_rows]
|
||||
result.tool_exec_environment_variables = env_vars
|
||||
result.secrets = env_vars
|
||||
|
||||
# initial message sequence (skip if _init_with_no_messages is True)
|
||||
if not _init_with_no_messages:
|
||||
@@ -824,7 +832,7 @@ class AgentManager:
|
||||
)
|
||||
session.expire(agent, ["tags"])
|
||||
|
||||
agent_secrets = agent_update.secrets or agent_update.tool_exec_environment_variables
|
||||
agent_secrets = agent_update.secrets if agent_update.secrets is not None else agent_update.tool_exec_environment_variables
|
||||
if agent_secrets is not None:
|
||||
# Fetch existing environment variables to check if values changed
|
||||
result = await session.execute(select(AgentEnvironmentVariable).where(AgentEnvironmentVariable.agent_id == aid))
|
||||
@@ -832,25 +840,35 @@ class AgentManager:
|
||||
|
||||
# TODO: do we need to delete each time or can we just upsert?
|
||||
await session.execute(delete(AgentEnvironmentVariable).where(AgentEnvironmentVariable.agent_id == aid))
|
||||
# Encrypt environment variable values
|
||||
# Only re-encrypt if the value has actually changed
|
||||
|
||||
# Decrypt existing values to check for changes (async to avoid blocking)
|
||||
existing_values: dict[str, str | None] = {}
|
||||
for k, existing_env in existing_env_vars.items():
|
||||
if existing_env.value_enc:
|
||||
existing_secret = Secret.from_encrypted(existing_env.value_enc)
|
||||
existing_values[k] = await existing_secret.get_plaintext_async()
|
||||
else:
|
||||
existing_values[k] = None
|
||||
|
||||
# Identify values that need encryption (new or changed)
|
||||
to_encrypt = {
|
||||
k: v
|
||||
for k, v in agent_secrets.items()
|
||||
if k not in existing_env_vars or existing_values.get(k) != v or not existing_env_vars[k].value_enc
|
||||
}
|
||||
|
||||
# Batch encrypt new/changed values concurrently (async to avoid blocking event loop)
|
||||
new_secrets = await Secret.from_plaintexts_async(to_encrypt) if to_encrypt else {}
|
||||
|
||||
# Build rows, reusing existing encrypted values where unchanged
|
||||
env_rows = []
|
||||
for k, v in agent_secrets.items():
|
||||
# Check if value changed to avoid unnecessary re-encryption
|
||||
existing_env = existing_env_vars.get(k)
|
||||
existing_value = None
|
||||
if existing_env and existing_env.value_enc:
|
||||
existing_secret = Secret.from_encrypted(existing_env.value_enc)
|
||||
existing_value = await existing_secret.get_plaintext_async()
|
||||
|
||||
# Encrypt value (reuse existing encrypted value if unchanged)
|
||||
if existing_value == v and existing_env and existing_env.value_enc:
|
||||
# Value unchanged, reuse existing encrypted value
|
||||
value_enc = existing_env.value_enc
|
||||
if k in new_secrets:
|
||||
# New or changed value - use newly encrypted value
|
||||
value_enc = new_secrets[k].get_encrypted()
|
||||
else:
|
||||
# Value changed or new, encrypt
|
||||
value_secret = Secret.from_plaintext(v)
|
||||
value_enc = value_secret.get_encrypted()
|
||||
# Value unchanged - reuse existing encrypted value
|
||||
value_enc = existing_env_vars[k].value_enc
|
||||
|
||||
row = {
|
||||
"agent_id": aid,
|
||||
@@ -875,7 +893,11 @@ class AgentManager:
|
||||
await session.flush()
|
||||
await session.refresh(agent)
|
||||
|
||||
return await agent.to_pydantic_async()
|
||||
# Convert without decrypting to release DB connection before PBKDF2
|
||||
agent_encrypted = await agent.to_pydantic_async(decrypt=False)
|
||||
|
||||
# Decrypt secrets outside session
|
||||
return (await decrypt_agent_secrets([agent_encrypted]))[0]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -899,7 +921,8 @@ class AgentManager:
|
||||
agent.message_ids = message_ids
|
||||
|
||||
await agent.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@trace_method
|
||||
async def list_agents_async(
|
||||
@@ -969,10 +992,16 @@ class AgentManager:
|
||||
query = query.limit(limit)
|
||||
result = await session.execute(query)
|
||||
agents = result.scalars().all()
|
||||
return await asyncio.gather(
|
||||
*[agent.to_pydantic_async(include_relationships=include_relationships, include=include) for agent in agents]
|
||||
|
||||
# Convert to pydantic without decrypting (keeps encrypted values)
|
||||
# This allows us to release the DB connection before expensive PBKDF2 operations
|
||||
agents_encrypted = await bounded_gather(
|
||||
[agent.to_pydantic_async(include_relationships=include_relationships, include=include, decrypt=False) for agent in agents]
|
||||
)
|
||||
|
||||
# DB session released - now decrypt secrets outside session to prevent connection holding
|
||||
return await decrypt_agent_secrets(agents_encrypted)
|
||||
|
||||
@trace_method
|
||||
async def count_agents_async(
|
||||
self,
|
||||
@@ -1067,7 +1096,12 @@ class AgentManager:
|
||||
|
||||
query = query.distinct(AgentModel.id).order_by(AgentModel.id).limit(limit)
|
||||
result = await session.execute(query)
|
||||
return await asyncio.gather(*[agent.to_pydantic_async() for agent in result.scalars()])
|
||||
|
||||
# Convert without decrypting to release DB connection before PBKDF2
|
||||
agents_encrypted = await bounded_gather([agent.to_pydantic_async(decrypt=False) for agent in result.scalars()])
|
||||
|
||||
# Decrypt secrets outside session
|
||||
return await decrypt_agent_secrets(agents_encrypted)
|
||||
|
||||
@trace_method
|
||||
async def size_async(
|
||||
@@ -1092,8 +1126,8 @@ class AgentManager:
|
||||
) -> PydanticAgentState:
|
||||
"""Fetch an agent by its ID."""
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
try:
|
||||
async with db_registry.async_session() as session:
|
||||
query = select(AgentModel)
|
||||
query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
|
||||
query = query.where(AgentModel.id == agent_id)
|
||||
@@ -1105,13 +1139,17 @@ class AgentManager:
|
||||
if agent is None:
|
||||
raise NoResultFound(f"Agent with ID {agent_id} not found")
|
||||
|
||||
return await agent.to_pydantic_async(include_relationships=include_relationships, include=include)
|
||||
except NoResultFound:
|
||||
# Re-raise NoResultFound without logging to preserve 404 handling
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching agent {agent_id}: {str(e)}")
|
||||
raise
|
||||
# Convert without decrypting to release DB connection before PBKDF2
|
||||
agent_encrypted = await agent.to_pydantic_async(include_relationships=include_relationships, include=include, decrypt=False)
|
||||
|
||||
# Decrypt secrets outside session
|
||||
return (await decrypt_agent_secrets([agent_encrypted]))[0]
|
||||
except NoResultFound:
|
||||
# Re-raise NoResultFound without logging to preserve 404 handling
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching agent {agent_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -1122,8 +1160,8 @@ class AgentManager:
|
||||
include_relationships: Optional[List[str]] = None,
|
||||
) -> list[PydanticAgentState]:
|
||||
"""Fetch a list of agents by their IDs."""
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
try:
|
||||
async with db_registry.async_session() as session:
|
||||
query = select(AgentModel)
|
||||
query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
|
||||
query = query.where(AgentModel.id.in_(agent_ids))
|
||||
@@ -1136,10 +1174,16 @@ class AgentManager:
|
||||
logger.warning(f"No agents found with IDs: {agent_ids}")
|
||||
return []
|
||||
|
||||
return await asyncio.gather(*[agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents])
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching agents with IDs {agent_ids}: {str(e)}")
|
||||
raise
|
||||
# Convert without decrypting to release DB connection before PBKDF2
|
||||
agents_encrypted = await bounded_gather(
|
||||
[agent.to_pydantic_async(include_relationships=include_relationships, decrypt=False) for agent in agents]
|
||||
)
|
||||
|
||||
# Decrypt secrets outside session
|
||||
return await decrypt_agent_secrets(agents_encrypted)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching agents with IDs {agent_ids}: {str(e)}")
|
||||
raise
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@@ -1216,7 +1260,8 @@ class AgentManager:
|
||||
await session.commit()
|
||||
for agent in agents_to_delete:
|
||||
await session.delete(agent)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception(f"Failed to hard delete Agent with ID {agent_id}")
|
||||
@@ -1357,7 +1402,7 @@ class AgentManager:
|
||||
|
||||
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
||||
|
||||
if agent_state.message_ids == []:
|
||||
if not agent_state.message_ids: # Handles both None and empty list
|
||||
curr_system_message = None
|
||||
else:
|
||||
curr_system_message = await self.message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)
|
||||
@@ -1694,7 +1739,12 @@ class AgentManager:
|
||||
agent = await agent.update_async(session, actor=actor)
|
||||
# TODO: This refresh is expensive. If we can find out which fields are needed, we can save cost by only refreshing those fields.
|
||||
# or even better, not refresh at all.
|
||||
return await agent.to_pydantic_async()
|
||||
|
||||
# Convert without decrypting to release DB connection before PBKDF2
|
||||
agent_encrypted = await agent.to_pydantic_async(decrypt=False)
|
||||
|
||||
# Decrypt secrets outside session
|
||||
return (await decrypt_agent_secrets([agent_encrypted]))[0]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -1856,7 +1906,12 @@ class AgentManager:
|
||||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||||
# TODO: This refresh is expensive. If we can find out which fields are needed, we can save cost by only refreshing those fields.
|
||||
# or even better, not refresh at all.
|
||||
return await agent.to_pydantic_async()
|
||||
|
||||
# Convert without decrypting to release DB connection before PBKDF2
|
||||
agent_encrypted = await agent.to_pydantic_async(decrypt=False)
|
||||
|
||||
# Decrypt secrets outside session
|
||||
return (await decrypt_agent_secrets([agent_encrypted]))[0]
|
||||
|
||||
# ======================================================================================================================
|
||||
# Block management
|
||||
@@ -1888,25 +1943,25 @@ class AgentManager:
|
||||
) -> PydanticBlock:
|
||||
"""Gets a block attached to an agent by its label."""
|
||||
async with db_registry.async_session() as session:
|
||||
block = None
|
||||
matched_block = None
|
||||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||||
for block in agent.core_memory:
|
||||
if block.label == block_label:
|
||||
block = block
|
||||
matched_block = block
|
||||
break
|
||||
if not block:
|
||||
if not matched_block:
|
||||
raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'")
|
||||
|
||||
update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||||
|
||||
# Validate limit constraints before updating
|
||||
validate_block_limit_constraint(update_data, block)
|
||||
validate_block_limit_constraint(update_data, matched_block)
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(block, key, value)
|
||||
setattr(matched_block, key, value)
|
||||
|
||||
await block.update_async(session, actor=actor)
|
||||
return block.to_pydantic()
|
||||
await matched_block.update_async(session, actor=actor)
|
||||
return matched_block.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@@ -1944,7 +1999,11 @@ class AgentManager:
|
||||
# TODO: I have too many things rn so lets look at this later
|
||||
# await session.commit()
|
||||
|
||||
return await agent.to_pydantic_async()
|
||||
# Convert without decrypting to release DB connection before PBKDF2
|
||||
agent_encrypted = await agent.to_pydantic_async(decrypt=False)
|
||||
|
||||
# Decrypt secrets outside session
|
||||
return (await decrypt_agent_secrets([agent_encrypted]))[0]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -1965,7 +2024,12 @@ class AgentManager:
|
||||
raise NoResultFound(f"No block with id '{block_id}' found for agent '{agent_id}' with actor id: '{actor.id}'")
|
||||
|
||||
await agent.update_async(session, actor=actor)
|
||||
return await agent.to_pydantic_async()
|
||||
|
||||
# Convert without decrypting to release DB connection before PBKDF2
|
||||
agent_encrypted = await agent.to_pydantic_async(decrypt=False)
|
||||
|
||||
# Decrypt secrets outside session
|
||||
return (await decrypt_agent_secrets([agent_encrypted]))[0]
|
||||
|
||||
# ======================================================================================================================
|
||||
# Passage Management
|
||||
@@ -2400,7 +2464,7 @@ class AgentManager:
|
||||
# Use ISO format if no timezone is set
|
||||
formatted_timestamp = str(timestamp) if timestamp else "Unknown"
|
||||
|
||||
result_dict = {"timestamp": formatted_timestamp, "content": passage.text, "tags": passage.tags or []}
|
||||
result_dict = {"id": passage.id, "timestamp": formatted_timestamp, "content": passage.text, "tags": passage.tags or []}
|
||||
|
||||
# Add relevance metadata if available
|
||||
if metadata:
|
||||
@@ -2570,7 +2634,8 @@ class AgentManager:
|
||||
agent.tool_rules = tool_rules
|
||||
session.add(agent)
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@@ -2643,7 +2708,8 @@ class AgentManager:
|
||||
else:
|
||||
logger.info(f"All {len(tool_ids)} tools already attached to agent {agent_id}")
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -2767,7 +2833,8 @@ class AgentManager:
|
||||
else:
|
||||
logger.debug(f"Detached tool id={tool_id} from agent id={agent_id}")
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@@ -2804,7 +2871,8 @@ class AgentManager:
|
||||
else:
|
||||
logger.info(f"Detached all {detached_count} tools from agent {agent_id}")
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@@ -2832,7 +2900,8 @@ class AgentManager:
|
||||
|
||||
agent.tool_rules = tool_rules
|
||||
session.add(agent)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -3067,10 +3136,14 @@ class AgentManager:
|
||||
# Generate visible content for each file
|
||||
line_chunker = LineChunker()
|
||||
visible_content_map = {}
|
||||
for file_metadata in file_metadata_with_content:
|
||||
for i, file_metadata in enumerate(file_metadata_with_content):
|
||||
content_lines = line_chunker.chunk_text(file_metadata=file_metadata)
|
||||
visible_content_map[file_metadata.file_name] = "\n".join(content_lines)
|
||||
|
||||
# Yield to event loop every 100 files to prevent saturation
|
||||
if i > 0 and i % 100 == 0:
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# Use bulk attach to avoid race conditions and duplicate LRU eviction decisions
|
||||
closed_files = await self.file_agent_manager.attach_files_bulk(
|
||||
agent_id=agent_state.id,
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Dict, List, Optional
|
||||
|
||||
from sqlalchemy import delete, or_, select
|
||||
|
||||
from letta.errors import EmbeddingConfigRequiredError
|
||||
from letta.helpers.tpuf_client import should_use_tpuf
|
||||
from letta.log import get_logger
|
||||
from letta.orm import ArchivalPassage, Archive as ArchiveModel, ArchivesAgents
|
||||
@@ -17,7 +18,7 @@ from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.services.helpers.agent_manager_helper import validate_agent_exists_async
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
from letta.utils import enforce_types
|
||||
from letta.utils import bounded_gather, decrypt_agent_secrets, enforce_types
|
||||
from letta.validators import raise_on_invalid_id
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -191,7 +192,8 @@ class ArchiveManager:
|
||||
is_owner=is_owner,
|
||||
)
|
||||
session.add(archives_agents)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@@ -224,7 +226,8 @@ class ArchiveManager:
|
||||
else:
|
||||
logger.info(f"Detached agent {agent_id} from archive {archive_id}")
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@@ -431,6 +434,8 @@ class ArchiveManager:
|
||||
return archive
|
||||
|
||||
# Create a default archive for this agent
|
||||
if agent_state.embedding_config is None:
|
||||
raise EmbeddingConfigRequiredError(agent_id=agent_state.id, operation="create_default_archive")
|
||||
archive_name = f"{agent_state.name}'s Archive"
|
||||
archive = await self.create_archive_async(
|
||||
name=archive_name,
|
||||
@@ -549,8 +554,13 @@ class ArchiveManager:
|
||||
result = await session.execute(query)
|
||||
agents_orm = result.scalars().all()
|
||||
|
||||
agents = await asyncio.gather(*[agent.to_pydantic_async(include_relationships=[], include=include) for agent in agents_orm])
|
||||
return agents
|
||||
# Convert without decrypting to release DB connection before PBKDF2
|
||||
agents_encrypted = await bounded_gather(
|
||||
[agent.to_pydantic_async(include_relationships=[], include=include, decrypt=False) for agent in agents_orm]
|
||||
)
|
||||
|
||||
# Decrypt secrets outside session
|
||||
return await decrypt_agent_secrets(agents_encrypted)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -609,6 +619,7 @@ class ArchiveManager:
|
||||
|
||||
# update the archive with the namespace
|
||||
await session.execute(update(ArchiveModel).where(ArchiveModel.id == archive_id).values(_vector_db_namespace=namespace_name))
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
return namespace_name
|
||||
|
||||
@@ -19,7 +19,7 @@ from letta.schemas.enums import ActorType, PrimitiveType
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
from letta.utils import enforce_types
|
||||
from letta.utils import bounded_gather, decrypt_agent_secrets, enforce_types
|
||||
from letta.validators import raise_on_invalid_id
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -100,7 +100,8 @@ class BlockManager:
|
||||
block = BlockModel(**data, organization_id=actor.organization_id)
|
||||
await block.create_async(session, actor=actor, no_commit=True, no_refresh=True)
|
||||
pydantic_block = block.to_pydantic()
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
return pydantic_block
|
||||
|
||||
@enforce_types
|
||||
@@ -119,18 +120,19 @@ class BlockManager:
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
# Validate all blocks before creating any
|
||||
validated_data = []
|
||||
for block in blocks:
|
||||
block_data = block.model_dump(to_orm=True, exclude_none=True)
|
||||
validate_block_creation(block_data)
|
||||
validated_data.append(block_data)
|
||||
|
||||
block_models = [
|
||||
BlockModel(**block.model_dump(to_orm=True, exclude_none=True), organization_id=actor.organization_id) for block in blocks
|
||||
]
|
||||
block_models = [BlockModel(**data, organization_id=actor.organization_id) for data in validated_data]
|
||||
created_models = await BlockModel.batch_create_async(
|
||||
items=block_models, db_session=session, actor=actor, no_commit=True, no_refresh=True
|
||||
)
|
||||
result = [m.to_pydantic() for m in created_models]
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
return result
|
||||
|
||||
@enforce_types
|
||||
@@ -150,7 +152,8 @@ class BlockManager:
|
||||
|
||||
await block.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
|
||||
pydantic_block = block.to_pydantic()
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
return pydantic_block
|
||||
|
||||
@enforce_types
|
||||
@@ -502,10 +505,13 @@ class BlockManager:
|
||||
result = await session.execute(query)
|
||||
agents_orm = result.scalars().all()
|
||||
|
||||
agents = await asyncio.gather(
|
||||
*[agent.to_pydantic_async(include_relationships=include_relationships, include=include) for agent in agents_orm]
|
||||
# Convert without decrypting to release DB connection before PBKDF2
|
||||
agents_encrypted = await bounded_gather(
|
||||
[agent.to_pydantic_async(include_relationships=[], include=include, decrypt=False) for agent in agents_orm]
|
||||
)
|
||||
return agents
|
||||
|
||||
# Decrypt secrets outside session
|
||||
return await decrypt_agent_secrets(agents_encrypted)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -591,7 +597,8 @@ class BlockManager:
|
||||
new_val = new_val[: block.limit]
|
||||
block.value = new_val
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
if return_hydrated:
|
||||
# TODO: implement for async
|
||||
@@ -669,7 +676,8 @@ class BlockManager:
|
||||
|
||||
# 7) Flush changes, then commit once
|
||||
block = await block.update_async(db_session=session, actor=actor, no_commit=True)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
return block.to_pydantic()
|
||||
|
||||
@@ -757,7 +765,8 @@ class BlockManager:
|
||||
block = await self._move_block_to_sequence(session, block, previous_entry.sequence_number, actor)
|
||||
|
||||
# 4) Commit
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
return block.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@@ -805,5 +814,6 @@ class BlockManager:
|
||||
|
||||
block = await self._move_block_to_sequence(session, block, next_entry.sequence_number, actor)
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
return block.to_pydantic()
|
||||
|
||||
357
letta/services/conversation_manager.py
Normal file
357
letta/services/conversation_manager.py
Normal file
@@ -0,0 +1,357 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from letta.orm.conversation import Conversation as ConversationModel
|
||||
from letta.orm.conversation_messages import ConversationMessage as ConversationMessageModel
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.message import Message as MessageModel
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.conversation import Conversation as PydanticConversation, CreateConversation, UpdateConversation
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
class ConversationManager:
|
||||
"""Manager class to handle business logic related to Conversations."""
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def create_conversation(
|
||||
self,
|
||||
agent_id: str,
|
||||
conversation_create: CreateConversation,
|
||||
actor: PydanticUser,
|
||||
) -> PydanticConversation:
|
||||
"""Create a new conversation for an agent."""
|
||||
async with db_registry.async_session() as session:
|
||||
conversation = ConversationModel(
|
||||
agent_id=agent_id,
|
||||
summary=conversation_create.summary,
|
||||
organization_id=actor.organization_id,
|
||||
)
|
||||
await conversation.create_async(session, actor=actor)
|
||||
return conversation.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def get_conversation_by_id(
|
||||
self,
|
||||
conversation_id: str,
|
||||
actor: PydanticUser,
|
||||
) -> PydanticConversation:
|
||||
"""Retrieve a conversation by its ID, including in-context message IDs."""
|
||||
async with db_registry.async_session() as session:
|
||||
conversation = await ConversationModel.read_async(
|
||||
db_session=session,
|
||||
identifier=conversation_id,
|
||||
actor=actor,
|
||||
check_is_deleted=True,
|
||||
)
|
||||
|
||||
# Get the in-context message IDs for this conversation
|
||||
message_ids = await self.get_message_ids_for_conversation(
|
||||
conversation_id=conversation_id,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# Build the pydantic model with in_context_message_ids
|
||||
pydantic_conversation = conversation.to_pydantic()
|
||||
pydantic_conversation.in_context_message_ids = message_ids
|
||||
return pydantic_conversation
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def list_conversations(
|
||||
self,
|
||||
agent_id: str,
|
||||
actor: PydanticUser,
|
||||
limit: int = 50,
|
||||
after: Optional[str] = None,
|
||||
) -> List[PydanticConversation]:
|
||||
"""List conversations for an agent with cursor-based pagination."""
|
||||
async with db_registry.async_session() as session:
|
||||
conversations = await ConversationModel.list_async(
|
||||
db_session=session,
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
limit=limit,
|
||||
after=after,
|
||||
ascending=False,
|
||||
)
|
||||
return [conv.to_pydantic() for conv in conversations]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def update_conversation(
|
||||
self,
|
||||
conversation_id: str,
|
||||
conversation_update: UpdateConversation,
|
||||
actor: PydanticUser,
|
||||
) -> PydanticConversation:
|
||||
"""Update a conversation."""
|
||||
async with db_registry.async_session() as session:
|
||||
conversation = await ConversationModel.read_async(
|
||||
db_session=session,
|
||||
identifier=conversation_id,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# Set attributes on the model
|
||||
update_data = conversation_update.model_dump(exclude_none=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(conversation, key, value)
|
||||
|
||||
# Commit the update
|
||||
updated_conversation = await conversation.update_async(
|
||||
db_session=session,
|
||||
actor=actor,
|
||||
)
|
||||
return updated_conversation.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def delete_conversation(
|
||||
self,
|
||||
conversation_id: str,
|
||||
actor: PydanticUser,
|
||||
) -> None:
|
||||
"""Soft delete a conversation."""
|
||||
async with db_registry.async_session() as session:
|
||||
conversation = await ConversationModel.read_async(
|
||||
db_session=session,
|
||||
identifier=conversation_id,
|
||||
actor=actor,
|
||||
)
|
||||
# Soft delete by setting is_deleted flag
|
||||
conversation.is_deleted = True
|
||||
await conversation.update_async(db_session=session, actor=actor)
|
||||
|
||||
# ==================== Message Management Methods ====================
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def get_message_ids_for_conversation(
|
||||
self,
|
||||
conversation_id: str,
|
||||
actor: PydanticUser,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get ordered message IDs for a conversation.
|
||||
|
||||
Returns message IDs ordered by position in the conversation.
|
||||
Only returns messages that are currently in_context.
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
query = (
|
||||
select(ConversationMessageModel.message_id)
|
||||
.where(
|
||||
ConversationMessageModel.conversation_id == conversation_id,
|
||||
ConversationMessageModel.organization_id == actor.organization_id,
|
||||
ConversationMessageModel.in_context == True,
|
||||
ConversationMessageModel.is_deleted == False,
|
||||
)
|
||||
.order_by(ConversationMessageModel.position)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def get_messages_for_conversation(
|
||||
self,
|
||||
conversation_id: str,
|
||||
actor: PydanticUser,
|
||||
) -> List[PydanticMessage]:
|
||||
"""
|
||||
Get ordered Message objects for a conversation.
|
||||
|
||||
Returns full Message objects ordered by position in the conversation.
|
||||
Only returns messages that are currently in_context.
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
query = (
|
||||
select(MessageModel)
|
||||
.join(
|
||||
ConversationMessageModel,
|
||||
MessageModel.id == ConversationMessageModel.message_id,
|
||||
)
|
||||
.where(
|
||||
ConversationMessageModel.conversation_id == conversation_id,
|
||||
ConversationMessageModel.organization_id == actor.organization_id,
|
||||
ConversationMessageModel.in_context == True,
|
||||
ConversationMessageModel.is_deleted == False,
|
||||
)
|
||||
.order_by(ConversationMessageModel.position)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return [msg.to_pydantic() for msg in result.scalars().all()]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def add_messages_to_conversation(
|
||||
self,
|
||||
conversation_id: str,
|
||||
agent_id: str,
|
||||
message_ids: List[str],
|
||||
actor: PydanticUser,
|
||||
starting_position: Optional[int] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Add messages to a conversation's tracking table.
|
||||
|
||||
Creates ConversationMessage entries with auto-incrementing positions.
|
||||
|
||||
Args:
|
||||
conversation_id: The conversation to add messages to
|
||||
agent_id: The agent ID
|
||||
message_ids: List of message IDs to add
|
||||
actor: The user performing the action
|
||||
starting_position: Optional starting position (defaults to next available)
|
||||
"""
|
||||
if not message_ids:
|
||||
return
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
# Get starting position if not provided
|
||||
if starting_position is None:
|
||||
query = select(func.coalesce(func.max(ConversationMessageModel.position), -1)).where(
|
||||
ConversationMessageModel.conversation_id == conversation_id,
|
||||
ConversationMessageModel.organization_id == actor.organization_id,
|
||||
)
|
||||
result = await session.execute(query)
|
||||
max_position = result.scalar()
|
||||
# Use explicit None check instead of `or` to handle position=0 correctly
|
||||
if max_position is None:
|
||||
max_position = -1
|
||||
starting_position = max_position + 1
|
||||
|
||||
# Create ConversationMessage entries
|
||||
for i, message_id in enumerate(message_ids):
|
||||
conv_msg = ConversationMessageModel(
|
||||
conversation_id=conversation_id,
|
||||
agent_id=agent_id,
|
||||
message_id=message_id,
|
||||
position=starting_position + i,
|
||||
in_context=True,
|
||||
organization_id=actor.organization_id,
|
||||
)
|
||||
session.add(conv_msg)
|
||||
|
||||
await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def update_in_context_messages(
|
||||
self,
|
||||
conversation_id: str,
|
||||
in_context_message_ids: List[str],
|
||||
actor: PydanticUser,
|
||||
) -> None:
|
||||
"""
|
||||
Update which messages are in context for a conversation.
|
||||
|
||||
Sets in_context=True for messages in the list, False for others.
|
||||
|
||||
Args:
|
||||
conversation_id: The conversation to update
|
||||
in_context_message_ids: List of message IDs that should be in context
|
||||
actor: The user performing the action
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
# Get all conversation messages for this conversation
|
||||
query = select(ConversationMessageModel).where(
|
||||
ConversationMessageModel.conversation_id == conversation_id,
|
||||
ConversationMessageModel.organization_id == actor.organization_id,
|
||||
ConversationMessageModel.is_deleted == False,
|
||||
)
|
||||
result = await session.execute(query)
|
||||
conv_messages = result.scalars().all()
|
||||
|
||||
# Update in_context status
|
||||
in_context_set = set(in_context_message_ids)
|
||||
for conv_msg in conv_messages:
|
||||
conv_msg.in_context = conv_msg.message_id in in_context_set
|
||||
|
||||
await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def list_conversation_messages(
|
||||
self,
|
||||
conversation_id: str,
|
||||
actor: PydanticUser,
|
||||
limit: Optional[int] = 100,
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
) -> List[LettaMessage]:
|
||||
"""
|
||||
List all messages in a conversation with pagination support.
|
||||
|
||||
Unlike get_messages_for_conversation, this returns ALL messages
|
||||
(not just in_context) and supports cursor-based pagination.
|
||||
Messages are always ordered by position (oldest first).
|
||||
|
||||
Args:
|
||||
conversation_id: The conversation to list messages for
|
||||
actor: The user performing the action
|
||||
limit: Maximum number of messages to return
|
||||
before: Return messages before this message ID
|
||||
after: Return messages after this message ID
|
||||
|
||||
Returns:
|
||||
List of LettaMessage objects
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
# Build base query joining Message with ConversationMessage
|
||||
query = (
|
||||
select(MessageModel)
|
||||
.join(
|
||||
ConversationMessageModel,
|
||||
MessageModel.id == ConversationMessageModel.message_id,
|
||||
)
|
||||
.where(
|
||||
ConversationMessageModel.conversation_id == conversation_id,
|
||||
ConversationMessageModel.organization_id == actor.organization_id,
|
||||
ConversationMessageModel.is_deleted == False,
|
||||
)
|
||||
)
|
||||
|
||||
# Handle cursor-based pagination
|
||||
if before:
|
||||
# Get the position of the cursor message
|
||||
cursor_query = select(ConversationMessageModel.position).where(
|
||||
ConversationMessageModel.conversation_id == conversation_id,
|
||||
ConversationMessageModel.message_id == before,
|
||||
)
|
||||
cursor_result = await session.execute(cursor_query)
|
||||
cursor_position = cursor_result.scalar_one_or_none()
|
||||
if cursor_position is not None:
|
||||
query = query.where(ConversationMessageModel.position < cursor_position)
|
||||
|
||||
if after:
|
||||
# Get the position of the cursor message
|
||||
cursor_query = select(ConversationMessageModel.position).where(
|
||||
ConversationMessageModel.conversation_id == conversation_id,
|
||||
ConversationMessageModel.message_id == after,
|
||||
)
|
||||
cursor_result = await session.execute(cursor_query)
|
||||
cursor_position = cursor_result.scalar_one_or_none()
|
||||
if cursor_position is not None:
|
||||
query = query.where(ConversationMessageModel.position > cursor_position)
|
||||
|
||||
# Order by position (oldest first)
|
||||
query = query.order_by(ConversationMessageModel.position.asc())
|
||||
|
||||
# Apply limit
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
|
||||
result = await session.execute(query)
|
||||
messages = [msg.to_pydantic() for msg in result.scalars().all()]
|
||||
|
||||
# Convert to LettaMessages
|
||||
return PydanticMessage.to_letta_messages_from_list(messages, reverse=False, text_is_assistant_message=True)
|
||||
@@ -22,7 +22,7 @@ from letta.schemas.source_metadata import FileStats, OrganizationSourcesStats, S
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.settings import settings
|
||||
from letta.utils import enforce_types
|
||||
from letta.utils import bounded_gather, enforce_types
|
||||
from letta.validators import raise_on_invalid_id
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -85,7 +85,7 @@ class FileManager:
|
||||
# invalidate cache for this new file
|
||||
await self._invalidate_file_caches(file_orm.id, actor, file_orm.original_file_name, file_orm.source_id)
|
||||
|
||||
return await file_orm.to_pydantic_async()
|
||||
return file_orm.to_pydantic()
|
||||
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
@@ -124,14 +124,20 @@ class FileManager:
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
file_orm = result.scalar_one()
|
||||
file_orm = result.scalar_one_or_none()
|
||||
else:
|
||||
# fast path (metadata only)
|
||||
file_orm = await FileMetadataModel.read_async(
|
||||
db_session=session,
|
||||
identifier=file_id,
|
||||
actor=actor,
|
||||
)
|
||||
try:
|
||||
file_orm = await FileMetadataModel.read_async(
|
||||
db_session=session,
|
||||
identifier=file_id,
|
||||
actor=actor,
|
||||
)
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
if file_orm is None:
|
||||
return None
|
||||
|
||||
return await file_orm.to_pydantic_async(include_content=include_content, strip_directory_prefix=strip_directory_prefix)
|
||||
|
||||
@@ -278,7 +284,7 @@ class FileManager:
|
||||
identifier=file_id,
|
||||
actor=actor,
|
||||
)
|
||||
return await file_orm.to_pydantic_async()
|
||||
return file_orm.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -408,7 +414,7 @@ class FileManager:
|
||||
actor: PydanticUser,
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
limit: Optional[int] = 1000,
|
||||
ascending: Optional[bool] = True,
|
||||
include_content: bool = False,
|
||||
strip_directory_prefix: bool = False,
|
||||
@@ -445,9 +451,15 @@ class FileManager:
|
||||
)
|
||||
|
||||
# convert all files to pydantic models
|
||||
file_metadatas = await asyncio.gather(
|
||||
*[file.to_pydantic_async(include_content=include_content, strip_directory_prefix=strip_directory_prefix) for file in files]
|
||||
)
|
||||
if include_content:
|
||||
file_metadatas = await bounded_gather(
|
||||
[
|
||||
file.to_pydantic_async(include_content=include_content, strip_directory_prefix=strip_directory_prefix)
|
||||
for file in files
|
||||
]
|
||||
)
|
||||
else:
|
||||
file_metadatas = [file.to_pydantic(strip_directory_prefix=strip_directory_prefix) for file in files]
|
||||
|
||||
# if status checking is enabled, check all files sequentially to avoid db pool exhaustion
|
||||
# Each status check may update the file in the database, so concurrent checks with many
|
||||
@@ -473,7 +485,7 @@ class FileManager:
|
||||
await self._invalidate_file_caches(file_id, actor, file.original_file_name, file.source_id)
|
||||
|
||||
await file.hard_delete_async(db_session=session, actor=actor)
|
||||
return await file.to_pydantic_async()
|
||||
return file.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -555,7 +567,7 @@ class FileManager:
|
||||
file_orm = result.scalar_one_or_none()
|
||||
|
||||
if file_orm:
|
||||
return await file_orm.to_pydantic_async()
|
||||
return file_orm.to_pydantic()
|
||||
return None
|
||||
|
||||
@enforce_types
|
||||
@@ -664,7 +676,10 @@ class FileManager:
|
||||
result = await session.execute(query)
|
||||
files_orm = result.scalars().all()
|
||||
|
||||
return await asyncio.gather(*[file.to_pydantic_async(include_content=include_content) for file in files_orm])
|
||||
if include_content:
|
||||
return await bounded_gather([file.to_pydantic_async(include_content=include_content) for file in files_orm])
|
||||
else:
|
||||
return [file.to_pydantic() for file in files_orm]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -709,4 +724,7 @@ class FileManager:
|
||||
result = await session.execute(query)
|
||||
files_orm = result.scalars().all()
|
||||
|
||||
return await asyncio.gather(*[file.to_pydantic_async(include_content=include_content) for file in files_orm])
|
||||
if include_content:
|
||||
return await bounded_gather([file.to_pydantic_async(include_content=include_content) for file in files_orm])
|
||||
else:
|
||||
return [file.to_pydantic() for file in files_orm]
|
||||
|
||||
@@ -50,8 +50,10 @@ class FileProcessor:
|
||||
"""Chunk text and generate embeddings with fallback to default chunker if needed"""
|
||||
filename = file_metadata.file_name
|
||||
|
||||
# Create file-type-specific chunker
|
||||
text_chunker = LlamaIndexChunker(file_type=file_metadata.file_type, chunk_size=self.embedder.embedding_config.embedding_chunk_size)
|
||||
# Create file-type-specific chunker in thread pool to avoid blocking event loop
|
||||
text_chunker = await asyncio.to_thread(
|
||||
LlamaIndexChunker, file_type=file_metadata.file_type, chunk_size=self.embedder.embedding_config.embedding_chunk_size
|
||||
)
|
||||
|
||||
# First attempt with file-specific chunker
|
||||
try:
|
||||
|
||||
@@ -200,7 +200,8 @@ class FileAgentManager:
|
||||
stmt = delete(FileAgentModel).where(and_(or_(*conditions), FileAgentModel.organization_id == actor.organization_id))
|
||||
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
return result.rowcount
|
||||
|
||||
@@ -291,6 +292,32 @@ class FileAgentManager:
|
||||
else:
|
||||
return [r.to_pydantic() for r in rows]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def get_file_ids_for_agent_by_source(
|
||||
self,
|
||||
agent_id: str,
|
||||
source_id: str,
|
||||
actor: PydanticUser,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get all file IDs attached to an agent from a specific source.
|
||||
|
||||
This queries the files_agents junction table directly, ensuring we get
|
||||
exactly the files that were attached, regardless of any changes to the
|
||||
source's file list.
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
stmt = select(FileAgentModel.file_id).where(
|
||||
and_(
|
||||
FileAgentModel.agent_id == agent_id,
|
||||
FileAgentModel.source_id == source_id,
|
||||
FileAgentModel.organization_id == actor.organization_id,
|
||||
)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def list_files_for_agent_paginated(
|
||||
@@ -405,7 +432,8 @@ class FileAgentManager:
|
||||
.values(last_accessed_at=func.now())
|
||||
)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -425,7 +453,8 @@ class FileAgentManager:
|
||||
.values(last_accessed_at=func.now())
|
||||
)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -458,7 +487,8 @@ class FileAgentManager:
|
||||
)
|
||||
|
||||
closed_file_names = [row.file_name for row in (await session.execute(stmt))]
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
return closed_file_names
|
||||
|
||||
@enforce_types
|
||||
@@ -702,7 +732,8 @@ class FileAgentManager:
|
||||
.values(is_open=False, visible_content=None)
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
return closed_file_names
|
||||
|
||||
async def _get_association_by_file_id(self, session, agent_id: str, file_id: str, actor: PydanticUser) -> FileAgentModel:
|
||||
|
||||
@@ -246,7 +246,8 @@ class GroupManager:
|
||||
)
|
||||
await session.execute(delete_stmt)
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
|
||||
@@ -258,7 +259,7 @@ class GroupManager:
|
||||
|
||||
# Update turns counter
|
||||
group.turns_counter = (group.turns_counter + 1) % group.sleeptime_agent_frequency
|
||||
await group.update_async(session, actor=actor)
|
||||
await group.update_async(session, actor=actor, no_refresh=True)
|
||||
return group.turns_counter
|
||||
|
||||
@enforce_types
|
||||
@@ -275,7 +276,7 @@ class GroupManager:
|
||||
# Update last processed message id
|
||||
prev_last_processed_message_id = group.last_processed_message_id
|
||||
group.last_processed_message_id = last_processed_message_id
|
||||
await group.update_async(session, actor=actor)
|
||||
await group.update_async(session, actor=actor, no_refresh=True)
|
||||
|
||||
return prev_last_processed_message_id
|
||||
|
||||
@@ -434,7 +435,8 @@ class GroupManager:
|
||||
|
||||
# Add block to group
|
||||
session.add(GroupsBlocks(group_id=group_id, block_id=block_id))
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
|
||||
@@ -452,7 +454,8 @@ class GroupManager:
|
||||
# Remove block from group
|
||||
delete_group_block = delete(GroupsBlocks).where(and_(GroupsBlocks.group_id == group_id, GroupsBlocks.block_id == block_id))
|
||||
await session.execute(delete_group_block)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@staticmethod
|
||||
def ensure_buffer_length_range_valid(
|
||||
|
||||
@@ -1099,8 +1099,8 @@ async def build_source_passage_query(
|
||||
embedded_text = np.array(embeddings[0])
|
||||
embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()
|
||||
|
||||
# Base query for source passages
|
||||
query = select(SourcePassage).where(SourcePassage.organization_id == actor.organization_id)
|
||||
# Base query for source passages - use noload to prevent lazy loading which can block the event loop
|
||||
query = select(SourcePassage).options(noload(SourcePassage.organization)).where(SourcePassage.organization_id == actor.organization_id)
|
||||
|
||||
# If agent_id is specified, join with SourcesAgents to get only passages linked to that agent
|
||||
if agent_id is not None:
|
||||
@@ -1208,23 +1208,26 @@ async def build_agent_passage_query(
|
||||
embedded_text = np.array(embeddings[0])
|
||||
embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()
|
||||
|
||||
# Base query for passages
|
||||
# Base query for passages - use noload to prevent lazy loading which can block the event loop
|
||||
if agent_id:
|
||||
# Query for agent passages - join through archives_agents
|
||||
# Agent_id takes precedence if both agent_id and archive_id are provided
|
||||
query = (
|
||||
select(ArchivalPassage)
|
||||
.options(noload(ArchivalPassage.organization), noload(ArchivalPassage.passage_tags))
|
||||
.join(ArchivesAgents, ArchivalPassage.archive_id == ArchivesAgents.archive_id)
|
||||
.where(ArchivesAgents.agent_id == agent_id, ArchivalPassage.organization_id == actor.organization_id)
|
||||
)
|
||||
elif archive_id:
|
||||
# Query for archive passages directly
|
||||
query = select(ArchivalPassage).where(
|
||||
ArchivalPassage.archive_id == archive_id, ArchivalPassage.organization_id == actor.organization_id
|
||||
query = (
|
||||
select(ArchivalPassage)
|
||||
.options(noload(ArchivalPassage.organization), noload(ArchivalPassage.passage_tags))
|
||||
.where(ArchivalPassage.archive_id == archive_id, ArchivalPassage.organization_id == actor.organization_id)
|
||||
)
|
||||
else:
|
||||
# Org-wide search - all passages in organization
|
||||
query = select(ArchivalPassage).where(ArchivalPassage.organization_id == actor.organization_id)
|
||||
query = (
|
||||
select(ArchivalPassage)
|
||||
.options(noload(ArchivalPassage.organization), noload(ArchivalPassage.passage_tags))
|
||||
.where(ArchivalPassage.organization_id == actor.organization_id)
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
if start_date:
|
||||
|
||||
@@ -24,7 +24,7 @@ from letta.schemas.identity import (
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
from letta.utils import enforce_types
|
||||
from letta.utils import bounded_gather, decrypt_agent_secrets, enforce_types
|
||||
from letta.validators import raise_on_invalid_id
|
||||
|
||||
|
||||
@@ -257,7 +257,8 @@ class IdentityManager:
|
||||
if identity.organization_id != actor.organization_id:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
await session.delete(identity)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -335,7 +336,14 @@ class IdentityManager:
|
||||
ascending=ascending,
|
||||
identity_id=identity.id,
|
||||
)
|
||||
return await asyncio.gather(*[agent.to_pydantic_async(include_relationships=[], include=include) for agent in agents])
|
||||
|
||||
# Convert without decrypting to release DB connection before PBKDF2
|
||||
agents_encrypted = await bounded_gather(
|
||||
[agent.to_pydantic_async(include_relationships=[], include=include, decrypt=False) for agent in agents]
|
||||
)
|
||||
|
||||
# Decrypt secrets outside session
|
||||
return await decrypt_agent_secrets(agents_encrypted)
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
|
||||
|
||||
@@ -58,7 +58,8 @@ class JobManager:
|
||||
job.organization_id = actor.organization_id
|
||||
job = await job.create_async(session, actor=actor, no_commit=True, no_refresh=True) # Save job in the database
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
# Convert to pydantic first, then add agent_id if needed
|
||||
result = super(JobModel, job).to_pydantic()
|
||||
@@ -122,7 +123,8 @@ class JobManager:
|
||||
# Get the updated metadata for callback
|
||||
final_metadata = job.metadata_
|
||||
result = job.to_pydantic()
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
# Dispatch callback outside of database session if needed
|
||||
if needs_callback:
|
||||
@@ -143,7 +145,8 @@ class JobManager:
|
||||
job.callback_error = callback_result.get("callback_error")
|
||||
await job.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
|
||||
result = job.to_pydantic()
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
return result
|
||||
|
||||
@@ -462,7 +465,8 @@ class JobManager:
|
||||
job = await self._verify_job_access_async(session=session, job_id=job_id, actor=actor, access=["write"])
|
||||
job.ttft_ns = ttft_ns
|
||||
await job.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to record TTFT for job {job_id}: {e}")
|
||||
|
||||
@@ -475,7 +479,8 @@ class JobManager:
|
||||
job = await self._verify_job_access_async(session=session, job_id=job_id, actor=actor, access=["write"])
|
||||
job.total_duration_ns = total_duration_ns
|
||||
await job.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to record response duration for job {job_id}: {e}")
|
||||
|
||||
|
||||
@@ -45,7 +45,8 @@ class LLMBatchManager:
|
||||
)
|
||||
await batch.create_async(session, actor=actor, no_commit=True, no_refresh=True)
|
||||
pydantic_batch = batch.to_pydantic()
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
return pydantic_batch
|
||||
|
||||
@enforce_types
|
||||
@@ -98,7 +99,8 @@ class LLMBatchManager:
|
||||
)
|
||||
|
||||
await session.run_sync(lambda ses: ses.bulk_update_mappings(LLMBatchJob, mappings))
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -285,7 +287,8 @@ class LLMBatchManager:
|
||||
created_items = await LLMBatchItem.batch_create_async(orm_items, session, actor=actor, no_commit=True, no_refresh=True)
|
||||
|
||||
pydantic_items = [item.to_pydantic() for item in created_items]
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
return pydantic_items
|
||||
|
||||
@enforce_types
|
||||
@@ -421,7 +424,8 @@ class LLMBatchManager:
|
||||
|
||||
if mappings:
|
||||
await session.run_sync(lambda ses: ses.bulk_update_mappings(LLMBatchItem, mappings))
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
|
||||
@@ -82,7 +82,10 @@ class AsyncBaseMCPClient:
|
||||
result = await self.session.call_tool(tool_name, tool_args)
|
||||
except Exception as e:
|
||||
if e.__class__.__name__ == "McpError":
|
||||
logger.warning(f"MCP tool '{tool_name}' execution failed: {str(e)}")
|
||||
# MCP errors are typically user-facing issues from external MCP servers
|
||||
# (e.g., resource not found, invalid arguments, permission errors)
|
||||
# Log at debug level to avoid triggering production alerts for expected failures
|
||||
logger.debug(f"MCP tool '{tool_name}' execution failed: {str(e)}")
|
||||
raise
|
||||
|
||||
parsed_content = []
|
||||
|
||||
307
letta/services/mcp/fastmcp_client.py
Normal file
307
letta/services/mcp/fastmcp_client.py
Normal file
@@ -0,0 +1,307 @@
|
||||
"""FastMCP-based MCP clients with server-side OAuth support.
|
||||
|
||||
This module provides MCP client implementations using the FastMCP library,
|
||||
with support for server-side OAuth flows where authorization URLs are
|
||||
forwarded to web clients instead of opening a browser.
|
||||
|
||||
These clients replace the existing AsyncSSEMCPClient and AsyncStreamableHTTPMCPClient
|
||||
implementations that used the lower-level MCP SDK directly.
|
||||
"""
|
||||
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
from fastmcp import Client
|
||||
from fastmcp.client.transports import SSETransport, StreamableHttpTransport
|
||||
from mcp import Tool as MCPTool
|
||||
|
||||
from letta.functions.mcp_client.types import SSEServerConfig, StreamableHTTPServerConfig
|
||||
from letta.log import get_logger
|
||||
from letta.services.mcp.server_side_oauth import ServerSideOAuth
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AsyncFastMCPSSEClient:
|
||||
"""SSE MCP client using FastMCP with server-side OAuth support.
|
||||
|
||||
This client connects to MCP servers using Server-Sent Events (SSE) transport.
|
||||
It supports both authenticated and unauthenticated connections, with OAuth
|
||||
handled via the ServerSideOAuth class for server-side flows.
|
||||
|
||||
Args:
|
||||
server_config: SSE server configuration including URL, headers, and auth settings
|
||||
oauth: Optional ServerSideOAuth instance for OAuth authentication
|
||||
agent_id: Optional agent ID to include in request headers
|
||||
"""
|
||||
|
||||
AGENT_ID_HEADER = "X-Agent-Id"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_config: SSEServerConfig,
|
||||
oauth: Optional[ServerSideOAuth] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
):
|
||||
self.server_config = server_config
|
||||
self.oauth = oauth
|
||||
self.agent_id = agent_id
|
||||
self.client: Optional[Client] = None
|
||||
self.initialized = False
|
||||
self.exit_stack = AsyncExitStack()
|
||||
|
||||
async def connect_to_server(self):
|
||||
"""Establish connection to the MCP server.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If connection to the server fails
|
||||
"""
|
||||
try:
|
||||
headers = {}
|
||||
if self.server_config.custom_headers:
|
||||
headers.update(self.server_config.custom_headers)
|
||||
if self.server_config.auth_header and self.server_config.auth_token:
|
||||
headers[self.server_config.auth_header] = self.server_config.auth_token
|
||||
if self.agent_id:
|
||||
headers[self.AGENT_ID_HEADER] = self.agent_id
|
||||
|
||||
transport = SSETransport(
|
||||
url=self.server_config.server_url,
|
||||
headers=headers if headers else None,
|
||||
auth=self.oauth, # Pass ServerSideOAuth instance (or None)
|
||||
)
|
||||
|
||||
self.client = Client(transport)
|
||||
await self.client._connect()
|
||||
self.initialized = True
|
||||
except httpx.HTTPStatusError as e:
|
||||
# Re-raise HTTP status errors for OAuth flow handling
|
||||
if e.response.status_code == 401:
|
||||
raise ConnectionError("401 Unauthorized")
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def list_tools(self, serialize: bool = False) -> List[MCPTool]:
|
||||
"""List available tools from the MCP server.
|
||||
|
||||
Args:
|
||||
serialize: If True, return tools as dictionaries instead of MCPTool objects
|
||||
|
||||
Returns:
|
||||
List of tools available on the server
|
||||
|
||||
Raises:
|
||||
RuntimeError: If client has not been initialized
|
||||
"""
|
||||
self._check_initialized()
|
||||
tools = await self.client.list_tools()
|
||||
if serialize:
|
||||
serializable_tools = []
|
||||
for tool in tools:
|
||||
if hasattr(tool, "model_dump"):
|
||||
serializable_tools.append(tool.model_dump())
|
||||
elif hasattr(tool, "dict"):
|
||||
serializable_tools.append(tool.dict())
|
||||
elif hasattr(tool, "__dict__"):
|
||||
serializable_tools.append(tool.__dict__)
|
||||
else:
|
||||
serializable_tools.append(str(tool))
|
||||
return serializable_tools
|
||||
return tools
|
||||
|
||||
async def execute_tool(self, tool_name: str, tool_args: dict) -> Tuple[str, bool]:
|
||||
"""Execute a tool on the MCP server.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool to execute
|
||||
tool_args: Arguments to pass to the tool
|
||||
|
||||
Returns:
|
||||
Tuple of (result_content, success_flag)
|
||||
|
||||
Raises:
|
||||
RuntimeError: If client has not been initialized
|
||||
"""
|
||||
self._check_initialized()
|
||||
try:
|
||||
result = await self.client.call_tool(tool_name, tool_args)
|
||||
except Exception as e:
|
||||
if e.__class__.__name__ == "McpError":
|
||||
logger.warning(f"MCP tool '{tool_name}' execution failed: {str(e)}")
|
||||
raise
|
||||
|
||||
# Parse content from result
|
||||
parsed_content = []
|
||||
for content_piece in result.content:
|
||||
if hasattr(content_piece, "text"):
|
||||
parsed_content.append(content_piece.text)
|
||||
logger.debug(f"MCP tool result parsed content (text): {parsed_content}")
|
||||
else:
|
||||
parsed_content.append(str(content_piece))
|
||||
logger.debug(f"MCP tool result parsed content (other): {parsed_content}")
|
||||
|
||||
if parsed_content:
|
||||
final_content = " ".join(parsed_content)
|
||||
else:
|
||||
final_content = "Empty response from tool"
|
||||
|
||||
return final_content, not result.is_error
|
||||
|
||||
def _check_initialized(self):
|
||||
"""Check if the client has been initialized."""
|
||||
if not self.initialized:
|
||||
logger.error("MCPClient has not been initialized")
|
||||
raise RuntimeError("MCPClient has not been initialized")
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up client resources."""
|
||||
if self.client:
|
||||
try:
|
||||
await self.client.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error during FastMCP client cleanup: {e}")
|
||||
self.initialized = False
|
||||
|
||||
|
||||
class AsyncFastMCPStreamableHTTPClient:
|
||||
"""Streamable HTTP MCP client using FastMCP with server-side OAuth support.
|
||||
|
||||
This client connects to MCP servers using Streamable HTTP transport.
|
||||
It supports both authenticated and unauthenticated connections, with OAuth
|
||||
handled via the ServerSideOAuth class for server-side flows.
|
||||
|
||||
Args:
|
||||
server_config: Streamable HTTP server configuration
|
||||
oauth: Optional ServerSideOAuth instance for OAuth authentication
|
||||
agent_id: Optional agent ID to include in request headers
|
||||
"""
|
||||
|
||||
AGENT_ID_HEADER = "X-Agent-Id"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_config: StreamableHTTPServerConfig,
|
||||
oauth: Optional[ServerSideOAuth] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
):
|
||||
self.server_config = server_config
|
||||
self.oauth = oauth
|
||||
self.agent_id = agent_id
|
||||
self.client: Optional[Client] = None
|
||||
self.initialized = False
|
||||
self.exit_stack = AsyncExitStack()
|
||||
|
||||
async def connect_to_server(self):
|
||||
"""Establish connection to the MCP server.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If connection to the server fails
|
||||
"""
|
||||
try:
|
||||
headers = {}
|
||||
if self.server_config.custom_headers:
|
||||
headers.update(self.server_config.custom_headers)
|
||||
if self.server_config.auth_header and self.server_config.auth_token:
|
||||
headers[self.server_config.auth_header] = self.server_config.auth_token
|
||||
if self.agent_id:
|
||||
headers[self.AGENT_ID_HEADER] = self.agent_id
|
||||
|
||||
transport = StreamableHttpTransport(
|
||||
url=self.server_config.server_url,
|
||||
headers=headers if headers else None,
|
||||
auth=self.oauth, # Pass ServerSideOAuth instance (or None)
|
||||
)
|
||||
|
||||
self.client = Client(transport)
|
||||
await self.client._connect()
|
||||
self.initialized = True
|
||||
except httpx.HTTPStatusError as e:
|
||||
# Re-raise HTTP status errors for OAuth flow handling
|
||||
if e.response.status_code == 401:
|
||||
raise ConnectionError("401 Unauthorized")
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def list_tools(self, serialize: bool = False) -> List[MCPTool]:
|
||||
"""List available tools from the MCP server.
|
||||
|
||||
Args:
|
||||
serialize: If True, return tools as dictionaries instead of MCPTool objects
|
||||
|
||||
Returns:
|
||||
List of tools available on the server
|
||||
|
||||
Raises:
|
||||
RuntimeError: If client has not been initialized
|
||||
"""
|
||||
self._check_initialized()
|
||||
tools = await self.client.list_tools()
|
||||
if serialize:
|
||||
serializable_tools = []
|
||||
for tool in tools:
|
||||
if hasattr(tool, "model_dump"):
|
||||
serializable_tools.append(tool.model_dump())
|
||||
elif hasattr(tool, "dict"):
|
||||
serializable_tools.append(tool.dict())
|
||||
elif hasattr(tool, "__dict__"):
|
||||
serializable_tools.append(tool.__dict__)
|
||||
else:
|
||||
serializable_tools.append(str(tool))
|
||||
return serializable_tools
|
||||
return tools
|
||||
|
||||
async def execute_tool(self, tool_name: str, tool_args: dict) -> Tuple[str, bool]:
|
||||
"""Execute a tool on the MCP server.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool to execute
|
||||
tool_args: Arguments to pass to the tool
|
||||
|
||||
Returns:
|
||||
Tuple of (result_content, success_flag)
|
||||
|
||||
Raises:
|
||||
RuntimeError: If client has not been initialized
|
||||
"""
|
||||
self._check_initialized()
|
||||
try:
|
||||
result = await self.client.call_tool(tool_name, tool_args)
|
||||
except Exception as e:
|
||||
if e.__class__.__name__ == "McpError":
|
||||
logger.warning(f"MCP tool '{tool_name}' execution failed: {str(e)}")
|
||||
raise
|
||||
|
||||
# Parse content from result
|
||||
parsed_content = []
|
||||
for content_piece in result.content:
|
||||
if hasattr(content_piece, "text"):
|
||||
parsed_content.append(content_piece.text)
|
||||
logger.debug(f"MCP tool result parsed content (text): {parsed_content}")
|
||||
else:
|
||||
parsed_content.append(str(content_piece))
|
||||
logger.debug(f"MCP tool result parsed content (other): {parsed_content}")
|
||||
|
||||
if parsed_content:
|
||||
final_content = " ".join(parsed_content)
|
||||
else:
|
||||
final_content = "Empty response from tool"
|
||||
|
||||
return final_content, not result.is_error
|
||||
|
||||
def _check_initialized(self):
|
||||
"""Check if the client has been initialized."""
|
||||
if not self.initialized:
|
||||
logger.error("MCPClient has not been initialized")
|
||||
raise RuntimeError("MCPClient has not been initialized")
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up client resources."""
|
||||
if self.client:
|
||||
try:
|
||||
await self.client.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error during FastMCP client cleanup: {e}")
|
||||
self.initialized = False
|
||||
@@ -6,7 +6,7 @@ import secrets
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Callable, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Callable, Optional, Tuple
|
||||
|
||||
from mcp.client.auth import OAuthClientProvider, TokenStorage
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
|
||||
@@ -18,7 +18,9 @@ from letta.schemas.mcp import MCPOAuthSessionUpdate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.services.mcp.types import OauthStreamEvent
|
||||
from letta.services.mcp_manager import MCPManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.services.mcp_manager import MCPManager
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -26,7 +28,7 @@ logger = get_logger(__name__)
|
||||
class DatabaseTokenStorage(TokenStorage):
|
||||
"""Database-backed token storage using MCPOAuth table via mcp_manager."""
|
||||
|
||||
def __init__(self, session_id: str, mcp_manager: MCPManager, actor: PydanticUser):
|
||||
def __init__(self, session_id: str, mcp_manager: "MCPManager", actor: PydanticUser):
|
||||
self.session_id = session_id
|
||||
self.mcp_manager = mcp_manager
|
||||
self.actor = actor
|
||||
@@ -150,9 +152,10 @@ class MCPOAuthSession:
|
||||
try:
|
||||
oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None)
|
||||
|
||||
# Encrypt the authorization_code and store only in _enc column
|
||||
# Encrypt the authorization_code and store only in _enc column (async to avoid blocking event loop)
|
||||
if code is not None:
|
||||
oauth_record.authorization_code_enc = Secret.from_plaintext(code).get_encrypted()
|
||||
code_secret = await Secret.from_plaintext_async(code)
|
||||
oauth_record.authorization_code_enc = code_secret.get_encrypted()
|
||||
|
||||
oauth_record.status = OAuthSessionStatus.AUTHORIZED
|
||||
oauth_record.state = state
|
||||
@@ -186,12 +189,17 @@ async def create_oauth_provider(
|
||||
session_id: str,
|
||||
server_url: str,
|
||||
redirect_uri: str,
|
||||
mcp_manager: MCPManager,
|
||||
mcp_manager: "MCPManager",
|
||||
actor: PydanticUser,
|
||||
logo_uri: Optional[str] = None,
|
||||
url_callback: Optional[Callable[[str], None]] = None,
|
||||
) -> OAuthClientProvider:
|
||||
"""Create an OAuth provider for MCP server authentication."""
|
||||
"""Create an OAuth provider for MCP server authentication.
|
||||
|
||||
DEPRECATED: Use ServerSideOAuth from letta.services.mcp.server_side_oauth instead.
|
||||
This function is kept for backwards compatibility but will be removed in a future version.
|
||||
"""
|
||||
logger.warning("create_oauth_provider is deprecated. Use ServerSideOAuth from letta.services.mcp.server_side_oauth instead.")
|
||||
|
||||
client_metadata_dict = {
|
||||
"client_name": "Letta",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user