chore: bump v0.16.1 (#3107)
This commit is contained in:
@@ -57,7 +57,7 @@ RUN set -eux; \
|
||||
esac; \
|
||||
apt-get update && \
|
||||
# Install curl, Python, PostgreSQL client libraries, and Redis
|
||||
apt-get install -y curl python3 python3-venv libpq-dev && \
|
||||
apt-get install -y curl python3 python3-venv libpq-dev redis-server && \
|
||||
# Install Node.js
|
||||
curl -fsSL https://deb.nodesource.com/setup_${NODE_VERSION}.x | bash - && \
|
||||
apt-get install -y nodejs && \
|
||||
@@ -95,7 +95,7 @@ COPY --from=builder /app .
|
||||
# Copy initialization SQL if it exists
|
||||
COPY init.sql /docker-entrypoint-initdb.d/
|
||||
|
||||
EXPOSE 8283 5432 4317 4318
|
||||
EXPOSE 8283 5432 6379 4317 4318
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/docker-entrypoint.sh"]
|
||||
CMD ["./letta/server/startup.sh"]
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
"""add project constraint on tools
|
||||
|
||||
Revision ID: 39577145c45d
|
||||
Revises: d0880aae6cee
|
||||
Create Date: 2025-12-17 15:46:06.184858
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "39577145c45d"
|
||||
down_revision: Union[str, None] = "d0880aae6cee"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the ``uix_organization_project_name`` unique constraint on ``tools``.

    The constraint spans ``organization_id``, ``project_id`` and ``name``.
    """
    constraint_name = "uix_organization_project_name"
    table_name = "tools"
    unique_columns = ["organization_id", "project_id", "name"]

    # NOTE(review): postgresql_nulls_not_distinct is forwarded to the PostgreSQL
    # dialect; NULLS NOT DISTINCT is presumably only available on PostgreSQL 15+
    # -- confirm the minimum supported server version.
    op.create_unique_constraint(
        constraint_name,
        table_name,
        unique_columns,
        postgresql_nulls_not_distinct=True,
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``uix_organization_project_name`` unique constraint from ``tools``."""
    constraint_name = "uix_organization_project_name"
    table_name = "tools"

    # Reverses upgrade(): removes the unique constraint by name.
    op.drop_constraint(constraint_name, table_name, type_="unique")
|
||||
@@ -19638,6 +19638,97 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/metadata/balance": {
|
||||
"get": {
|
||||
"description": "Retrieve the current usage balances for the organization.",
|
||||
"summary": "Retrieve current organization balance",
|
||||
"tags": ["metadata"],
|
||||
"parameters": [],
|
||||
"operationId": "metadata.retrieveCurrentBalances",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "200",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"total_balance": {
|
||||
"type": "number"
|
||||
},
|
||||
"monthly_credit_balance": {
|
||||
"type": "number"
|
||||
},
|
||||
"purchased_credit_balance": {
|
||||
"type": "number"
|
||||
},
|
||||
"billing_tier": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"total_balance",
|
||||
"monthly_credit_balance",
|
||||
"purchased_credit_balance",
|
||||
"billing_tier"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/metadata/feedback": {
|
||||
"post": {
|
||||
"description": "Send feedback from users to improve our services.",
|
||||
"summary": "Send user feedback",
|
||||
"tags": ["metadata"],
|
||||
"parameters": [],
|
||||
"operationId": "metadata.sendFeedback",
|
||||
"requestBody": {
|
||||
"description": "Body",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"maxLength": 10000
|
||||
},
|
||||
"feature": {
|
||||
"default": "letta-code",
|
||||
"type": "string",
|
||||
"enum": ["letta-code", "sdk"]
|
||||
}
|
||||
},
|
||||
"required": ["message"]
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "200",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"success": {
|
||||
"type": "boolean"
|
||||
}
|
||||
},
|
||||
"required": ["success"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/scheduled-messages/{scheduled_message_id}": {
|
||||
"delete": {
|
||||
"description": "Delete a scheduled message by its ID for a specific agent.",
|
||||
@@ -21952,6 +22043,11 @@
|
||||
"title": "Content",
|
||||
"description": "The message content sent by the assistant (can be a string or an array of content parts)"
|
||||
},
|
||||
"message_id": {
|
||||
"type": "string",
|
||||
"title": "Message Id",
|
||||
"description": "The unique identifier of the message."
|
||||
},
|
||||
"agent_id": {
|
||||
"anyOf": [
|
||||
{
|
||||
@@ -21972,9 +22068,9 @@
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["content", "created_at"],
|
||||
"required": ["content", "message_id", "created_at"],
|
||||
"title": "AssistantMessageListResult",
|
||||
"description": "Assistant message list result with agent context.\n\nShape is identical to UpdateAssistantMessage but includes the owning agent_id."
|
||||
"description": "Assistant message list result with agent context.\n\nShape is identical to UpdateAssistantMessage but includes the owning agent_id and message id."
|
||||
},
|
||||
"Audio": {
|
||||
"properties": {
|
||||
@@ -34376,6 +34472,11 @@
|
||||
"title": "Message Type",
|
||||
"default": "reasoning_message"
|
||||
},
|
||||
"message_id": {
|
||||
"type": "string",
|
||||
"title": "Message Id",
|
||||
"description": "The unique identifier of the message."
|
||||
},
|
||||
"agent_id": {
|
||||
"anyOf": [
|
||||
{
|
||||
@@ -34396,9 +34497,9 @@
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["reasoning", "created_at"],
|
||||
"required": ["reasoning", "message_id", "created_at"],
|
||||
"title": "ReasoningMessageListResult",
|
||||
"description": "Reasoning message list result with agent context.\n\nShape is identical to UpdateReasoningMessage but includes the owning agent_id."
|
||||
"description": "Reasoning message list result with agent context.\n\nShape is identical to UpdateReasoningMessage but includes the owning agent_id and message id."
|
||||
},
|
||||
"RedactedReasoningContent": {
|
||||
"properties": {
|
||||
@@ -36870,6 +36971,11 @@
|
||||
"title": "Content",
|
||||
"description": "The message content sent by the system (can be a string or an array of multi-modal content parts)"
|
||||
},
|
||||
"message_id": {
|
||||
"type": "string",
|
||||
"title": "Message Id",
|
||||
"description": "The unique identifier of the message."
|
||||
},
|
||||
"agent_id": {
|
||||
"anyOf": [
|
||||
{
|
||||
@@ -36890,9 +36996,9 @@
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["content", "created_at"],
|
||||
"required": ["content", "message_id", "created_at"],
|
||||
"title": "SystemMessageListResult",
|
||||
"description": "System message list result with agent context.\n\nShape is identical to UpdateSystemMessage but includes the owning agent_id."
|
||||
"description": "System message list result with agent context.\n\nShape is identical to UpdateSystemMessage but includes the owning agent_id and message id."
|
||||
},
|
||||
"TagSchema": {
|
||||
"properties": {
|
||||
@@ -39541,6 +39647,11 @@
|
||||
"title": "Content",
|
||||
"description": "The message content sent by the user (can be a string or an array of multi-modal content parts)"
|
||||
},
|
||||
"message_id": {
|
||||
"type": "string",
|
||||
"title": "Message Id",
|
||||
"description": "The unique identifier of the message."
|
||||
},
|
||||
"agent_id": {
|
||||
"anyOf": [
|
||||
{
|
||||
@@ -39561,9 +39672,9 @@
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["content", "created_at"],
|
||||
"required": ["content", "message_id", "created_at"],
|
||||
"title": "UserMessageListResult",
|
||||
"description": "User message list result with agent context.\n\nShape is identical to UpdateUserMessage but includes the owning agent_id."
|
||||
"description": "User message list result with agent context.\n\nShape is identical to UpdateUserMessage but includes the owning agent_id and message id."
|
||||
},
|
||||
"UserUpdate": {
|
||||
"properties": {
|
||||
|
||||
@@ -5,7 +5,7 @@ try:
|
||||
__version__ = version("letta")
|
||||
except PackageNotFoundError:
|
||||
# Fallback for development installations
|
||||
__version__ = "0.16.0"
|
||||
__version__ = "0.16.1"
|
||||
|
||||
if os.environ.get("LETTA_VERSION"):
|
||||
__version__ = os.environ["LETTA_VERSION"]
|
||||
|
||||
@@ -109,6 +109,8 @@ def validate_approval_tool_call_ids(approval_request_message: Message, approval_
|
||||
)
|
||||
|
||||
approval_responses = approval_response_message.approvals
|
||||
if not approval_responses:
|
||||
raise ValueError("Invalid approval response. Approval response message does not contain any approvals.")
|
||||
approval_response_tool_call_ids = [approval_response.tool_call_id for approval_response in approval_responses]
|
||||
|
||||
request_response_diff = set(approval_request_tool_call_ids).symmetric_difference(set(approval_response_tool_call_ids))
|
||||
|
||||
@@ -1902,8 +1902,8 @@ class LettaAgent(BaseAgent):
|
||||
start_time = get_utc_timestamp_ns()
|
||||
agent_step_span.add_event(name="tool_execution_started")
|
||||
|
||||
# Decrypt environment variable values
|
||||
sandbox_env_vars = {var.key: var.value_enc.get_plaintext() if var.value_enc else None for var in agent_state.secrets}
|
||||
# Use pre-decrypted environment variable values (populated in from_orm_async)
|
||||
sandbox_env_vars = {var.key: var.value or "" for var in agent_state.secrets}
|
||||
tool_execution_manager = ToolExecutionManager(
|
||||
agent_state=agent_state,
|
||||
message_manager=self.message_manager,
|
||||
|
||||
@@ -1184,8 +1184,8 @@ class LettaAgentV2(BaseAgentV2):
|
||||
start_time = get_utc_timestamp_ns()
|
||||
agent_step_span.add_event(name="tool_execution_started")
|
||||
|
||||
# Decrypt environment variable values
|
||||
sandbox_env_vars = {var.key: var.value_enc.get_plaintext() if var.value_enc else None for var in agent_state.secrets}
|
||||
# Use pre-decrypted environment variable values (populated in from_orm_async)
|
||||
sandbox_env_vars = {var.key: var.value or "" for var in agent_state.secrets}
|
||||
tool_execution_manager = ToolExecutionManager(
|
||||
agent_state=agent_state,
|
||||
message_manager=self.message_manager,
|
||||
|
||||
@@ -368,7 +368,9 @@ class LettaAgentV3(LettaAgentV2):
|
||||
# Cleanup and finalize (only runs if no exception occurred)
|
||||
try:
|
||||
if run_id:
|
||||
result = LettaResponse(messages=response_letta_messages, stop_reason=self.stop_reason, usage=self.usage)
|
||||
# Filter out LettaStopReason from messages (only valid in LettaStreamingResponse, not LettaResponse)
|
||||
filtered_messages = [m for m in response_letta_messages if not isinstance(m, LettaStopReason)]
|
||||
result = LettaResponse(messages=filtered_messages, stop_reason=self.stop_reason, usage=self.usage)
|
||||
if self.job_update_metadata is None:
|
||||
self.job_update_metadata = {}
|
||||
self.job_update_metadata["result"] = result.model_dump(mode="json")
|
||||
|
||||
@@ -438,8 +438,8 @@ class VoiceAgent(BaseAgent):
|
||||
)
|
||||
|
||||
# Use ToolExecutionManager for modern tool execution
|
||||
# Decrypt environment variable values
|
||||
sandbox_env_vars = {var.key: var.value_enc.get_plaintext() if var.value_enc else None for var in agent_state.secrets}
|
||||
# Use pre-decrypted environment variable values (populated in from_orm_async)
|
||||
sandbox_env_vars = {var.key: var.value or "" for var in agent_state.secrets}
|
||||
tool_execution_manager = ToolExecutionManager(
|
||||
agent_state=agent_state,
|
||||
message_manager=self.message_manager,
|
||||
|
||||
@@ -226,11 +226,14 @@ CORE_MEMORY_LINE_NUMBER_WARNING = "# NOTE: Line numbers shown below (with arrows
|
||||
|
||||
# Constants to do with summarization / conversation length window
|
||||
# The maximum number of tokens supported by the underlying model (e.g. 8k for gpt-4 and Mistral 7B)
|
||||
LLM_MAX_TOKENS = {
|
||||
LLM_MAX_CONTEXT_WINDOW = {
|
||||
"DEFAULT": 30000,
|
||||
# deepseek
|
||||
"deepseek-chat": 64000,
|
||||
"deepseek-reasoner": 64000,
|
||||
# glm (Z.AI)
|
||||
"glm-4.6": 200000,
|
||||
"glm-4.5": 128000,
|
||||
## OpenAI models: https://platform.openai.com/docs/models/overview
|
||||
# gpt-5
|
||||
"gpt-5": 272000,
|
||||
@@ -357,6 +360,9 @@ LLM_MAX_TOKENS = {
|
||||
"gemini-2.5-flash-preview-09-2025": 1048576,
|
||||
"gemini-2.5-flash-lite-preview-09-2025": 1048576,
|
||||
"gemini-2.5-computer-use-preview-10-2025": 1048576,
|
||||
# gemini 3
|
||||
"gemini-3-pro-preview": 1048576,
|
||||
"gemini-3-flash-preview": 1048576,
|
||||
# gemini latest aliases
|
||||
"gemini-flash-latest": 1048576,
|
||||
"gemini-flash-lite-latest": 1048576,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
from letta.interface import AgentInterface
|
||||
from letta.orm import User
|
||||
from letta.schemas.agent import AgentState
|
||||
@@ -204,14 +205,14 @@ class DynamicMultiAgent(BaseAgent):
|
||||
"holds info about them, and you should use this context to inform your decision."
|
||||
)
|
||||
self.agent_state.memory.update_block_value(label="persona", value=persona_block.value + group_chat_manager_persona)
|
||||
return Agent(
|
||||
return LettaAgent(
|
||||
agent_state=self.agent_state,
|
||||
interface=self.interface,
|
||||
user=self.user,
|
||||
save_last_response=True,
|
||||
)
|
||||
|
||||
def load_participant_agent(self, agent_id: str) -> Agent:
|
||||
def load_participant_agent(self, agent_id: str) -> LettaAgent:
|
||||
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user)
|
||||
persona_block = agent_state.memory.get_block(label="persona")
|
||||
group_chat_participant_persona = (
|
||||
@@ -220,7 +221,7 @@ class DynamicMultiAgent(BaseAgent):
|
||||
f"Description of the group: {self.description}. About you: "
|
||||
)
|
||||
agent_state.memory.update_block_value(label="persona", value=group_chat_participant_persona + persona_block.value)
|
||||
return Agent(
|
||||
return LettaAgent(
|
||||
agent_state=agent_state,
|
||||
interface=self.interface,
|
||||
user=self.user,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
from letta.interface import AgentInterface
|
||||
from letta.orm.group import Group
|
||||
from letta.orm.user import User
|
||||
@@ -17,7 +18,7 @@ def load_multi_agent(
|
||||
actor: User,
|
||||
interface: Union[AgentInterface, None] = None,
|
||||
mcp_clients: Optional[Dict[str, AsyncBaseMCPClient]] = None,
|
||||
) -> "Agent":
|
||||
) -> LettaAgent:
|
||||
if len(group.agent_ids) == 0:
|
||||
raise ValueError("Empty group: group must have at least one agent")
|
||||
|
||||
@@ -63,7 +64,7 @@ def load_multi_agent(
|
||||
)
|
||||
case ManagerType.sleeptime:
|
||||
if not agent_state.enable_sleeptime:
|
||||
return Agent(
|
||||
return LettaAgent(
|
||||
agent_state=agent_state,
|
||||
interface=interface,
|
||||
user=actor,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
from letta.interface import AgentInterface
|
||||
from letta.orm import User
|
||||
from letta.schemas.agent import AgentState
|
||||
@@ -131,7 +132,7 @@ class RoundRobinMultiAgent(BaseAgent):
|
||||
|
||||
return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
|
||||
|
||||
def load_participant_agent(self, agent_id: str) -> Agent:
|
||||
def load_participant_agent(self, agent_id: str) -> LettaAgent:
|
||||
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user)
|
||||
persona_block = agent_state.memory.get_block(label="persona")
|
||||
group_chat_participant_persona = (
|
||||
@@ -152,7 +153,7 @@ class RoundRobinMultiAgent(BaseAgent):
|
||||
"%%% END GROUP CHAT CONTEXT %%%"
|
||||
)
|
||||
agent_state.memory.update_block_value(label="persona", value=persona_block.value + group_chat_participant_persona)
|
||||
return Agent(
|
||||
return LettaAgent(
|
||||
agent_state=agent_state,
|
||||
interface=self.interface,
|
||||
user=self.user,
|
||||
|
||||
@@ -4,6 +4,7 @@ from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
from letta.groups.helpers import stringify_message
|
||||
from letta.interface import AgentInterface
|
||||
from letta.orm import User
|
||||
@@ -114,7 +115,7 @@ class SleeptimeMultiAgent(BaseAgent):
|
||||
self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.user)
|
||||
|
||||
participant_agent_state = self.agent_manager.get_agent_by_id(participant_agent_id, actor=self.user)
|
||||
participant_agent = Agent(
|
||||
participant_agent = LettaAgent(
|
||||
agent_state=participant_agent_state,
|
||||
interface=StreamingServerInterface(),
|
||||
user=self.user,
|
||||
@@ -212,7 +213,7 @@ class SleeptimeMultiAgent(BaseAgent):
|
||||
|
||||
try:
|
||||
# Load main agent
|
||||
main_agent = Agent(
|
||||
main_agent = LettaAgent(
|
||||
agent_state=self.agent_state,
|
||||
interface=self.interface,
|
||||
user=self.user,
|
||||
|
||||
@@ -30,6 +30,7 @@ from openai.types.responses import (
|
||||
from openai.types.responses.response_stream_event import ResponseStreamEvent
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.llm_api.error_utils import is_context_window_overflow_message
|
||||
from letta.llm_api.openai_client import is_openai_reasoning_model
|
||||
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
|
||||
from letta.log import get_logger
|
||||
@@ -746,6 +747,14 @@ class SimpleOpenAIStreamingInterface:
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
# IMPORTANT: If this is a context window overflow, we should propagate the
|
||||
# exception upward so the agent loop can compact/summarize + retry.
|
||||
# Yielding an error stop reason here would prematurely terminate the user's
|
||||
# stream even though a retry path exists.
|
||||
msg = str(e)
|
||||
if is_context_window_overflow_message(msg):
|
||||
raise
|
||||
|
||||
logger.exception("Error processing stream: %s", e)
|
||||
if ttft_span:
|
||||
ttft_span.add_event(
|
||||
|
||||
22
letta/llm_api/error_utils.py
Normal file
22
letta/llm_api/error_utils.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""Shared helpers for provider error detection/mapping.
|
||||
|
||||
Keep these utilities free of heavy imports to avoid circular dependencies between
|
||||
LLM clients (provider-specific) and streaming interfaces.
|
||||
"""
|
||||
|
||||
|
||||
def is_context_window_overflow_message(msg: str) -> bool:
    """Best-effort detection for context window overflow errors.

    Different providers (and even different API surfaces within the same provider)
    may phrase context-window errors differently. We centralize the heuristic so
    all layers (clients, streaming interfaces, agent loops) behave consistently.

    Matching is case-insensitive so that a provider changing the capitalization
    of an error message does not silently disable overflow handling.

    Args:
        msg: The provider error message (typically ``str(exception)``).

    Returns:
        True if the message contains any known context-overflow phrase;
        False otherwise, including for empty or ``None`` input.
    """
    if not msg:
        return False

    # All markers are lowercase because the message is lowercased before matching.
    # "maximum context length" also covers the longer OpenAI phrasing
    # "This model's maximum context length is ...", so that variant is not listed.
    overflow_markers = (
        "exceeds the context window",
        "maximum context length",
        "context_length_exceeded",
        "input tokens exceed the configured limit",
    )
    lowered = msg.lower()
    return any(marker in lowered for marker in overflow_markers)
|
||||
@@ -14,6 +14,8 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
class GoogleAIClient(GoogleVertexClient):
|
||||
provider_label = "Google AI"
|
||||
|
||||
def _get_client(self):
|
||||
timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
|
||||
return genai.Client(
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
GOOGLE_MODEL_TO_CONTEXT_LENGTH = {
|
||||
"gemini-3-pro-preview": 1048576,
|
||||
"gemini-3-flash-preview": 1048576,
|
||||
"gemini-2.5-pro": 1048576,
|
||||
"gemini-2.5-flash": 1048576,
|
||||
"gemini-live-2.5-flash": 1048576,
|
||||
|
||||
@@ -46,6 +46,7 @@ logger = get_logger(__name__)
|
||||
|
||||
class GoogleVertexClient(LLMClientBase):
|
||||
MAX_RETRIES = model_settings.gemini_max_retries
|
||||
provider_label = "Google Vertex"
|
||||
|
||||
def _get_client(self):
|
||||
timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
|
||||
@@ -56,6 +57,12 @@ class GoogleVertexClient(LLMClientBase):
|
||||
http_options=HttpOptions(api_version="v1", timeout=timeout_ms),
|
||||
)
|
||||
|
||||
def _provider_prefix(self) -> str:
|
||||
return f"[{self.provider_label}]"
|
||||
|
||||
def _provider_name(self) -> str:
|
||||
return self.provider_label
|
||||
|
||||
@trace_method
|
||||
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
||||
"""
|
||||
@@ -148,7 +155,7 @@ class GoogleVertexClient(LLMClientBase):
|
||||
config=request_data["config"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error streaming Google Vertex request: {e} with request data: {json.dumps(request_data)}")
|
||||
logger.error(f"Error streaming {self._provider_name()} request: {e} with request data: {json.dumps(request_data)}")
|
||||
raise e
|
||||
# Direct yield - keeps response alive in generator's local scope throughout iteration
|
||||
# This is required because the SDK's connection lifecycle is tied to the response object
|
||||
@@ -448,9 +455,9 @@ class GoogleVertexClient(LLMClientBase):
|
||||
if content is None or content.role is None or content.parts is None:
|
||||
# This means the response is malformed like MALFORMED_FUNCTION_CALL
|
||||
if candidate.finish_reason == "MALFORMED_FUNCTION_CALL":
|
||||
raise LLMServerError(f"Malformed response from Google Vertex: {candidate.finish_reason}")
|
||||
raise LLMServerError(f"Malformed response from {self._provider_name()}: {candidate.finish_reason}")
|
||||
else:
|
||||
raise LLMServerError(f"Invalid response data from Google Vertex: {candidate.model_dump()}")
|
||||
raise LLMServerError(f"Invalid response data from {self._provider_name()}: {candidate.model_dump()}")
|
||||
|
||||
role = content.role
|
||||
assert role == "model", f"Unknown role in response: {role}"
|
||||
@@ -742,55 +749,55 @@ class GoogleVertexClient(LLMClientBase):
|
||||
def handle_llm_error(self, e: Exception) -> Exception:
|
||||
# Handle Google GenAI specific errors
|
||||
if isinstance(e, errors.ClientError):
|
||||
logger.warning(f"[Google Vertex] Client error ({e.code}): {e}")
|
||||
logger.warning(f"{self._provider_prefix()} Client error ({e.code}): {e}")
|
||||
|
||||
# Handle specific error codes
|
||||
if e.code == 400:
|
||||
error_str = str(e).lower()
|
||||
if "context" in error_str and ("exceed" in error_str or "limit" in error_str or "too long" in error_str):
|
||||
return ContextWindowExceededError(
|
||||
message=f"Bad request to Google Vertex (context window exceeded): {str(e)}",
|
||||
message=f"Bad request to {self._provider_name()} (context window exceeded): {str(e)}",
|
||||
)
|
||||
else:
|
||||
return LLMBadRequestError(
|
||||
message=f"Bad request to Google Vertex: {str(e)}",
|
||||
message=f"Bad request to {self._provider_name()}: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
elif e.code == 401:
|
||||
return LLMAuthenticationError(
|
||||
message=f"Authentication failed with Google Vertex: {str(e)}",
|
||||
message=f"Authentication failed with {self._provider_name()}: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
elif e.code == 403:
|
||||
return LLMPermissionDeniedError(
|
||||
message=f"Permission denied by Google Vertex: {str(e)}",
|
||||
message=f"Permission denied by {self._provider_name()}: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
elif e.code == 404:
|
||||
return LLMNotFoundError(
|
||||
message=f"Resource not found in Google Vertex: {str(e)}",
|
||||
message=f"Resource not found in {self._provider_name()}: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
elif e.code == 408:
|
||||
return LLMTimeoutError(
|
||||
message=f"Request to Google Vertex timed out: {str(e)}",
|
||||
message=f"Request to {self._provider_name()} timed out: {str(e)}",
|
||||
code=ErrorCode.TIMEOUT,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None},
|
||||
)
|
||||
elif e.code == 422:
|
||||
return LLMUnprocessableEntityError(
|
||||
message=f"Invalid request content for Google Vertex: {str(e)}",
|
||||
message=f"Invalid request content for {self._provider_name()}: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
elif e.code == 429:
|
||||
logger.warning("[Google Vertex] Rate limited (429). Consider backoff.")
|
||||
logger.warning(f"{self._provider_prefix()} Rate limited (429). Consider backoff.")
|
||||
return LLMRateLimitError(
|
||||
message=f"Rate limited by Google Vertex: {str(e)}",
|
||||
message=f"Rate limited by {self._provider_name()}: {str(e)}",
|
||||
code=ErrorCode.RATE_LIMIT_EXCEEDED,
|
||||
)
|
||||
else:
|
||||
return LLMServerError(
|
||||
message=f"Google Vertex client error: {str(e)}",
|
||||
message=f"{self._provider_name()} client error: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={
|
||||
"status_code": e.code,
|
||||
@@ -799,12 +806,12 @@ class GoogleVertexClient(LLMClientBase):
|
||||
)
|
||||
|
||||
if isinstance(e, errors.ServerError):
|
||||
logger.warning(f"[Google Vertex] Server error ({e.code}): {e}")
|
||||
logger.warning(f"{self._provider_prefix()} Server error ({e.code}): {e}")
|
||||
|
||||
# Handle specific server error codes
|
||||
if e.code == 500:
|
||||
return LLMServerError(
|
||||
message=f"Google Vertex internal server error: {str(e)}",
|
||||
message=f"{self._provider_name()} internal server error: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={
|
||||
"status_code": e.code,
|
||||
@@ -813,13 +820,13 @@ class GoogleVertexClient(LLMClientBase):
|
||||
)
|
||||
elif e.code == 502:
|
||||
return LLMConnectionError(
|
||||
message=f"Bad gateway from Google Vertex: {str(e)}",
|
||||
message=f"Bad gateway from {self._provider_name()}: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None},
|
||||
)
|
||||
elif e.code == 503:
|
||||
return LLMServerError(
|
||||
message=f"Google Vertex service unavailable: {str(e)}",
|
||||
message=f"{self._provider_name()} service unavailable: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={
|
||||
"status_code": e.code,
|
||||
@@ -828,13 +835,13 @@ class GoogleVertexClient(LLMClientBase):
|
||||
)
|
||||
elif e.code == 504:
|
||||
return LLMTimeoutError(
|
||||
message=f"Gateway timeout from Google Vertex: {str(e)}",
|
||||
message=f"Gateway timeout from {self._provider_name()}: {str(e)}",
|
||||
code=ErrorCode.TIMEOUT,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None},
|
||||
)
|
||||
else:
|
||||
return LLMServerError(
|
||||
message=f"Google Vertex server error: {str(e)}",
|
||||
message=f"{self._provider_name()} server error: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={
|
||||
"status_code": e.code,
|
||||
@@ -843,9 +850,9 @@ class GoogleVertexClient(LLMClientBase):
|
||||
)
|
||||
|
||||
if isinstance(e, errors.APIError):
|
||||
logger.warning(f"[Google Vertex] API error ({e.code}): {e}")
|
||||
logger.warning(f"{self._provider_prefix()} API error ({e.code}): {e}")
|
||||
return LLMServerError(
|
||||
message=f"Google Vertex API error: {str(e)}",
|
||||
message=f"{self._provider_name()} API error: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={
|
||||
"status_code": e.code,
|
||||
@@ -855,9 +862,9 @@ class GoogleVertexClient(LLMClientBase):
|
||||
|
||||
# Handle connection-related errors
|
||||
if "connection" in str(e).lower() or "timeout" in str(e).lower():
|
||||
logger.warning(f"[Google Vertex] Connection/timeout error: {e}")
|
||||
logger.warning(f"{self._provider_prefix()} Connection/timeout error: {e}")
|
||||
return LLMConnectionError(
|
||||
message=f"Failed to connect to Google Vertex: {str(e)}",
|
||||
message=f"Failed to connect to {self._provider_name()}: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None},
|
||||
)
|
||||
|
||||
@@ -26,6 +26,7 @@ from letta.errors import (
|
||||
LLMTimeoutError,
|
||||
LLMUnprocessableEntityError,
|
||||
)
|
||||
from letta.llm_api.error_utils import is_context_window_overflow_message
|
||||
from letta.llm_api.helpers import (
|
||||
add_inner_thoughts_to_functions,
|
||||
convert_response_format_to_responses_api,
|
||||
@@ -978,11 +979,7 @@ class OpenAIClient(LLMClientBase):
|
||||
error_code = error_details.get("code")
|
||||
|
||||
# Check both the error code and message content for context length issues
|
||||
if (
|
||||
error_code == "context_length_exceeded"
|
||||
or "This model's maximum context length is" in str(e)
|
||||
or "Input tokens exceed the configured limit" in str(e)
|
||||
):
|
||||
if error_code == "context_length_exceeded" or is_context_window_overflow_message(str(e)):
|
||||
return ContextWindowExceededError(
|
||||
message=f"Bad request to OpenAI (context window exceeded): {str(e)}",
|
||||
)
|
||||
@@ -993,6 +990,25 @@ class OpenAIClient(LLMClientBase):
|
||||
details=e.body,
|
||||
)
|
||||
|
||||
# NOTE: The OpenAI Python SDK may raise a generic `openai.APIError` while *iterating*
|
||||
# over a stream (e.g. Responses API streaming). In this case we don't necessarily
|
||||
# get a `BadRequestError` with a structured error body, but we still want to
|
||||
# trigger Letta's context window compaction / retry logic.
|
||||
#
|
||||
# Example message:
|
||||
# "Your input exceeds the context window of this model. Please adjust your input and try again."
|
||||
if isinstance(e, openai.APIError):
|
||||
msg = str(e)
|
||||
if is_context_window_overflow_message(msg):
|
||||
return ContextWindowExceededError(
|
||||
message=f"OpenAI request exceeded the context window: {msg}",
|
||||
details={
|
||||
"provider_exception_type": type(e).__name__,
|
||||
# Best-effort extraction (may not exist on APIError)
|
||||
"body": getattr(e, "body", None),
|
||||
},
|
||||
)
|
||||
|
||||
if isinstance(e, openai.AuthenticationError):
|
||||
logger.error(f"[OpenAI] Authentication error (401): {str(e)}") # More severe log level
|
||||
return LLMAuthenticationError(
|
||||
|
||||
@@ -18,6 +18,6 @@ SIMPLE = {
|
||||
# '\n#',
|
||||
# '\n\n\n',
|
||||
],
|
||||
# "max_context_length": LLM_MAX_TOKENS,
|
||||
# "max_context_length": LLM_MAX_CONTEXT_WINDOW,
|
||||
"max_length": 512,
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ SIMPLE = {
|
||||
# This controls the maximum number of tokens that the model can generate
|
||||
# Cap this at the model context length (assuming 8k for Mistral 7B)
|
||||
# "max_tokens": 8000,
|
||||
# "max_tokens": LLM_MAX_TOKENS,
|
||||
# "max_tokens": LLM_MAX_CONTEXT_WINDOW,
|
||||
# This controls how LM studio handles context overflow
|
||||
# In Letta we handle this ourselves, so this should be commented out
|
||||
# "lmstudio": {"context_overflow_policy": 2},
|
||||
|
||||
@@ -20,7 +20,7 @@ SIMPLE = {
|
||||
# '\n#',
|
||||
# '\n\n\n',
|
||||
],
|
||||
# "num_ctx": LLM_MAX_TOKENS,
|
||||
# "num_ctx": LLM_MAX_CONTEXT_WINDOW,
|
||||
},
|
||||
"stream": False,
|
||||
# turn off Ollama's own prompt formatting
|
||||
|
||||
@@ -19,5 +19,5 @@ SIMPLE = {
|
||||
],
|
||||
"max_new_tokens": 3072,
|
||||
# "truncation_length": 4096, # assuming llama2 models
|
||||
# "truncation_length": LLM_MAX_TOKENS, # assuming mistral 7b
|
||||
# "truncation_length": LLM_MAX_CONTEXT_WINDOW, # assuming mistral 7b
|
||||
}
|
||||
|
||||
@@ -20,5 +20,5 @@ SIMPLE = {
|
||||
],
|
||||
# "max_tokens": 3072,
|
||||
# "truncation_length": 4096, # assuming llama2 models
|
||||
# "truncation_length": LLM_MAX_TOKENS, # assuming mistral 7b
|
||||
# "truncation_length": LLM_MAX_CONTEXT_WINDOW, # assuming mistral 7b
|
||||
}
|
||||
|
||||
@@ -434,7 +434,9 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin
|
||||
state["multi_agent_group"] = multi_agent_group
|
||||
state["managed_group"] = multi_agent_group
|
||||
# Convert ORM env vars to Pydantic with async decryption
|
||||
env_vars_pydantic = [await PydanticAgentEnvVar.from_orm_async(e) for e in tool_exec_environment_variables]
|
||||
env_vars_pydantic = []
|
||||
for e in tool_exec_environment_variables:
|
||||
env_vars_pydantic.append(await PydanticAgentEnvVar.from_orm_async(e))
|
||||
state["tool_exec_environment_variables"] = env_vars_pydantic
|
||||
state["secrets"] = env_vars_pydantic
|
||||
state["model"] = self.llm_config.handle if self.llm_config else None
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from sqlalchemy import JSON, String, Text, UniqueConstraint
|
||||
@@ -11,6 +12,7 @@ from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.enums import MCPServerType
|
||||
from letta.schemas.mcp import MCPServer
|
||||
from letta.schemas.secret import Secret
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.organization import Organization
|
||||
@@ -60,6 +62,23 @@ class MCPServer(SqlalchemyBase, OrganizationMixin):
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="mcp_servers")
|
||||
|
||||
def to_pydantic(self):
|
||||
"""Convert ORM model to Pydantic model, handling encrypted fields."""
|
||||
# Parse custom_headers from JSON if stored as string
|
||||
return self.__pydantic_model__(
|
||||
id=self.id,
|
||||
server_type=self.server_type,
|
||||
server_name=self.server_name,
|
||||
server_url=self.server_url,
|
||||
token_enc=Secret.from_encrypted(self.token_enc) if self.token_enc else None,
|
||||
custom_headers_enc=Secret.from_encrypted(self.custom_headers_enc) if self.custom_headers_enc else None,
|
||||
stdio_config=self.stdio_config,
|
||||
organization_id=self.organization_id,
|
||||
created_by_id=self.created_by_id,
|
||||
last_updated_by_id=self.last_updated_by_id,
|
||||
metadata_=self.metadata_,
|
||||
)
|
||||
|
||||
|
||||
class MCPTools(SqlalchemyBase, OrganizationMixin):
|
||||
"""Represents a mapping of MCP server ID to tool ID"""
|
||||
|
||||
@@ -29,6 +29,7 @@ class Tool(SqlalchemyBase, OrganizationMixin, ProjectMixin):
|
||||
# An organization should not have multiple tools with the same name
|
||||
__table_args__ = (
|
||||
UniqueConstraint("name", "organization_id", name="uix_name_organization"),
|
||||
UniqueConstraint("organization_id", "project_id", "name", name="uix_organization_project_name", postgresql_nulls_not_distinct=True),
|
||||
Index("ix_tools_created_at_name", "created_at", "name"),
|
||||
Index("ix_tools_organization_id", "organization_id"),
|
||||
Index("ix_tools_organization_id_name", "organization_id", "name"),
|
||||
|
||||
@@ -563,9 +563,13 @@ LettaMessageUpdateUnion = Annotated[
|
||||
class SystemMessageListResult(UpdateSystemMessage):
|
||||
"""System message list result with agent context.
|
||||
|
||||
Shape is identical to UpdateSystemMessage but includes the owning agent_id.
|
||||
Shape is identical to UpdateSystemMessage but includes the owning agent_id and message id.
|
||||
"""
|
||||
|
||||
message_id: str = Field(
|
||||
...,
|
||||
description="The unique identifier of the message.",
|
||||
)
|
||||
agent_id: str | None = Field(
|
||||
default=None,
|
||||
description="The unique identifier of the agent that owns the message.",
|
||||
@@ -577,9 +581,13 @@ class SystemMessageListResult(UpdateSystemMessage):
|
||||
class UserMessageListResult(UpdateUserMessage):
|
||||
"""User message list result with agent context.
|
||||
|
||||
Shape is identical to UpdateUserMessage but includes the owning agent_id.
|
||||
Shape is identical to UpdateUserMessage but includes the owning agent_id and message id.
|
||||
"""
|
||||
|
||||
message_id: str = Field(
|
||||
...,
|
||||
description="The unique identifier of the message.",
|
||||
)
|
||||
agent_id: str | None = Field(
|
||||
default=None,
|
||||
description="The unique identifier of the agent that owns the message.",
|
||||
@@ -591,9 +599,13 @@ class UserMessageListResult(UpdateUserMessage):
|
||||
class ReasoningMessageListResult(UpdateReasoningMessage):
|
||||
"""Reasoning message list result with agent context.
|
||||
|
||||
Shape is identical to UpdateReasoningMessage but includes the owning agent_id.
|
||||
Shape is identical to UpdateReasoningMessage but includes the owning agent_id and message id.
|
||||
"""
|
||||
|
||||
message_id: str = Field(
|
||||
...,
|
||||
description="The unique identifier of the message.",
|
||||
)
|
||||
agent_id: str | None = Field(
|
||||
default=None,
|
||||
description="The unique identifier of the agent that owns the message.",
|
||||
@@ -605,9 +617,13 @@ class ReasoningMessageListResult(UpdateReasoningMessage):
|
||||
class AssistantMessageListResult(UpdateAssistantMessage):
|
||||
"""Assistant message list result with agent context.
|
||||
|
||||
Shape is identical to UpdateAssistantMessage but includes the owning agent_id.
|
||||
Shape is identical to UpdateAssistantMessage but includes the owning agent_id and message id.
|
||||
"""
|
||||
|
||||
message_id: str = Field(
|
||||
...,
|
||||
description="The unique identifier of the message.",
|
||||
)
|
||||
agent_id: str | None = Field(
|
||||
default=None,
|
||||
description="The unique identifier of the agent that owns the message.",
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from letta.functions.mcp_client.types import (
|
||||
MCP_AUTH_HEADER_AUTHORIZATION,
|
||||
MCP_AUTH_TOKEN_BEARER_PREFIX,
|
||||
@@ -48,68 +51,110 @@ class MCPServer(BaseMCPServer):
|
||||
last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
|
||||
metadata_: Optional[Dict[str, Any]] = Field(default_factory=dict, description="A dictionary of additional metadata for the tool.")
|
||||
|
||||
def get_token_secret(self) -> Secret:
|
||||
"""Get the token as a Secret object. Prefers encrypted, falls back to plaintext with error logging."""
|
||||
if self.token_enc is not None:
|
||||
return self.token_enc
|
||||
# Fallback to plaintext with error logging via Secret.from_db()
|
||||
return Secret.from_db(encrypted_value=None, plaintext_value=self.token)
|
||||
def get_token_secret(self) -> Optional[Secret]:
|
||||
"""Get the token as a Secret object."""
|
||||
return self.token_enc
|
||||
|
||||
def get_custom_headers_secret(self) -> Secret:
|
||||
"""Get custom headers as a Secret object (stores JSON string). Prefers encrypted, falls back to plaintext with error logging."""
|
||||
if self.custom_headers_enc is not None:
|
||||
return self.custom_headers_enc
|
||||
# Fallback to plaintext with error logging via Secret.from_db()
|
||||
# Convert dict to JSON string for Secret storage
|
||||
plaintext_json = json.dumps(self.custom_headers) if self.custom_headers else None
|
||||
return Secret.from_db(encrypted_value=None, plaintext_value=plaintext_json)
|
||||
def get_custom_headers_secret(self) -> Optional[Secret]:
|
||||
"""Get the custom headers as a Secret object (JSON string)."""
|
||||
return self.custom_headers_enc
|
||||
|
||||
def get_custom_headers_dict(self) -> Optional[Dict[str, str]]:
|
||||
"""Get custom headers as a plaintext dictionary."""
|
||||
"""Get the custom headers as a dictionary."""
|
||||
if self.custom_headers_enc:
|
||||
json_str = self.custom_headers_enc.get_plaintext()
|
||||
if json_str:
|
||||
try:
|
||||
return json.loads(json_str)
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
logger.warning(f"Failed to parse custom_headers_enc for MCP server {self.id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_custom_headers_dict_async(self) -> Optional[Dict[str, str]]:
|
||||
"""Get custom headers as a plaintext dictionary (async version)."""
|
||||
secret = self.get_custom_headers_secret()
|
||||
json_str = secret.get_plaintext()
|
||||
if secret is None:
|
||||
return None
|
||||
json_str = await secret.get_plaintext_async()
|
||||
if json_str:
|
||||
try:
|
||||
return json.loads(json_str)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return None
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
logger.warning(f"Failed to parse custom_headers_enc for MCP server {self.id}: {e}")
|
||||
return None
|
||||
|
||||
def set_token_secret(self, secret: Secret) -> None:
|
||||
"""Set token from a Secret object, updating both encrypted and plaintext fields."""
|
||||
"""Set token from a Secret object."""
|
||||
self.token_enc = secret
|
||||
secret_dict = secret.to_dict()
|
||||
# Only set plaintext during migration phase
|
||||
if not secret.was_encrypted:
|
||||
self.token = secret_dict["plaintext"]
|
||||
else:
|
||||
self.token = None
|
||||
|
||||
def set_custom_headers_secret(self, secret: Secret) -> None:
|
||||
"""Set custom headers from a Secret object (containing JSON string), updating both fields."""
|
||||
"""Set custom headers from a Secret object (JSON string)."""
|
||||
self.custom_headers_enc = secret
|
||||
secret_dict = secret.to_dict()
|
||||
# Parse JSON string to dict for plaintext field
|
||||
json_str = secret_dict.get("plaintext")
|
||||
if json_str and not secret.was_encrypted:
|
||||
try:
|
||||
self.custom_headers = json.loads(json_str)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
self.custom_headers = None
|
||||
else:
|
||||
self.custom_headers = None
|
||||
|
||||
def to_config(
|
||||
self,
|
||||
environment_variables: Optional[Dict[str, str]] = None,
|
||||
resolve_variables: bool = True,
|
||||
) -> Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig]:
|
||||
# Get decrypted values directly from encrypted columns
|
||||
token_plaintext = self.token_enc.get_plaintext() if self.token_enc else None
|
||||
|
||||
# Get custom headers as dict from encrypted column
|
||||
headers_plaintext = None
|
||||
if self.custom_headers_enc:
|
||||
json_str = self.custom_headers_enc.get_plaintext()
|
||||
if json_str:
|
||||
try:
|
||||
headers_plaintext = json.loads(json_str)
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
logger.warning(f"Failed to parse custom_headers_enc for MCP server {self.id}: {e}")
|
||||
|
||||
if self.server_type == MCPServerType.SSE:
|
||||
config = SSEServerConfig(
|
||||
server_name=self.server_name,
|
||||
server_url=self.server_url,
|
||||
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if token_plaintext and not headers_plaintext else None,
|
||||
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {token_plaintext}" if token_plaintext and not headers_plaintext else None,
|
||||
custom_headers=headers_plaintext,
|
||||
)
|
||||
if resolve_variables:
|
||||
config.resolve_environment_variables(environment_variables)
|
||||
return config
|
||||
elif self.server_type == MCPServerType.STDIO:
|
||||
if self.stdio_config is None:
|
||||
raise ValueError("stdio_config is required for STDIO server type")
|
||||
if resolve_variables:
|
||||
self.stdio_config.resolve_environment_variables(environment_variables)
|
||||
return self.stdio_config
|
||||
elif self.server_type == MCPServerType.STREAMABLE_HTTP:
|
||||
if self.server_url is None:
|
||||
raise ValueError("server_url is required for STREAMABLE_HTTP server type")
|
||||
|
||||
config = StreamableHTTPServerConfig(
|
||||
server_name=self.server_name,
|
||||
server_url=self.server_url,
|
||||
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if token_plaintext and not headers_plaintext else None,
|
||||
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {token_plaintext}" if token_plaintext and not headers_plaintext else None,
|
||||
custom_headers=headers_plaintext,
|
||||
)
|
||||
if resolve_variables:
|
||||
config.resolve_environment_variables(environment_variables)
|
||||
return config
|
||||
else:
|
||||
raise ValueError(f"Unsupported server type: {self.server_type}")
|
||||
|
||||
async def to_config_async(
|
||||
self,
|
||||
environment_variables: Optional[Dict[str, str]] = None,
|
||||
resolve_variables: bool = True,
|
||||
) -> Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig]:
|
||||
"""Async version of to_config() that uses async decryption."""
|
||||
# Get decrypted values for use in config
|
||||
token_secret = self.get_token_secret()
|
||||
token_plaintext = token_secret.get_plaintext()
|
||||
token_plaintext = await token_secret.get_plaintext_async() if token_secret else None
|
||||
|
||||
# Get custom headers as dict
|
||||
headers_plaintext = self.get_custom_headers_dict()
|
||||
headers_plaintext = await self.get_custom_headers_dict_async()
|
||||
|
||||
if self.server_type == MCPServerType.SSE:
|
||||
config = SSEServerConfig(
|
||||
@@ -228,66 +273,6 @@ class MCPOAuthSession(BaseMCPOAuth):
|
||||
created_at: datetime = Field(default_factory=datetime.now, description="Session creation time")
|
||||
updated_at: datetime = Field(default_factory=datetime.now, description="Last update time")
|
||||
|
||||
def get_access_token_secret(self) -> Secret:
|
||||
"""Get the access token as a Secret object, preferring encrypted over plaintext."""
|
||||
if self.access_token_enc is not None:
|
||||
return self.access_token_enc
|
||||
return Secret.from_db(None, self.access_token)
|
||||
|
||||
def get_refresh_token_secret(self) -> Secret:
|
||||
"""Get the refresh token as a Secret object, preferring encrypted over plaintext."""
|
||||
if self.refresh_token_enc is not None:
|
||||
return self.refresh_token_enc
|
||||
return Secret.from_db(None, self.refresh_token)
|
||||
|
||||
def get_client_secret_secret(self) -> Secret:
|
||||
"""Get the client secret as a Secret object, preferring encrypted over plaintext."""
|
||||
if self.client_secret_enc is not None:
|
||||
return self.client_secret_enc
|
||||
return Secret.from_db(None, self.client_secret)
|
||||
|
||||
def get_authorization_code_secret(self) -> Secret:
|
||||
"""Get the authorization code as a Secret object, preferring encrypted over plaintext."""
|
||||
if self.authorization_code_enc is not None:
|
||||
return self.authorization_code_enc
|
||||
return Secret.from_db(None, self.authorization_code)
|
||||
|
||||
def set_access_token_secret(self, secret: Secret) -> None:
|
||||
"""Set access token from a Secret object."""
|
||||
self.access_token_enc = secret
|
||||
secret_dict = secret.to_dict()
|
||||
if not secret.was_encrypted:
|
||||
self.access_token = secret_dict["plaintext"]
|
||||
else:
|
||||
self.access_token = None
|
||||
|
||||
def set_refresh_token_secret(self, secret: Secret) -> None:
|
||||
"""Set refresh token from a Secret object."""
|
||||
self.refresh_token_enc = secret
|
||||
secret_dict = secret.to_dict()
|
||||
if not secret.was_encrypted:
|
||||
self.refresh_token = secret_dict["plaintext"]
|
||||
else:
|
||||
self.refresh_token = None
|
||||
|
||||
def set_client_secret_secret(self, secret: Secret) -> None:
|
||||
"""Set client secret from a Secret object."""
|
||||
self.client_secret_enc = secret
|
||||
secret_dict = secret.to_dict()
|
||||
if not secret.was_encrypted:
|
||||
self.client_secret = secret_dict["plaintext"]
|
||||
else:
|
||||
self.client_secret = None
|
||||
|
||||
def set_authorization_code_secret(self, secret: Secret) -> None:
|
||||
"""Set authorization code from a Secret object."""
|
||||
self.authorization_code_enc = secret
|
||||
secret_dict = secret.to_dict()
|
||||
if not secret.was_encrypted:
|
||||
self.authorization_code = secret_dict["plaintext"]
|
||||
else:
|
||||
self.authorization_code = None
|
||||
|
||||
|
||||
class MCPOAuthSessionCreate(BaseMCPOAuth):
|
||||
"""Create a new OAuth session."""
|
||||
|
||||
@@ -165,68 +165,36 @@ class MCPOAuthSession(BaseMCPOAuth):
|
||||
updated_at: datetime = Field(default_factory=datetime.now, description="Last update time")
|
||||
|
||||
def get_access_token_secret(self) -> Secret:
|
||||
"""Get the access token as a Secret object. Prefers encrypted, falls back to plaintext with error logging."""
|
||||
if self.access_token_enc is not None:
|
||||
return self.access_token_enc
|
||||
# Fallback to plaintext with error logging via Secret.from_db()
|
||||
return Secret.from_db(encrypted_value=None, plaintext_value=self.access_token)
|
||||
"""Get the access token as a Secret object."""
|
||||
return self.access_token_enc if self.access_token_enc is not None else Secret.from_plaintext(None)
|
||||
|
||||
def get_refresh_token_secret(self) -> Secret:
|
||||
"""Get the refresh token as a Secret object. Prefers encrypted, falls back to plaintext with error logging."""
|
||||
if self.refresh_token_enc is not None:
|
||||
return self.refresh_token_enc
|
||||
# Fallback to plaintext with error logging via Secret.from_db()
|
||||
return Secret.from_db(encrypted_value=None, plaintext_value=self.refresh_token)
|
||||
"""Get the refresh token as a Secret object."""
|
||||
return self.refresh_token_enc if self.refresh_token_enc is not None else Secret.from_plaintext(None)
|
||||
|
||||
def get_client_secret_secret(self) -> Secret:
|
||||
"""Get the client secret as a Secret object. Prefers encrypted, falls back to plaintext with error logging."""
|
||||
if self.client_secret_enc is not None:
|
||||
return self.client_secret_enc
|
||||
# Fallback to plaintext with error logging via Secret.from_db()
|
||||
return Secret.from_db(encrypted_value=None, plaintext_value=self.client_secret)
|
||||
"""Get the client secret as a Secret object."""
|
||||
return self.client_secret_enc if self.client_secret_enc is not None else Secret.from_plaintext(None)
|
||||
|
||||
def get_authorization_code_secret(self) -> Secret:
|
||||
"""Get the authorization code as a Secret object. Prefers encrypted, falls back to plaintext with error logging."""
|
||||
if self.authorization_code_enc is not None:
|
||||
return self.authorization_code_enc
|
||||
# Fallback to plaintext with error logging via Secret.from_db()
|
||||
return Secret.from_db(encrypted_value=None, plaintext_value=self.authorization_code)
|
||||
"""Get the authorization code as a Secret object."""
|
||||
return self.authorization_code_enc if self.authorization_code_enc is not None else Secret.from_plaintext(None)
|
||||
|
||||
def set_access_token_secret(self, secret: Secret) -> None:
|
||||
"""Set access token from a Secret object."""
|
||||
self.access_token_enc = secret
|
||||
secret_dict = secret.to_dict()
|
||||
if not secret.was_encrypted:
|
||||
self.access_token = secret_dict["plaintext"]
|
||||
else:
|
||||
self.access_token = None
|
||||
|
||||
def set_refresh_token_secret(self, secret: Secret) -> None:
|
||||
"""Set refresh token from a Secret object."""
|
||||
self.refresh_token_enc = secret
|
||||
secret_dict = secret.to_dict()
|
||||
if not secret.was_encrypted:
|
||||
self.refresh_token = secret_dict["plaintext"]
|
||||
else:
|
||||
self.refresh_token = None
|
||||
|
||||
def set_client_secret_secret(self, secret: Secret) -> None:
|
||||
"""Set client secret from a Secret object."""
|
||||
self.client_secret_enc = secret
|
||||
secret_dict = secret.to_dict()
|
||||
if not secret.was_encrypted:
|
||||
self.client_secret = secret_dict["plaintext"]
|
||||
else:
|
||||
self.client_secret = None
|
||||
|
||||
def set_authorization_code_secret(self, secret: Secret) -> None:
|
||||
"""Set authorization code from a Secret object."""
|
||||
self.authorization_code_enc = secret
|
||||
secret_dict = secret.to_dict()
|
||||
if not secret.was_encrypted:
|
||||
self.authorization_code = secret_dict["plaintext"]
|
||||
else:
|
||||
self.authorization_code = None
|
||||
|
||||
|
||||
class MCPOAuthSessionCreate(BaseMCPOAuth):
|
||||
@@ -290,7 +258,7 @@ class UpdateMCPServerRequest(LettaBase):
|
||||
]
|
||||
|
||||
|
||||
def convert_generic_to_union(server) -> MCPServerUnion:
|
||||
async def convert_generic_to_union(server) -> MCPServerUnion:
|
||||
"""
|
||||
Convert a generic MCPServer (from letta.schemas.mcp) to the appropriate MCPServerUnion type
|
||||
based on the server_type field.
|
||||
@@ -319,24 +287,30 @@ def convert_generic_to_union(server) -> MCPServerUnion:
|
||||
env=server.stdio_config.env if server.stdio_config else None,
|
||||
)
|
||||
elif server.server_type == MCPServerType.SSE:
|
||||
# Get decrypted values from encrypted columns (async)
|
||||
token = await server.token_enc.get_plaintext_async() if server.token_enc else None
|
||||
headers = await server.get_custom_headers_dict_async()
|
||||
return SSEMCPServer(
|
||||
id=server.id,
|
||||
server_name=server.server_name,
|
||||
mcp_server_type=MCPServerType.SSE,
|
||||
server_url=server.server_url,
|
||||
auth_header="Authorization" if server.token else None,
|
||||
auth_token=f"Bearer {server.token}" if server.token else None,
|
||||
custom_headers=server.custom_headers,
|
||||
auth_header="Authorization" if token else None,
|
||||
auth_token=f"Bearer {token}" if token else None,
|
||||
custom_headers=headers,
|
||||
)
|
||||
elif server.server_type == MCPServerType.STREAMABLE_HTTP:
|
||||
# Get decrypted values from encrypted columns (async)
|
||||
token = await server.token_enc.get_plaintext_async() if server.token_enc else None
|
||||
headers = await server.get_custom_headers_dict_async()
|
||||
return StreamableHTTPMCPServer(
|
||||
id=server.id,
|
||||
server_name=server.server_name,
|
||||
mcp_server_type=MCPServerType.STREAMABLE_HTTP,
|
||||
server_url=server.server_url,
|
||||
auth_header="Authorization" if server.token else None,
|
||||
auth_token=f"Bearer {server.token}" if server.token else None,
|
||||
custom_headers=server.custom_headers,
|
||||
auth_header="Authorization" if token else None,
|
||||
auth_token=f"Bearer {token}" if token else None,
|
||||
custom_headers=headers,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown server type: {server.server_type}")
|
||||
|
||||
@@ -360,6 +360,7 @@ class Message(BaseMessage):
|
||||
if isinstance(lm, SystemMessage):
|
||||
letta_search_results.append(
|
||||
SystemMessageListResult(
|
||||
message_id=message.id,
|
||||
message_type=lm.message_type,
|
||||
content=lm.content,
|
||||
agent_id=message.agent_id,
|
||||
@@ -369,6 +370,7 @@ class Message(BaseMessage):
|
||||
elif isinstance(lm, UserMessage):
|
||||
letta_search_results.append(
|
||||
UserMessageListResult(
|
||||
message_id=message.id,
|
||||
message_type=lm.message_type,
|
||||
content=lm.content,
|
||||
agent_id=message.agent_id,
|
||||
@@ -378,6 +380,7 @@ class Message(BaseMessage):
|
||||
elif isinstance(lm, ReasoningMessage):
|
||||
letta_search_results.append(
|
||||
ReasoningMessageListResult(
|
||||
message_id=message.id,
|
||||
message_type=lm.message_type,
|
||||
reasoning=lm.reasoning,
|
||||
agent_id=message.agent_id,
|
||||
@@ -387,6 +390,7 @@ class Message(BaseMessage):
|
||||
elif isinstance(lm, AssistantMessage):
|
||||
letta_search_results.append(
|
||||
AssistantMessageListResult(
|
||||
message_id=message.id,
|
||||
message_type=lm.message_type,
|
||||
content=lm.content,
|
||||
agent_id=message.agent_id,
|
||||
|
||||
@@ -285,7 +285,7 @@ class AnthropicModelSettings(ModelSettings):
|
||||
"temperature": self.temperature,
|
||||
"max_tokens": self.max_output_tokens,
|
||||
"extended_thinking": self.thinking.type == "enabled",
|
||||
"thinking_budget_tokens": self.thinking.budget_tokens,
|
||||
"max_reasoning_tokens": self.thinking.budget_tokens,
|
||||
"verbosity": self.verbosity,
|
||||
"parallel_tool_calls": self.parallel_tool_calls,
|
||||
"effort": self.effort,
|
||||
|
||||
@@ -108,7 +108,7 @@ class AnthropicProvider(Provider):
|
||||
base_url: str = "https://api.anthropic.com/v1"
|
||||
|
||||
async def check_api_key(self):
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
if api_key:
|
||||
anthropic_client = anthropic.Anthropic(api_key=api_key)
|
||||
try:
|
||||
@@ -121,13 +121,23 @@ class AnthropicProvider(Provider):
|
||||
else:
|
||||
raise ValueError("No API key provided")
|
||||
|
||||
def get_default_max_output_tokens(self, model_name: str) -> int:
|
||||
"""Get the default max output tokens for Anthropic models."""
|
||||
if "opus" in model_name:
|
||||
return 16384
|
||||
elif "sonnet" in model_name:
|
||||
return 16384
|
||||
elif "haiku" in model_name:
|
||||
return 8192
|
||||
return 8192 # default for anthropic
|
||||
|
||||
async def list_llm_models_async(self) -> list[LLMConfig]:
|
||||
"""
|
||||
https://docs.anthropic.com/claude/docs/models-overview
|
||||
|
||||
NOTE: currently there is no GET /models, so we need to hardcode
|
||||
"""
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
if api_key:
|
||||
anthropic_client = anthropic.AsyncAnthropic(api_key=api_key)
|
||||
elif model_settings.anthropic_api_key:
|
||||
@@ -171,11 +181,7 @@ class AnthropicProvider(Provider):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
max_tokens = 8192
|
||||
if "claude-3-opus" in model["id"]:
|
||||
max_tokens = 4096
|
||||
if "claude-3-haiku" in model["id"]:
|
||||
max_tokens = 4096
|
||||
max_tokens = self.get_default_max_output_tokens(model["id"])
|
||||
# TODO: set for 3-7 extended thinking mode
|
||||
|
||||
# NOTE: from 2025-02
|
||||
|
||||
@@ -5,7 +5,7 @@ import httpx
|
||||
from openai import AsyncAzureOpenAI
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LLM_MAX_TOKENS
|
||||
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LLM_MAX_CONTEXT_WINDOW
|
||||
from letta.errors import ErrorCode, LLMAuthenticationError
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ProviderCategory, ProviderType
|
||||
@@ -60,7 +60,7 @@ class AzureProvider(Provider):
|
||||
async def azure_openai_get_deployed_model_list(self) -> list:
|
||||
"""https://learn.microsoft.com/en-us/rest/api/azureopenai/models/list?view=rest-azureopenai-2023-05-15&tabs=HTTP"""
|
||||
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
client = AsyncAzureOpenAI(api_key=api_key, api_version=self.api_version, azure_endpoint=self.base_url)
|
||||
|
||||
try:
|
||||
@@ -127,6 +127,7 @@ class AzureProvider(Provider):
|
||||
model_endpoint=model_endpoint,
|
||||
context_window=context_window_size,
|
||||
handle=self.get_handle(model_name),
|
||||
max_tokens=self.get_default_max_output_tokens(model_name),
|
||||
provider_name=self.name,
|
||||
provider_category=self.provider_category,
|
||||
)
|
||||
@@ -165,11 +166,11 @@ class AzureProvider(Provider):
|
||||
|
||||
def get_model_context_window(self, model_name: str) -> int | None:
|
||||
# Hard coded as there are no API endpoints for this
|
||||
llm_default = LLM_MAX_TOKENS.get(model_name, 4096)
|
||||
llm_default = LLM_MAX_CONTEXT_WINDOW.get(model_name, 4096)
|
||||
return AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, llm_default)
|
||||
|
||||
async def check_api_key(self):
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
if not api_key:
|
||||
raise ValueError("No API key provided")
|
||||
|
||||
|
||||
@@ -145,6 +145,19 @@ class Provider(ProviderBase):
|
||||
async def get_model_context_window_async(self, model_name: str) -> int | None:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_default_max_output_tokens(self, model_name: str) -> int:
|
||||
"""
|
||||
Get the default max output tokens for a model.
|
||||
Override in subclasses for model-specific logic.
|
||||
|
||||
Args:
|
||||
model_name (str): The name of the model.
|
||||
|
||||
Returns:
|
||||
int: The default max output tokens for the model.
|
||||
"""
|
||||
return 4096 # sensible fallback
|
||||
|
||||
def get_handle(self, model_name: str, is_embedding: bool = False, base_name: str | None = None) -> str:
|
||||
"""
|
||||
Get the handle for a model, with support for custom overrides.
|
||||
|
||||
@@ -26,8 +26,8 @@ class BedrockProvider(Provider):
|
||||
|
||||
try:
|
||||
# Decrypt credentials before using
|
||||
access_key = self.access_key_enc.get_plaintext() if self.access_key_enc else None
|
||||
secret_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
access_key = await self.access_key_enc.get_plaintext_async() if self.access_key_enc else None
|
||||
secret_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
|
||||
session = Session()
|
||||
async with session.client(
|
||||
@@ -70,6 +70,7 @@ class BedrockProvider(Provider):
|
||||
model_endpoint=None,
|
||||
context_window=self.get_model_context_window(model_arn),
|
||||
handle=self.get_handle(model_arn),
|
||||
max_tokens=self.get_default_max_output_tokens(model_arn),
|
||||
provider_name=self.name,
|
||||
provider_category=self.provider_category,
|
||||
)
|
||||
|
||||
@@ -41,7 +41,7 @@ class CerebrasProvider(OpenAIProvider):
|
||||
async def list_llm_models_async(self) -> list[LLMConfig]:
|
||||
from letta.llm_api.openai import openai_get_model_list_async
|
||||
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
response = await openai_get_model_list_async(self.base_url, api_key=api_key)
|
||||
|
||||
if "data" in response:
|
||||
@@ -74,6 +74,7 @@ class CerebrasProvider(OpenAIProvider):
|
||||
model_endpoint=self.base_url,
|
||||
context_window=context_window_size,
|
||||
handle=self.get_handle(model_name),
|
||||
max_tokens=self.get_default_max_output_tokens(model_name),
|
||||
put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
|
||||
provider_name=self.name,
|
||||
provider_category=self.provider_category,
|
||||
|
||||
@@ -34,7 +34,7 @@ class DeepSeekProvider(OpenAIProvider):
|
||||
async def list_llm_models_async(self) -> list[LLMConfig]:
|
||||
from letta.llm_api.openai import openai_get_model_list_async
|
||||
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
response = await openai_get_model_list_async(self.base_url, api_key=api_key)
|
||||
data = response.get("data", response)
|
||||
|
||||
@@ -55,6 +55,7 @@ class DeepSeekProvider(OpenAIProvider):
|
||||
model_endpoint=self.base_url,
|
||||
context_window=context_window_size,
|
||||
handle=self.get_handle(model_name),
|
||||
max_tokens=self.get_default_max_output_tokens(model_name),
|
||||
put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
|
||||
provider_name=self.name,
|
||||
provider_category=self.provider_category,
|
||||
|
||||
@@ -7,7 +7,7 @@ logger = get_logger(__name__)
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LLM_MAX_TOKENS
|
||||
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LLM_MAX_CONTEXT_WINDOW
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ProviderCategory, ProviderType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
@@ -23,14 +23,20 @@ class GoogleAIProvider(Provider):
|
||||
async def check_api_key(self):
    """Validate this provider's API key against the Google AI API.

    The key is read from ``self.api_key_enc`` via its async decrypt method;
    when no encrypted key is set, ``None`` is passed to the validator.
    Errors, if any, come from ``google_ai_check_valid_api_key_async``
    (its failure behaviour is not visible here).
    """
    from letta.llm_api.google_ai_client import google_ai_check_valid_api_key_async

    # Decrypt exactly once, using the async path only.
    api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
    await google_ai_check_valid_api_key_async(api_key)
|
||||
|
||||
def get_default_max_output_tokens(self, model_name: str) -> int:
    """Get the default max output tokens for Google Gemini models.

    Args:
        model_name: Model identifier as reported by the provider.

    Returns:
        65536 when the name contains "2.5" or "2-5", or starts with
        "gemini-3"; otherwise 8192.
    """
    is_large_output_model = (
        model_name.startswith("gemini-3")
        or "2.5" in model_name
        or "2-5" in model_name
    )
    if is_large_output_model:
        return 65536
    return 8192  # default for google gemini
|
||||
|
||||
async def list_llm_models_async(self):
|
||||
from letta.llm_api.google_ai_client import google_ai_get_model_list_async
|
||||
|
||||
# Get and filter the model list
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=api_key)
|
||||
model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]]
|
||||
model_options = [str(m["name"]) for m in model_options]
|
||||
@@ -50,7 +56,7 @@ class GoogleAIProvider(Provider):
|
||||
model_endpoint=self.base_url,
|
||||
context_window=context_window,
|
||||
handle=self.get_handle(model),
|
||||
max_tokens=8192,
|
||||
max_tokens=self.get_default_max_output_tokens(model),
|
||||
provider_name=self.name,
|
||||
provider_category=self.provider_category,
|
||||
)
|
||||
@@ -64,7 +70,7 @@ class GoogleAIProvider(Provider):
|
||||
from letta.llm_api.google_ai_client import google_ai_get_model_list_async
|
||||
|
||||
# TODO: use base_url instead
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=api_key)
|
||||
return self._list_embedding_models(model_options)
|
||||
|
||||
@@ -95,8 +101,8 @@ class GoogleAIProvider(Provider):
|
||||
logger.warning("This is deprecated, use get_model_context_window_async when possible.")
|
||||
from letta.llm_api.google_ai_client import google_ai_get_model_context_window
|
||||
|
||||
if model_name in LLM_MAX_TOKENS:
|
||||
return LLM_MAX_TOKENS[model_name]
|
||||
if model_name in LLM_MAX_CONTEXT_WINDOW:
|
||||
return LLM_MAX_CONTEXT_WINDOW[model_name]
|
||||
else:
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
return google_ai_get_model_context_window(self.base_url, api_key, model_name)
|
||||
@@ -104,8 +110,8 @@ class GoogleAIProvider(Provider):
|
||||
async def get_model_context_window_async(self, model_name: str) -> int | None:
    """Return the context window size for ``model_name``.

    Names present in ``LLM_MAX_CONTEXT_WINDOW`` are answered from that
    table. Any other name is looked up through
    ``google_ai_get_model_context_window_async`` using this provider's
    ``base_url`` and decrypted API key (``None`` when no key is stored).
    """
    from letta.llm_api.google_ai_client import google_ai_get_model_context_window_async

    if model_name in LLM_MAX_CONTEXT_WINDOW:
        return LLM_MAX_CONTEXT_WINDOW[model_name]

    # Not in the static table: ask the API, decrypting the key once (async path).
    api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
    return await google_ai_get_model_context_window_async(self.base_url, api_key, model_name)
|
||||
|
||||
@@ -16,6 +16,12 @@ class GoogleVertexProvider(Provider):
|
||||
google_cloud_project: str = Field(..., description="GCP project ID for the Google Vertex API.")
|
||||
google_cloud_location: str = Field(..., description="GCP region for the Google Vertex API.")
|
||||
|
||||
def get_default_max_output_tokens(self, model_name: str) -> int:
    """Get the default max output tokens for Google Vertex models.

    Args:
        model_name: Model identifier as listed for the Vertex provider.

    Returns:
        65536 when the name contains "2.5" or "2-5", or starts with
        "gemini-3"; otherwise 8192.
    """
    is_large_output_model = (
        model_name.startswith("gemini-3")
        or "2.5" in model_name
        or "2-5" in model_name
    )
    if is_large_output_model:
        return 65536
    return 8192  # default for google vertex
|
||||
|
||||
async def list_llm_models_async(self) -> list[LLMConfig]:
|
||||
from letta.llm_api.google_constants import GOOGLE_MODEL_TO_CONTEXT_LENGTH
|
||||
|
||||
@@ -28,7 +34,7 @@ class GoogleVertexProvider(Provider):
|
||||
model_endpoint=f"https://{self.google_cloud_location}-aiplatform.googleapis.com/v1/projects/{self.google_cloud_project}/locations/{self.google_cloud_location}",
|
||||
context_window=context_length,
|
||||
handle=self.get_handle(model),
|
||||
max_tokens=8192,
|
||||
max_tokens=self.get_default_max_output_tokens(model),
|
||||
provider_name=self.name,
|
||||
provider_category=self.provider_category,
|
||||
)
|
||||
|
||||
@@ -16,7 +16,7 @@ class GroqProvider(OpenAIProvider):
|
||||
async def list_llm_models_async(self) -> list[LLMConfig]:
|
||||
from letta.llm_api.openai import openai_get_model_list_async
|
||||
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
response = await openai_get_model_list_async(self.base_url, api_key=api_key)
|
||||
configs = []
|
||||
for model in response["data"]:
|
||||
@@ -29,6 +29,7 @@ class GroqProvider(OpenAIProvider):
|
||||
model_endpoint=self.base_url,
|
||||
context_window=model["context_window"],
|
||||
handle=self.get_handle(model["id"]),
|
||||
max_tokens=self.get_default_max_output_tokens(model["id"]),
|
||||
provider_name=self.name,
|
||||
provider_category=self.provider_category,
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ class LettaProvider(Provider):
|
||||
model_endpoint=LETTA_MODEL_ENDPOINT,
|
||||
context_window=30000,
|
||||
handle=self.get_handle("letta-free"),
|
||||
max_tokens=self.get_default_max_output_tokens("letta-free"),
|
||||
provider_name=self.name,
|
||||
provider_category=self.provider_category,
|
||||
)
|
||||
|
||||
@@ -61,6 +61,7 @@ class LMStudioOpenAIProvider(OpenAIProvider):
|
||||
model_endpoint=self.model_endpoint_url,
|
||||
context_window=context_window_size,
|
||||
handle=self.get_handle(model_name),
|
||||
max_tokens=self.get_default_max_output_tokens(model_name),
|
||||
compatibility_type=compatibility_type,
|
||||
provider_name=self.name,
|
||||
provider_category=self.provider_category,
|
||||
|
||||
@@ -18,7 +18,7 @@ class MistralProvider(Provider):
|
||||
|
||||
# Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
|
||||
# See: https://openrouter.ai/docs/requests
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
response = await mistral_get_model_list_async(self.base_url, api_key=api_key)
|
||||
|
||||
assert "data" in response, f"Mistral model query response missing 'data' field: {response}"
|
||||
@@ -34,6 +34,7 @@ class MistralProvider(Provider):
|
||||
model_endpoint=self.base_url,
|
||||
context_window=model["max_context_length"],
|
||||
handle=self.get_handle(model["id"]),
|
||||
max_tokens=self.get_default_max_output_tokens(model["id"]),
|
||||
provider_name=self.name,
|
||||
provider_category=self.provider_category,
|
||||
)
|
||||
|
||||
@@ -125,6 +125,7 @@ class OllamaProvider(OpenAIProvider):
|
||||
# model_wrapper=self.default_prompt_formatter,
|
||||
context_window=context_window,
|
||||
handle=self.get_handle(model_name),
|
||||
max_tokens=self.get_default_max_output_tokens(model_name),
|
||||
provider_name=self.name,
|
||||
provider_category=self.provider_category,
|
||||
# put_inner_thoughts_in_kwargs=True,
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LLM_MAX_TOKENS
|
||||
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LLM_MAX_CONTEXT_WINDOW
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ProviderCategory, ProviderType
|
||||
@@ -26,9 +26,17 @@ class OpenAIProvider(Provider):
|
||||
from letta.llm_api.openai import openai_check_valid_api_key # TODO: DO NOT USE THIS - old code path
|
||||
|
||||
# Decrypt API key before using
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
openai_check_valid_api_key(self.base_url, api_key)
|
||||
|
||||
def get_default_max_output_tokens(self, model_name: str) -> int:
    """Get the default max output tokens for OpenAI models.

    Args:
        model_name: Model identifier as reported by the provider.

    Returns:
        100000 for names starting with "o1" or "o3"; 16384 for every other
        name (this includes "gpt-5*" models, which use the same value as
        the general default).
    """
    if model_name.startswith(("o1", "o3")):
        return 100000
    return 16384  # default for openai (also the gpt-5 value)
|
||||
|
||||
async def _get_models_async(self) -> list[dict]:
|
||||
from letta.llm_api.openai import openai_get_model_list_async
|
||||
|
||||
@@ -40,7 +48,7 @@ class OpenAIProvider(Provider):
|
||||
extra_params = {"verbose": True} if "nebius.com" in self.base_url else None
|
||||
|
||||
# Decrypt API key before using
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
|
||||
response = await openai_get_model_list_async(
|
||||
self.base_url,
|
||||
@@ -154,6 +162,7 @@ class OpenAIProvider(Provider):
|
||||
model_endpoint=self.base_url,
|
||||
context_window=context_window_size,
|
||||
handle=handle,
|
||||
max_tokens=self.get_default_max_output_tokens(model_name),
|
||||
provider_name=self.name,
|
||||
provider_category=self.provider_category,
|
||||
)
|
||||
@@ -190,16 +199,16 @@ class OpenAIProvider(Provider):
|
||||
return llm_config
|
||||
|
||||
def get_model_context_window_size(self, model_name: str) -> int | None:
    """Return the context window size for ``model_name``.

    Looks the name up in ``LLM_MAX_CONTEXT_WINDOW``. Unknown names fall back
    to the table's ``"DEFAULT"`` entry, and the fallback is logged at debug
    level together with the provider's base URL and class name.
    """
    if model_name in LLM_MAX_CONTEXT_WINDOW:
        return LLM_MAX_CONTEXT_WINDOW[model_name]

    default_window = LLM_MAX_CONTEXT_WINDOW["DEFAULT"]
    # The message is %-formatted by the logger, so the default must be passed
    # as an argument; a "{...}" placeholder here would be logged literally.
    logger.debug(
        "Model %s on %s for provider %s not found in LLM_MAX_CONTEXT_WINDOW. Using default of %s",
        model_name,
        self.base_url,
        self.__class__.__name__,
        default_window,
    )
    return default_window
|
||||
|
||||
def get_model_context_window(self, model_name: str) -> int | None:
|
||||
return self.get_model_context_window_size(model_name)
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LLM_MAX_TOKENS
|
||||
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LLM_MAX_CONTEXT_WINDOW
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ProviderCategory, ProviderType
|
||||
@@ -41,6 +41,7 @@ class OpenRouterProvider(OpenAIProvider):
|
||||
model_endpoint=self.base_url,
|
||||
context_window=context_window_size,
|
||||
handle=handle,
|
||||
max_tokens=self.get_default_max_output_tokens(model_name),
|
||||
provider_name=self.name,
|
||||
provider_category=self.provider_category,
|
||||
)
|
||||
|
||||
@@ -30,7 +30,7 @@ class TogetherProvider(OpenAIProvider):
|
||||
async def list_llm_models_async(self) -> list[LLMConfig]:
|
||||
from letta.llm_api.openai import openai_get_model_list_async
|
||||
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
models = await openai_get_model_list_async(self.base_url, api_key=api_key)
|
||||
return self._list_llm_models(models)
|
||||
|
||||
@@ -93,7 +93,7 @@ class TogetherProvider(OpenAIProvider):
|
||||
return configs
|
||||
|
||||
async def check_api_key(self):
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
if not api_key:
|
||||
raise ValueError("No API key provided")
|
||||
|
||||
|
||||
@@ -45,6 +45,7 @@ class VLLMProvider(Provider):
|
||||
model_wrapper=self.default_prompt_formatter,
|
||||
context_window=model["max_model_len"],
|
||||
handle=self.get_handle(model_name, base_name=self.handle_base) if self.handle_base else self.get_handle(model_name),
|
||||
max_tokens=self.get_default_max_output_tokens(model_name),
|
||||
provider_name=self.name,
|
||||
provider_category=self.provider_category,
|
||||
)
|
||||
|
||||
@@ -38,7 +38,7 @@ class XAIProvider(OpenAIProvider):
|
||||
async def list_llm_models_async(self) -> list[LLMConfig]:
|
||||
from letta.llm_api.openai import openai_get_model_list_async
|
||||
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
response = await openai_get_model_list_async(self.base_url, api_key=api_key)
|
||||
|
||||
data = response.get("data", response)
|
||||
@@ -65,6 +65,7 @@ class XAIProvider(OpenAIProvider):
|
||||
model_endpoint=self.base_url,
|
||||
context_window=context_window_size,
|
||||
handle=self.get_handle(model_name),
|
||||
max_tokens=self.get_default_max_output_tokens(model_name),
|
||||
provider_name=self.name,
|
||||
provider_category=self.provider_category,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, PrivateAttr
|
||||
@@ -17,22 +16,17 @@ class Secret(BaseModel):
|
||||
This class ensures that sensitive data remains encrypted as much as possible
|
||||
while passing through the codebase, only decrypting when absolutely necessary.
|
||||
|
||||
Migration status (Phase 1 - encrypted-first reads with plaintext fallback):
|
||||
- Reads: Prefer _enc columns, fallback to plaintext columns with ERROR logging
|
||||
- Writes: Still dual-write to both _enc and plaintext columns for backward compatibility
|
||||
- Encryption: Optional - if LETTA_ENCRYPTION_KEY is not set, stores plaintext in _enc column
|
||||
|
||||
TODO (Phase 2): Remove plaintext fallback in from_db() after verifying no error logs
|
||||
TODO (Phase 3): Remove dual-write logic in to_dict() and set_*_secret() methods
|
||||
TODO (Phase 4): Remove from_db() plaintext_value parameter, was_encrypted flag, and plaintext columns
|
||||
Usage:
|
||||
- Create from plaintext: Secret.from_plaintext(value)
|
||||
- Create from encrypted DB value: Secret.from_encrypted(encrypted_value)
|
||||
- Get encrypted for storage: secret.get_encrypted()
|
||||
- Get plaintext when needed: secret.get_plaintext()
|
||||
"""
|
||||
|
||||
# Store the encrypted value as a regular field
|
||||
encrypted_value: Optional[str] = None
|
||||
# Cache the decrypted value to avoid repeated decryption (not serialized for security)
|
||||
_plaintext_cache: Optional[str] = PrivateAttr(default=None)
|
||||
# Flag to indicate if the value was originally encrypted
|
||||
was_encrypted: bool = False
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
@@ -51,7 +45,7 @@ class Secret(BaseModel):
|
||||
A Secret instance with the encrypted (or plaintext) value
|
||||
"""
|
||||
if value is None:
|
||||
return cls.model_construct(encrypted_value=None, was_encrypted=False)
|
||||
return cls.model_construct(encrypted_value=None)
|
||||
|
||||
# Guard against double encryption - check if value is already encrypted
|
||||
if CryptoUtils.is_encrypted(value):
|
||||
@@ -60,7 +54,7 @@ class Secret(BaseModel):
|
||||
# Try to encrypt, but fall back to storing plaintext if no encryption key
|
||||
try:
|
||||
encrypted = CryptoUtils.encrypt(value)
|
||||
return cls.model_construct(encrypted_value=encrypted, was_encrypted=False)
|
||||
return cls.model_construct(encrypted_value=encrypted)
|
||||
except ValueError as e:
|
||||
# No encryption key available, store as plaintext in the _enc column
|
||||
if "No encryption key configured" in str(e):
|
||||
@@ -68,7 +62,7 @@ class Secret(BaseModel):
|
||||
"No encryption key configured. Storing Secret value as plaintext in _enc column. "
|
||||
"Set LETTA_ENCRYPTION_KEY environment variable to enable encryption."
|
||||
)
|
||||
instance = cls.model_construct(encrypted_value=value, was_encrypted=False)
|
||||
instance = cls.model_construct(encrypted_value=value)
|
||||
instance._plaintext_cache = value # Cache it since we know the plaintext
|
||||
return instance
|
||||
raise # Re-raise if it's a different error
|
||||
@@ -76,47 +70,15 @@ class Secret(BaseModel):
|
||||
@classmethod
|
||||
def from_encrypted(cls, encrypted_value: Optional[str]) -> "Secret":
|
||||
"""
|
||||
Create a Secret from an already encrypted value.
|
||||
Create a Secret from an already encrypted value (read from DB).
|
||||
|
||||
Args:
|
||||
encrypted_value: The encrypted value
|
||||
encrypted_value: The encrypted value from the _enc column
|
||||
|
||||
Returns:
|
||||
A Secret instance
|
||||
"""
|
||||
return cls.model_construct(encrypted_value=encrypted_value, was_encrypted=True)
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, encrypted_value: Optional[str], plaintext_value: Optional[str] = None) -> "Secret":
|
||||
"""
|
||||
Create a Secret from database values. Prefers encrypted column, falls back to plaintext with error logging.
|
||||
|
||||
During Phase 1 of migration, this method:
|
||||
1. Uses encrypted_value if available (preferred)
|
||||
2. Falls back to plaintext_value with ERROR logging if encrypted is unavailable
|
||||
3. Returns empty Secret if neither is available
|
||||
|
||||
The error logging helps identify any records that haven't been migrated to encrypted columns.
|
||||
|
||||
Args:
|
||||
encrypted_value: The encrypted value from the database (_enc column)
|
||||
plaintext_value: The plaintext value from the database (legacy column, fallback only)
|
||||
|
||||
Returns:
|
||||
A Secret instance with the value from encrypted or plaintext column
|
||||
"""
|
||||
if encrypted_value is not None:
|
||||
return cls.from_encrypted(encrypted_value)
|
||||
# Fallback to plaintext with error logging - this helps identify unmigrated data
|
||||
if plaintext_value is not None:
|
||||
logger.error(
|
||||
"MIGRATION_NEEDED: Reading from plaintext column instead of encrypted column. "
|
||||
"This indicates data that hasn't been migrated to the _enc column yet. "
|
||||
"Please run migrate data to _enc columns as plaintext columns will be deprecated.",
|
||||
stack_info=True,
|
||||
)
|
||||
return cls.from_plaintext(plaintext_value)
|
||||
return cls.from_plaintext(None)
|
||||
return cls.model_construct(encrypted_value=encrypted_value)
|
||||
|
||||
def get_encrypted(self) -> Optional[str]:
|
||||
"""
|
||||
@@ -146,14 +108,8 @@ class Secret(BaseModel):
|
||||
if self.encrypted_value is None:
|
||||
return None
|
||||
|
||||
# Use cached value if available, but only if it looks like plaintext
|
||||
# or we're confident we can decrypt it
|
||||
# Use cached value if available
|
||||
if self._plaintext_cache is not None:
|
||||
# If this was explicitly created as plaintext, trust the cache
|
||||
# This prevents false positives from is_encrypted() heuristic
|
||||
if not self.was_encrypted:
|
||||
return self._plaintext_cache
|
||||
# For encrypted values, trust the cache (already decrypted previously)
|
||||
return self._plaintext_cache
|
||||
|
||||
# Try to decrypt
|
||||
@@ -213,8 +169,6 @@ class Secret(BaseModel):
|
||||
|
||||
# Use cached value if available
|
||||
if self._plaintext_cache is not None:
|
||||
if not self.was_encrypted:
|
||||
return self._plaintext_cache
|
||||
return self._plaintext_cache
|
||||
|
||||
# Try to decrypt (async)
|
||||
@@ -265,14 +219,6 @@ class Secret(BaseModel):
|
||||
"""Representation that doesn't expose the actual value."""
|
||||
return self.__str__()
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert to dictionary for database storage.
|
||||
|
||||
Returns both encrypted and plaintext values for dual-write during migration.
|
||||
"""
|
||||
return {"encrypted": self.get_encrypted(), "plaintext": self.get_plaintext() if not self.was_encrypted else None}
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""
|
||||
Compare two secrets by their plaintext values.
|
||||
|
||||
@@ -241,6 +241,10 @@ def create_application() -> "FastAPI":
|
||||
os.environ.setdefault("DD_PROFILING_MEMORY_ENABLED", str(telemetry_settings.datadog_profiling_memory_enabled).lower())
|
||||
os.environ.setdefault("DD_PROFILING_HEAP_ENABLED", str(telemetry_settings.datadog_profiling_heap_enabled).lower())
|
||||
|
||||
# Enable LLM Observability for tracking LLM calls, prompts, and completions
|
||||
os.environ.setdefault("DD_LLMOBS_ENABLED", "1")
|
||||
os.environ.setdefault("DD_LLMOBS_ML_APP", "memgpt-server")
|
||||
|
||||
# Note: DD_LOGS_INJECTION, DD_APPSEC_ENABLED, DD_IAST_ENABLED, DD_APPSEC_SCA_ENABLED
|
||||
# are set via deployment configs and automatically picked up by ddtrace
|
||||
|
||||
|
||||
@@ -16,6 +16,33 @@ from letta.settings import model_settings
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def strip_policy_specs(text: str) -> str:
    """
    Remove Claude policy injection blocks from message text.

    Claude injects policy instructions in two forms:
    1. Appended with prefix: 'user: <policy_spec>...'
    2. As entire message: '<policy_spec>...'

    Everything from the policy start marker onwards is dropped, since it is
    all injected policy content.

    Args:
        text: Raw message text.

    Returns:
        "" when the text starts with '<policy_spec>'; the whitespace-stripped
        text preceding 'user: <policy_spec>' when that marker occurs; the
        original text unchanged when neither form is present.
    """
    # Form 2: the entire message is a policy spec (starts with the tag).
    if text.startswith("<policy_spec>"):
        logger.info("[Proxy Helpers] Stripped policy injection (entire message)")
        return ""

    # Form 1: the policy spec is appended after real content, with a prefix.
    policy_start = text.find("user: <policy_spec>")
    if policy_start != -1:
        logger.info("[Proxy Helpers] Stripped policy injection from position %s", policy_start)
        # Keep only what precedes the marker.
        return text[:policy_start].strip()

    # No policy injection found, return original text.
    return text
|
||||
|
||||
|
||||
def extract_user_messages(body: bytes) -> list[str]:
|
||||
"""Extract user messages from request body."""
|
||||
messages = []
|
||||
@@ -28,12 +55,19 @@ def extract_user_messages(body: bytes) -> list[str]:
|
||||
if msg.get("role") == "user":
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, str):
|
||||
user_messages.append(content)
|
||||
# Strip policy specs before adding
|
||||
cleaned = strip_policy_specs(content)
|
||||
if cleaned: # Only add if not empty after stripping
|
||||
user_messages.append(cleaned)
|
||||
elif isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict):
|
||||
if block.get("type") == "text":
|
||||
user_messages.append(block.get("text", ""))
|
||||
text = block.get("text", "")
|
||||
# Strip policy specs from text blocks
|
||||
cleaned = strip_policy_specs(text)
|
||||
if cleaned: # Only add if not empty after stripping
|
||||
user_messages.append(cleaned)
|
||||
elif block.get("type") == "image":
|
||||
user_messages.append("[IMAGE]")
|
||||
|
||||
@@ -419,6 +453,7 @@ async def get_or_create_claude_code_agent(
|
||||
server,
|
||||
actor,
|
||||
project_id: str = None,
|
||||
agent_id: str = None,
|
||||
):
|
||||
"""
|
||||
Get or create a special agent for Claude Code sessions.
|
||||
@@ -427,12 +462,24 @@ async def get_or_create_claude_code_agent(
|
||||
server: SyncServer instance
|
||||
actor: Actor performing the operation (user ID)
|
||||
project_id: Optional project ID to associate the agent with
|
||||
agent_id: Optional specific agent ID to use (from X-LETTA-AGENT-ID header)
|
||||
|
||||
Returns:
|
||||
Agent ID
|
||||
Agent instance
|
||||
"""
|
||||
from letta.schemas.agent import CreateAgent
|
||||
|
||||
# If a specific agent ID is provided, try to use it directly
|
||||
if agent_id:
|
||||
logger.debug(f"Attempting to fetch agent by ID: {agent_id}")
|
||||
try:
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
||||
logger.info(f"Found agent via X-LETTA-AGENT-ID header: {agent.id} (name: {agent.name})")
|
||||
return agent
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not find agent with ID {agent_id}: {e}. Falling back to default behavior.")
|
||||
# Fall through to default behavior below
|
||||
|
||||
# Create short user identifier from UUID (first 8 chars)
|
||||
if actor:
|
||||
user_short_id = str(actor.id)[:8] if hasattr(actor, "id") else str(actor)[:8]
|
||||
|
||||
@@ -638,10 +638,11 @@ async def run_tool_for_agent(
|
||||
)
|
||||
|
||||
# Build environment variables dict from agent secrets
|
||||
# Use pre-decrypted value field (populated in from_orm_async)
|
||||
sandbox_env_vars = {}
|
||||
if agent.tool_exec_environment_variables:
|
||||
for env_var in agent.tool_exec_environment_variables:
|
||||
sandbox_env_vars[env_var.key] = env_var.value_enc.get_plaintext() if env_var.value_enc else None
|
||||
sandbox_env_vars[env_var.key] = env_var.value or ""
|
||||
|
||||
# Create tool execution manager and execute the tool
|
||||
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
|
||||
|
||||
@@ -62,8 +62,8 @@ async def anthropic_messages_proxy(
|
||||
# Claude Code sends full conversation history, but we only want to persist the new message
|
||||
user_messages = [all_user_messages[-1]] if all_user_messages else []
|
||||
|
||||
# Filter out system/metadata requests
|
||||
user_messages = [s for s in user_messages if not s.startswith("<system-reminder>")]
|
||||
# Filter out system/metadata requests and policy specs
|
||||
user_messages = [s for s in user_messages if not s.startswith("<system-reminder>") and not s.startswith("<policy_spec>")]
|
||||
if not user_messages:
|
||||
logger.debug(f"[{PROXY_NAME}] Skipping capture/memory for this turn")
|
||||
|
||||
@@ -99,10 +99,14 @@ async def anthropic_messages_proxy(
|
||||
# Message persistence happens in the background after the response is returned.
|
||||
agent = None
|
||||
try:
|
||||
# Check if X-LETTA-AGENT-ID header is provided
|
||||
custom_agent_id = request.headers.get("x-letta-agent-id")
|
||||
|
||||
agent = await get_or_create_claude_code_agent(
|
||||
server=server,
|
||||
actor=actor,
|
||||
project_id=project_id,
|
||||
agent_id=custom_agent_id,
|
||||
)
|
||||
logger.debug(f"[{PROXY_NAME}] Using agent ID: {agent.id}")
|
||||
except Exception as e:
|
||||
|
||||
@@ -50,7 +50,7 @@ async def create_mcp_server(
|
||||
# TODO: add the tools to the MCP server table we made.
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
new_server = await server.mcp_server_manager.create_mcp_server_from_request(request, actor=actor)
|
||||
return convert_generic_to_union(new_server)
|
||||
return await convert_generic_to_union(new_server)
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -67,7 +67,10 @@ async def list_mcp_servers(
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
mcp_servers = await server.mcp_server_manager.list_mcp_servers(actor=actor)
|
||||
return [convert_generic_to_union(mcp_server) for mcp_server in mcp_servers]
|
||||
result = []
|
||||
for mcp_server in mcp_servers:
|
||||
result.append(await convert_generic_to_union(mcp_server))
|
||||
return result
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -85,7 +88,7 @@ async def retrieve_mcp_server(
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
current_server = await server.mcp_server_manager.get_mcp_server_by_id_async(mcp_server_id=mcp_server_id, actor=actor)
|
||||
return convert_generic_to_union(current_server)
|
||||
return await convert_generic_to_union(current_server)
|
||||
|
||||
|
||||
@router.delete(
|
||||
@@ -125,7 +128,7 @@ async def update_mcp_server(
|
||||
updated_server = await server.mcp_server_manager.update_mcp_server_by_id(
|
||||
mcp_server_id=mcp_server_id, mcp_server_update=internal_update, actor=actor
|
||||
)
|
||||
return convert_generic_to_union(updated_server)
|
||||
return await convert_generic_to_union(updated_server)
|
||||
|
||||
|
||||
@router.get("/{mcp_server_id}/tools", response_model=List[Tool], operation_id="mcp_list_tools_for_mcp_server")
|
||||
@@ -238,7 +241,7 @@ async def connect_mcp_server(
|
||||
mcp_server = await server.mcp_server_manager.get_mcp_server_by_id_async(mcp_server_id=mcp_server_id, actor=actor)
|
||||
|
||||
# Convert the MCP server to the appropriate config type
|
||||
config = mcp_server.to_config(resolve_variables=False)
|
||||
config = await mcp_server.to_config_async(resolve_variables=False)
|
||||
|
||||
async def oauth_stream_generator(
|
||||
mcp_config: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig],
|
||||
|
||||
@@ -122,8 +122,8 @@ async def check_existing_provider(
|
||||
provider = await server.provider_manager.get_provider_async(provider_id=provider_id, actor=actor)
|
||||
|
||||
# Create a ProviderCheck from the existing provider
|
||||
api_key = provider.api_key_enc.get_plaintext() if provider.api_key_enc else None
|
||||
access_key = provider.access_key_enc.get_plaintext() if provider.access_key_enc else None
|
||||
api_key = await provider.api_key_enc.get_plaintext_async() if provider.api_key_enc else None
|
||||
access_key = await provider.access_key_enc.get_plaintext_async() if provider.access_key_enc else None
|
||||
provider_check = ProviderCheck(
|
||||
provider_type=provider.provider_type,
|
||||
api_key=api_key,
|
||||
|
||||
@@ -15,6 +15,7 @@ from letta.errors import (
|
||||
LettaMCPTimeoutError,
|
||||
LettaToolCreateError,
|
||||
LettaToolNameConflictError,
|
||||
LLMError,
|
||||
)
|
||||
from letta.functions.functions import derive_openai_json_schema
|
||||
from letta.functions.mcp_client.exceptions import MCPTimeoutError
|
||||
@@ -426,7 +427,10 @@ async def list_mcp_servers(
|
||||
else:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
mcp_servers = await server.mcp_manager.list_mcp_servers(actor=actor)
|
||||
return {server.server_name: server.to_config(resolve_variables=False) for server in mcp_servers}
|
||||
result = {}
|
||||
for mcp_server in mcp_servers:
|
||||
result[mcp_server.server_name] = await mcp_server.to_config_async(resolve_variables=False)
|
||||
return result
|
||||
|
||||
|
||||
# NOTE: async because the MCP client/session calls are async
|
||||
@@ -555,7 +559,10 @@ async def add_mcp_server_to_config(
|
||||
|
||||
# TODO: don't do this in the future (just return MCPServer)
|
||||
all_servers = await server.mcp_manager.list_mcp_servers(actor=actor)
|
||||
return [server.to_config() for server in all_servers]
|
||||
result = []
|
||||
for mcp_server in all_servers:
|
||||
result.append(await mcp_server.to_config_async())
|
||||
return result
|
||||
|
||||
|
||||
@router.patch(
|
||||
@@ -580,7 +587,7 @@ async def update_mcp_server(
|
||||
updated_server = await server.mcp_manager.update_mcp_server_by_name(
|
||||
mcp_server_name=mcp_server_name, mcp_server_update=request, actor=actor
|
||||
)
|
||||
return updated_server.to_config()
|
||||
return await updated_server.to_config_async()
|
||||
|
||||
|
||||
@router.delete(
|
||||
@@ -607,7 +614,10 @@ async def delete_mcp_server_from_config(
|
||||
|
||||
# TODO: don't do this in the future (just return MCPServer)
|
||||
all_servers = await server.mcp_manager.list_mcp_servers(actor=actor)
|
||||
return [server.to_config() for server in all_servers]
|
||||
result = []
|
||||
for mcp_server in all_servers:
|
||||
result.append(await mcp_server.to_config_async())
|
||||
return result
|
||||
|
||||
|
||||
@deprecated("Deprecated in favor of /mcp/servers/connect which handles OAuth flow via SSE stream")
|
||||
@@ -794,7 +804,7 @@ async def execute_mcp_tool(
|
||||
raise NoResultFound(f"MCP server '{mcp_server_name}' not found")
|
||||
|
||||
# Create client and connect
|
||||
server_config = mcp_server.to_config()
|
||||
server_config = await mcp_server.to_config_async()
|
||||
server_config.resolve_environment_variables()
|
||||
client = await server.mcp_manager.get_mcp_client(server_config, actor)
|
||||
await client.connect_to_server()
|
||||
@@ -924,6 +934,14 @@ async def generate_tool_from_prompt(
|
||||
)
|
||||
response_data = await llm_client.request_async(request_data, llm_config)
|
||||
response = await llm_client.convert_response_to_chat_completion(response_data, input_messages, llm_config)
|
||||
|
||||
# Validate that we got a tool call response
|
||||
if not response.choices or not response.choices[0].message.tool_calls:
|
||||
error_msg = (
|
||||
response.choices[0].message.content if response.choices and response.choices[0].message.content else "No response from LLM"
|
||||
)
|
||||
raise LLMError(f"Failed to generate tool '{request.tool_name}': LLM did not return a tool call. Response: {error_msg}")
|
||||
|
||||
output = json.loads(response.choices[0].message.tool_calls[0].function.arguments)
|
||||
pip_requirements = [PipRequirement(name=k, version=v or None) for k, v in json.loads(output["pip_requirements_json"]).items()]
|
||||
|
||||
|
||||
@@ -62,8 +62,8 @@ async def zai_messages_proxy(
|
||||
# Claude Code sends full conversation history, but we only want to persist the new message
|
||||
user_messages = [all_user_messages[-1]] if all_user_messages else []
|
||||
|
||||
# Filter out system/metadata requests
|
||||
user_messages = [s for s in user_messages if not s.startswith("<system-reminder>")]
|
||||
# Filter out system/metadata requests and policy specs
|
||||
user_messages = [s for s in user_messages if not s.startswith("<system-reminder>") and not s.startswith("<policy_spec>")]
|
||||
if not user_messages:
|
||||
logger.debug(f"[{PROXY_NAME}] Skipping capture/memory for this turn")
|
||||
|
||||
|
||||
@@ -12,6 +12,29 @@ wait_for_postgres() {
|
||||
done
|
||||
}
|
||||
|
||||
# Function to wait for Redis to be ready
|
||||
wait_for_redis() {
|
||||
until redis-cli ping 2>/dev/null | grep -q PONG; do
|
||||
echo "Waiting for Redis to be ready..."
|
||||
sleep 1
|
||||
done
|
||||
}
|
||||
|
||||
# Check if we're configured for external Redis
|
||||
if [ -n "$LETTA_REDIS_HOST" ]; then
|
||||
echo "External Redis configuration detected, using env var LETTA_REDIS_HOST=$LETTA_REDIS_HOST"
|
||||
else
|
||||
echo "No external Redis configuration detected, starting internal Redis..."
|
||||
redis-server --daemonize yes --bind 0.0.0.0
|
||||
|
||||
# Wait for Redis to be ready
|
||||
wait_for_redis
|
||||
|
||||
# Set default Redis host for internal redis
|
||||
export LETTA_REDIS_HOST="localhost"
|
||||
echo "Using internal Redis at: $LETTA_REDIS_HOST"
|
||||
fi
|
||||
|
||||
# Check if we're configured for external Postgres
|
||||
if [ -n "$LETTA_PG_URI" ]; then
|
||||
echo "External Postgres configuration detected, using env var LETTA_PG_URI"
|
||||
|
||||
@@ -717,8 +717,8 @@ class AgentManager:
|
||||
return await self.append_to_in_context_messages_async(init_messages, agent_id=agent_state.id, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@trace_method
|
||||
async def update_agent_async(
|
||||
self,
|
||||
agent_id: str,
|
||||
@@ -841,7 +841,7 @@ class AgentManager:
|
||||
existing_value = None
|
||||
if existing_env and existing_env.value_enc:
|
||||
existing_secret = Secret.from_encrypted(existing_env.value_enc)
|
||||
existing_value = existing_secret.get_plaintext()
|
||||
existing_value = await existing_secret.get_plaintext_async()
|
||||
|
||||
# Encrypt value (reuse existing encrypted value if unchanged)
|
||||
if existing_value == v and existing_env and existing_env.value_enc:
|
||||
@@ -1081,8 +1081,8 @@ class AgentManager:
|
||||
return await AgentModel.size_async(db_session=session, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@trace_method
|
||||
async def get_agent_by_id_async(
|
||||
self,
|
||||
agent_id: str,
|
||||
@@ -1142,8 +1142,8 @@ class AgentManager:
|
||||
raise
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@trace_method
|
||||
async def get_agent_archive_ids_async(self, agent_id: str, actor: PydanticUser) -> List[str]:
|
||||
"""Get all archive IDs associated with an agent."""
|
||||
from letta.orm import ArchivesAgents
|
||||
@@ -1156,8 +1156,8 @@ class AgentManager:
|
||||
return archive_ids
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@trace_method
|
||||
async def validate_agent_exists_async(self, agent_id: str, actor: PydanticUser) -> None:
|
||||
"""
|
||||
Validate that an agent exists and user has access to it.
|
||||
@@ -1174,8 +1174,8 @@ class AgentManager:
|
||||
await validate_agent_exists_async(session, agent_id, actor)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@trace_method
|
||||
async def delete_agent_async(self, agent_id: str, actor: PydanticUser) -> None:
|
||||
"""
|
||||
Deletes an agent and its associated relationships.
|
||||
@@ -1237,8 +1237,8 @@ class AgentManager:
|
||||
# TODO: This can be fixed by having an actual relationship in the ORM for message_ids
|
||||
# TODO: This can also be made more efficient, instead of getting, setting, we can do it all in one db session for one query.
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@trace_method
|
||||
async def get_in_context_messages(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]:
|
||||
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
||||
return await self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor)
|
||||
@@ -1250,8 +1250,8 @@ class AgentManager:
|
||||
return self.message_manager.get_message_by_id(message_id=message_ids[0], actor=actor)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@trace_method
|
||||
async def get_system_message_async(self, agent_id: str, actor: PydanticUser) -> PydanticMessage:
|
||||
agent = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=[], actor=actor)
|
||||
return await self.message_manager.get_message_by_id_async(message_id=agent.message_ids[0], actor=actor)
|
||||
@@ -1432,8 +1432,8 @@ class AgentManager:
|
||||
return self.update_agent(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@trace_method
|
||||
async def set_in_context_messages_async(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState:
|
||||
return await self.update_agent_async(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor)
|
||||
|
||||
@@ -1543,8 +1543,8 @@ class AgentManager:
|
||||
return agent_state
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@trace_method
|
||||
async def update_memory_if_changed_async(self, agent_id: str, new_memory: Memory, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Update internal memory object and system prompt if there have been modifications.
|
||||
@@ -1656,9 +1656,9 @@ class AgentManager:
|
||||
# Source Management
|
||||
# ======================================================================================================================
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
|
||||
@trace_method
|
||||
async def attach_source_async(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Attaches a source to an agent.
|
||||
@@ -1732,8 +1732,8 @@ class AgentManager:
|
||||
self.append_to_in_context_messages(messages=[message], agent_id=agent_id, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@trace_method
|
||||
async def append_system_message_async(self, agent_id: str, content: str, actor: PydanticUser):
|
||||
"""
|
||||
Async version of append_system_message.
|
||||
@@ -1820,9 +1820,9 @@ class AgentManager:
|
||||
return [source.to_pydantic() for source in sources]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
|
||||
@trace_method
|
||||
async def detach_source_async(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Detaches a source from an agent.
|
||||
@@ -1909,9 +1909,9 @@ class AgentManager:
|
||||
return block.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
|
||||
@trace_method
|
||||
async def attach_block_async(self, agent_id: str, block_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""Attaches a block to an agent. For sleeptime agents, also attaches to paired agents in the same group."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -2503,9 +2503,9 @@ class AgentManager:
|
||||
# Tool Management
|
||||
# ======================================================================================================================
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@raise_on_invalid_id(param_name="tool_id", expected_prefix=PrimitiveType.TOOL)
|
||||
@trace_method
|
||||
async def attach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> None:
|
||||
"""
|
||||
Attaches a tool to an agent.
|
||||
@@ -2573,8 +2573,8 @@ class AgentManager:
|
||||
await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@trace_method
|
||||
async def bulk_attach_tools_async(self, agent_id: str, tool_ids: List[str], actor: PydanticUser) -> None:
|
||||
"""
|
||||
Efficiently attaches multiple tools to an agent in a single operation.
|
||||
@@ -2739,9 +2739,9 @@ class AgentManager:
|
||||
return PydanticAgentState(**agent_state_dict)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@raise_on_invalid_id(param_name="tool_id", expected_prefix=PrimitiveType.TOOL)
|
||||
@trace_method
|
||||
async def detach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> None:
|
||||
"""
|
||||
Detaches a tool from an agent.
|
||||
@@ -2770,8 +2770,8 @@ class AgentManager:
|
||||
await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@trace_method
|
||||
async def bulk_detach_tools_async(self, agent_id: str, tool_ids: List[str], actor: PydanticUser) -> None:
|
||||
"""
|
||||
Efficiently detaches multiple tools from an agent in a single operation.
|
||||
@@ -2807,8 +2807,8 @@ class AgentManager:
|
||||
await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@trace_method
|
||||
async def modify_approvals_async(self, agent_id: str, tool_name: str, requires_approval: bool, actor: PydanticUser) -> None:
|
||||
def is_target_rule(rule):
|
||||
return rule.tool_name == tool_name and rule.type == "requires_approval"
|
||||
@@ -3157,8 +3157,8 @@ class AgentManager:
|
||||
return results
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@trace_method
|
||||
async def get_agent_files_config_async(self, agent_id: str, actor: PydanticUser) -> Tuple[int, int]:
|
||||
"""Get per_file_view_window_char_limit and max_files_open for an agent.
|
||||
|
||||
@@ -3214,8 +3214,8 @@ class AgentManager:
|
||||
return per_file_limit, max_files
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@trace_method
|
||||
async def get_agent_max_files_open_async(self, agent_id: str, actor: PydanticUser) -> int:
|
||||
"""Get max_files_open for an agent.
|
||||
|
||||
@@ -3243,8 +3243,8 @@ class AgentManager:
|
||||
return row
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@trace_method
|
||||
async def get_agent_per_file_view_window_char_limit_async(self, agent_id: str, actor: PydanticUser) -> int:
|
||||
"""Get per_file_view_window_char_limit for an agent.
|
||||
|
||||
@@ -3272,8 +3272,8 @@ class AgentManager:
|
||||
return row
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@trace_method
|
||||
async def get_context_window(self, agent_id: str, actor: PydanticUser) -> ContextWindowOverview:
|
||||
agent_state, system_message, num_messages, num_archival_memories = await self.rebuild_system_prompt_async(
|
||||
agent_id=agent_id, actor=actor, force=True, dry_run=True
|
||||
|
||||
@@ -55,8 +55,8 @@ class ArchiveManager:
|
||||
raise
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE)
|
||||
@trace_method
|
||||
async def get_archive_by_id_async(
|
||||
self,
|
||||
archive_id: str,
|
||||
@@ -72,8 +72,8 @@ class ArchiveManager:
|
||||
return archive.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE)
|
||||
@trace_method
|
||||
async def update_archive_async(
|
||||
self,
|
||||
archive_id: str,
|
||||
@@ -99,8 +99,8 @@ class ArchiveManager:
|
||||
return archive.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@trace_method
|
||||
async def list_archives_async(
|
||||
self,
|
||||
*,
|
||||
@@ -150,9 +150,9 @@ class ArchiveManager:
|
||||
return [a.to_pydantic() for a in archives]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE)
|
||||
@trace_method
|
||||
async def attach_agent_to_archive_async(
|
||||
self,
|
||||
agent_id: str,
|
||||
@@ -194,9 +194,9 @@ class ArchiveManager:
|
||||
await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE)
|
||||
@trace_method
|
||||
async def detach_agent_from_archive_async(
|
||||
self,
|
||||
agent_id: str,
|
||||
@@ -227,8 +227,8 @@ class ArchiveManager:
|
||||
await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@trace_method
|
||||
async def get_default_archive_for_agent_async(
|
||||
self,
|
||||
agent_id: str,
|
||||
@@ -260,8 +260,8 @@ class ArchiveManager:
|
||||
return None
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE)
|
||||
@trace_method
|
||||
async def delete_archive_async(
|
||||
self,
|
||||
archive_id: str,
|
||||
@@ -278,8 +278,8 @@ class ArchiveManager:
|
||||
logger.info(f"Deleted archive {archive_id}")
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE)
|
||||
@trace_method
|
||||
async def create_passage_in_archive_async(
|
||||
self,
|
||||
archive_id: str,
|
||||
@@ -360,9 +360,9 @@ class ArchiveManager:
|
||||
return created_passage
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE)
|
||||
@raise_on_invalid_id(param_name="passage_id", expected_prefix=PrimitiveType.PASSAGE)
|
||||
@trace_method
|
||||
async def delete_passage_from_archive_async(
|
||||
self,
|
||||
archive_id: str,
|
||||
@@ -470,8 +470,8 @@ class ArchiveManager:
|
||||
raise
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE)
|
||||
@trace_method
|
||||
async def get_agents_for_archive_async(
|
||||
self,
|
||||
archive_id: str,
|
||||
@@ -583,8 +583,8 @@ class ArchiveManager:
|
||||
return agent_ids[0]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE)
|
||||
@trace_method
|
||||
async def get_or_set_vector_db_namespace_async(
|
||||
self,
|
||||
archive_id: str,
|
||||
|
||||
@@ -134,8 +134,8 @@ class BlockManager:
|
||||
return result
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
|
||||
@trace_method
|
||||
async def update_block_async(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock:
|
||||
"""Update a block by its ID with the given BlockUpdate object."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -154,8 +154,8 @@ class BlockManager:
|
||||
return pydantic_block
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
|
||||
@trace_method
|
||||
async def delete_block_async(self, block_id: str, actor: PydanticUser) -> None:
|
||||
"""Delete a block by its ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -353,8 +353,8 @@ class BlockManager:
|
||||
return [block.to_pydantic() for block in blocks]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
|
||||
@trace_method
|
||||
async def get_block_by_id_async(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]:
|
||||
"""Retrieve a block by its name."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -413,8 +413,8 @@ class BlockManager:
|
||||
return pydantic_blocks
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
|
||||
@trace_method
|
||||
async def get_agents_for_block_async(
|
||||
self,
|
||||
block_id: str,
|
||||
@@ -600,9 +600,9 @@ class BlockManager:
|
||||
return None
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@trace_method
|
||||
async def checkpoint_block_async(
|
||||
self,
|
||||
block_id: str,
|
||||
@@ -710,8 +710,8 @@ class BlockManager:
|
||||
return updated_block
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
|
||||
@trace_method
|
||||
async def undo_checkpoint_block(
|
||||
self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None
|
||||
) -> PydanticBlock:
|
||||
@@ -761,8 +761,8 @@ class BlockManager:
|
||||
return block.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
|
||||
@trace_method
|
||||
async def redo_checkpoint_block(
|
||||
self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None
|
||||
) -> PydanticBlock:
|
||||
|
||||
@@ -93,8 +93,8 @@ class FileManager:
|
||||
|
||||
# TODO: We make actor optional for now, but should most likely be enforced due to security reasons
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE)
|
||||
@trace_method
|
||||
# @async_redis_cache(
|
||||
# key_func=lambda self, file_id, actor=None, include_content=False, strip_directory_prefix=False: f"{file_id}:{actor.organization_id if actor else 'none'}:{include_content}:{strip_directory_prefix}",
|
||||
# prefix="file_content",
|
||||
@@ -136,8 +136,8 @@ class FileManager:
|
||||
return await file_orm.to_pydantic_async(include_content=include_content, strip_directory_prefix=strip_directory_prefix)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE)
|
||||
@trace_method
|
||||
async def update_file_status(
|
||||
self,
|
||||
*,
|
||||
@@ -354,8 +354,8 @@ class FileManager:
|
||||
return file_metadata
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE)
|
||||
@trace_method
|
||||
async def upsert_file_content(
|
||||
self,
|
||||
*,
|
||||
@@ -400,8 +400,8 @@ class FileManager:
|
||||
return await result.scalar_one().to_pydantic_async(include_content=True)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
|
||||
@trace_method
|
||||
async def list_files(
|
||||
self,
|
||||
source_id: str,
|
||||
@@ -462,8 +462,8 @@ class FileManager:
|
||||
return file_metadatas
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE)
|
||||
@trace_method
|
||||
async def delete_file(self, file_id: str, actor: PydanticUser) -> PydanticFileMetadata:
|
||||
"""Delete a file by its ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -517,8 +517,8 @@ class FileManager:
|
||||
return f"{source.name}/{base}_({count}){ext}"
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
|
||||
@trace_method
|
||||
# @async_redis_cache(
|
||||
# key_func=lambda self, original_filename, source_id, actor: f"{original_filename}:{source_id}:{actor.organization_id}",
|
||||
# prefix="file_by_name",
|
||||
|
||||
@@ -65,8 +65,8 @@ class GroupManager:
|
||||
return [group.to_pydantic() for group in groups]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
|
||||
@trace_method
|
||||
async def retrieve_group_async(self, group_id: str, actor: PydanticUser) -> PydanticGroup:
|
||||
async with db_registry.async_session() as session:
|
||||
group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor)
|
||||
@@ -123,8 +123,8 @@ class GroupManager:
|
||||
return new_group.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
|
||||
@trace_method
|
||||
async def modify_group_async(self, group_id: str, group_update: GroupUpdate, actor: PydanticUser) -> PydanticGroup:
|
||||
async with db_registry.async_session() as session:
|
||||
group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor)
|
||||
@@ -187,16 +187,16 @@ class GroupManager:
|
||||
return group.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
|
||||
@trace_method
|
||||
async def delete_group_async(self, group_id: str, actor: PydanticUser) -> None:
|
||||
async with db_registry.async_session() as session:
|
||||
group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor)
|
||||
await group.hard_delete_async(session)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
|
||||
@trace_method
|
||||
async def list_group_messages_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
@@ -233,8 +233,8 @@ class GroupManager:
|
||||
return messages
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
|
||||
@trace_method
|
||||
async def reset_messages_async(self, group_id: str, actor: PydanticUser) -> None:
|
||||
async with db_registry.async_session() as session:
|
||||
# Ensure group is loadable by user
|
||||
@@ -249,8 +249,8 @@ class GroupManager:
|
||||
await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
|
||||
@trace_method
|
||||
async def bump_turns_counter_async(self, group_id: str, actor: PydanticUser) -> int:
|
||||
async with db_registry.async_session() as session:
|
||||
# Ensure group is loadable by user
|
||||
@@ -262,9 +262,9 @@ class GroupManager:
|
||||
return group.turns_counter
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
|
||||
@raise_on_invalid_id(param_name="last_processed_message_id", expected_prefix=PrimitiveType.MESSAGE)
|
||||
@trace_method
|
||||
async def get_last_processed_message_id_and_update_async(
|
||||
self, group_id: str, last_processed_message_id: str, actor: PydanticUser
|
||||
) -> str:
|
||||
@@ -413,9 +413,9 @@ class GroupManager:
|
||||
session.add(BlocksAgents(agent_id=manager_agent.id, block_id=block.id, block_label=block.label))
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
|
||||
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
|
||||
@trace_method
|
||||
async def attach_block_async(self, group_id: str, block_id: str, actor: PydanticUser) -> None:
|
||||
"""Attach a block to a group."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -437,9 +437,9 @@ class GroupManager:
|
||||
await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
|
||||
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
|
||||
@trace_method
|
||||
async def detach_block_async(self, group_id: str, block_id: str, actor: PydanticUser) -> None:
|
||||
"""Detach a block from a group."""
|
||||
async with db_registry.async_session() as session:
|
||||
|
||||
@@ -83,8 +83,8 @@ class IdentityManager:
|
||||
return [identity.to_pydantic() for identity in identities], next_cursor, has_more
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
|
||||
@trace_method
|
||||
async def get_identity_async(self, identity_id: str, actor: PydanticUser) -> PydanticIdentity:
|
||||
async with db_registry.async_session() as session:
|
||||
identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)
|
||||
@@ -165,8 +165,8 @@ class IdentityManager:
|
||||
)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
|
||||
@trace_method
|
||||
async def update_identity_async(
|
||||
self, identity_id: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False
|
||||
) -> PydanticIdentity:
|
||||
@@ -229,8 +229,8 @@ class IdentityManager:
|
||||
return existing_identity.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
|
||||
@trace_method
|
||||
async def upsert_identity_properties_async(
|
||||
self, identity_id: str, properties: List[IdentityProperty], actor: PydanticUser
|
||||
) -> PydanticIdentity:
|
||||
@@ -247,8 +247,8 @@ class IdentityManager:
|
||||
)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
|
||||
@trace_method
|
||||
async def delete_identity_async(self, identity_id: str, actor: PydanticUser) -> None:
|
||||
async with db_registry.async_session() as session:
|
||||
identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)
|
||||
@@ -305,8 +305,8 @@ class IdentityManager:
|
||||
current_relationship.extend(new_items)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
|
||||
@trace_method
|
||||
async def list_agents_for_identity_async(
|
||||
self,
|
||||
identity_id: str,
|
||||
@@ -338,8 +338,8 @@ class IdentityManager:
|
||||
return await asyncio.gather(*[agent.to_pydantic_async(include_relationships=[], include=include) for agent in agents])
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
|
||||
@trace_method
|
||||
async def list_blocks_for_identity_async(
|
||||
self,
|
||||
identity_id: str,
|
||||
@@ -370,9 +370,9 @@ class IdentityManager:
|
||||
return [block.to_pydantic() for block in blocks]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@trace_method
|
||||
async def attach_agent_async(self, identity_id: str, agent_id: str, actor: PydanticUser) -> None:
|
||||
"""
|
||||
Attach an agent to an identity.
|
||||
@@ -388,9 +388,9 @@ class IdentityManager:
|
||||
await identity.update_async(db_session=session, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@trace_method
|
||||
async def detach_agent_async(self, identity_id: str, agent_id: str, actor: PydanticUser) -> None:
|
||||
"""
|
||||
Detach an agent from an identity.
|
||||
@@ -406,9 +406,9 @@ class IdentityManager:
|
||||
await identity.update_async(db_session=session, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
|
||||
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
|
||||
@trace_method
|
||||
async def attach_block_async(self, identity_id: str, block_id: str, actor: PydanticUser) -> None:
|
||||
"""
|
||||
Attach a block to an identity.
|
||||
@@ -424,9 +424,9 @@ class IdentityManager:
|
||||
await identity.update_async(db_session=session, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
|
||||
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
|
||||
@trace_method
|
||||
async def detach_block_async(self, identity_id: str, block_id: str, actor: PydanticUser) -> None:
|
||||
"""
|
||||
Detach a block from an identity.
|
||||
|
||||
@@ -70,8 +70,8 @@ class JobManager:
|
||||
return result
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB)
|
||||
@trace_method
|
||||
async def update_job_by_id_async(
|
||||
self, job_id: str, job_update: JobUpdate, actor: PydanticUser, safe_update: bool = False
|
||||
) -> PydanticJob:
|
||||
@@ -148,8 +148,8 @@ class JobManager:
|
||||
return result
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB)
|
||||
@trace_method
|
||||
async def safe_update_job_status_async(
|
||||
self,
|
||||
job_id: str,
|
||||
@@ -189,8 +189,8 @@ class JobManager:
|
||||
return False
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB)
|
||||
@trace_method
|
||||
async def get_job_by_id_async(self, job_id: str, actor: PydanticUser) -> PydanticJob:
|
||||
"""Fetch a job by its ID asynchronously."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -304,8 +304,8 @@ class JobManager:
|
||||
return [job.to_pydantic() for job in jobs]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB)
|
||||
@trace_method
|
||||
async def delete_job_by_id_async(self, job_id: str, actor: PydanticUser) -> PydanticJob:
|
||||
"""Delete a job by its ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -314,8 +314,8 @@ class JobManager:
|
||||
return job.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
||||
@trace_method
|
||||
async def get_run_messages(
|
||||
self,
|
||||
run_id: str,
|
||||
@@ -372,8 +372,8 @@ class JobManager:
|
||||
return messages
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
||||
@trace_method
|
||||
async def get_step_messages(
|
||||
self,
|
||||
run_id: str,
|
||||
@@ -537,8 +537,8 @@ class JobManager:
|
||||
return result
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB)
|
||||
@trace_method
|
||||
async def get_job_steps(
|
||||
self,
|
||||
job_id: str,
|
||||
|
||||
@@ -37,14 +37,12 @@ class DatabaseTokenStorage(TokenStorage):
|
||||
if not oauth_session:
|
||||
return None
|
||||
|
||||
# Decrypt tokens using getter methods
|
||||
access_token_secret = oauth_session.get_access_token_secret()
|
||||
access_token = access_token_secret.get_plaintext()
|
||||
# Read tokens directly from _enc columns
|
||||
access_token = await oauth_session.access_token_enc.get_plaintext_async() if oauth_session.access_token_enc else None
|
||||
if not access_token:
|
||||
return None
|
||||
|
||||
refresh_token_secret = oauth_session.get_refresh_token_secret()
|
||||
refresh_token = refresh_token_secret.get_plaintext()
|
||||
refresh_token = await oauth_session.refresh_token_enc.get_plaintext_async() if oauth_session.refresh_token_enc else None
|
||||
|
||||
return OAuthToken(
|
||||
access_token=access_token,
|
||||
@@ -72,9 +70,8 @@ class DatabaseTokenStorage(TokenStorage):
|
||||
if not oauth_session or not oauth_session.client_id:
|
||||
return None
|
||||
|
||||
# Decrypt client secret using getter method
|
||||
client_secret_secret = oauth_session.get_client_secret_secret()
|
||||
client_secret = client_secret_secret.get_plaintext()
|
||||
# Read client secret directly from _enc column
|
||||
client_secret = await oauth_session.client_secret_enc.get_plaintext_async() if oauth_session.client_secret_enc else None
|
||||
|
||||
return OAuthClientInformationFull(
|
||||
client_id=oauth_session.client_id,
|
||||
@@ -147,19 +144,15 @@ class MCPOAuthSession:
|
||||
|
||||
async def store_authorization_code(self, code: str, state: str) -> Optional[MCPOAuth]:
|
||||
"""Store the authorization code from OAuth callback."""
|
||||
# Use mcp_manager to ensure proper encryption
|
||||
from letta.schemas.mcp import MCPOAuthSessionUpdate
|
||||
from letta.schemas.secret import Secret
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None)
|
||||
|
||||
# Encrypt the authorization_code before storing
|
||||
# Encrypt the authorization_code and store only in _enc column
|
||||
if code is not None:
|
||||
oauth_record.authorization_code_enc = Secret.from_plaintext(code).get_encrypted()
|
||||
# Keep plaintext for dual-write during migration
|
||||
oauth_record.authorization_code = code
|
||||
|
||||
oauth_record.status = OAuthSessionStatus.AUTHORIZED
|
||||
oauth_record.state = state
|
||||
@@ -234,10 +227,10 @@ async def create_oauth_provider(
|
||||
logger.info(f"Waiting for authorization code for session {session_id}")
|
||||
while time.time() - start_time < timeout:
|
||||
oauth_session = await mcp_manager.get_oauth_session_by_id(session_id, actor)
|
||||
if oauth_session and oauth_session.authorization_code:
|
||||
# Decrypt the authorization code before returning
|
||||
auth_code_secret = oauth_session.get_authorization_code_secret()
|
||||
return auth_code_secret.get_plaintext(), oauth_session.state
|
||||
if oauth_session and oauth_session.authorization_code_enc:
|
||||
# Read authorization code directly from _enc column
|
||||
auth_code = await oauth_session.authorization_code_enc.get_plaintext_async()
|
||||
return auth_code, oauth_session.state
|
||||
elif oauth_session and oauth_session.status == OAuthSessionStatus.ERROR:
|
||||
raise Exception("OAuth authorization failed")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
@@ -70,7 +70,7 @@ class MCPManager:
|
||||
try:
|
||||
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor=actor)
|
||||
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
|
||||
server_config = mcp_config.to_config()
|
||||
server_config = await mcp_config.to_config_async()
|
||||
mcp_client = await self.get_mcp_client(server_config, actor, agent_id=agent_id)
|
||||
await mcp_client.connect_to_server()
|
||||
|
||||
@@ -116,7 +116,7 @@ class MCPManager:
|
||||
# read from DB
|
||||
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor=actor)
|
||||
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
|
||||
server_config = mcp_config.to_config(environment_variables)
|
||||
server_config = await mcp_config.to_config_async(environment_variables)
|
||||
else:
|
||||
# read from config file
|
||||
mcp_config = await self.read_mcp_config()
|
||||
@@ -419,16 +419,14 @@ class MCPManager:
|
||||
server_type=server_config.type,
|
||||
server_url=server_config.server_url,
|
||||
)
|
||||
# Encrypt sensitive fields
|
||||
# Encrypt sensitive fields - write only to _enc columns
|
||||
token = server_config.resolve_token()
|
||||
if token:
|
||||
token_secret = Secret.from_plaintext(token)
|
||||
mcp_server.set_token_secret(token_secret)
|
||||
mcp_server.token_enc = Secret.from_plaintext(token)
|
||||
if server_config.custom_headers:
|
||||
# Convert dict to JSON string, then encrypt as Secret
|
||||
headers_json = json.dumps(server_config.custom_headers)
|
||||
headers_secret = Secret.from_plaintext(headers_json)
|
||||
mcp_server.set_custom_headers_secret(headers_secret)
|
||||
mcp_server.custom_headers_enc = Secret.from_plaintext(headers_json)
|
||||
|
||||
elif isinstance(server_config, StreamableHTTPServerConfig):
|
||||
mcp_server = MCPServer(
|
||||
@@ -436,16 +434,14 @@ class MCPManager:
|
||||
server_type=server_config.type,
|
||||
server_url=server_config.server_url,
|
||||
)
|
||||
# Encrypt sensitive fields
|
||||
# Encrypt sensitive fields - write only to _enc columns
|
||||
token = server_config.resolve_token()
|
||||
if token:
|
||||
token_secret = Secret.from_plaintext(token)
|
||||
mcp_server.set_token_secret(token_secret)
|
||||
mcp_server.token_enc = Secret.from_plaintext(token)
|
||||
if server_config.custom_headers:
|
||||
# Convert dict to JSON string, then encrypt as Secret
|
||||
headers_json = json.dumps(server_config.custom_headers)
|
||||
headers_secret = Secret.from_plaintext(headers_json)
|
||||
mcp_server.set_custom_headers_secret(headers_secret)
|
||||
mcp_server.custom_headers_enc = Secret.from_plaintext(headers_json)
|
||||
else:
|
||||
raise ValueError(f"Unsupported server config type: {type(server_config)}")
|
||||
|
||||
@@ -539,57 +535,44 @@ class MCPManager:
|
||||
# Update tool attributes with only the fields that were explicitly set
|
||||
update_data = mcp_server_update.model_dump(to_orm=True, exclude_unset=True)
|
||||
|
||||
# Handle encryption for token if provided
|
||||
# Only re-encrypt if the value has actually changed
|
||||
# Handle encryption for token if provided - write only to _enc column
|
||||
if "token" in update_data and update_data["token"] is not None:
|
||||
# Check if value changed
|
||||
# Check if value changed by reading from _enc column only
|
||||
existing_token = None
|
||||
if mcp_server.token_enc:
|
||||
existing_secret = Secret.from_encrypted(mcp_server.token_enc)
|
||||
existing_token = existing_secret.get_plaintext()
|
||||
elif mcp_server.token:
|
||||
existing_token = mcp_server.token
|
||||
existing_token = await existing_secret.get_plaintext_async()
|
||||
|
||||
# Only re-encrypt if different
|
||||
if existing_token != update_data["token"]:
|
||||
mcp_server.token_enc = Secret.from_plaintext(update_data["token"]).get_encrypted()
|
||||
# Keep plaintext for dual-write during migration
|
||||
mcp_server.token = update_data["token"]
|
||||
|
||||
# Remove from update_data since we set directly on mcp_server
|
||||
update_data.pop("token", None)
|
||||
update_data.pop("token_enc", None)
|
||||
|
||||
# Handle encryption for custom_headers if provided
|
||||
# Only re-encrypt if the value has actually changed
|
||||
# Handle encryption for custom_headers if provided - write only to _enc column
|
||||
if "custom_headers" in update_data:
|
||||
if update_data["custom_headers"] is not None:
|
||||
# custom_headers is a Dict[str, str], serialize to JSON then encrypt
|
||||
import json
|
||||
|
||||
json_str = json.dumps(update_data["custom_headers"])
|
||||
|
||||
# Check if value changed
|
||||
# Check if value changed by reading from _enc column only
|
||||
existing_headers_json = None
|
||||
if mcp_server.custom_headers_enc:
|
||||
existing_secret = Secret.from_encrypted(mcp_server.custom_headers_enc)
|
||||
existing_headers_json = existing_secret.get_plaintext()
|
||||
elif mcp_server.custom_headers:
|
||||
existing_headers_json = json.dumps(mcp_server.custom_headers)
|
||||
existing_headers_json = await existing_secret.get_plaintext_async()
|
||||
|
||||
# Only re-encrypt if different
|
||||
if existing_headers_json != json_str:
|
||||
mcp_server.custom_headers_enc = Secret.from_plaintext(json_str).get_encrypted()
|
||||
# Keep plaintext for dual-write during migration
|
||||
mcp_server.custom_headers = update_data["custom_headers"]
|
||||
|
||||
# Remove from update_data since we set directly on mcp_server
|
||||
update_data.pop("custom_headers", None)
|
||||
update_data.pop("custom_headers_enc", None)
|
||||
else:
|
||||
# Ensure custom_headers None is stored as SQL NULL, not JSON null
|
||||
# Ensure custom_headers_enc None is stored as SQL NULL
|
||||
update_data.pop("custom_headers", None)
|
||||
setattr(mcp_server, "custom_headers", null())
|
||||
setattr(mcp_server, "custom_headers_enc", None)
|
||||
|
||||
for key, value in update_data.items():
|
||||
@@ -811,7 +794,7 @@ class MCPManager:
|
||||
if oauth_provider is None and hasattr(server_config, "server_url"):
|
||||
oauth_session = await self.get_oauth_session_by_server(server_config.server_url, actor)
|
||||
# Check if access token exists by attempting to decrypt it
|
||||
if oauth_session and oauth_session.get_access_token_secret().get_plaintext():
|
||||
if oauth_session and oauth_session.access_token_enc and await oauth_session.access_token_enc.get_plaintext_async():
|
||||
# Create OAuth provider from stored credentials
|
||||
from letta.services.mcp.oauth_utils import create_oauth_provider
|
||||
|
||||
@@ -836,31 +819,25 @@ class MCPManager:
|
||||
raise ValueError(f"Unsupported server config type: {type(server_config)}")
|
||||
|
||||
# OAuth-related methods
|
||||
def _oauth_orm_to_pydantic(self, oauth_session: MCPOAuth) -> MCPOAuthSession:
|
||||
async def _oauth_orm_to_pydantic_async(self, oauth_session: MCPOAuth) -> MCPOAuthSession:
|
||||
"""
|
||||
Convert OAuth ORM model to Pydantic model, handling decryption of sensitive fields.
|
||||
|
||||
Note: Prefers encrypted columns (_enc fields), falls back to plaintext with error logging.
|
||||
This helps identify unmigrated data during the migration period.
|
||||
Convert OAuth ORM model to Pydantic model, reading directly from encrypted columns.
|
||||
"""
|
||||
# Get decrypted values - prefer encrypted, fallback to plaintext with error logging
|
||||
access_token = Secret.from_db(
|
||||
encrypted_value=oauth_session.access_token_enc, plaintext_value=oauth_session.access_token
|
||||
).get_plaintext()
|
||||
# Convert encrypted columns to Secret objects
|
||||
authorization_code_enc = (
|
||||
Secret.from_encrypted(oauth_session.authorization_code_enc) if oauth_session.authorization_code_enc else None
|
||||
)
|
||||
access_token_enc = Secret.from_encrypted(oauth_session.access_token_enc) if oauth_session.access_token_enc else None
|
||||
refresh_token_enc = Secret.from_encrypted(oauth_session.refresh_token_enc) if oauth_session.refresh_token_enc else None
|
||||
client_secret_enc = Secret.from_encrypted(oauth_session.client_secret_enc) if oauth_session.client_secret_enc else None
|
||||
|
||||
refresh_token = Secret.from_db(
|
||||
encrypted_value=oauth_session.refresh_token_enc, plaintext_value=oauth_session.refresh_token
|
||||
).get_plaintext()
|
||||
# Get plaintext values from encrypted columns (primary source of truth)
|
||||
authorization_code = await authorization_code_enc.get_plaintext_async() if authorization_code_enc else None
|
||||
access_token = await access_token_enc.get_plaintext_async() if access_token_enc else None
|
||||
refresh_token = await refresh_token_enc.get_plaintext_async() if refresh_token_enc else None
|
||||
client_secret = await client_secret_enc.get_plaintext_async() if client_secret_enc else None
|
||||
|
||||
client_secret = Secret.from_db(
|
||||
encrypted_value=oauth_session.client_secret_enc, plaintext_value=oauth_session.client_secret
|
||||
).get_plaintext()
|
||||
|
||||
authorization_code = Secret.from_db(
|
||||
encrypted_value=oauth_session.authorization_code_enc, plaintext_value=oauth_session.authorization_code
|
||||
).get_plaintext()
|
||||
|
||||
# Create the Pydantic object with encrypted fields as Secret objects
|
||||
# Create the Pydantic object with both encrypted and plaintext fields
|
||||
pydantic_session = MCPOAuthSession(
|
||||
id=oauth_session.id,
|
||||
state=oauth_session.state,
|
||||
@@ -870,25 +847,24 @@ class MCPManager:
|
||||
user_id=oauth_session.user_id,
|
||||
organization_id=oauth_session.organization_id,
|
||||
authorization_url=oauth_session.authorization_url,
|
||||
authorization_code=authorization_code,
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
token_type=oauth_session.token_type,
|
||||
expires_at=oauth_session.expires_at,
|
||||
scope=oauth_session.scope,
|
||||
client_id=oauth_session.client_id,
|
||||
client_secret=client_secret,
|
||||
redirect_uri=oauth_session.redirect_uri,
|
||||
status=oauth_session.status,
|
||||
created_at=oauth_session.created_at,
|
||||
updated_at=oauth_session.updated_at,
|
||||
# Encrypted fields as Secret objects (converted from encrypted strings in DB)
|
||||
authorization_code_enc=Secret.from_encrypted(oauth_session.authorization_code_enc)
|
||||
if oauth_session.authorization_code_enc
|
||||
else None,
|
||||
access_token_enc=Secret.from_encrypted(oauth_session.access_token_enc) if oauth_session.access_token_enc else None,
|
||||
refresh_token_enc=Secret.from_encrypted(oauth_session.refresh_token_enc) if oauth_session.refresh_token_enc else None,
|
||||
client_secret_enc=Secret.from_encrypted(oauth_session.client_secret_enc) if oauth_session.client_secret_enc else None,
|
||||
# Plaintext fields populated from encrypted columns
|
||||
authorization_code=authorization_code,
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
client_secret=client_secret,
|
||||
# Encrypted fields as Secret objects
|
||||
authorization_code_enc=authorization_code_enc,
|
||||
access_token_enc=access_token_enc,
|
||||
refresh_token_enc=refresh_token_enc,
|
||||
client_secret_enc=client_secret_enc,
|
||||
)
|
||||
return pydantic_session
|
||||
|
||||
@@ -911,7 +887,7 @@ class MCPManager:
|
||||
oauth_session = await oauth_session.create_async(session, actor=actor)
|
||||
|
||||
# Convert to Pydantic model - note: new sessions won't have tokens yet
|
||||
return self._oauth_orm_to_pydantic(oauth_session)
|
||||
return await self._oauth_orm_to_pydantic_async(oauth_session)
|
||||
|
||||
@enforce_types
|
||||
async def get_oauth_session_by_id(self, session_id: str, actor: PydanticUser) -> Optional[MCPOAuthSession]:
|
||||
@@ -919,7 +895,7 @@ class MCPManager:
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor)
|
||||
return self._oauth_orm_to_pydantic(oauth_session)
|
||||
return await self._oauth_orm_to_pydantic_async(oauth_session)
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@@ -945,7 +921,7 @@ class MCPManager:
|
||||
if not oauth_session:
|
||||
return None
|
||||
|
||||
return self._oauth_orm_to_pydantic(oauth_session)
|
||||
return await self._oauth_orm_to_pydantic_async(oauth_session)
|
||||
|
||||
@enforce_types
|
||||
async def update_oauth_session(self, session_id: str, session_update: MCPOAuthSessionUpdate, actor: PydanticUser) -> MCPOAuthSession:
|
||||
@@ -957,56 +933,41 @@ class MCPManager:
|
||||
if session_update.authorization_url is not None:
|
||||
oauth_session.authorization_url = session_update.authorization_url
|
||||
|
||||
# Handle encryption for authorization_code
|
||||
# Only re-encrypt if the value has actually changed
|
||||
# Handle encryption for authorization_code - write only to _enc column
|
||||
if session_update.authorization_code is not None:
|
||||
# Check if value changed
|
||||
# Check if value changed by reading from _enc column only
|
||||
existing_code = None
|
||||
if oauth_session.authorization_code_enc:
|
||||
existing_secret = Secret.from_encrypted(oauth_session.authorization_code_enc)
|
||||
existing_code = existing_secret.get_plaintext()
|
||||
elif oauth_session.authorization_code:
|
||||
existing_code = oauth_session.authorization_code
|
||||
existing_code = await existing_secret.get_plaintext_async()
|
||||
|
||||
# Only re-encrypt if different
|
||||
if existing_code != session_update.authorization_code:
|
||||
oauth_session.authorization_code_enc = Secret.from_plaintext(session_update.authorization_code).get_encrypted()
|
||||
# Keep plaintext for dual-write during migration
|
||||
oauth_session.authorization_code = session_update.authorization_code
|
||||
|
||||
# Handle encryption for access_token
|
||||
# Only re-encrypt if the value has actually changed
|
||||
# Handle encryption for access_token - write only to _enc column
|
||||
if session_update.access_token is not None:
|
||||
# Check if value changed
|
||||
# Check if value changed by reading from _enc column only
|
||||
existing_token = None
|
||||
if oauth_session.access_token_enc:
|
||||
existing_secret = Secret.from_encrypted(oauth_session.access_token_enc)
|
||||
existing_token = existing_secret.get_plaintext()
|
||||
elif oauth_session.access_token:
|
||||
existing_token = oauth_session.access_token
|
||||
existing_token = await existing_secret.get_plaintext_async()
|
||||
|
||||
# Only re-encrypt if different
|
||||
if existing_token != session_update.access_token:
|
||||
oauth_session.access_token_enc = Secret.from_plaintext(session_update.access_token).get_encrypted()
|
||||
# Keep plaintext for dual-write during migration
|
||||
oauth_session.access_token = session_update.access_token
|
||||
|
||||
# Handle encryption for refresh_token
|
||||
# Only re-encrypt if the value has actually changed
|
||||
# Handle encryption for refresh_token - write only to _enc column
|
||||
if session_update.refresh_token is not None:
|
||||
# Check if value changed
|
||||
# Check if value changed by reading from _enc column only
|
||||
existing_refresh = None
|
||||
if oauth_session.refresh_token_enc:
|
||||
existing_secret = Secret.from_encrypted(oauth_session.refresh_token_enc)
|
||||
existing_refresh = existing_secret.get_plaintext()
|
||||
elif oauth_session.refresh_token:
|
||||
existing_refresh = oauth_session.refresh_token
|
||||
existing_refresh = await existing_secret.get_plaintext_async()
|
||||
|
||||
# Only re-encrypt if different
|
||||
if existing_refresh != session_update.refresh_token:
|
||||
oauth_session.refresh_token_enc = Secret.from_plaintext(session_update.refresh_token).get_encrypted()
|
||||
# Keep plaintext for dual-write during migration
|
||||
oauth_session.refresh_token = session_update.refresh_token
|
||||
|
||||
if session_update.token_type is not None:
|
||||
oauth_session.token_type = session_update.token_type
|
||||
@@ -1017,22 +978,17 @@ class MCPManager:
|
||||
if session_update.client_id is not None:
|
||||
oauth_session.client_id = session_update.client_id
|
||||
|
||||
# Handle encryption for client_secret
|
||||
# Only re-encrypt if the value has actually changed
|
||||
# Handle encryption for client_secret - write only to _enc column
|
||||
if session_update.client_secret is not None:
|
||||
# Check if value changed
|
||||
# Check if value changed by reading from _enc column only
|
||||
existing_secret_val = None
|
||||
if oauth_session.client_secret_enc:
|
||||
existing_secret = Secret.from_encrypted(oauth_session.client_secret_enc)
|
||||
existing_secret_val = existing_secret.get_plaintext()
|
||||
elif oauth_session.client_secret:
|
||||
existing_secret_val = oauth_session.client_secret
|
||||
existing_secret_val = await existing_secret.get_plaintext_async()
|
||||
|
||||
# Only re-encrypt if different
|
||||
if existing_secret_val != session_update.client_secret:
|
||||
oauth_session.client_secret_enc = Secret.from_plaintext(session_update.client_secret).get_encrypted()
|
||||
# Keep plaintext for dual-write during migration
|
||||
oauth_session.client_secret = session_update.client_secret
|
||||
|
||||
if session_update.redirect_uri is not None:
|
||||
oauth_session.redirect_uri = session_update.redirect_uri
|
||||
@@ -1044,7 +1000,7 @@ class MCPManager:
|
||||
|
||||
oauth_session = await oauth_session.update_async(db_session=session, actor=actor)
|
||||
|
||||
return self._oauth_orm_to_pydantic(oauth_session)
|
||||
return await self._oauth_orm_to_pydantic_async(oauth_session)
|
||||
|
||||
@enforce_types
|
||||
async def delete_oauth_session(self, session_id: str, actor: PydanticUser) -> None:
|
||||
|
||||
@@ -162,7 +162,7 @@ class MCPServerManager:
|
||||
mcp_client = None
|
||||
try:
|
||||
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
|
||||
server_config = mcp_config.to_config()
|
||||
server_config = await mcp_config.to_config_async()
|
||||
mcp_client = await self.get_mcp_client(server_config, actor, agent_id=agent_id)
|
||||
await mcp_client.connect_to_server()
|
||||
|
||||
@@ -210,7 +210,7 @@ class MCPServerManager:
|
||||
|
||||
# Get the MCP server config
|
||||
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
|
||||
server_config = mcp_config.to_config(environment_variables)
|
||||
server_config = await mcp_config.to_config_async(environment_variables)
|
||||
|
||||
mcp_client = await self.get_mcp_client(server_config, actor, agent_id=agent_id)
|
||||
await mcp_client.connect_to_server()
|
||||
@@ -691,7 +691,7 @@ class MCPServerManager:
|
||||
existing_token = None
|
||||
if mcp_server.token_enc:
|
||||
existing_secret = Secret.from_encrypted(mcp_server.token_enc)
|
||||
existing_token = existing_secret.get_plaintext()
|
||||
existing_token = await existing_secret.get_plaintext_async()
|
||||
elif mcp_server.token:
|
||||
existing_token = mcp_server.token
|
||||
|
||||
@@ -718,7 +718,7 @@ class MCPServerManager:
|
||||
existing_headers_json = None
|
||||
if mcp_server.custom_headers_enc:
|
||||
existing_secret = Secret.from_encrypted(mcp_server.custom_headers_enc)
|
||||
existing_headers_json = existing_secret.get_plaintext()
|
||||
existing_headers_json = await existing_secret.get_plaintext_async()
|
||||
elif mcp_server.custom_headers:
|
||||
existing_headers_json = json.dumps(mcp_server.custom_headers)
|
||||
|
||||
@@ -961,7 +961,7 @@ class MCPServerManager:
|
||||
if oauth_provider is None and hasattr(server_config, "server_url"):
|
||||
oauth_session = await self.get_oauth_session_by_server(server_config.server_url, actor)
|
||||
# Check if access token exists by attempting to decrypt it
|
||||
if oauth_session and oauth_session.get_access_token_secret().get_plaintext():
|
||||
if oauth_session and await oauth_session.get_access_token_secret().get_plaintext_async():
|
||||
# Create OAuth provider from stored credentials
|
||||
from letta.services.mcp.oauth_utils import create_oauth_provider
|
||||
|
||||
@@ -986,29 +986,24 @@ class MCPServerManager:
|
||||
raise ValueError(f"Unsupported server config type: {type(server_config)}")
|
||||
|
||||
# OAuth-related methods
|
||||
def _oauth_orm_to_pydantic(self, oauth_session: MCPOAuth) -> MCPOAuthSession:
|
||||
async def _oauth_orm_to_pydantic_async(self, oauth_session: MCPOAuth) -> MCPOAuthSession:
|
||||
"""
|
||||
Convert OAuth ORM model to Pydantic model, handling decryption of sensitive fields.
|
||||
|
||||
Note: Prefers encrypted columns (_enc fields), falls back to plaintext with error logging.
|
||||
This helps identify unmigrated data during the migration period.
|
||||
Note: Prefers encrypted columns (_enc fields), falls back to legacy plaintext columns.
|
||||
"""
|
||||
# Get decrypted values - prefer encrypted, fallback to plaintext with error logging
|
||||
access_token = Secret.from_db(
|
||||
encrypted_value=oauth_session.access_token_enc, plaintext_value=oauth_session.access_token
|
||||
).get_plaintext()
|
||||
# Get decrypted values - prefer encrypted, fallback to legacy plaintext
|
||||
access_token_secret = Secret.from_encrypted(oauth_session.access_token_enc)
|
||||
access_token = await access_token_secret.get_plaintext_async()
|
||||
|
||||
refresh_token = Secret.from_db(
|
||||
encrypted_value=oauth_session.refresh_token_enc, plaintext_value=oauth_session.refresh_token
|
||||
).get_plaintext()
|
||||
refresh_token_secret = Secret.from_encrypted(oauth_session.refresh_token_enc)
|
||||
refresh_token = await refresh_token_secret.get_plaintext_async()
|
||||
|
||||
client_secret = Secret.from_db(
|
||||
encrypted_value=oauth_session.client_secret_enc, plaintext_value=oauth_session.client_secret
|
||||
).get_plaintext()
|
||||
client_secret_secret = Secret.from_encrypted(oauth_session.client_secret_enc)
|
||||
client_secret = await client_secret_secret.get_plaintext_async()
|
||||
|
||||
authorization_code = Secret.from_db(
|
||||
encrypted_value=oauth_session.authorization_code_enc, plaintext_value=oauth_session.authorization_code
|
||||
).get_plaintext()
|
||||
authorization_code_secret = Secret.from_encrypted(oauth_session.authorization_code_enc)
|
||||
authorization_code = await authorization_code_secret.get_plaintext_async()
|
||||
|
||||
# Create the Pydantic object with encrypted fields as Secret objects
|
||||
pydantic_session = MCPOAuthSession(
|
||||
@@ -1061,7 +1056,7 @@ class MCPServerManager:
|
||||
oauth_session = await oauth_session.create_async(session, actor=actor)
|
||||
|
||||
# Convert to Pydantic model - note: new sessions won't have tokens yet
|
||||
return self._oauth_orm_to_pydantic(oauth_session)
|
||||
return await self._oauth_orm_to_pydantic_async(oauth_session)
|
||||
|
||||
@enforce_types
|
||||
async def get_oauth_session_by_id(self, session_id: str, actor: PydanticUser) -> Optional[MCPOAuthSession]:
|
||||
@@ -1069,7 +1064,7 @@ class MCPServerManager:
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor)
|
||||
return self._oauth_orm_to_pydantic(oauth_session)
|
||||
return await self._oauth_orm_to_pydantic_async(oauth_session)
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@@ -1095,7 +1090,7 @@ class MCPServerManager:
|
||||
if not oauth_session:
|
||||
return None
|
||||
|
||||
return self._oauth_orm_to_pydantic(oauth_session)
|
||||
return await self._oauth_orm_to_pydantic_async(oauth_session)
|
||||
|
||||
@enforce_types
|
||||
async def update_oauth_session(self, session_id: str, session_update: MCPOAuthSessionUpdate, actor: PydanticUser) -> MCPOAuthSession:
|
||||
@@ -1114,7 +1109,7 @@ class MCPServerManager:
|
||||
existing_code = None
|
||||
if oauth_session.authorization_code_enc:
|
||||
existing_secret = Secret.from_encrypted(oauth_session.authorization_code_enc)
|
||||
existing_code = existing_secret.get_plaintext()
|
||||
existing_code = await existing_secret.get_plaintext_async()
|
||||
elif oauth_session.authorization_code:
|
||||
existing_code = oauth_session.authorization_code
|
||||
|
||||
@@ -1131,7 +1126,7 @@ class MCPServerManager:
|
||||
existing_token = None
|
||||
if oauth_session.access_token_enc:
|
||||
existing_secret = Secret.from_encrypted(oauth_session.access_token_enc)
|
||||
existing_token = existing_secret.get_plaintext()
|
||||
existing_token = await existing_secret.get_plaintext_async()
|
||||
elif oauth_session.access_token:
|
||||
existing_token = oauth_session.access_token
|
||||
|
||||
@@ -1148,7 +1143,7 @@ class MCPServerManager:
|
||||
existing_refresh = None
|
||||
if oauth_session.refresh_token_enc:
|
||||
existing_secret = Secret.from_encrypted(oauth_session.refresh_token_enc)
|
||||
existing_refresh = existing_secret.get_plaintext()
|
||||
existing_refresh = await existing_secret.get_plaintext_async()
|
||||
elif oauth_session.refresh_token:
|
||||
existing_refresh = oauth_session.refresh_token
|
||||
|
||||
@@ -1174,7 +1169,7 @@ class MCPServerManager:
|
||||
existing_secret_val = None
|
||||
if oauth_session.client_secret_enc:
|
||||
existing_secret = Secret.from_encrypted(oauth_session.client_secret_enc)
|
||||
existing_secret_val = existing_secret.get_plaintext()
|
||||
existing_secret_val = await existing_secret.get_plaintext_async()
|
||||
elif oauth_session.client_secret:
|
||||
existing_secret_val = oauth_session.client_secret
|
||||
|
||||
@@ -1194,7 +1189,7 @@ class MCPServerManager:
|
||||
|
||||
oauth_session = await oauth_session.update_async(db_session=session, actor=actor)
|
||||
|
||||
return self._oauth_orm_to_pydantic(oauth_session)
|
||||
return await self._oauth_orm_to_pydantic_async(oauth_session)
|
||||
|
||||
@enforce_types
|
||||
async def delete_oauth_session(self, session_id: str, actor: PydanticUser) -> None:
|
||||
|
||||
@@ -345,8 +345,8 @@ class MessageManager:
|
||||
return combined_messages
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="message_id", expected_prefix=PrimitiveType.MESSAGE)
|
||||
@trace_method
|
||||
async def get_message_by_id_async(self, message_id: str, actor: PydanticUser) -> Optional[PydanticMessage]:
|
||||
"""Fetch a message by ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -754,8 +754,8 @@ class MessageManager:
|
||||
return message
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="message_id", expected_prefix=PrimitiveType.MESSAGE)
|
||||
@trace_method
|
||||
async def delete_message_by_id_async(self, message_id: str, actor: PydanticUser, strict_mode: bool = False) -> bool:
|
||||
"""Delete a message (async version with turbopuffer support)."""
|
||||
# capture agent_id before deletion
|
||||
|
||||
@@ -95,8 +95,8 @@ class ProviderManager:
|
||||
return provider_pydantic
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
|
||||
@trace_method
|
||||
async def update_provider_async(self, provider_id: str, provider_update: ProviderUpdate, actor: PydanticUser) -> PydanticProvider:
|
||||
"""Update provider details."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -115,7 +115,7 @@ class ProviderManager:
|
||||
existing_api_key = None
|
||||
if existing_provider.api_key_enc:
|
||||
existing_secret = Secret.from_encrypted(existing_provider.api_key_enc)
|
||||
existing_api_key = existing_secret.get_plaintext()
|
||||
existing_api_key = await existing_secret.get_plaintext_async()
|
||||
|
||||
# Only re-encrypt if different
|
||||
if existing_api_key != update_data["api_key"]:
|
||||
@@ -132,7 +132,7 @@ class ProviderManager:
|
||||
existing_access_key = None
|
||||
if existing_provider.access_key_enc:
|
||||
existing_secret = Secret.from_encrypted(existing_provider.access_key_enc)
|
||||
existing_access_key = existing_secret.get_plaintext()
|
||||
existing_access_key = await existing_secret.get_plaintext_async()
|
||||
|
||||
# Only re-encrypt if different
|
||||
if existing_access_key != update_data["access_key"]:
|
||||
@@ -151,8 +151,8 @@ class ProviderManager:
|
||||
return existing_provider.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
|
||||
@trace_method
|
||||
async def delete_provider_by_id_async(self, provider_id: str, actor: PydanticUser):
|
||||
"""Delete a provider."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -284,8 +284,8 @@ class ProviderManager:
|
||||
return [provider.to_pydantic() for provider in all_providers]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
|
||||
@trace_method
|
||||
async def get_provider_async(self, provider_id: str, actor: PydanticUser) -> PydanticProvider:
|
||||
async with db_registry.async_session() as session:
|
||||
# First try to get as organization-specific provider
|
||||
@@ -336,7 +336,7 @@ class ProviderManager:
|
||||
if providers:
|
||||
# Decrypt the API key before returning
|
||||
api_key_secret = providers[0].api_key_enc
|
||||
return api_key_secret.get_plaintext() if api_key_secret else None
|
||||
return await api_key_secret.get_plaintext_async() if api_key_secret else None
|
||||
return None
|
||||
|
||||
@enforce_types
|
||||
@@ -349,8 +349,8 @@ class ProviderManager:
|
||||
# Decrypt the credentials before returning
|
||||
access_key_secret = providers[0].access_key_enc
|
||||
api_key_secret = providers[0].api_key_enc
|
||||
access_key = access_key_secret.get_plaintext() if access_key_secret else None
|
||||
secret_key = api_key_secret.get_plaintext() if api_key_secret else None
|
||||
access_key = await access_key_secret.get_plaintext_async() if access_key_secret else None
|
||||
secret_key = await api_key_secret.get_plaintext_async() if api_key_secret else None
|
||||
region = providers[0].region
|
||||
return access_key, secret_key, region
|
||||
return None, None, None
|
||||
@@ -379,7 +379,7 @@ class ProviderManager:
|
||||
if providers:
|
||||
# Decrypt the API key before returning
|
||||
api_key_secret = providers[0].api_key_enc
|
||||
api_key = api_key_secret.get_plaintext() if api_key_secret else None
|
||||
api_key = await api_key_secret.get_plaintext_async() if api_key_secret else None
|
||||
base_url = providers[0].base_url
|
||||
api_version = providers[0].api_version
|
||||
return api_key, base_url, api_version
|
||||
@@ -400,7 +400,7 @@ class ProviderManager:
|
||||
).cast_to_subtype()
|
||||
|
||||
# TODO: add more string sanity checks here before we hit actual endpoints
|
||||
if not provider.api_key_enc or not provider.api_key_enc.get_plaintext():
|
||||
if not provider.api_key_enc or not await provider.api_key_enc.get_plaintext_async():
|
||||
raise ValueError("API key is required!")
|
||||
|
||||
await provider.check_api_key()
|
||||
@@ -439,8 +439,8 @@ class ProviderManager:
|
||||
return
|
||||
|
||||
# Create provider instance with necessary parameters
|
||||
api_key = provider.api_key_enc.get_plaintext() if provider.api_key_enc else None
|
||||
access_key = provider.access_key_enc.get_plaintext() if provider.access_key_enc else None
|
||||
api_key = await provider.api_key_enc.get_plaintext_async() if provider.api_key_enc else None
|
||||
access_key = await provider.access_key_enc.get_plaintext_async() if provider.access_key_enc else None
|
||||
kwargs = {
|
||||
"name": provider.name,
|
||||
"api_key": api_key,
|
||||
@@ -516,8 +516,8 @@ class ProviderManager:
|
||||
continue
|
||||
|
||||
# Convert Provider to ProviderCreate
|
||||
api_key = provider.api_key_enc.get_plaintext() if provider.api_key_enc else None
|
||||
access_key = provider.access_key_enc.get_plaintext() if provider.access_key_enc else None
|
||||
api_key = await provider.api_key_enc.get_plaintext_async() if provider.api_key_enc else None
|
||||
access_key = await provider.access_key_enc.get_plaintext_async() if provider.access_key_enc else None
|
||||
provider_create = ProviderCreate(
|
||||
name=provider.name,
|
||||
provider_type=provider.provider_type,
|
||||
|
||||
@@ -101,8 +101,8 @@ class SandboxConfigManager:
|
||||
return db_sandbox.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG)
|
||||
@trace_method
|
||||
async def update_sandbox_config_async(
|
||||
self, sandbox_config_id: str, sandbox_update: SandboxConfigUpdate, actor: PydanticUser
|
||||
) -> PydanticSandboxConfig:
|
||||
@@ -130,8 +130,8 @@ class SandboxConfigManager:
|
||||
return sandbox.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG)
|
||||
@trace_method
|
||||
async def delete_sandbox_config_async(self, sandbox_config_id: str, actor: PydanticUser) -> PydanticSandboxConfig:
|
||||
"""Delete a sandbox configuration by its ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -178,8 +178,8 @@ class SandboxConfigManager:
|
||||
return None
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG)
|
||||
@trace_method
|
||||
async def create_sandbox_env_var_async(
|
||||
self, env_var_create: SandboxEnvironmentVariableCreate, sandbox_config_id: str, actor: PydanticUser
|
||||
) -> PydanticEnvVar:
|
||||
@@ -267,8 +267,8 @@ class SandboxConfigManager:
|
||||
return await PydanticEnvVar.from_orm_async(env_var)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG)
|
||||
@trace_method
|
||||
async def list_sandbox_env_vars_async(
|
||||
self,
|
||||
sandbox_config_id: str,
|
||||
@@ -285,7 +285,10 @@ class SandboxConfigManager:
|
||||
organization_id=actor.organization_id,
|
||||
sandbox_config_id=sandbox_config_id,
|
||||
)
|
||||
return [await PydanticEnvVar.from_orm_async(env_var) for env_var in env_vars]
|
||||
result = []
|
||||
for env_var in env_vars:
|
||||
result.append(await PydanticEnvVar.from_orm_async(env_var))
|
||||
return result
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -301,11 +304,14 @@ class SandboxConfigManager:
|
||||
organization_id=actor.organization_id,
|
||||
key=key,
|
||||
)
|
||||
return [await PydanticEnvVar.from_orm_async(env_var) for env_var in env_vars]
|
||||
result = []
|
||||
for env_var in env_vars:
|
||||
result.append(await PydanticEnvVar.from_orm_async(env_var))
|
||||
return result
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG)
|
||||
@trace_method
|
||||
def get_sandbox_env_vars_as_dict(
|
||||
self, sandbox_config_id: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50
|
||||
) -> Dict[str, str]:
|
||||
@@ -317,8 +323,8 @@ class SandboxConfigManager:
|
||||
return result
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG)
|
||||
@trace_method
|
||||
async def get_sandbox_env_vars_as_dict_async(
|
||||
self, sandbox_config_id: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50
|
||||
) -> Dict[str, str]:
|
||||
@@ -327,8 +333,8 @@ class SandboxConfigManager:
|
||||
return {env_var.key: env_var.value for env_var in env_vars}
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG)
|
||||
@trace_method
|
||||
async def get_sandbox_env_var_by_key_and_sandbox_config_id_async(
|
||||
self, key: str, sandbox_config_id: str, actor: Optional[PydanticUser] = None
|
||||
) -> Optional[PydanticEnvVar]:
|
||||
|
||||
@@ -201,8 +201,8 @@ class SourceManager:
|
||||
return sources
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
|
||||
@trace_method
|
||||
async def update_source(self, source_id: str, source_update: SourceUpdate, actor: PydanticUser) -> PydanticSource:
|
||||
"""Update a source by its ID with the given SourceUpdate object."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -225,8 +225,8 @@ class SourceManager:
|
||||
return source.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
|
||||
@trace_method
|
||||
async def delete_source(self, source_id: str, actor: PydanticUser) -> PydanticSource:
|
||||
"""Delete a source by its ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -270,8 +270,8 @@ class SourceManager:
|
||||
return await SourceModel.size_async(db_session=session, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
|
||||
@trace_method
|
||||
async def list_attached_agents(
|
||||
self, source_id: str, actor: PydanticUser, ids_only: bool = False
|
||||
) -> Union[List[PydanticAgentState], List[str]]:
|
||||
@@ -321,11 +321,11 @@ class SourceManager:
|
||||
result = await session.execute(query)
|
||||
agents_orm = result.scalars().all()
|
||||
|
||||
return await asyncio.gather(*[agent.to_pydantic_async() for agent in agents_orm])
|
||||
return await asyncio.gather(*[agent.to_pydantic_async(include=[]) for agent in agents_orm])
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
|
||||
@trace_method
|
||||
async def get_agents_for_source_id(
|
||||
self,
|
||||
source_id: str,
|
||||
@@ -439,8 +439,8 @@ class SourceManager:
|
||||
|
||||
# TODO: We make actor optional for now, but should most likely be enforced due to security reasons
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
|
||||
@trace_method
|
||||
async def get_source_by_id(self, source_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSource]:
|
||||
"""Retrieve a source by its ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
|
||||
@@ -33,9 +33,9 @@ class FeedbackType(str, Enum):
|
||||
|
||||
class StepManager:
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
||||
@trace_method
|
||||
async def list_steps_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
@@ -82,11 +82,11 @@ class StepManager:
|
||||
return [step.to_pydantic() for step in steps]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
|
||||
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
||||
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
|
||||
@trace_method
|
||||
def log_step(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
@@ -140,11 +140,11 @@ class StepManager:
|
||||
return new_step.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
|
||||
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
||||
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
|
||||
@trace_method
|
||||
async def log_step_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
@@ -207,24 +207,24 @@ class StepManager:
|
||||
return pydantic_step
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
|
||||
@trace_method
|
||||
async def get_step_async(self, step_id: str, actor: PydanticUser) -> PydanticStep:
|
||||
async with db_registry.async_session() as session:
|
||||
step = await StepModel.read_async(db_session=session, identifier=step_id, actor=actor)
|
||||
return step.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
|
||||
@trace_method
|
||||
async def get_step_metrics_async(self, step_id: str, actor: PydanticUser) -> PydanticStepMetrics:
|
||||
async with db_registry.async_session() as session:
|
||||
metrics = await StepMetricsModel.read_async(db_session=session, identifier=step_id, actor=actor)
|
||||
return metrics.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
|
||||
@trace_method
|
||||
async def add_feedback_async(
|
||||
self, step_id: str, feedback: FeedbackType | None, actor: PydanticUser, tags: list[str] | None = None
|
||||
) -> PydanticStep:
|
||||
@@ -239,8 +239,8 @@ class StepManager:
|
||||
return step.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
|
||||
@trace_method
|
||||
async def update_step_transaction_id(self, actor: PydanticUser, step_id: str, transaction_id: str) -> PydanticStep:
|
||||
"""Update the transaction ID for a step.
|
||||
|
||||
@@ -267,8 +267,8 @@ class StepManager:
|
||||
return step.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
|
||||
@trace_method
|
||||
async def list_step_messages_async(
|
||||
self,
|
||||
step_id: str,
|
||||
@@ -291,8 +291,8 @@ class StepManager:
|
||||
return [message.to_pydantic() for message in messages]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
|
||||
@trace_method
|
||||
async def update_step_stop_reason(self, actor: PydanticUser, step_id: str, stop_reason: StopReasonType) -> PydanticStep:
|
||||
"""Update the stop reason for a step.
|
||||
|
||||
@@ -319,8 +319,8 @@ class StepManager:
|
||||
return step
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
|
||||
@trace_method
|
||||
async def update_step_error_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
@@ -369,8 +369,8 @@ class StepManager:
|
||||
return pydantic_step
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
|
||||
@trace_method
|
||||
async def update_step_success_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
@@ -420,8 +420,8 @@ class StepManager:
|
||||
return pydantic_step
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
|
||||
@trace_method
|
||||
async def update_step_cancelled_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
@@ -460,10 +460,10 @@ class StepManager:
|
||||
return pydantic_step
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
||||
@trace_method
|
||||
async def record_step_metrics_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
|
||||
@@ -16,7 +16,7 @@ from letta.llm_api.llm_client import LLMClient
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.prompts import gpt_summarize
|
||||
from letta.schemas.enums import AgentType, MessageRole
|
||||
from letta.schemas.enums import AgentType, MessageRole, ProviderType
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
@@ -457,9 +457,63 @@ async def simple_summary(
|
||||
summarizer_llm_config.put_inner_thoughts_in_kwargs = False
|
||||
summarizer_llm_config.enable_reasoner = False
|
||||
|
||||
async def _run_summarizer_request(req_data: dict, req_messages_obj: list[Message]) -> str:
    """Send one summarization request and return the assistant's text.

    Anthropic and Bedrock endpoints are served through provider-side
    streaming (long non-streaming requests can fail on those providers);
    every other endpoint uses a plain request that is then normalized
    through the chat-completions conversion.

    Args:
        req_data: Provider request payload built by the LLM client.
        req_messages_obj: Messages the payload was built from; only used by
            the non-streaming response conversion.

    Returns:
        The summary text with surrounding whitespace stripped.

    Raises:
        Exception: If the streaming path accumulates no text, or the
            non-streaming response carries no content.
    """
    endpoint_type = summarizer_llm_config.model_endpoint_type
    streaming_providers = (ProviderType.anthropic, ProviderType.bedrock)

    if endpoint_type in streaming_providers:
        logger.info(
            "Summarizer: using provider streaming (%s/%s) to avoid long-request failures",
            summarizer_llm_config.model_endpoint_type,
            summarizer_llm_config.model,
        )
        # Imported lazily, only when the streaming path is actually taken.
        from letta.interfaces.anthropic_parallel_tool_call_streaming_interface import (
            SimpleAnthropicStreamingInterface,
        )

        collector = SimpleAnthropicStreamingInterface(
            requires_approval_tools=[],
            run_id=None,
            step_id=None,
        )

        # AnthropicClient.stream_async sets request_data["stream"] = True internally.
        provider_stream = await llm_client.stream_async(req_data, summarizer_llm_config)
        async for _chunk in collector.process(provider_stream):
            # Nothing is emitted to a caller; draining the stream lets the
            # interface accumulate the complete content.
            pass

        text_fragments = [piece.text for piece in collector.get_content() if isinstance(piece, TextContent)]
        streamed_text = "".join(text_fragments).strip()
        if streamed_text:
            return streamed_text
        logger.warning("No content returned from summarizer (streaming path)")
        raise Exception("Summary failed to generate")
    else:
        # Default: non-streaming provider request, normalized via chat-completions conversion.
        logger.debug(
            "Summarizer: using non-streaming request (%s/%s)",
            summarizer_llm_config.model_endpoint_type,
            summarizer_llm_config.model,
        )
        raw_response = await llm_client.request_async(req_data, summarizer_llm_config)
        completion = await llm_client.convert_response_to_chat_completion(
            raw_response,
            req_messages_obj,
            summarizer_llm_config,
        )
        summary_content = completion.choices[0].message.content
        if summary_content is None:
            logger.warning("No content returned from summarizer")
            raise Exception("Summary failed to generate")
        return summary_content.strip()
|
||||
|
||||
request_data = llm_client.build_request_data(AgentType.letta_v1_agent, input_messages_obj, summarizer_llm_config, tools=[])
|
||||
try:
|
||||
response_data = await llm_client.request_async(request_data, summarizer_llm_config)
|
||||
summary = await _run_summarizer_request(request_data, input_messages_obj)
|
||||
except Exception as e:
|
||||
# handle LLM error (likely a context window exceeded error)
|
||||
try:
|
||||
@@ -497,7 +551,7 @@ async def simple_summary(
|
||||
)
|
||||
|
||||
try:
|
||||
response_data = await llm_client.request_async(request_data, summarizer_llm_config)
|
||||
summary = await _run_summarizer_request(request_data, input_messages_obj)
|
||||
except Exception as fallback_error_a:
|
||||
# Fallback B: hard-truncate the user transcript to fit a conservative char budget
|
||||
logger.warning(f"Clamped tool returns still overflowed ({fallback_error_a}). Falling back to transcript truncation.")
|
||||
@@ -534,21 +588,12 @@ async def simple_summary(
|
||||
tools=[],
|
||||
)
|
||||
try:
|
||||
response_data = await llm_client.request_async(request_data, summarizer_llm_config)
|
||||
summary = await _run_summarizer_request(request_data, input_messages_obj)
|
||||
except Exception as fallback_error_b:
|
||||
logger.error(f"Transcript truncation fallback also failed: {fallback_error_b}. Propagating error.")
|
||||
logger.info(f"Full fallback summarization payload: {request_data}")
|
||||
raise llm_client.handle_llm_error(fallback_error_b)
|
||||
|
||||
response = await llm_client.convert_response_to_chat_completion(response_data, input_messages_obj, summarizer_llm_config)
|
||||
if response.choices[0].message.content is None:
|
||||
logger.warning("No content returned from summarizer")
|
||||
# TODO raise an error error instead?
|
||||
# return "[Summary failed to generate]"
|
||||
raise Exception("Summary failed to generate")
|
||||
else:
|
||||
summary = response.choices[0].message.content.strip()
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
|
||||
@@ -101,7 +101,12 @@ async def summarize_via_sliding_window(
|
||||
|
||||
# get index of the last message with a valid cutoff role at or before the cutoff point
|
||||
assistant_message_index = next(
|
||||
(i for i in reversed(range(1, message_cutoff_index + 1)) if in_context_messages[i].role in valid_cutoff_roles), None
|
||||
(
|
||||
i
|
||||
for i in reversed(range(1, message_cutoff_index + 1))
|
||||
if i < len(in_context_messages) and in_context_messages[i].role in valid_cutoff_roles
|
||||
),
|
||||
None,
|
||||
)
|
||||
if assistant_message_index is None:
|
||||
logger.warning(f"No assistant message found for evicting up to index {message_cutoff_index}, incrementing eviction percentage")
|
||||
|
||||
@@ -143,9 +143,10 @@ class AsyncToolSandboxModal(AsyncToolSandboxBase):
|
||||
logger.warning(f"Could not load sandbox env vars for tool {self.tool_name}: {e}")
|
||||
|
||||
# Add agent-specific environment variables (these override sandbox-level)
|
||||
# Use the pre-decrypted value field which was populated in from_orm_async()
|
||||
if agent_state and agent_state.secrets:
|
||||
for secret in agent_state.secrets:
|
||||
env_vars[secret.key] = secret.value_enc.get_plaintext() if secret.value_enc else None
|
||||
env_vars[secret.key] = secret.value or ""
|
||||
|
||||
# Add any additional env vars passed at runtime (highest priority)
|
||||
if additional_env_vars:
|
||||
|
||||
@@ -323,7 +323,9 @@ class Settings(BaseSettings):
|
||||
|
||||
# LLM request timeout settings (model + embedding model)
|
||||
llm_request_timeout_seconds: float = Field(default=60.0, ge=10.0, le=1800.0, description="Timeout for LLM requests in seconds")
|
||||
llm_stream_timeout_seconds: float = Field(default=60.0, ge=10.0, le=1800.0, description="Timeout for LLM streaming requests in seconds")
|
||||
llm_stream_timeout_seconds: float = Field(
|
||||
default=600.0, ge=10.0, le=1800.0, description="Timeout for LLM streaming requests in seconds"
|
||||
)
|
||||
|
||||
# For embeddings
|
||||
enable_pinecone: bool = False
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "letta"
|
||||
version = "0.16.0"
|
||||
version = "0.16.1"
|
||||
description = "Create LLM agents with long-term memory and custom tools"
|
||||
authors = [
|
||||
{name = "Letta Team", email = "contact@letta.com"},
|
||||
@@ -72,6 +72,7 @@ dependencies = [
|
||||
"google-genai>=1.52.0",
|
||||
"datadog>=0.49.1",
|
||||
"psutil>=5.9.0",
|
||||
"ddtrace>=4.0.1",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -27,6 +27,8 @@ from letta.schemas.message import Message as PydanticMessage, MessageCreate
|
||||
from letta.schemas.run import Run as PydanticRun
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.run_manager import RunManager
|
||||
from letta.services.summarizer.summarizer import simple_summary
|
||||
from letta.settings import model_settings
|
||||
|
||||
# Constants
|
||||
DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig.default_config(provider="openai")
|
||||
@@ -240,6 +242,49 @@ async def test_summarize_empty_message_buffer(server: SyncServer, actor, llm_con
|
||||
assert "No assistant message found" in str(e) or "empty" in str(e).lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.skipif(
    not model_settings.anthropic_api_key,
    reason="Missing LETTA_ANTHROPIC_API_KEY (or equivalent settings) for Anthropic integration test",
)
async def test_simple_summary_anthropic_uses_streaming_and_returns_summary(actor, monkeypatch):
    """Regression test: Anthropic summaries go through streaming and yield real text.

    ``AnthropicClient.request_async`` is patched to raise, so any use of the
    non-streaming request path surfaces as a failure instead of passing silently.
    """
    from letta.llm_api.anthropic_client import AnthropicClient

    async def _forbidden_request_async(self, *args, **kwargs):
        raise AssertionError("Anthropic summarizer should not call request_async (must use streaming)")

    monkeypatch.setattr(AnthropicClient, "request_async", _forbidden_request_async)

    # A two-turn conversation keeps the live API call fast and cheap.
    user_turn = PydanticMessage(
        role=MessageRole.user,
        content=[TextContent(type="text", text="I'm planning a trip to Paris in April.")],
    )
    assistant_turn = PydanticMessage(
        role=MessageRole.assistant,
        content=[
            TextContent(
                type="text",
                text="Great—your priorities are museums and cafes, and you want to stay under $200/day.",
            )
        ],
    )
    conversation = [user_turn, assistant_turn]

    haiku_config = get_llm_config("claude-4-5-haiku.json")

    summary = await simple_summary(messages=conversation, llm_config=haiku_config, actor=actor)

    assert isinstance(summary, str)
    assert len(summary) > 10
    # The summary should mention at least one detail from the conversation above.
    lowered_summary = summary.lower()
    expected_tokens = ["paris", "april", "museum", "cafe", "$200", "200"]
    assert any(token in lowered_summary for token in expected_tokens)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"llm_config",
|
||||
|
||||
@@ -943,12 +943,13 @@ async def test_mcp_server_token_encryption_on_create(server, default_user, encry
|
||||
assert created_server is not None
|
||||
assert created_server.server_name == "test-encrypted-server"
|
||||
|
||||
# Verify plaintext token is accessible (dual-write during migration)
|
||||
assert created_server.token == "sk-test-secret-token-12345"
|
||||
# Verify plaintext token field is NOT set (no dual-write)
|
||||
assert created_server.token is None
|
||||
|
||||
# Verify token_enc is a Secret object
|
||||
# Verify token_enc is a Secret object and decrypts correctly
|
||||
assert created_server.token_enc is not None
|
||||
assert isinstance(created_server.token_enc, Secret)
|
||||
assert created_server.token_enc.get_plaintext() == "sk-test-secret-token-12345"
|
||||
|
||||
# Read directly from database to verify encryption
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -958,9 +959,6 @@ async def test_mcp_server_token_encryption_on_create(server, default_user, encry
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Verify plaintext column has the value (dual-write)
|
||||
assert server_orm.token == "sk-test-secret-token-12345"
|
||||
|
||||
# Verify encrypted column is populated and different from plaintext
|
||||
assert server_orm.token_enc is not None
|
||||
assert server_orm.token_enc != "sk-test-secret-token-12345"
|
||||
@@ -994,8 +992,12 @@ async def test_mcp_server_token_decryption_on_read(server, default_user, encrypt
|
||||
# Read the server back
|
||||
retrieved_server = await server.mcp_manager.get_mcp_server_by_id_async(server_id, actor=default_user)
|
||||
|
||||
# Verify the token is decrypted correctly
|
||||
assert retrieved_server.token == "sk-test-decrypt-token-67890"
|
||||
# Verify plaintext token field is NOT set (no dual-write)
|
||||
assert retrieved_server.token is None
|
||||
|
||||
# Verify the token is decrypted correctly via token_enc
|
||||
assert retrieved_server.token_enc is not None
|
||||
assert retrieved_server.token_enc.get_plaintext() == "sk-test-decrypt-token-67890"
|
||||
|
||||
# Verify we can get the decrypted token through the secret getter
|
||||
token_secret = retrieved_server.get_token_secret()
|
||||
@@ -1028,8 +1030,11 @@ async def test_mcp_server_custom_headers_encryption(server, default_user, encryp
|
||||
created_server = await server.mcp_manager.create_mcp_server(mcp_server, actor=default_user)
|
||||
|
||||
try:
|
||||
# Verify custom_headers are accessible
|
||||
assert created_server.custom_headers == custom_headers
|
||||
# Verify plaintext custom_headers field is NOT set (no dual-write)
|
||||
assert created_server.custom_headers is None
|
||||
|
||||
# Verify custom_headers are accessible via encrypted field
|
||||
assert created_server.get_custom_headers_dict() == custom_headers
|
||||
|
||||
# Verify custom_headers_enc is a Secret object (stores JSON string)
|
||||
assert created_server.custom_headers_enc is not None
|
||||
|
||||
@@ -8,7 +8,7 @@ AGENTS_CREATE_PARAMS = [
|
||||
# Verify model_settings is populated with config values
|
||||
# Note: The 'model' field itself is separate from model_settings
|
||||
"model_settings": {
|
||||
"max_output_tokens": 4096,
|
||||
"max_output_tokens": 16384,
|
||||
"parallel_tool_calls": False,
|
||||
"provider_type": "openai",
|
||||
"temperature": 0.7,
|
||||
@@ -27,7 +27,7 @@ AGENTS_UPDATE_PARAMS = [
|
||||
{
|
||||
# After updating just the name, model_settings should still be present
|
||||
"model_settings": {
|
||||
"max_output_tokens": 4096,
|
||||
"max_output_tokens": 16384,
|
||||
"parallel_tool_calls": False,
|
||||
"provider_type": "openai",
|
||||
"temperature": 0.7,
|
||||
|
||||
@@ -84,10 +84,8 @@ class TestMCPServerEncryption:
|
||||
decrypted_token = CryptoUtils.decrypt(db_server.token_enc)
|
||||
assert decrypted_token == token
|
||||
|
||||
# Legacy plaintext column should be None (or empty for dual-write)
|
||||
# During migration phase, might store both
|
||||
if db_server.token:
|
||||
assert db_server.token == token # Dual-write phase
|
||||
# Plaintext column should NOT be written to (encrypted-only)
|
||||
assert db_server.token is None
|
||||
|
||||
# Clean up
|
||||
await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
|
||||
@@ -176,9 +174,9 @@ class TestMCPServerEncryption:
|
||||
|
||||
assert test_server is not None
|
||||
assert test_server.server_name == server_name
|
||||
# Token should be decrypted when accessed via the secret method
|
||||
token_secret = test_server.get_token_secret()
|
||||
assert token_secret.get_plaintext() == plaintext_token
|
||||
# Token should be decrypted when accessed via the _enc column
|
||||
assert test_server.token_enc is not None
|
||||
assert test_server.token_enc.get_plaintext() == plaintext_token
|
||||
|
||||
# Clean up
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -220,15 +218,15 @@ class TestMCPServerEncryption:
|
||||
# Should work without encryption key - stores plaintext in _enc column
|
||||
created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user)
|
||||
|
||||
# Check database - should store plaintext in _enc column
|
||||
# Check database - should store plaintext in _enc column (no encryption key)
|
||||
async with db_registry.async_session() as session:
|
||||
result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == created_server.id))
|
||||
db_server = result.scalar_one()
|
||||
|
||||
# Token should be stored as plaintext in _enc column (not encrypted)
|
||||
assert db_server.token_enc == token # Plaintext stored directly
|
||||
# Legacy plaintext column should also be populated (dual-write)
|
||||
assert db_server.token == token
|
||||
# Plaintext column should NOT be written to (encrypted-only)
|
||||
assert db_server.token is None
|
||||
|
||||
# Clean up
|
||||
await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
|
||||
@@ -346,10 +344,13 @@ class TestMCPOAuthEncryption:
|
||||
test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user)
|
||||
assert test_session is not None
|
||||
|
||||
# Tokens should be decrypted
|
||||
assert test_session.access_token == access_token
|
||||
assert test_session.refresh_token == refresh_token
|
||||
assert test_session.client_secret == client_secret
|
||||
# Tokens should be decrypted from _enc columns
|
||||
assert test_session.access_token_enc is not None
|
||||
assert test_session.access_token_enc.get_plaintext() == access_token
|
||||
assert test_session.refresh_token_enc is not None
|
||||
assert test_session.refresh_token_enc.get_plaintext() == refresh_token
|
||||
assert test_session.client_secret_enc is not None
|
||||
assert test_session.client_secret_enc.get_plaintext() == client_secret
|
||||
|
||||
# Clean up not needed - test database is reset
|
||||
|
||||
@@ -396,9 +397,11 @@ class TestMCPOAuthEncryption:
|
||||
|
||||
updated_session = await server.mcp_manager.update_oauth_session(created_session.id, new_update, actor=default_user)
|
||||
|
||||
# Verify update worked
|
||||
assert updated_session.access_token == new_access_token
|
||||
assert updated_session.refresh_token == new_refresh_token
|
||||
# Verify update worked - read from _enc columns
|
||||
assert updated_session.access_token_enc is not None
|
||||
assert updated_session.access_token_enc.get_plaintext() == new_access_token
|
||||
assert updated_session.refresh_token_enc is not None
|
||||
assert updated_session.refresh_token_enc.get_plaintext() == new_refresh_token
|
||||
|
||||
# Check database encryption
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -459,8 +462,9 @@ class TestMCPOAuthEncryption:
|
||||
test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user)
|
||||
assert test_session is not None
|
||||
|
||||
# Should use encrypted value only (plaintext is ignored)
|
||||
assert test_session.access_token == new_encrypted_token
|
||||
# Should read from encrypted column only (plaintext is ignored)
|
||||
assert test_session.access_token_enc is not None
|
||||
assert test_session.access_token_enc.get_plaintext() == new_encrypted_token
|
||||
|
||||
# Clean up not needed - test database is reset
|
||||
|
||||
@@ -469,15 +473,13 @@ class TestMCPOAuthEncryption:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plaintext_only_record_fallback_with_error_logging(self, server, default_user, caplog):
|
||||
"""Test that records with only plaintext values fall back to plaintext with error logging.
|
||||
async def test_plaintext_only_record_returns_none(self, server, default_user):
|
||||
"""Test that records with only plaintext values return None for encrypted fields.
|
||||
|
||||
Note: In Phase 1 of migration, if a record only has plaintext value
|
||||
(no encrypted value), the system falls back to plaintext but logs an error
|
||||
to help identify unmigrated data.
|
||||
With encrypted-only migration complete, if a record only has plaintext value
|
||||
(no encrypted value), the system returns None for that field since we only
|
||||
read from _enc columns now.
|
||||
"""
|
||||
import logging
|
||||
|
||||
# Set encryption key directly on settings
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = self.MOCK_ENCRYPTION_KEY
|
||||
@@ -494,7 +496,7 @@ class TestMCPOAuthEncryption:
|
||||
server_url="https://test.com/mcp",
|
||||
server_name="Plaintext Only Test",
|
||||
# Only plaintext value, no encrypted
|
||||
access_token=plaintext_token, # Legacy plaintext - should fallback with error log
|
||||
access_token=plaintext_token, # Legacy plaintext - should be ignored
|
||||
access_token_enc=None, # No encrypted value
|
||||
client_id="test-client",
|
||||
user_id=default_user.id,
|
||||
@@ -505,17 +507,12 @@ class TestMCPOAuthEncryption:
|
||||
session.add(db_oauth)
|
||||
await session.commit()
|
||||
|
||||
# Retrieve through manager - should log error about plaintext fallback
|
||||
with caplog.at_level(logging.ERROR):
|
||||
test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user)
|
||||
|
||||
# Retrieve through manager
|
||||
test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user)
|
||||
assert test_session is not None
|
||||
|
||||
# Should fall back to plaintext value
|
||||
assert test_session.access_token == plaintext_token
|
||||
|
||||
# Should have logged an error about reading from plaintext column
|
||||
assert "MIGRATION_NEEDED" in caplog.text
|
||||
# Should return None since we only read from _enc columns now
|
||||
assert test_session.access_token_enc is None
|
||||
|
||||
# Clean up not needed - test database is reset
|
||||
|
||||
|
||||
@@ -28,7 +28,6 @@ class TestSecret:
|
||||
# Should store encrypted value
|
||||
assert secret.encrypted_value is not None
|
||||
assert secret.encrypted_value != plaintext
|
||||
assert secret.was_encrypted is False
|
||||
|
||||
# Should decrypt to original value
|
||||
assert secret.get_plaintext() == plaintext
|
||||
@@ -52,7 +51,6 @@ class TestSecret:
|
||||
# Should store the plaintext value directly in encrypted_value
|
||||
assert secret.encrypted_value == plaintext
|
||||
assert secret.get_plaintext() == plaintext
|
||||
assert not secret.was_encrypted
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
@@ -61,7 +59,6 @@ class TestSecret:
|
||||
secret = Secret.from_plaintext(None)
|
||||
|
||||
assert secret.encrypted_value is None
|
||||
assert secret.was_encrypted is False
|
||||
assert secret.get_plaintext() is None
|
||||
assert secret.is_empty() is True
|
||||
|
||||
@@ -79,78 +76,10 @@ class TestSecret:
|
||||
secret = Secret.from_encrypted(encrypted)
|
||||
|
||||
assert secret.encrypted_value == encrypted
|
||||
assert secret.was_encrypted is True
|
||||
assert secret.get_plaintext() == plaintext
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
def test_from_db_with_encrypted_value(self):
|
||||
"""Test creating a Secret from database with encrypted value."""
|
||||
from letta.settings import settings
|
||||
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = self.MOCK_KEY
|
||||
|
||||
try:
|
||||
plaintext = "database-secret"
|
||||
encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY)
|
||||
|
||||
secret = Secret.from_db(encrypted_value=encrypted, plaintext_value=None)
|
||||
|
||||
assert secret.encrypted_value == encrypted
|
||||
assert secret.was_encrypted is True
|
||||
assert secret.get_plaintext() == plaintext
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
def test_from_db_with_plaintext_value_fallback(self, caplog):
|
||||
"""Test creating a Secret from database with only plaintext value falls back with error logging.
|
||||
|
||||
Note: In Phase 1 of migration, from_db() prefers encrypted but falls back to plaintext
|
||||
with error logging to help identify unmigrated data.
|
||||
"""
|
||||
import logging
|
||||
|
||||
from letta.settings import settings
|
||||
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = self.MOCK_KEY
|
||||
|
||||
try:
|
||||
plaintext = "legacy-plaintext"
|
||||
|
||||
# When only plaintext is provided, should fall back to plaintext with error logging
|
||||
with caplog.at_level(logging.ERROR):
|
||||
secret = Secret.from_db(encrypted_value=None, plaintext_value=plaintext)
|
||||
|
||||
# Should use the plaintext value (fallback)
|
||||
assert secret.get_plaintext() == plaintext
|
||||
|
||||
# Should have logged an error about reading from plaintext column
|
||||
assert "MIGRATION_NEEDED" in caplog.text
|
||||
assert "plaintext column" in caplog.text
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
def test_from_db_dual_read(self):
|
||||
"""Test dual read functionality - prefer encrypted over plaintext."""
|
||||
from letta.settings import settings
|
||||
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = self.MOCK_KEY
|
||||
|
||||
try:
|
||||
plaintext = "correct-value"
|
||||
old_plaintext = "old-legacy-value"
|
||||
encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY)
|
||||
|
||||
# When both values exist, should prefer encrypted
|
||||
secret = Secret.from_db(encrypted_value=encrypted, plaintext_value=old_plaintext)
|
||||
|
||||
assert secret.get_plaintext() == plaintext # Should use encrypted value, not plaintext
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
def test_get_encrypted(self):
|
||||
"""Test getting the encrypted value for database storage."""
|
||||
from letta.settings import settings
|
||||
|
||||
79
uv.lock
generated
79
uv.lock
generated
@@ -912,51 +912,43 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "ddtrace"
|
||||
version = "3.16.2"
|
||||
version = "4.0.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "bytecode" },
|
||||
{ name = "envier" },
|
||||
{ name = "legacy-cgi", marker = "python_full_version >= '3.13'" },
|
||||
{ name = "opentelemetry-api" },
|
||||
{ name = "protobuf" },
|
||||
{ name = "wrapt" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c0/35/028fe174ec1a1da8977d4900297f4493a77e93dee1af700f473e692d010e/ddtrace-3.16.2.tar.gz", hash = "sha256:cfef021790635b6dda949e89298b7fed3b5e686c55b46afe9483cebcc0f10a86", size = 7408082, upload-time = "2025-10-21T19:29:32.004Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/71/54/e9f58f6e631f3c14b9300b7742bb76bcf3e8d73097ae70ddedfee81bf8f6/ddtrace-4.0.1.tar.gz", hash = "sha256:821d811de1d530ab61cdfb2d7f986d25a79c4e67d22a91190bf95a6c7abacdad", size = 7543615, upload-time = "2025-12-16T20:11:10.173Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a8/7f/55a8753b6ee574b34ee9c3ae48f6b35f4b04de3ef0e4044ad1adb6e7831b/ddtrace-3.16.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:8d644051c265be9865e68274ebbc4a17934a78fc447bf7d6611ecbe49fbe4b7a", size = 6334301, upload-time = "2025-10-21T19:26:27.261Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/37/a9/a58ae59088f00e237068c4522bb23b12d93e03a9e76cf73f2a357c963533/ddtrace-3.16.2-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:9a51665c563f2cc56ccc29f3404e68618bbfd70e60983de239ba8d2edc5ce7e8", size = 6679330, upload-time = "2025-10-21T19:26:28.867Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/81/48/df6dc3e2b7fff37ad813ddc5651a3e1f8240757b46f464d0e4b3ccf58a11/ddtrace-3.16.2-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:527d257f8b020d61f53686fa8791529d1a5b5c33719ed47b28110b8449f14257", size = 7402656, upload-time = "2025-10-21T19:26:30.865Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b3/d3/377f88a42b9df3bb12fe8c11dfc46548b996dec13abea23fbf0714994a32/ddtrace-3.16.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2eb5a196f74fa5fd8e62105a4faeb4b89e58ed2e00401b182888e407c8d23e94", size = 7668527, upload-time = "2025-10-21T19:26:32.663Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e6/01/8ee880b739afb614b2268c5e86cb039d509b3947aa6669ad861130fc62a0/ddtrace-3.16.2-cp311-cp311-manylinux_2_28_i686.whl", hash = "sha256:59bb645bd5f58465df651e3ae9809bdc2fed339e9c7eae53a41b941d40497508", size = 5521123, upload-time = "2025-10-21T19:26:35.209Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ed/60/cd88ae82999fe8e532259289f37f5afa08e9fae2204464372cc08859bf60/ddtrace-3.16.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b2e752daab0206c3abd118621851cc39673f9ed5c2180264ffa4fd147b6351be", size = 8415337, upload-time = "2025-10-21T19:26:37.117Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f6/7d/5ec01e65bf3a5c022439292e65c23ce80cd0c4726daba8bc7cfa12252665/ddtrace-3.16.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:5f9cd3345d686369254e072f049974386f58ee583a24be99862c480d7347819e", size = 6609971, upload-time = "2025-10-21T19:26:39.462Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ba/1f/8df630beb55734d46f4340144c7350c0c820eb4fedc0d3c38f5a694b4eea/ddtrace-3.16.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1ed11dfc53359e2c0d8aa377ceddd1409520dafc7d41c3e5e3432897160c2b23", size = 8744007, upload-time = "2025-10-21T19:26:41.421Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5c/eb/70f518a5da3d0b1ca5585e81f3af1ce5119cee9cb4289134ce76d4614ad0/ddtrace-3.16.2-cp311-cp311-win32.whl", hash = "sha256:5c64499f3c2cd906be1f01f880b631157c0a8d4f0b074164efead09c0bd22c6d", size = 5043118, upload-time = "2025-10-21T19:26:43.527Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5c/29/2a4b0cd621151063912f8919f03cc1f3cbf494097861de73d9c000d8f1b7/ddtrace-3.16.2-cp311-cp311-win_amd64.whl", hash = "sha256:53110eca9026052c37751018e1a2d7ce6631fe54d456ebf92e575e6f75d14781", size = 5603897, upload-time = "2025-10-21T19:26:45.52Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/18/90/a2f44bdac307ed2f428ac16168d57d601586489e2e31cdb7b8360c5a1981/ddtrace-3.16.2-cp311-cp311-win_arm64.whl", hash = "sha256:0246ccbdf2b4f393410cb798219737fff106161941c4aa9048ce7d812192806c", size = 5326912, upload-time = "2025-10-21T19:26:47.654Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e0/18/763b25401be47ede7b6118f0470ad081c051e8ed60e8beeed0c2b0f4c7ef/ddtrace-3.16.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:4c3f770eee6085155c52f24e9d14e413977a3ea2eb0f0338e17bf5e87159a11c", size = 6335314, upload-time = "2025-10-21T19:26:50.051Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3d/3a/4be0d1ad80384b888c0959e0198a6413b5bcfd17da189289f27b487bfa26/ddtrace-3.16.2-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:296c820c6612fbf863534f638c10d811b31b7985cc30c8f20679d4d70249464d", size = 6684784, upload-time = "2025-10-21T19:26:53.557Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/88/6e/5bc4ec8404a65832dd5f045064f7af41a02183dbd3f7f7c3e13276d75874/ddtrace-3.16.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:867101dc679c6b77ce77524e434ebe87b22bbece04f10226960fbd38811f4022", size = 7382532, upload-time = "2025-10-21T19:26:55.623Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/61/e2/b0b220b76fb90a91c8a322a1511ff6f28f0e7c9f5174bd2492dd28392bfb/ddtrace-3.16.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:35dce94534d5a13914ca943f4dd718f76c941f50f99dcfab729b0b9ec3587e0a", size = 7656342, upload-time = "2025-10-21T19:26:57.636Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f1/ad/f9dbfd6be8fb032d087b3362558947bb39a0329a30b84ead30fcf2c9668e/ddtrace-3.16.2-cp312-cp312-manylinux_2_28_i686.whl", hash = "sha256:975de343cf9c643a5d7b0006d36fe716cdf9957012faea66c1001c6195a964f8", size = 5504371, upload-time = "2025-10-21T19:27:00.054Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d0/e1/c7375cffa27f4558d2b3e5cba043a9770bce9373f0b9fd2c2441992e7dc9/ddtrace-3.16.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:053d9311c4db88f91b197209be945694afb360dc12153c9fc9eee676db1c2ea3", size = 8398772, upload-time = "2025-10-21T19:27:02.101Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/21/3a/3f6eccf9ddf65bc8dd6fef5c94d742df6cbe3a112dab53f60da4fa691420/ddtrace-3.16.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:f3b94723d50ab235608c1e4d216c4ac22dfb9a21a253a488097648e456f1d82b", size = 6588990, upload-time = "2025-10-21T19:27:04.512Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cc/35/32bdb07845720a2b694495253fb341506deb6b745ed674b5281516824273/ddtrace-3.16.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b07c930bf83ed54996656faafbe065f96402d73231c7005b97f98cd0f002713a", size = 8731538, upload-time = "2025-10-21T19:27:07.007Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fb/ad/e6fe3e4316191fa54c1211e553e6ee20aa0f3c148096001c55e74b556c39/ddtrace-3.16.2-cp312-cp312-win32.whl", hash = "sha256:c778dfd7bf839ca94b815fcf1985b390060694806535bdd69641986b76915894", size = 5036303, upload-time = "2025-10-21T19:27:09.755Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d0/20/3fa392f338cfd22c1c8229e6b260c1ab67a01d6e4507cae3c4de7d720e5b/ddtrace-3.16.2-cp312-cp312-win_amd64.whl", hash = "sha256:3a81183b1681ddc04062dbe990770b50b1062bceb5b1780f2526daa6ecd3a909", size = 5594871, upload-time = "2025-10-21T19:27:11.857Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/39/65/3d08e2e6ac26e8f016e6ced86c070d200380e7d37c73de7f22986c162be5/ddtrace-3.16.2-cp312-cp312-win_arm64.whl", hash = "sha256:1b74d506e660244f7df29e2cda643f058e591da66ef8d2a71435c78eda3b47ce", size = 5313712, upload-time = "2025-10-21T19:27:14.084Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/13/2e/a7dde061252cd565f92d627b0c372db2747874bdbc17333f154b2be6bb18/ddtrace-3.16.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:74c7b6c2ef0043b902c6cd7eda3cad5a866d1251c83f14b8c49abb8e1e67a2b1", size = 6329879, upload-time = "2025-10-21T19:27:21.968Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9e/ba/d6e486dc27f9ba04be4ad90789dfbceb394ea5960219ed83aef5c7878634/ddtrace-3.16.2-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:adfe6373014b9f37a99d894986c6117b7fb0486d52b4990c2647f93e838c540a", size = 6679188, upload-time = "2025-10-21T19:27:24.339Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/06/d5/2090abf84fe9cfc941b3a903c638b8157d7a15d016348df6ce7cad733e1c/ddtrace-3.16.2-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:66e68846b9b617a1311d8b76c1c5d7bee552cb0af68dc2ccf9af91abc55a74af", size = 7377794, upload-time = "2025-10-21T19:27:26.986Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b6/88/07bccc2d9b22ec6114015468772a3451697fcc3c1a39bfb357bdbdfb43e2/ddtrace-3.16.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6822b1b9cd2f6ad3db4d76a05e22adad21878bcff6bd9ba7b7cd581c9136c00c", size = 7649664, upload-time = "2025-10-21T19:27:29.294Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6f/e7/433d8d2d5a9614a1adcadf54cf3c21afe151fa714208ff7fd40e3435496c/ddtrace-3.16.2-cp313-cp313-manylinux_2_28_i686.whl", hash = "sha256:9f4f5adedcfb1f42a02ac7c7370887001ce71826141885da49bb2da39a440125", size = 5498780, upload-time = "2025-10-21T19:27:31.703Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4c/a1/1339416eeeb39dab87f750038460f922e40681b3ca0089d6c6a5139c00ff/ddtrace-3.16.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:070cf3f7571a2b8640a6663f476c157b110b3a88b1c7f7941615b746bc4c6c99", size = 8394692, upload-time = "2025-10-21T19:27:34.52Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/75/24/a380640e605daebfb7791bdcc92bef8d7474ad913807276f9f21117aa323/ddtrace-3.16.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:873a06e21bd00a20d4453cea30bf44e877b657d0cf911aabbf6c8945cfbb31b6", size = 6584925, upload-time = "2025-10-21T19:27:37.209Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/15/80/6126563c16d9a28cb03f4318bde899e402662e14e97351569640ff833608/ddtrace-3.16.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:af28ecad6a6379bb5e18979baeb3defe5dea101bc0608e5032a18aa4c38340b2", size = 8727833, upload-time = "2025-10-21T19:27:39.584Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ce/9c/1b0aa8e0984e3d4f58e7ee9e6a5b938ae45645847f7ae52583fbaff1314e/ddtrace-3.16.2-cp313-cp313-win32.whl", hash = "sha256:a42fc81e7bd6a80c297ee891ac8169a0d1efcdd6b7b0ea5478f3bc3deb1aa8af", size = 5033513, upload-time = "2025-10-21T19:27:44.02Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6b/66/dfb088db59b580a6e1ff0a180029ff90dab684b6a60abfe4a5dad0f8e19b/ddtrace-3.16.2-cp313-cp313-win_amd64.whl", hash = "sha256:f61291e94d37ae1de5456e11e1f55a56eda6a622a0b4e6a1e59b481cc87e29ba", size = 5592118, upload-time = "2025-10-21T19:27:47.486Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8a/f4/164c40d8cd0392e86ea58194428fa1bd0fff7224ee01d5fa77abe4ae467c/ddtrace-3.16.2-cp313-cp313-win_arm64.whl", hash = "sha256:c3dedde96f9906556c20c76ef0e7f8a2fc24db394b5a095ce6c6dab21facf20f", size = 5311416, upload-time = "2025-10-21T19:27:50.107Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7e/22/de77c4fbbbb0f6ae6b6957ea5ccf7662566cac87c4191b12e3ced0bf5bd5/ddtrace-4.0.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:fb4ae8e3656433dff51defd973dba10096689648849c7198225b4ba3ebb51980", size = 6494051, upload-time = "2025-12-16T20:09:13.979Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e8/10/ef09907e0579efd2da60bc9f7155927b5ea9ed4dd582437489c975f88ee5/ddtrace-4.0.1-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:5febca426fb6192dfbdf89fcccca07a37cc826ee3ed56d5056559e838209f37a", size = 6897193, upload-time = "2025-12-16T20:09:15.929Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/21/e0/fd04c0c0db3ec14c491505e0147aef5bdaaf1715d7a8ebe9e1bebfe7ee78/ddtrace-4.0.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:57fe91e21889b67c920c17e02fd7ede9452a34295a17bae2e20a839adb557ca3", size = 7591528, upload-time = "2025-12-16T20:09:17.752Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5c/2a/60c78706838ebc1db85c8775c667b2155a9c56b3c18cb65450bf20c64e6b/ddtrace-4.0.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:93db5c87981af73ef5753ce47eb248e379bb7cfc230eb45ddd1b085e45708904", size = 7878277, upload-time = "2025-12-16T20:09:19.627Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/98/95/31b65e51af746d8ba0e7173075c2c0c862065f076208b517a95879575ecc/ddtrace-4.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8546e1bf69b663cb008556ed6ed78d379ee652732e95e5f6a274c586d9f7536a", size = 8545841, upload-time = "2025-12-16T20:09:21.555Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d8/f3/72445c91ade150be37475d9e097629576992594a81d6d7e1aef49c07d59a/ddtrace-4.0.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f626f21a2e46d3356b0c8bce9cc755617251ea768ebd7a4071496f968a99f4f5", size = 8940918, upload-time = "2025-12-16T20:09:25.01Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/24/6f/8427c4cb312983b7a9f60685abd1f2ff6f57ba75ce0c418e1f23264a972c/ddtrace-4.0.1-cp311-cp311-win32.whl", hash = "sha256:a67e3a37add54f833995aae193014f7dcbae7a88ba8cfdfa4f936cf7e5c1c146", size = 5119431, upload-time = "2025-12-16T20:09:27.096Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0e/86/5d02d7fb1129f435bba84882e7d01d3ee366ac06e2248bb8cc99c1d67995/ddtrace-4.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:67526ba098a4ee6706ddeaef60634994d7a9341e1a4199378678c23c0ac881b1", size = 5614752, upload-time = "2025-12-16T20:09:29.174Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f0/a5/8cf03d88ceeee017257a584c764725ce89b451c6c6be8d6cfce3efc7a629/ddtrace-4.0.1-cp311-cp311-win_arm64.whl", hash = "sha256:ccf8e42746c7991f7364b0d8ea4990b0a2747f7b9d307b918f72b92fae94e023", size = 5325937, upload-time = "2025-12-16T20:09:31.352Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a2/c1/e738c16be9532988896fab2088f0242280c2e5a11650c64b5255b9f9cc3d/ddtrace-4.0.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:a3c0d61fcab9f4d40f70744233fad3e07410789f35dfa95a16e3cdcaaa3cf229", size = 6496084, upload-time = "2025-12-16T20:09:33.513Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8f/30/67ee93be4a108f23bdd7d9e24690b8dac18a34bbbee72bdb22b995baf6be/ddtrace-4.0.1-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:7acaf4ab1a43fa38232d478a51e3c74056eebe4f12d8380d65819e508ef3338b", size = 6901643, upload-time = "2025-12-16T20:09:35.704Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/11/bd/3d13286033f89aef4e615f2699c4f711af1d07bf437adbbdce6cbace626c/ddtrace-4.0.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:49aacfdc485896b6c1424d762708f5638aa70c5db154086f9f2c4bd38af381cf", size = 7574010, upload-time = "2025-12-16T20:09:37.997Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b9/96/0298a6caa2533e923239dbe0ef0de9073b92ce613e7bdb2ba4b6f648d01e/ddtrace-4.0.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:66001bad0ee687e3a0a33cd1e3e085759105572351d19bb57fd88ab488023581", size = 7871849, upload-time = "2025-12-16T20:09:40.185Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7a/1a/46b8e36568c0103adf071564035fb8ece960e40145d172a09703a6a7e10c/ddtrace-4.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4bd0d06f177442977480364dc1b2d00ad4d074af6c66dc7ce084fa3eb19f54ca", size = 8535939, upload-time = "2025-12-16T20:09:42.509Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c8/b0/9600d30c9b4ad34eba79865d427f1ed9504786f0801607db9498457bedf4/ddtrace-4.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:5c497e5778e8bbb188fd759b2662c9167f3937f6f34cc2b0827ad3507df02315", size = 8935929, upload-time = "2025-12-16T20:09:45.04Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0f/1d/5129926b0147f2d6acfa5d607ca8c403672c1a374ba5563585773673eb5b/ddtrace-4.0.1-cp312-cp312-win32.whl", hash = "sha256:bd9d58088d145cf685a0cfa774f0920baee256d05451df495b9fe3276ae2a38e", size = 5113021, upload-time = "2025-12-16T20:09:47.527Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/63/05/5a69c835e4ced4653cbc1731fe5f77380684fe2d001a84b9d2005451ca41/ddtrace-4.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:104b640b2826b92b001a8b36e59568fdaace2bab2fd815906bc134771d006942", size = 5605021, upload-time = "2025-12-16T20:09:49.867Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3b/10/7638c595b71f4947209423616a09a0f333400f30c82a8b9a2ad91e68eb2c/ddtrace-4.0.1-cp312-cp312-win_arm64.whl", hash = "sha256:ec9f626b9b5e0bcc9d0aaf7eb166bfc7400c9e2e870aa67b57977114b78782b3", size = 5314413, upload-time = "2025-12-16T20:09:52.858Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/09/84/55bc61e8cf57cc13783159e032dabb856baab85ea696f9b66566d968020b/ddtrace-4.0.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:141c49150dfdb0149d2c1e5bd046639b4f7afd9086b4670eb27556ace41d398e", size = 6490597, upload-time = "2025-12-16T20:09:55.444Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7d/56/1265401768c3088d8242d405955201a005e25ec79edefdf6a067b534d19c/ddtrace-4.0.1-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:807578ac1a09b81781d3ba18a31a361a6fe73a070d5cd776e37a749f22829e72", size = 6897761, upload-time = "2025-12-16T20:09:58.468Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/71/f3/6590cd15d6572621f2e56c7eb2048f60b5eeee0847f92834983f132a0465/ddtrace-4.0.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c6a946090a9e103c26f14c3772b0ffd3ea0a3b5fbb81496315e78fab59328949", size = 7564985, upload-time = "2025-12-16T20:10:01.07Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/60/0ba2b5cef4f96bfa2230e3c597873b65ed3124e369a46069f068adcfafe5/ddtrace-4.0.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2413d4649a02efb9f73f84338d3ae44908d3ecabfc6e2af40a3fc1c66e67bcba", size = 7859984, upload-time = "2025-12-16T20:10:03.44Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a5/ef/934407db95b6d3e657482b310af3b88a4423d18cb8314058319f5d1be2af/ddtrace-4.0.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:b85abbde9559a94ded61574b45ad1cf6b528856000b115e76fe536b186a9cb4a", size = 8531203, upload-time = "2025-12-16T20:10:06.135Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/95/a6/d5551232bef69965d21c8157bd5f5dee3671f2a5a9707544b2c6e212885b/ddtrace-4.0.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a5edb950dddcf9310c6d8990c7f38ca3376a1c455f135b0e98619e0f44a4c407", size = 8928902, upload-time = "2025-12-16T20:10:09.065Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/56/fd/c7a2df468dd536d98707119d9c3c580e69c1d7e99eefb969c697efa52deb/ddtrace-4.0.1-cp313-cp313-win32.whl", hash = "sha256:9ec00c8ddb8dc89ab89466a981b1ac813beaee6310a6cf62904ae625c3efc858", size = 5110305, upload-time = "2025-12-16T20:10:11.972Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/32/8e/bcc9b58fa4b0c7574e7576d0019b3819097b25743f768689c5d4f0257154/ddtrace-4.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:bede88c7777d8eebf9cc693264d0d78a86ffca9b3920e61cdc105a2bc67a7109", size = 5601559, upload-time = "2025-12-16T20:10:14.453Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0b/17/2d7e9e75f1ba0dd36f127886baae7addf789b01e1ab749215c84ad06c512/ddtrace-4.0.1-cp313-cp313-win_arm64.whl", hash = "sha256:d62b0d064607635f2f10b7f535f3780bf1585091aba20cb37d24ab214570f54a", size = 5311762, upload-time = "2025-12-16T20:10:16.83Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2324,18 +2316,9 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a8/73/91a506e17bb1bc6d20c2c04cf7b459dc58951bfbfe7f97f2c952646b4500/langsmith-0.4.18-py3-none-any.whl", hash = "sha256:ad63154f503678356aadf5b999f40393b4bbd332aee2d04cde3e431c61f2e1c2", size = 376444, upload-time = "2025-08-26T17:00:03.564Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "legacy-cgi"
|
||||
version = "2.6.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a6/ed/300cabc9693209d5a03e2ebc5eb5c4171b51607c08ed84a2b71c9015e0f3/legacy_cgi-2.6.3.tar.gz", hash = "sha256:4c119d6cb8e9d8b6ad7cc0ddad880552c62df4029622835d06dfd18f438a8154", size = 24401, upload-time = "2025-03-27T00:48:56.957Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/5a/33/68c6c38193684537757e0d50a7ccb4f4656e5c2f7cd2be737a9d4a1bff71/legacy_cgi-2.6.3-py3-none-any.whl", hash = "sha256:6df2ea5ae14c71ef6f097f8b6372b44f6685283dc018535a75c924564183cdab", size = 19851, upload-time = "2025-03-27T00:48:55.366Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "letta"
|
||||
version = "0.16.0"
|
||||
version = "0.16.1"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "aiomultiprocess" },
|
||||
@@ -2348,6 +2331,7 @@ dependencies = [
|
||||
{ name = "colorama" },
|
||||
{ name = "datadog" },
|
||||
{ name = "datamodel-code-generator", extra = ["http"] },
|
||||
{ name = "ddtrace" },
|
||||
{ name = "demjson3" },
|
||||
{ name = "docstring-parser" },
|
||||
{ name = "exa-py" },
|
||||
@@ -2498,6 +2482,7 @@ requires-dist = [
|
||||
{ name = "colorama", specifier = ">=0.4.6" },
|
||||
{ name = "datadog", specifier = ">=0.49.1" },
|
||||
{ name = "datamodel-code-generator", extras = ["http"], specifier = ">=0.25.0" },
|
||||
{ name = "ddtrace", specifier = ">=4.0.1" },
|
||||
{ name = "ddtrace", marker = "extra == 'profiling'", specifier = ">=2.18.2" },
|
||||
{ name = "demjson3", specifier = ">=3.0.6" },
|
||||
{ name = "docker", marker = "extra == 'desktop'", specifier = ">=7.1.0" },
|
||||
|
||||
Reference in New Issue
Block a user