chore: bump v0.16.4 (#3168)

This commit is contained in:
cthomas
2026-01-29 12:50:03 -08:00
committed by GitHub
143 changed files with 88406 additions and 803 deletions

View File

@@ -0,0 +1,43 @@
---
name: llm-provider-usage-statistics
description: Reference guide for token counting and prefix caching across LLM providers (OpenAI, Anthropic, Gemini). Use when debugging token counts or optimizing prefix caching.
---
# LLM Provider Usage Statistics
Reference documentation for how different LLM providers report token usage.
## Quick Reference: Token Counting Semantics
| Provider | `input_tokens` meaning | Cache tokens | Must add cache to get total? |
|----------|------------------------|--------------|------------------------------|
| OpenAI | TOTAL (includes cached) | `cached_tokens` is subset | No |
| Anthropic | NON-cached only | `cache_read_input_tokens` + `cache_creation_input_tokens` | **Yes** |
| Gemini | TOTAL (includes cached) | `cached_content_token_count` is subset | No |
**Critical difference:** Anthropic's `input_tokens` excludes cached tokens, so you must add them:
```
total_input = input_tokens + cache_read_input_tokens + cache_creation_input_tokens
```
## Quick Reference: Prefix Caching
| Provider | Min tokens | How to enable | TTL |
|----------|-----------|---------------|-----|
| OpenAI | 1,024 | Automatic | ~5-10 min |
| Anthropic | 1,024 | Requires `cache_control` breakpoints | 5 min |
| Gemini 2.0+ | 1,024 | Automatic (implicit) | Variable |
## Quick Reference: Reasoning/Thinking Tokens
| Provider | Field name | Models |
|----------|-----------|--------|
| OpenAI | `reasoning_tokens` | o1, o3 models |
| Anthropic | N/A | (thinking is in content blocks, not usage) |
| Gemini | `thoughts_token_count` | Gemini 2.0 with thinking enabled |
## Provider Reference Files
- **OpenAI:** [references/openai.md](references/openai.md) - Chat Completions vs Responses API, reasoning models, cached_tokens
- **Anthropic:** [references/anthropic.md](references/anthropic.md) - cache_control setup, beta headers, cache token fields
- **Gemini:** [references/gemini.md](references/gemini.md) - implicit caching, thinking tokens, usage_metadata fields

View File

@@ -0,0 +1,83 @@
# Anthropic Usage Statistics
## Response Format
```
response.usage.input_tokens # NON-cached input tokens only
response.usage.output_tokens # Output tokens
response.usage.cache_read_input_tokens # Tokens read from cache
response.usage.cache_creation_input_tokens # Tokens written to cache
```
## Critical: Token Calculation
**Anthropic's `input_tokens` is NOT the total.** To get total input tokens:
```python
total_input = input_tokens + cache_read_input_tokens + cache_creation_input_tokens
```
This differs from OpenAI (`prompt_tokens`) and Gemini (`prompt_token_count`), where the reported input count already includes cached tokens.
## Prefix Caching (Prompt Caching)
**Requirements:**
- Minimum 1,024 tokens for Claude 3.5 Haiku/Sonnet
- Minimum 2,048 tokens for Claude 3 Opus
- Requires explicit `cache_control` breakpoints in messages
- TTL: 5 minutes
**How to enable:**
Add `cache_control` to message content:
```python
{
"role": "user",
"content": [
{
"type": "text",
"text": "...",
"cache_control": {"type": "ephemeral"}
}
]
}
```
**Beta header required:**
```python
betas = ["prompt-caching-2024-07-31"]
```
## Cache Behavior
- `cache_creation_input_tokens`: Tokens that were cached on this request (cache write)
- `cache_read_input_tokens`: Tokens that were read from existing cache (cache hit)
- On first request: expect `cache_creation_input_tokens > 0`
- On subsequent requests with same prefix: expect `cache_read_input_tokens > 0`
## Streaming
In streaming mode, usage is reported in two events:
1. **`message_start`**: Initial usage (may have cache info)
```python
event.message.usage.input_tokens
event.message.usage.output_tokens
event.message.usage.cache_read_input_tokens
event.message.usage.cache_creation_input_tokens
```
2. **`message_delta`**: Cumulative output tokens
```python
event.usage.output_tokens # This is CUMULATIVE, not incremental
```
**Important:** Per Anthropic docs, `message_delta` token counts are cumulative, so assign (don't accumulate).
## Letta Implementation
- **Client:** `letta/llm_api/anthropic_client.py`
- **Streaming interfaces:**
- `letta/interfaces/anthropic_streaming_interface.py`
- `letta/interfaces/anthropic_parallel_tool_call_streaming_interface.py` (tracks cache tokens)
- **Extract method:** `AnthropicClient.extract_usage_statistics()`
- **Cache control:** `_add_cache_control_to_system_message()`, `_add_cache_control_to_messages()`

View File

@@ -0,0 +1,81 @@
# Gemini Usage Statistics
## Response Format
Gemini returns usage in `usage_metadata`:
```
response.usage_metadata.prompt_token_count # Total input tokens
response.usage_metadata.candidates_token_count # Output tokens
response.usage_metadata.total_token_count # Sum
response.usage_metadata.cached_content_token_count # Tokens from cache (optional)
response.usage_metadata.thoughts_token_count # Reasoning tokens (optional)
```
## Token Counting
- `prompt_token_count` is the TOTAL (includes cached)
- `cached_content_token_count` is a subset (when present)
- This matches OpenAI's semantics: the reported prompt count already includes any cached tokens, so do not add `cached_content_token_count` on top of it
## Implicit Caching (Gemini 2.0+)
**Requirements:**
- Minimum 1,024 tokens
- Automatic (no opt-in required)
- Available on Gemini 2.0 Flash and later models
**Behavior:**
- Caching is probabilistic and server-side
- `cached_content_token_count` may or may not be present
- When present, indicates tokens that were served from cache
**Note:** Unlike Anthropic, Gemini doesn't have explicit cache_control. Caching is implicit and managed by Google's infrastructure.
## Reasoning/Thinking Tokens
For models with extended thinking (like Gemini 2.0 with thinking enabled):
- `thoughts_token_count` reports tokens used for reasoning
- These are similar to OpenAI's `reasoning_tokens`
**Enabling thinking:**
```python
generation_config = {
"thinking_config": {
"thinking_budget": 1024 # Max thinking tokens
}
}
```
## Streaming
In streaming mode:
- `usage_metadata` is typically in the **final chunk**
- Same fields as non-streaming
- May not be present in intermediate chunks
**Important:** `stream_async()` returns an async generator (not awaitable):
```python
# Correct:
stream = client.stream_async(request_data, llm_config)
async for chunk in stream:
...
# Incorrect (will error):
stream = await client.stream_async(...) # TypeError!
```
## APIs
Gemini has two APIs:
- **Google AI (google_ai):** Uses `google.genai` SDK
- **Vertex AI (google_vertex):** Uses same SDK with different auth
Both share the same response format.
## Letta Implementation
- **Client:** `letta/llm_api/google_vertex_client.py` (handles both google_ai and google_vertex)
- **Streaming interface:** `letta/interfaces/gemini_streaming_interface.py`
- **Extract method:** `GoogleVertexClient.extract_usage_statistics()`
- Response is a `GenerateContentResponse` object with `.usage_metadata` attribute

View File

@@ -0,0 +1,61 @@
# OpenAI Usage Statistics
## APIs and Response Formats
OpenAI has two APIs with different response structures:
### Chat Completions API
```
response.usage.prompt_tokens # Total input tokens (includes cached)
response.usage.completion_tokens # Output tokens
response.usage.total_tokens # Sum
response.usage.prompt_tokens_details.cached_tokens # Subset that was cached
response.usage.completion_tokens_details.reasoning_tokens # For o1/o3 models
```
### Responses API (newer)
```
response.usage.input_tokens # Total input tokens
response.usage.output_tokens # Output tokens
response.usage.total_tokens # Sum
response.usage.input_tokens_details.cached_tokens # Subset that was cached
response.usage.output_tokens_details.reasoning_tokens # For reasoning models
```
## Prefix Caching
**Requirements:**
- Minimum 1,024 tokens in the prefix
- Automatic (no opt-in required)
- Cached in 128-token increments
- TTL: approximately 5-10 minutes of inactivity
**Supported models:** GPT-4o, GPT-4o-mini, o1, o1-mini, o3-mini
**Cache behavior:**
- `cached_tokens` will be a multiple of 128
- Cache hit means those tokens were not re-processed
- Cost: cached input tokens are billed at a discounted rate compared to non-cached input tokens
## Reasoning Models (o1, o3)
For reasoning models, additional tokens are used for "thinking":
- `reasoning_tokens` in `completion_tokens_details`
- These are output tokens used for internal reasoning
- Not visible in the response content
## Streaming
In streaming mode, usage is reported in the **final chunk** when `stream_options.include_usage=True`:
```python
request_data["stream_options"] = {"include_usage": True}
```
The final chunk will have `chunk.usage` with the same structure as non-streaming.
## Letta Implementation
- **Client:** `letta/llm_api/openai_client.py`
- **Streaming interface:** `letta/interfaces/openai_streaming_interface.py`
- **Extract method:** `OpenAIClient.extract_usage_statistics()`
- Uses OpenAI SDK's pydantic models (`ChatCompletion`) for type-safe parsing

View File

@@ -0,0 +1,36 @@
"""nullable embedding for archives and passages
Revision ID: 297e8217e952
Revises: 308a180244fc
Create Date: 2026-01-20 14:11:21.137232
"""
from typing import Sequence, Union
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "297e8217e952"  # unique ID of this migration
down_revision: Union[str, None] = "308a180244fc"  # parent migration this one applies on top of
branch_labels: Union[str, Sequence[str], None] = None  # no branch labels for this revision
depends_on: Union[str, Sequence[str], None] = None  # no cross-branch dependencies
def upgrade() -> None:
    """Relax embedding_config to be nullable on the passage/archive tables."""
    # All three tables share the identical JSON column type.
    json_type = postgresql.JSON(astext_type=sa.Text())
    for table in ("archival_passages", "archives", "source_passages"):
        op.alter_column(table, "embedding_config", existing_type=json_type, nullable=True)
def downgrade() -> None:
    """Restore NOT NULL on embedding_config, reversing the upgrade order."""
    json_type = postgresql.JSON(astext_type=sa.Text())
    for table in ("source_passages", "archives", "archival_passages"):
        op.alter_column(table, "embedding_config", existing_type=json_type, nullable=False)

View File

@@ -0,0 +1,31 @@
"""last_synced column for providers
Revision ID: 308a180244fc
Revises: 82feb220a9b8
Create Date: 2026-01-05 18:54:15.996786
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "308a180244fc"  # unique ID of this migration
down_revision: Union[str, None] = "82feb220a9b8"  # parent migration this one applies on top of
branch_labels: Union[str, Sequence[str], None] = None  # no branch labels for this revision
depends_on: Union[str, Sequence[str], None] = None  # no cross-branch dependencies
def upgrade() -> None:
    """Add providers.last_synced: a nullable, timezone-aware sync timestamp."""
    last_synced = sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True)
    op.add_column("providers", last_synced)
def downgrade() -> None:
    """Remove the last_synced column added by upgrade()."""
    op.drop_column("providers", "last_synced")

View File

@@ -0,0 +1,32 @@
"""Add v2 protocol fields to provider_traces
Revision ID: 9275f62ad282
Revises: 297e8217e952
Create Date: 2026-01-22
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "9275f62ad282"  # unique ID of this migration
down_revision: Union[str, None] = "297e8217e952"  # parent migration this one applies on top of
branch_labels: Union[str, Sequence[str], None] = None  # no branch labels for this revision
depends_on: Union[str, Sequence[str], None] = None  # no cross-branch dependencies
def upgrade() -> None:
    """Add v2-protocol attribution and config columns to provider_traces."""
    new_columns = (
        sa.Column("org_id", sa.String(), nullable=True),
        sa.Column("user_id", sa.String(), nullable=True),
        sa.Column("compaction_settings", sa.JSON(), nullable=True),
        sa.Column("llm_config", sa.JSON(), nullable=True),
    )
    for column in new_columns:
        op.add_column("provider_traces", column)
def downgrade() -> None:
    """Drop the v2-protocol columns, in reverse order of their creation."""
    for column_name in ("llm_config", "compaction_settings", "user_id", "org_id"):
        op.drop_column("provider_traces", column_name)

View File

@@ -0,0 +1,59 @@
"""create provider_trace_metadata table
Revision ID: a1b2c3d4e5f8
Revises: 9275f62ad282
Create Date: 2026-01-28
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from letta.settings import settings
# revision identifiers, used by Alembic.
revision: str = "a1b2c3d4e5f8"  # unique ID of this migration
down_revision: Union[str, None] = "9275f62ad282"  # parent migration this one applies on top of
branch_labels: Union[str, Sequence[str], None] = None  # no branch labels for this revision
depends_on: Union[str, Sequence[str], None] = None  # no cross-branch dependencies
def upgrade() -> None:
    """Create the provider_trace_metadata table with step/id indexes.

    Postgres-only: skipped entirely when no Postgres URI is configured
    (i.e. when running against the default SQLite backend).
    """
    if not settings.letta_pg_uri_no_default:
        return
    columns = [
        sa.Column("id", sa.String(), nullable=False),
        sa.Column("step_id", sa.String(), nullable=True),
        sa.Column("agent_id", sa.String(), nullable=True),
        sa.Column("agent_tags", sa.JSON(), nullable=True),
        sa.Column("call_type", sa.String(), nullable=True),
        sa.Column("run_id", sa.String(), nullable=True),
        sa.Column("source", sa.String(), nullable=True),
        sa.Column("org_id", sa.String(), nullable=True),
        sa.Column("user_id", sa.String(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
        sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
        sa.Column("_created_by_id", sa.String(), nullable=True),
        sa.Column("_last_updated_by_id", sa.String(), nullable=True),
        sa.Column("organization_id", sa.String(), nullable=False),
    ]
    op.create_table(
        "provider_trace_metadata",
        *columns,
        sa.ForeignKeyConstraint(["organization_id"], ["organizations.id"]),
        # Composite PK on (created_at, id) — created_at first, e.g. for
        # time-partitioned storage.
        sa.PrimaryKeyConstraint("created_at", "id"),
    )
    op.create_index("ix_provider_trace_metadata_step_id", "provider_trace_metadata", ["step_id"], unique=False)
    op.create_index("ix_provider_trace_metadata_id", "provider_trace_metadata", ["id"], unique=True)
def downgrade() -> None:
    """Drop provider_trace_metadata and its indexes (Postgres only)."""
    if not settings.letta_pg_uri_no_default:
        return
    # Drop indexes first, in reverse order of creation, then the table.
    for index_name in ("ix_provider_trace_metadata_id", "ix_provider_trace_metadata_step_id"):
        op.drop_index(index_name, table_name="provider_trace_metadata")
    op.drop_table("provider_trace_metadata")

View File

@@ -14,7 +14,7 @@ services:
- ./.persist/pgdata-test:/var/lib/postgresql/data
- ./init.sql:/docker-entrypoint-initdb.d/init.sql
ports:
- "5432:5432"
- '5432:5432'
letta_server:
image: letta/letta:latest
hostname: letta
@@ -25,8 +25,8 @@ services:
depends_on:
- letta_db
ports:
- "8083:8083"
- "8283:8283"
- '8083:8083'
- '8283:8283'
environment:
- LETTA_PG_DB=${LETTA_PG_DB:-letta}
- LETTA_PG_USER=${LETTA_PG_USER:-letta}

48587
fern/openapi.json Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -5,7 +5,7 @@ try:
__version__ = version("letta")
except PackageNotFoundError:
# Fallback for development installations
__version__ = "0.16.3"
__version__ = "0.16.4"
if os.environ.get("LETTA_VERSION"):
__version__ = os.environ["LETTA_VERSION"]

View File

@@ -27,12 +27,16 @@ class LettaLLMAdapter(ABC):
agent_id: str | None = None,
agent_tags: list[str] | None = None,
run_id: str | None = None,
org_id: str | None = None,
user_id: str | None = None,
) -> None:
self.llm_client: LLMClientBase = llm_client
self.llm_config: LLMConfig = llm_config
self.agent_id: str | None = agent_id
self.agent_tags: list[str] | None = agent_tags
self.run_id: str | None = run_id
self.org_id: str | None = org_id
self.user_id: str | None = user_id
self.message_id: str | None = None
self.request_data: dict | None = None
self.response_data: dict | None = None

View File

@@ -127,6 +127,9 @@ class LettaLLMRequestAdapter(LettaLLMAdapter):
agent_id=self.agent_id,
agent_tags=self.agent_tags,
run_id=self.run_id,
org_id=self.org_id,
user_id=self.user_id,
llm_config=self.llm_config.model_dump() if self.llm_config else None,
),
),
label="create_provider_trace",

View File

@@ -33,8 +33,10 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
agent_id: str | None = None,
agent_tags: list[str] | None = None,
run_id: str | None = None,
org_id: str | None = None,
user_id: str | None = None,
) -> None:
super().__init__(llm_client, llm_config, agent_id=agent_id, agent_tags=agent_tags, run_id=run_id)
super().__init__(llm_client, llm_config, agent_id=agent_id, agent_tags=agent_tags, run_id=run_id, org_id=org_id, user_id=user_id)
self.interface: OpenAIStreamingInterface | AnthropicStreamingInterface | None = None
async def invoke_llm(
@@ -60,7 +62,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
self.request_data = request_data
# Instantiate streaming interface
if self.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock]:
if self.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock, ProviderType.minimax]:
self.interface = AnthropicStreamingInterface(
use_assistant_message=use_assistant_message,
put_inner_thoughts_in_kwarg=self.llm_config.put_inner_thoughts_in_kwargs,
@@ -68,7 +70,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
run_id=self.run_id,
step_id=step_id,
)
elif self.llm_config.model_endpoint_type == ProviderType.openai:
elif self.llm_config.model_endpoint_type in [ProviderType.openai, ProviderType.openrouter]:
# For non-v1 agents, always use Chat Completions streaming interface
self.interface = OpenAIStreamingInterface(
use_assistant_message=use_assistant_message,
@@ -114,64 +116,9 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
# Extract reasoning content from the interface
self.reasoning_content = self.interface.get_reasoning_content()
# Extract usage statistics
# Some providers don't provide usage in streaming, use fallback if needed
if hasattr(self.interface, "input_tokens") and hasattr(self.interface, "output_tokens"):
# Handle cases where tokens might not be set (e.g., LMStudio)
input_tokens = self.interface.input_tokens
output_tokens = self.interface.output_tokens
# Fallback to estimated values if not provided
if not input_tokens and hasattr(self.interface, "fallback_input_tokens"):
input_tokens = self.interface.fallback_input_tokens
if not output_tokens and hasattr(self.interface, "fallback_output_tokens"):
output_tokens = self.interface.fallback_output_tokens
# Extract cache token data (OpenAI/Gemini use cached_tokens, Anthropic uses cache_read_tokens)
# None means provider didn't report, 0 means provider reported 0
cached_input_tokens = None
if hasattr(self.interface, "cached_tokens") and self.interface.cached_tokens is not None:
cached_input_tokens = self.interface.cached_tokens
elif hasattr(self.interface, "cache_read_tokens") and self.interface.cache_read_tokens is not None:
cached_input_tokens = self.interface.cache_read_tokens
# Extract cache write tokens (Anthropic only)
cache_write_tokens = None
if hasattr(self.interface, "cache_creation_tokens") and self.interface.cache_creation_tokens is not None:
cache_write_tokens = self.interface.cache_creation_tokens
# Extract reasoning tokens (OpenAI o1/o3 models use reasoning_tokens, Gemini uses thinking_tokens)
reasoning_tokens = None
if hasattr(self.interface, "reasoning_tokens") and self.interface.reasoning_tokens is not None:
reasoning_tokens = self.interface.reasoning_tokens
elif hasattr(self.interface, "thinking_tokens") and self.interface.thinking_tokens is not None:
reasoning_tokens = self.interface.thinking_tokens
# Calculate actual total input tokens
#
# ANTHROPIC: input_tokens is NON-cached only, must add cache tokens
# Total = input_tokens + cache_read_input_tokens + cache_creation_input_tokens
#
# OPENAI/GEMINI: input_tokens is already TOTAL
# cached_tokens is a subset, NOT additive
is_anthropic = hasattr(self.interface, "cache_read_tokens") or hasattr(self.interface, "cache_creation_tokens")
if is_anthropic:
actual_input_tokens = (input_tokens or 0) + (cached_input_tokens or 0) + (cache_write_tokens or 0)
else:
actual_input_tokens = input_tokens or 0
self.usage = LettaUsageStatistics(
step_count=1,
completion_tokens=output_tokens or 0,
prompt_tokens=actual_input_tokens,
total_tokens=actual_input_tokens + (output_tokens or 0),
cached_input_tokens=cached_input_tokens,
cache_write_tokens=cache_write_tokens,
reasoning_tokens=reasoning_tokens,
)
else:
# Default usage statistics if not available
self.usage = LettaUsageStatistics(step_count=1, completion_tokens=0, prompt_tokens=0, total_tokens=0)
# Extract usage statistics from the streaming interface
self.usage = self.interface.get_usage_statistics()
self.usage.step_count = 1
# Store any additional data from the interface
self.message_id = self.interface.letta_message_id
@@ -236,6 +183,9 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
agent_id=self.agent_id,
agent_tags=self.agent_tags,
run_id=self.run_id,
org_id=self.org_id,
user_id=self.user_id,
llm_config=self.llm_config.model_dump() if self.llm_config else None,
),
),
label="create_provider_trace",

View File

@@ -46,6 +46,9 @@ class SimpleLLMRequestAdapter(LettaLLMRequestAdapter):
agent_tags=self.agent_tags,
run_id=self.run_id,
call_type="agent_step",
org_id=self.org_id,
user_id=self.user_id,
llm_config=self.llm_config.model_dump() if self.llm_config else None,
)
try:
self.response_data = await self.llm_client.request_async_with_telemetry(request_data, self.llm_config)

View File

@@ -14,8 +14,8 @@ from letta.schemas.enums import ProviderType
from letta.schemas.letta_message import LettaMessage
from letta.schemas.letta_message_content import LettaMessageContentUnion
from letta.schemas.provider_trace import ProviderTrace
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.server.rest_api.streaming_response import get_cancellation_event_for_run
from letta.settings import settings
from letta.utils import safe_create_task
@@ -70,8 +70,11 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
# Store request data
self.request_data = request_data
# Get cancellation event for this run to enable graceful cancellation (before branching)
cancellation_event = get_cancellation_event_for_run(self.run_id) if self.run_id else None
# Instantiate streaming interface
if self.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock]:
if self.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock, ProviderType.minimax]:
# NOTE: different
self.interface = SimpleAnthropicStreamingInterface(
requires_approval_tools=requires_approval_tools,
@@ -81,6 +84,7 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
elif self.llm_config.model_endpoint_type in [
ProviderType.openai,
ProviderType.deepseek,
ProviderType.openrouter,
ProviderType.zai,
ProviderType.chatgpt_oauth,
]:
@@ -102,6 +106,7 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
requires_approval_tools=requires_approval_tools,
run_id=self.run_id,
step_id=step_id,
cancellation_event=cancellation_event,
)
else:
self.interface = SimpleOpenAIStreamingInterface(
@@ -112,12 +117,14 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
model=self.llm_config.model,
run_id=self.run_id,
step_id=step_id,
cancellation_event=cancellation_event,
)
elif self.llm_config.model_endpoint_type in [ProviderType.google_ai, ProviderType.google_vertex]:
self.interface = SimpleGeminiStreamingInterface(
requires_approval_tools=requires_approval_tools,
run_id=self.run_id,
step_id=step_id,
cancellation_event=cancellation_event,
)
else:
raise ValueError(f"Streaming not supported for provider {self.llm_config.model_endpoint_type}")
@@ -157,68 +164,10 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
# Extract all content parts
self.content: List[LettaMessageContentUnion] = self.interface.get_content()
# Extract usage statistics
# Some providers don't provide usage in streaming, use fallback if needed
if hasattr(self.interface, "input_tokens") and hasattr(self.interface, "output_tokens"):
# Handle cases where tokens might not be set (e.g., LMStudio)
input_tokens = self.interface.input_tokens
output_tokens = self.interface.output_tokens
# Fallback to estimated values if not provided
if not input_tokens and hasattr(self.interface, "fallback_input_tokens"):
input_tokens = self.interface.fallback_input_tokens
if not output_tokens and hasattr(self.interface, "fallback_output_tokens"):
output_tokens = self.interface.fallback_output_tokens
# Extract cache token data (OpenAI/Gemini use cached_tokens)
# None means provider didn't report, 0 means provider reported 0
cached_input_tokens = None
if hasattr(self.interface, "cached_tokens") and self.interface.cached_tokens is not None:
cached_input_tokens = self.interface.cached_tokens
# Anthropic uses cache_read_tokens for cache hits
elif hasattr(self.interface, "cache_read_tokens") and self.interface.cache_read_tokens is not None:
cached_input_tokens = self.interface.cache_read_tokens
# Extract cache write tokens (Anthropic only)
# None means provider didn't report, 0 means provider reported 0
cache_write_tokens = None
if hasattr(self.interface, "cache_creation_tokens") and self.interface.cache_creation_tokens is not None:
cache_write_tokens = self.interface.cache_creation_tokens
# Extract reasoning tokens (OpenAI o1/o3 models use reasoning_tokens, Gemini uses thinking_tokens)
# None means provider didn't report, 0 means provider reported 0
reasoning_tokens = None
if hasattr(self.interface, "reasoning_tokens") and self.interface.reasoning_tokens is not None:
reasoning_tokens = self.interface.reasoning_tokens
elif hasattr(self.interface, "thinking_tokens") and self.interface.thinking_tokens is not None:
reasoning_tokens = self.interface.thinking_tokens
# Calculate actual total input tokens for context window limit checks (summarization trigger).
#
# ANTHROPIC: input_tokens is NON-cached only, must add cache tokens
# Total = input_tokens + cache_read_input_tokens + cache_creation_input_tokens
#
# OPENAI/GEMINI: input_tokens (prompt_tokens/prompt_token_count) is already TOTAL
# cached_tokens is a subset, NOT additive
# Total = input_tokens (don't add cached_tokens or it double-counts!)
is_anthropic = hasattr(self.interface, "cache_read_tokens") or hasattr(self.interface, "cache_creation_tokens")
if is_anthropic:
actual_input_tokens = (input_tokens or 0) + (cached_input_tokens or 0) + (cache_write_tokens or 0)
else:
actual_input_tokens = input_tokens or 0
self.usage = LettaUsageStatistics(
step_count=1,
completion_tokens=output_tokens or 0,
prompt_tokens=actual_input_tokens,
total_tokens=actual_input_tokens + (output_tokens or 0),
cached_input_tokens=cached_input_tokens,
cache_write_tokens=cache_write_tokens,
reasoning_tokens=reasoning_tokens,
)
else:
# Default usage statistics if not available
self.usage = LettaUsageStatistics(step_count=1, completion_tokens=0, prompt_tokens=0, total_tokens=0)
# Extract usage statistics from the interface
# Each interface implements get_usage_statistics() with provider-specific logic
self.usage = self.interface.get_usage_statistics()
self.usage.step_count = 1
# Store any additional data from the interface
self.message_id = self.interface.letta_message_id
@@ -283,6 +232,9 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
agent_id=self.agent_id,
agent_tags=self.agent_tags,
run_id=self.run_id,
org_id=self.org_id,
user_id=self.user_id,
llm_config=self.llm_config.model_dump() if self.llm_config else None,
),
),
label="create_provider_trace",

View File

@@ -97,6 +97,25 @@ async def _prepare_in_context_messages_async(
return current_in_context_messages, new_in_context_messages
@trace_method
def validate_persisted_tool_call_ids(tool_return_message: Message, approval_response_message: ApprovalCreate) -> bool:
    """Return True iff the persisted tool returns and the approval response
    reference exactly the same set of tool_call_ids.

    Used as an idempotency check: a match means this approval was already
    processed and its tool returns persisted. Returns False when either side
    is empty/missing or when the id sets differ in any way.
    """
    returns = tool_return_message.tool_returns
    if not returns:
        return False
    approvals = approval_response_message.approvals
    if not approvals:
        return False
    # Set equality: order doesn't matter, but no extra or missing ids are
    # allowed on either side (equivalent to an empty symmetric difference).
    persisted_ids = {tool_return.tool_call_id for tool_return in returns}
    response_ids = {approval.tool_call_id for approval in approvals}
    return persisted_ids == response_ids
@trace_method
def validate_approval_tool_call_ids(approval_request_message: Message, approval_response_message: ApprovalCreate):
approval_requests = approval_request_message.tool_calls
@@ -227,6 +246,36 @@ async def _prepare_in_context_messages_no_persist_async(
if input_messages[0].type == "approval":
# User is trying to send an approval response
if current_in_context_messages and current_in_context_messages[-1].role != "approval":
# No pending approval request - check if this is an idempotent retry
# Check last few messages for a tool return matching the approval's tool_call_ids
# (approved tool return should be recent, but server-side tool calls may come after it)
approval_already_processed = False
recent_messages = current_in_context_messages[-10:] # Only check last 10 messages
for msg in reversed(recent_messages):
if msg.role == "tool" and validate_persisted_tool_call_ids(msg, input_messages[0]):
logger.info(
f"Idempotency check: Found matching tool return in recent history. "
f"tool_returns={msg.tool_returns}, approval_response.approvals={input_messages[0].approvals}"
)
approval_already_processed = True
break
if approval_already_processed:
# Approval already handled, just process follow-up messages if any or manually inject keep-alive message
keep_alive_messages = input_messages[1:] or [
MessageCreate(
role="user",
content=[
TextContent(
text="<system-alert>Automated keep-alive ping. Ignore this message and continue from where you stopped.</system-alert>"
)
],
)
]
new_in_context_messages = await create_input_messages(
input_messages=keep_alive_messages, agent_id=agent_state.id, timezone=agent_state.timezone, run_id=run_id, actor=actor
)
return current_in_context_messages, new_in_context_messages
logger.warn(
f"Cannot process approval response: No tool call is currently awaiting approval. Last message: {current_in_context_messages[-1]}"
)
@@ -235,7 +284,7 @@ async def _prepare_in_context_messages_no_persist_async(
"Please send a regular message to interact with the agent."
)
validate_approval_tool_call_ids(current_in_context_messages[-1], input_messages[0])
new_in_context_messages = create_approval_response_message_from_input(
new_in_context_messages = await create_approval_response_message_from_input(
agent_state=agent_state, input_message=input_messages[0], run_id=run_id
)
if len(input_messages) > 1:

View File

@@ -218,6 +218,7 @@ class LettaAgent(BaseAgent):
use_assistant_message: bool = True,
request_start_timestamp_ns: int | None = None,
include_return_message_types: list[MessageType] | None = None,
run_id: str | None = None,
):
agent_state = await self.agent_manager.get_agent_by_id_async(
agent_id=self.agent_id,
@@ -330,6 +331,7 @@ class LettaAgent(BaseAgent):
tool_rules_solver,
agent_step_span,
step_metrics,
run_id=run_id,
)
in_context_messages = current_in_context_messages + new_in_context_messages
@@ -418,6 +420,9 @@ class LettaAgent(BaseAgent):
agent_id=self.agent_id,
agent_tags=agent_state.tags,
run_id=self.current_run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
llm_config=self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None,
),
)
step_progression = StepProgression.LOGGED_TRACE
@@ -549,6 +554,7 @@ class LettaAgent(BaseAgent):
llm_config=agent_state.llm_config,
total_tokens=usage.total_tokens,
force=False,
run_id=run_id,
)
await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=False)
@@ -677,6 +683,7 @@ class LettaAgent(BaseAgent):
tool_rules_solver,
agent_step_span,
step_metrics,
run_id=run_id,
)
in_context_messages = current_in_context_messages + new_in_context_messages
@@ -766,6 +773,9 @@ class LettaAgent(BaseAgent):
agent_id=self.agent_id,
agent_tags=agent_state.tags,
run_id=self.current_run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
llm_config=self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None,
),
)
step_progression = StepProgression.LOGGED_TRACE
@@ -882,6 +892,7 @@ class LettaAgent(BaseAgent):
llm_config=agent_state.llm_config,
total_tokens=usage.total_tokens,
force=False,
run_id=run_id,
)
await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=False)
@@ -908,6 +919,7 @@ class LettaAgent(BaseAgent):
use_assistant_message: bool = True,
request_start_timestamp_ns: int | None = None,
include_return_message_types: list[MessageType] | None = None,
run_id: str | None = None,
) -> AsyncGenerator[str, None]:
"""
Carries out an invocation of the agent loop in a streaming fashion that yields partial tokens.
@@ -1027,6 +1039,8 @@ class LettaAgent(BaseAgent):
agent_state,
llm_client,
tool_rules_solver,
run_id=run_id,
step_id=step_id,
)
step_progression = StepProgression.STREAM_RECEIVED
@@ -1234,6 +1248,9 @@ class LettaAgent(BaseAgent):
agent_id=self.agent_id,
agent_tags=agent_state.tags,
run_id=self.current_run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
llm_config=self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None,
),
)
step_progression = StepProgression.LOGGED_TRACE
@@ -1378,6 +1395,7 @@ class LettaAgent(BaseAgent):
llm_config=agent_state.llm_config,
total_tokens=usage.total_tokens,
force=False,
run_id=run_id,
)
await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=False)
@@ -1441,6 +1459,7 @@ class LettaAgent(BaseAgent):
tool_rules_solver: ToolRulesSolver,
agent_step_span: "Span",
step_metrics: StepMetrics,
run_id: str | None = None,
) -> tuple[dict, dict, list[Message], list[Message], list[str]] | None:
for attempt in range(self.max_summarization_retries + 1):
try:
@@ -1461,6 +1480,7 @@ class LettaAgent(BaseAgent):
agent_id=self.agent_id,
agent_tags=agent_state.tags,
run_id=self.current_run_id,
step_id=step_metrics.id,
call_type="agent_step",
)
response = await llm_client.request_async_with_telemetry(request_data, agent_state.llm_config)
@@ -1488,6 +1508,7 @@ class LettaAgent(BaseAgent):
new_letta_messages=new_in_context_messages,
llm_config=agent_state.llm_config,
force=True,
run_id=run_id,
)
new_in_context_messages = []
log_event(f"agent.stream_no_tokens.retry_attempt.{attempt + 1}")
@@ -1503,6 +1524,8 @@ class LettaAgent(BaseAgent):
agent_state: AgentState,
llm_client: LLMClientBase,
tool_rules_solver: ToolRulesSolver,
run_id: str | None = None,
step_id: str | None = None,
) -> tuple[dict, AsyncStream[ChatCompletionChunk], list[Message], list[Message], list[str], int] | None:
for attempt in range(self.max_summarization_retries + 1):
try:
@@ -1530,6 +1553,7 @@ class LettaAgent(BaseAgent):
agent_id=self.agent_id,
agent_tags=agent_state.tags,
run_id=self.current_run_id,
step_id=step_id,
call_type="agent_step",
)
@@ -1555,6 +1579,7 @@ class LettaAgent(BaseAgent):
new_letta_messages=new_in_context_messages,
llm_config=agent_state.llm_config,
force=True,
run_id=run_id,
)
new_in_context_messages: list[Message] = []
log_event(f"agent.stream_no_tokens.retry_attempt.{attempt + 1}")
@@ -1568,10 +1593,17 @@ class LettaAgent(BaseAgent):
new_letta_messages: list[Message],
llm_config: LLMConfig,
force: bool,
run_id: str | None = None,
step_id: str | None = None,
) -> list[Message]:
if isinstance(e, ContextWindowExceededError):
return await self._rebuild_context_window(
in_context_messages=in_context_messages, new_letta_messages=new_letta_messages, llm_config=llm_config, force=force
in_context_messages=in_context_messages,
new_letta_messages=new_letta_messages,
llm_config=llm_config,
force=force,
run_id=run_id,
step_id=step_id,
)
else:
raise llm_client.handle_llm_error(e)
@@ -1584,6 +1616,8 @@ class LettaAgent(BaseAgent):
llm_config: LLMConfig,
total_tokens: int | None = None,
force: bool = False,
run_id: str | None = None,
step_id: str | None = None,
) -> list[Message]:
# If total tokens is reached, we truncate down
# TODO: This can be broken by bad configs, e.g. lower bound too high, initial messages too fat, etc.
@@ -1597,6 +1631,8 @@ class LettaAgent(BaseAgent):
new_letta_messages=new_letta_messages,
force=True,
clear=True,
run_id=run_id,
step_id=step_id,
)
else:
# NOTE (Sarah): Seems like this is doing nothing?
@@ -1606,6 +1642,8 @@ class LettaAgent(BaseAgent):
new_in_context_messages, updated = await self.summarizer.summarize(
in_context_messages=in_context_messages,
new_letta_messages=new_letta_messages,
run_id=run_id,
step_id=step_id,
)
await self.agent_manager.update_message_ids_async(
agent_id=self.agent_id,

View File

@@ -156,7 +156,11 @@ class LettaAgentV2(BaseAgentV2):
run_id=None,
messages=in_context_messages + input_messages_to_persist,
llm_adapter=LettaLLMRequestAdapter(
llm_client=self.llm_client, llm_config=self.agent_state.llm_config, agent_tags=self.agent_state.tags
llm_client=self.llm_client,
llm_config=self.agent_state.llm_config,
agent_tags=self.agent_state.tags,
org_id=self.actor.organization_id,
user_id=self.actor.id,
),
dry_run=True,
enforce_run_id_set=False,
@@ -213,6 +217,8 @@ class LettaAgentV2(BaseAgentV2):
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
),
run_id=run_id,
use_assistant_message=use_assistant_message,
@@ -236,6 +242,7 @@ class LettaAgentV2(BaseAgentV2):
new_letta_messages=self.response_messages,
total_tokens=self.usage.total_tokens,
force=False,
run_id=run_id,
)
if self.stop_reason is None:
@@ -297,6 +304,8 @@ class LettaAgentV2(BaseAgentV2):
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
)
else:
llm_adapter = LettaLLMRequestAdapter(
@@ -305,6 +314,8 @@ class LettaAgentV2(BaseAgentV2):
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
)
try:
@@ -343,6 +354,7 @@ class LettaAgentV2(BaseAgentV2):
new_letta_messages=self.response_messages,
total_tokens=self.usage.total_tokens,
force=False,
run_id=run_id,
)
except:
@@ -488,6 +500,8 @@ class LettaAgentV2(BaseAgentV2):
in_context_messages=messages,
new_letta_messages=self.response_messages,
force=True,
run_id=run_id,
step_id=step_id,
)
else:
raise e
@@ -1246,6 +1260,8 @@ class LettaAgentV2(BaseAgentV2):
new_letta_messages: list[Message],
total_tokens: int | None = None,
force: bool = False,
run_id: str | None = None,
step_id: str | None = None,
) -> list[Message]:
self.logger.warning("Running deprecated v2 summarizer. This should be removed in the future.")
# always skip summarization if last message is an approval request message
@@ -1268,6 +1284,8 @@ class LettaAgentV2(BaseAgentV2):
new_letta_messages=new_letta_messages,
force=True,
clear=True,
run_id=run_id,
step_id=step_id,
)
else:
# NOTE (Sarah): Seems like this is doing nothing?
@@ -1277,6 +1295,8 @@ class LettaAgentV2(BaseAgentV2):
new_in_context_messages, updated = await self.summarizer.summarize(
in_context_messages=in_context_messages,
new_letta_messages=new_letta_messages,
run_id=run_id,
step_id=step_id,
)
except Exception as e:
self.logger.error(f"Failed to summarize conversation history: {e}")

View File

@@ -41,6 +41,7 @@ from letta.schemas.step import StepProgression
from letta.schemas.step_metrics import StepMetrics
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.server.rest_api.utils import (
create_approval_request_message_from_llm_response,
create_letta_messages_from_llm_response,
@@ -72,6 +73,16 @@ class LettaAgentV3(LettaAgentV2):
* Support Gemini / OpenAI client
"""
def __init__(
self,
agent_state: AgentState,
actor: User,
conversation_id: str | None = None,
):
super().__init__(agent_state, actor)
# Set conversation_id after parent init (which calls _initialize_state)
self.conversation_id = conversation_id
def _initialize_state(self):
super()._initialize_state()
self._require_tool_call = False
@@ -168,7 +179,13 @@ class LettaAgentV3(LettaAgentV2):
input_messages_to_persist=input_messages_to_persist,
# TODO need to support non-streaming adapter too
llm_adapter=SimpleLLMRequestAdapter(
llm_client=self.llm_client, llm_config=self.agent_state.llm_config, agent_id=self.agent_state.id, run_id=run_id
llm_client=self.llm_client,
llm_config=self.agent_state.llm_config,
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
),
run_id=run_id,
# use_assistant_message=use_assistant_message,
@@ -310,14 +327,20 @@ class LettaAgentV3(LettaAgentV2):
llm_client=self.llm_client,
llm_config=self.agent_state.llm_config,
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
)
else:
llm_adapter = SimpleLLMRequestAdapter(
llm_client=self.llm_client,
llm_config=self.agent_state.llm_config,
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
)
try:
@@ -390,7 +413,9 @@ class LettaAgentV3(LettaAgentV2):
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
except Exception as e:
self.logger.warning(f"Error during agent stream: {e}", exc_info=True)
# Use repr() if str() is empty (happens with Exception() with no args)
error_detail = str(e) or repr(e)
self.logger.warning(f"Error during agent stream: {error_detail}", exc_info=True)
# Set stop_reason if not already set
if self.stop_reason is None:
@@ -411,7 +436,7 @@ class LettaAgentV3(LettaAgentV2):
run_id=run_id,
error_type="internal_error",
message="An error occurred during agent execution.",
detail=str(e),
detail=error_detail,
)
yield f"event: error\ndata: {error_message.model_dump_json()}\n\n"
@@ -486,10 +511,11 @@ class LettaAgentV3(LettaAgentV2):
new_messages: The new messages to persist
in_context_messages: The current in-context messages
"""
# make sure all the new messages have the correct run_id and step_id
# make sure all the new messages have the correct run_id, step_id, and conversation_id
for message in new_messages:
message.step_id = step_id
message.run_id = run_id
message.conversation_id = self.conversation_id
# persist the new message objects - ONLY place where messages are persisted
persisted_messages = await self.message_manager.create_many_messages_async(
@@ -653,7 +679,15 @@ class LettaAgentV3(LettaAgentV2):
return
step_id = approval_request.step_id
step_metrics = await self.step_manager.get_step_metrics_async(step_id=step_id, actor=self.actor)
if step_id is None:
# Old approval messages may not have step_id set - generate a new one
self.logger.warning(f"Approval request message {approval_request.id} has no step_id, generating new step_id")
step_id = generate_step_id()
step_progression, logged_step, step_metrics, agent_step_span = await self._step_checkpoint_start(
step_id=step_id, run_id=run_id
)
else:
step_metrics = await self.step_manager.get_step_metrics_async(step_id=step_id, actor=self.actor)
else:
# Check for job cancellation at the start of each step
if run_id and await self._check_run_cancellation(run_id):
@@ -760,7 +794,10 @@ class LettaAgentV3(LettaAgentV2):
# TODO: might want to delay this checkpoint in case of corrupated state
try:
summary_message, messages, _ = await self.compact(
messages, trigger_threshold=self.agent_state.llm_config.context_window
messages,
trigger_threshold=self.agent_state.llm_config.context_window,
run_id=run_id,
step_id=step_id,
)
self.logger.info("Summarization succeeded, continuing to retry LLM request")
continue
@@ -776,7 +813,10 @@ class LettaAgentV3(LettaAgentV2):
# update the messages
await self._checkpoint_messages(
run_id=run_id, step_id=step_id, new_messages=[summary_message], in_context_messages=messages
run_id=run_id,
step_id=step_id,
new_messages=[summary_message],
in_context_messages=messages,
)
else:
@@ -879,20 +919,30 @@ class LettaAgentV3(LettaAgentV2):
self.logger.info(
f"Context window exceeded (current: {self.context_token_estimate}, threshold: {self.agent_state.llm_config.context_window}), trying to compact messages"
)
summary_message, messages, _ = await self.compact(messages, trigger_threshold=self.agent_state.llm_config.context_window)
summary_message, messages, _ = await self.compact(
messages,
trigger_threshold=self.agent_state.llm_config.context_window,
run_id=run_id,
step_id=step_id,
)
# TODO: persist + return the summary message
# TODO: convert this to a SummaryMessage
self.response_messages.append(summary_message)
for message in Message.to_letta_messages(summary_message):
yield message
await self._checkpoint_messages(
run_id=run_id, step_id=step_id, new_messages=[summary_message], in_context_messages=messages
run_id=run_id,
step_id=step_id,
new_messages=[summary_message],
in_context_messages=messages,
)
except Exception as e:
# NOTE: message persistence does not happen in the case of an exception (rollback to previous state)
self.logger.warning(f"Error during step processing: {e}")
self.job_update_metadata = {"error": str(e)}
# Use repr() if str() is empty (happens with Exception() with no args)
error_detail = str(e) or repr(e)
self.logger.warning(f"Error during step processing: {error_detail}")
self.job_update_metadata = {"error": error_detail}
# This indicates we failed after we decided to stop stepping, which indicates a bug with our flow.
if not self.stop_reason:
@@ -1445,7 +1495,12 @@ class LettaAgentV3(LettaAgentV2):
@trace_method
async def compact(
self, messages, trigger_threshold: Optional[int] = None, compaction_settings: Optional["CompactionSettings"] = None
self,
messages,
trigger_threshold: Optional[int] = None,
compaction_settings: Optional["CompactionSettings"] = None,
run_id: Optional[str] = None,
step_id: Optional[str] = None,
) -> tuple[Message, list[Message], str]:
"""Compact the current in-context messages for this agent.
@@ -1472,7 +1527,7 @@ class LettaAgentV3(LettaAgentV2):
summarizer_config = CompactionSettings(model=handle)
# Build the LLMConfig used for summarization
summarizer_llm_config = self._build_summarizer_llm_config(
summarizer_llm_config = await self._build_summarizer_llm_config(
agent_llm_config=self.agent_state.llm_config,
summarizer_config=summarizer_config,
)
@@ -1484,6 +1539,10 @@ class LettaAgentV3(LettaAgentV2):
llm_config=summarizer_llm_config,
summarizer_config=summarizer_config,
in_context_messages=messages,
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,
step_id=step_id,
)
elif summarizer_config.mode == "sliding_window":
try:
@@ -1492,6 +1551,10 @@ class LettaAgentV3(LettaAgentV2):
llm_config=summarizer_llm_config,
summarizer_config=summarizer_config,
in_context_messages=messages,
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,
step_id=step_id,
)
except Exception as e:
self.logger.error(f"Sliding window summarization failed with exception: {str(e)}. Falling back to all mode.")
@@ -1500,6 +1563,10 @@ class LettaAgentV3(LettaAgentV2):
llm_config=summarizer_llm_config,
summarizer_config=summarizer_config,
in_context_messages=messages,
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,
step_id=step_id,
)
summarization_mode_used = "all"
else:
@@ -1533,6 +1600,10 @@ class LettaAgentV3(LettaAgentV2):
llm_config=self.agent_state.llm_config,
summarizer_config=summarizer_config,
in_context_messages=compacted_messages,
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,
step_id=step_id,
)
summarization_mode_used = "all"
@@ -1584,8 +1655,8 @@ class LettaAgentV3(LettaAgentV2):
return summary_message_obj, final_messages, summary
@staticmethod
def _build_summarizer_llm_config(
async def _build_summarizer_llm_config(
self,
agent_llm_config: LLMConfig,
summarizer_config: CompactionSettings,
) -> LLMConfig:
@@ -1611,12 +1682,41 @@ class LettaAgentV3(LettaAgentV2):
model_name = summarizer_config.model
# Start from the agent's config and override model + provider_name + handle
# Note: model_endpoint_type is NOT overridden - the parsed provider_name
# is a custom label (e.g. "claude-pro-max"), not the endpoint type (e.g. "anthropic")
base = agent_llm_config.model_copy()
base.provider_name = provider_name
base.model = model_name
base.handle = summarizer_config.model
# Check if the summarizer's provider matches the agent's provider
# If they match, we can safely use the agent's config as a base
# If they don't match, we need to load the default config for the new provider
from letta.schemas.enums import ProviderType
provider_matches = False
try:
# Check if provider_name is a valid ProviderType that matches agent's endpoint type
provider_type = ProviderType(provider_name)
provider_matches = provider_type.value == agent_llm_config.model_endpoint_type
except ValueError:
# provider_name is a custom label - check if it matches agent's provider_name
provider_matches = provider_name == agent_llm_config.provider_name
if provider_matches:
# Same provider - use agent's config as base and override model/handle
base = agent_llm_config.model_copy()
base.model = model_name
base.handle = summarizer_config.model
else:
# Different provider - load default config for this handle
from letta.services.provider_manager import ProviderManager
provider_manager = ProviderManager()
try:
base = await provider_manager.get_llm_config_from_handle(
handle=summarizer_config.model,
actor=self.actor,
)
except Exception as e:
self.logger.warning(
f"Failed to load LLM config for summarizer handle '{summarizer_config.model}': {e}. "
f"Falling back to agent's LLM config."
)
return agent_llm_config
# If explicit model_settings are provided for the summarizer, apply
# them just like server.create_agent_async does for agents.

View File

@@ -25,7 +25,7 @@ PROVIDER_ORDER = {
"xai": 12,
"lmstudio": 13,
"zai": 14,
"openrouter": 15, # Note: OpenRouter uses OpenRouterProvider, not a ProviderType enum
"openrouter": 15,
}
ADMIN_PREFIX = "/v1/admin"

View File

@@ -7,8 +7,52 @@ from typing import Dict, Optional, Tuple
from letta.errors import LettaToolCreateError
from letta.types import JsonDict
_ALLOWED_TYPING_NAMES = {name: obj for name, obj in vars(typing).items() if not name.startswith("_")}
_ALLOWED_BUILTIN_TYPES = {name: obj for name, obj in vars(builtins).items() if isinstance(obj, type)}
_ALLOWED_TYPE_NAMES = {**_ALLOWED_TYPING_NAMES, **_ALLOWED_BUILTIN_TYPES, "typing": typing}
def resolve_type(annotation: str):
def _resolve_annotation_node(node: ast.AST):
if isinstance(node, ast.Name):
if node.id == "None":
return type(None)
if node.id in _ALLOWED_TYPE_NAMES:
return _ALLOWED_TYPE_NAMES[node.id]
raise ValueError(f"Unsupported annotation name: {node.id}")
if isinstance(node, ast.Attribute):
if isinstance(node.value, ast.Name) and node.value.id == "typing" and node.attr in _ALLOWED_TYPING_NAMES:
return _ALLOWED_TYPING_NAMES[node.attr]
raise ValueError("Unsupported annotation attribute")
if isinstance(node, ast.Subscript):
origin = _resolve_annotation_node(node.value)
args = _resolve_subscript_slice(node.slice)
return origin[args]
if isinstance(node, ast.Tuple):
return tuple(_resolve_annotation_node(elt) for elt in node.elts)
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
left = _resolve_annotation_node(node.left)
right = _resolve_annotation_node(node.right)
return left | right
if isinstance(node, ast.Constant) and node.value is None:
return type(None)
raise ValueError("Unsupported annotation expression")
def _resolve_subscript_slice(slice_node: ast.AST):
    """Resolve the slice of a subscripted annotation (e.g. ``int, str`` in ``dict[int, str]``)."""
    # Python < 3.9 wraps simple subscript slices in ast.Index; unwrap for uniform handling.
    inner = slice_node.value if isinstance(slice_node, ast.Index) else slice_node
    if isinstance(inner, ast.Tuple):
        return tuple(_resolve_annotation_node(item) for item in inner.elts)
    return _resolve_annotation_node(inner)
def resolve_type(annotation: str, *, allow_unsafe_eval: bool = False, extra_globals: Optional[Dict[str, object]] = None):
"""
Resolve a type annotation string into a Python type.
Previously, primitive support for int, float, str, dict, list, set, tuple, bool.
@@ -23,15 +67,23 @@ def resolve_type(annotation: str):
ValueError: If the annotation is unsupported or invalid.
"""
python_types = {**vars(typing), **vars(builtins)}
if extra_globals:
python_types.update(extra_globals)
if annotation in python_types:
return python_types[annotation]
try:
# Allow use of typing and builtins in a safe eval context
return eval(annotation, python_types)
parsed = ast.parse(annotation, mode="eval")
return _resolve_annotation_node(parsed.body)
except Exception:
raise ValueError(f"Unsupported annotation: {annotation}")
if allow_unsafe_eval:
try:
return eval(annotation, python_types)
except Exception as exc:
raise ValueError(f"Unsupported annotation: {annotation}") from exc
raise ValueError(f"Unsupported annotation: {annotation}")
# TODO :: THIS MUST BE EDITED TO HANDLE THINGS
@@ -62,14 +114,34 @@ def get_function_annotations_from_source(source_code: str, function_name: str) -
# NOW json_loads -> ast.literal_eval -> typing.get_origin
def coerce_dict_args_by_annotations(function_args: JsonDict, annotations: Dict[str, str]) -> dict:
def coerce_dict_args_by_annotations(
function_args: JsonDict,
annotations: Dict[str, object],
*,
allow_unsafe_eval: bool = False,
extra_globals: Optional[Dict[str, object]] = None,
) -> dict:
coerced_args = dict(function_args) # Shallow copy
for arg_name, value in coerced_args.items():
if arg_name in annotations:
annotation_str = annotations[arg_name]
try:
arg_type = resolve_type(annotation_str)
annotation_value = annotations[arg_name]
if isinstance(annotation_value, str):
arg_type = resolve_type(
annotation_value,
allow_unsafe_eval=allow_unsafe_eval,
extra_globals=extra_globals,
)
elif isinstance(annotation_value, typing.ForwardRef):
arg_type = resolve_type(
annotation_value.__forward_arg__,
allow_unsafe_eval=allow_unsafe_eval,
extra_globals=extra_globals,
)
else:
arg_type = annotation_value
# Always parse strings using literal_eval or json if possible
if isinstance(value, str):

View File

@@ -383,12 +383,12 @@ def memory_replace(agent_state: "AgentState", label: str, old_str: str, new_str:
# snippet = "\n".join(new_value.split("\n")[start_line : end_line + 1])
# Prepare the success message
success_msg = f"The core memory block with label `{label}` has been edited. "
# success_msg += self._make_output(
# snippet, f"a snippet of {path}", start_line + 1
# )
# success_msg += f"A snippet of core memory block `{label}`:\n{snippet}\n"
success_msg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the memory block again if necessary."
success_msg = (
f"The core memory block with label `{label}` has been successfully edited. "
f"Your system prompt has been recompiled with the updated memory contents and is now active in your context. "
f"Review the changes and make sure they are as expected (correct indentation, "
f"no duplicate lines, etc). Edit the memory block again if necessary."
)
# return None
return success_msg
@@ -454,14 +454,12 @@ def memory_insert(agent_state: "AgentState", label: str, new_str: str, insert_li
agent_state.memory.update_block_value(label=label, value=new_value)
# Prepare the success message
success_msg = f"The core memory block with label `{label}` has been edited. "
# success_msg += self._make_output(
# snippet,
# "a snippet of the edited file",
# max(1, insert_line - SNIPPET_LINES + 1),
# )
# success_msg += f"A snippet of core memory block `{label}`:\n{snippet}\n"
success_msg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the memory block again if necessary."
success_msg = (
f"The core memory block with label `{label}` has been successfully edited. "
f"Your system prompt has been recompiled with the updated memory contents and is now active in your context. "
f"Review the changes and make sure they are as expected (correct indentation, "
f"no duplicate lines, etc). Edit the memory block again if necessary."
)
return success_msg
@@ -532,12 +530,12 @@ def memory_rethink(agent_state: "AgentState", label: str, new_memory: str) -> No
agent_state.memory.update_block_value(label=label, value=new_memory)
# Prepare the success message
success_msg = f"The core memory block with label `{label}` has been edited. "
# success_msg += self._make_output(
# snippet, f"a snippet of {path}", start_line + 1
# )
# success_msg += f"A snippet of core memory block `{label}`:\n{snippet}\n"
success_msg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the memory block again if necessary."
success_msg = (
f"The core memory block with label `{label}` has been successfully edited. "
f"Your system prompt has been recompiled with the updated memory contents and is now active in your context. "
f"Review the changes and make sure they are as expected (correct indentation, "
f"no duplicate lines, etc). Edit the memory block again if necessary."
)
# return None
return success_msg

View File

@@ -166,3 +166,61 @@ async def _convert_message_create_to_message(
batch_item_id=message_create.batch_item_id,
run_id=run_id,
)
async def _resolve_url_to_base64(url: str) -> tuple[str, str]:
"""Resolve URL to base64 data and media type."""
if url.startswith("file://"):
parsed = urlparse(url)
file_path = unquote(parsed.path)
image_bytes = await asyncio.to_thread(lambda: open(file_path, "rb").read())
media_type, _ = mimetypes.guess_type(file_path)
media_type = media_type or "image/jpeg"
else:
image_bytes, media_type = await _fetch_image_from_url(url)
media_type = media_type or mimetypes.guess_type(url)[0] or "image/png"
image_data = base64.standard_b64encode(image_bytes).decode("utf-8")
return image_data, media_type
async def resolve_tool_return_images(func_response: str | list) -> str | list:
    """Resolve URL and LettaImage sources to base64 for tool returns.

    Walks a tool-return content list and rewrites any image whose source is a
    URL into an inline ``Base64Image`` (fetching/reading the bytes via
    ``_resolve_url_to_base64``). Plain strings are returned unchanged.

    Note: ``ImageContent`` parts are mutated in place (their ``source`` is
    replaced), not copied.
    """
    if isinstance(func_response, str):
        return func_response

    resolved = []
    for part in func_response:
        if isinstance(part, ImageContent):
            if part.source.type == ImageSourceType.url:
                # Replace the URL source with inline base64 data (in-place mutation).
                image_data, media_type = await _resolve_url_to_base64(part.source.url)
                part.source = Base64Image(media_type=media_type, data=image_data)
            elif part.source.type == ImageSourceType.letta and not part.source.data:
                # Letta-hosted image with no inline data: intentionally left as-is.
                pass
            resolved.append(part)
        elif isinstance(part, TextContent):
            resolved.append(part)
        elif isinstance(part, dict):
            # Raw dict parts (un-parsed content): normalize image-URL and text dicts.
            if part.get("type") == "image" and part.get("source", {}).get("type") == "url":
                url = part["source"].get("url")
                if url:
                    image_data, media_type = await _resolve_url_to_base64(url)
                    resolved.append(
                        ImageContent(
                            source=Base64Image(
                                media_type=media_type,
                                data=image_data,
                                detail=part.get("source", {}).get("detail"),
                            )
                        )
                    )
                else:
                    # Image dict without a URL: pass through untouched.
                    resolved.append(part)
            elif part.get("type") == "text":
                resolved.append(TextContent(text=part.get("text", "")))
            else:
                resolved.append(part)
        else:
            # Unknown part type: preserve verbatim.
            resolved.append(part)
    return resolved

View File

@@ -7,6 +7,7 @@ from datetime import datetime, timezone
from typing import Any, Callable, List, Optional, Tuple
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
from letta.errors import LettaInvalidArgumentError
from letta.otel.tracing import trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole, TagMatchMode
@@ -321,6 +322,7 @@ class TurbopufferClient:
actor: "PydanticUser",
tags: Optional[List[str]] = None,
created_at: Optional[datetime] = None,
embeddings: Optional[List[List[float]]] = None,
) -> List[PydanticPassage]:
"""Insert passages into Turbopuffer.
@@ -332,6 +334,7 @@ class TurbopufferClient:
actor: User actor for embedding generation
tags: Optional list of tags to attach to all passages
created_at: Optional timestamp for retroactive entries (defaults to current UTC time)
embeddings: Optional pre-computed embeddings (must match 1:1 with text_chunks). If provided, skips embedding generation.
Returns:
List of PydanticPassage objects that were inserted
@@ -345,9 +348,30 @@ class TurbopufferClient:
logger.warning("All text chunks were empty, skipping insertion")
return []
# generate embeddings using the default config
filtered_texts = [text for _, text in filtered_chunks]
embeddings = await self._generate_embeddings(filtered_texts, actor)
# use provided embeddings only if dimensions match TPUF's expected dimension
use_provided_embeddings = False
if embeddings is not None:
if len(embeddings) != len(text_chunks):
raise LettaInvalidArgumentError(
f"embeddings length ({len(embeddings)}) must match text_chunks length ({len(text_chunks)})",
argument_name="embeddings",
)
# check if first non-empty embedding has correct dimensions
filtered_indices = [i for i, _ in filtered_chunks]
sample_embedding = embeddings[filtered_indices[0]] if filtered_indices else None
if sample_embedding is not None and len(sample_embedding) == self.default_embedding_config.embedding_dim:
use_provided_embeddings = True
filtered_embeddings = [embeddings[i] for i, _ in filtered_chunks]
else:
logger.debug(
f"Embedding dimension mismatch (got {len(sample_embedding) if sample_embedding else 'None'}, "
f"expected {self.default_embedding_config.embedding_dim}), regenerating embeddings"
)
if not use_provided_embeddings:
filtered_embeddings = await self._generate_embeddings(filtered_texts, actor)
namespace_name = await self._get_archive_namespace_name(archive_id)
@@ -379,7 +403,7 @@ class TurbopufferClient:
tags_arrays = [] # Store tags as arrays
passages = []
for (original_idx, text), embedding in zip(filtered_chunks, embeddings):
for (original_idx, text), embedding in zip(filtered_chunks, filtered_embeddings):
passage_id = passage_ids[original_idx]
# append to columns

View File

@@ -39,6 +39,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
from letta.server.rest_api.json_parser import JSONParser, PydanticJSONParser
from letta.server.rest_api.streaming_response import RunCancelledException
from letta.server.rest_api.utils import decrement_message_uuid
logger = get_logger(__name__)
@@ -145,6 +146,26 @@ class SimpleAnthropicStreamingInterface:
return tool_calls[0]
return None
def get_usage_statistics(self) -> "LettaUsageStatistics":
    """Build usage statistics from the token counts accumulated during streaming.

    Anthropic reports ``input_tokens`` as the NON-cached portion only, so the
    cache-read and cache-creation counts are added back to obtain the true
    prompt total.

    Returns:
        LettaUsageStatistics with token counts from the stream.
    """
    from letta.schemas.usage import LettaUsageStatistics

    cache_read = self.cache_read_tokens or 0
    cache_write = self.cache_creation_tokens or 0
    prompt_total = (self.input_tokens or 0) + cache_read + cache_write
    completion_total = self.output_tokens or 0
    return LettaUsageStatistics(
        prompt_tokens=prompt_total,
        completion_tokens=completion_total,
        total_tokens=prompt_total + completion_total,
        cached_input_tokens=cache_read or None,
        cache_write_tokens=cache_write or None,
        reasoning_tokens=None,  # Anthropic does not report reasoning tokens in usage
    )
def get_reasoning_content(self) -> list[TextContent | ReasoningContent | RedactedReasoningContent]:
def _process_group(
group: list[ReasoningMessage | HiddenReasoningMessage | AssistantMessage],
@@ -228,10 +249,10 @@ class SimpleAnthropicStreamingInterface:
prev_message_type = new_message_type
# print(f"Yielding message: {message}")
yield message
except asyncio.CancelledError as e:
except (asyncio.CancelledError, RunCancelledException) as e:
import traceback
logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc())
logger.info("Cancelled stream attempt but overriding (%s) %s: %s", type(e).__name__, e, traceback.format_exc())
async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
new_message_type = message.message_type
if new_message_type != prev_message_type:

View File

@@ -41,6 +41,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
from letta.server.rest_api.json_parser import JSONParser, PydanticJSONParser
from letta.server.rest_api.streaming_response import RunCancelledException
logger = get_logger(__name__)
@@ -127,6 +128,25 @@ class AnthropicStreamingInterface:
arguments = str(json.dumps(tool_input, indent=2))
return ToolCall(id=self.tool_call_id, function=FunctionCall(arguments=arguments, name=self.tool_call_name))
def get_usage_statistics(self) -> "LettaUsageStatistics":
    """Build a LettaUsageStatistics object from counts accumulated while streaming.

    Note: in this interface Anthropic's ``input_tokens`` is the non-cached
    portion only, and cache tokens are not tracked here, so the raw values
    are used as-is.

    Returns:
        LettaUsageStatistics with token counts from the stream.
    """
    from letta.schemas.usage import LettaUsageStatistics

    prompt = self.input_tokens or 0
    completion = self.output_tokens or 0
    return LettaUsageStatistics(
        prompt_tokens=prompt,
        completion_tokens=completion,
        total_tokens=prompt + completion,
        cached_input_tokens=None,  # cache tokens are not tracked by this interface
        cache_write_tokens=None,
        reasoning_tokens=None,
    )
def _check_inner_thoughts_complete(self, combined_args: str) -> bool:
"""
Check if inner thoughts are complete in the current tool call arguments
@@ -218,10 +238,10 @@ class AnthropicStreamingInterface:
message_index += 1
prev_message_type = new_message_type
yield message
except asyncio.CancelledError as e:
except (asyncio.CancelledError, RunCancelledException) as e:
import traceback
logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc())
logger.info("Cancelled stream attempt but overriding (%s) %s: %s", type(e).__name__, e, traceback.format_exc())
async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
new_message_type = message.message_type
if new_message_type != prev_message_type:
@@ -636,6 +656,25 @@ class SimpleAnthropicStreamingInterface:
arguments = str(json.dumps(tool_input, indent=2))
return ToolCall(id=self.tool_call_id, function=FunctionCall(arguments=arguments, name=self.tool_call_name))
def get_usage_statistics(self) -> "LettaUsageStatistics":
    """Assemble usage statistics from the token counts gathered during the stream.

    Anthropic's streaming ``input_tokens`` excludes cached tokens, and this
    interface does not track cache counters, so only the raw input/output
    counts are reported.

    Returns:
        LettaUsageStatistics with token counts from the stream.
    """
    from letta.schemas.usage import LettaUsageStatistics

    input_count = self.input_tokens if self.input_tokens is not None else 0
    output_count = self.output_tokens if self.output_tokens is not None else 0
    return LettaUsageStatistics(
        prompt_tokens=input_count,
        completion_tokens=output_count,
        total_tokens=input_count + output_count,
        cached_input_tokens=None,  # not tracked by this interface
        cache_write_tokens=None,
        reasoning_tokens=None,
    )
def get_reasoning_content(self) -> list[TextContent | ReasoningContent | RedactedReasoningContent]:
def _process_group(
group: list[ReasoningMessage | HiddenReasoningMessage | AssistantMessage],
@@ -726,10 +765,10 @@ class SimpleAnthropicStreamingInterface:
prev_message_type = new_message_type
# print(f"Yielding message: {message}")
yield message
except asyncio.CancelledError as e:
except (asyncio.CancelledError, RunCancelledException) as e:
import traceback
logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc())
logger.info("Cancelled stream attempt but overriding (%s) %s: %s", type(e).__name__, e, traceback.format_exc())
async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
new_message_type = message.message_type
if new_message_type != prev_message_type:

View File

@@ -26,6 +26,7 @@ from letta.schemas.letta_message_content import (
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
from letta.server.rest_api.streaming_response import RunCancelledException
from letta.server.rest_api.utils import decrement_message_uuid
from letta.utils import get_tool_call_id
@@ -43,9 +44,11 @@ class SimpleGeminiStreamingInterface:
requires_approval_tools: list = [],
run_id: str | None = None,
step_id: str | None = None,
cancellation_event: Optional["asyncio.Event"] = None,
):
self.run_id = run_id
self.step_id = step_id
self.cancellation_event = cancellation_event
# self.messages = messages
# self.tools = tools
@@ -89,6 +92,9 @@ class SimpleGeminiStreamingInterface:
# Raw usage from provider (for transparent logging in provider trace)
self.raw_usage: dict | None = None
# Track cancellation status
self.stream_was_cancelled: bool = False
def get_content(self) -> List[ReasoningContent | TextContent | ToolCallContent]:
"""This is (unusually) in chunked format, instead of merged"""
for content in self.content_parts:
@@ -116,6 +122,27 @@ class SimpleGeminiStreamingInterface:
"""Return all finalized tool calls collected during this message (parallel supported)."""
return list(self.collected_tool_calls)
def get_usage_statistics(self) -> "LettaUsageStatistics":
    """Assemble usage statistics from the token counts gathered during the stream.

    Returns:
        LettaUsageStatistics with token counts from the stream.

    Note:
        Gemini uses `thinking_tokens` instead of `reasoning_tokens` (OpenAI o1/o3).
    """
    from letta.schemas.usage import LettaUsageStatistics

    prompt = self.input_tokens or 0
    completion = self.output_tokens or 0
    return LettaUsageStatistics(
        prompt_tokens=prompt,
        completion_tokens=completion,
        total_tokens=prompt + completion,
        # Gemini: input_tokens already includes cached tokens; cached count is a subset.
        cached_input_tokens=self.cached_tokens,
        cache_write_tokens=None,  # Gemini doesn't report cache write tokens
        reasoning_tokens=self.thinking_tokens,  # Gemini uses thinking_tokens
    )
async def process(
self,
stream: AsyncIterator[GenerateContentResponse],
@@ -137,10 +164,10 @@ class SimpleGeminiStreamingInterface:
message_index += 1
prev_message_type = new_message_type
yield message
except asyncio.CancelledError as e:
except (asyncio.CancelledError, RunCancelledException) as e:
import traceback
logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc())
logger.info("Cancelled stream attempt but overriding (%s) %s: %s", type(e).__name__, e, traceback.format_exc())
async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
new_message_type = message.message_type
if new_message_type != prev_message_type:
@@ -164,7 +191,11 @@ class SimpleGeminiStreamingInterface:
yield LettaStopReason(stop_reason=StopReasonType.error)
raise e
finally:
logger.info("GeminiStreamingInterface: Stream processing complete.")
# Check if cancellation was signaled via shared event
if self.cancellation_event and self.cancellation_event.is_set():
self.stream_was_cancelled = True
logger.info(f"GeminiStreamingInterface: Stream processing complete. stream was cancelled: {self.stream_was_cancelled}")
async def _process_event(
self,

View File

@@ -54,6 +54,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
from letta.server.rest_api.json_parser import OptimisticJSONParser
from letta.server.rest_api.streaming_response import RunCancelledException
from letta.server.rest_api.utils import decrement_message_uuid
from letta.services.context_window_calculator.token_counter import create_token_counter
from letta.streaming_utils import (
@@ -82,6 +83,7 @@ class OpenAIStreamingInterface:
requires_approval_tools: list = [],
run_id: str | None = None,
step_id: str | None = None,
cancellation_event: Optional["asyncio.Event"] = None,
):
self.use_assistant_message = use_assistant_message
@@ -93,6 +95,7 @@ class OpenAIStreamingInterface:
self.put_inner_thoughts_in_kwarg = put_inner_thoughts_in_kwarg
self.run_id = run_id
self.step_id = step_id
self.cancellation_event = cancellation_event
self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
self.function_args_reader = JSONInnerThoughtsExtractor(wait_for_first_key=put_inner_thoughts_in_kwarg)
@@ -191,6 +194,28 @@ class OpenAIStreamingInterface:
function=FunctionCall(arguments=self._get_current_function_arguments(), name=function_name),
)
def get_usage_statistics(self) -> "LettaUsageStatistics":
    """Build usage statistics from counts accumulated while streaming.

    Prefers provider-reported token counts; falls back to locally estimated
    counts when the stream did not include usage data.

    Returns:
        LettaUsageStatistics with token counts from the stream.
    """
    from letta.schemas.usage import LettaUsageStatistics

    # Actual counts win over fallback estimates (0/None triggers the fallback).
    prompt = self.input_tokens or self.fallback_input_tokens
    completion = self.output_tokens or self.fallback_output_tokens
    return LettaUsageStatistics(
        prompt_tokens=prompt or 0,
        completion_tokens=completion or 0,
        total_tokens=(prompt or 0) + (completion or 0),
        # OpenAI: input_tokens is already the total; cached_tokens would be a subset.
        cached_input_tokens=None,  # cache tokens are not tracked by this interface
        cache_write_tokens=None,
        reasoning_tokens=None,  # reasoning tokens are not tracked by this interface
    )
async def process(
self,
stream: AsyncStream[ChatCompletionChunk],
@@ -226,14 +251,15 @@ class OpenAIStreamingInterface:
message_index += 1
prev_message_type = new_message_type
yield message
except asyncio.CancelledError as e:
except (asyncio.CancelledError, RunCancelledException) as e:
import traceback
self.stream_was_cancelled = True
logger.warning(
"Stream was cancelled (CancelledError). Attempting to process current event. "
"Stream was cancelled (%s). Attempting to process current event. "
f"Events received so far: {self.total_events_received}, last event: {self.last_event_type}. "
f"Error: %s, trace: %s",
type(e).__name__,
e,
traceback.format_exc(),
)
@@ -267,6 +293,10 @@ class OpenAIStreamingInterface:
yield LettaStopReason(stop_reason=StopReasonType.error)
raise e
finally:
# Check if cancellation was signaled via shared event
if self.cancellation_event and self.cancellation_event.is_set():
self.stream_was_cancelled = True
logger.info(
f"OpenAIStreamingInterface: Stream processing complete. "
f"Received {self.total_events_received} events, "
@@ -561,9 +591,11 @@ class SimpleOpenAIStreamingInterface:
model: str = None,
run_id: str | None = None,
step_id: str | None = None,
cancellation_event: Optional["asyncio.Event"] = None,
):
self.run_id = run_id
self.step_id = step_id
self.cancellation_event = cancellation_event
# Premake IDs for database writes
self.letta_message_id = Message.generate_id()
@@ -662,6 +694,28 @@ class SimpleOpenAIStreamingInterface:
raise ValueError("No tool calls available")
return calls[0]
def get_usage_statistics(self) -> "LettaUsageStatistics":
    """Build usage statistics from counts accumulated while streaming.

    Provider-reported counts take precedence; locally estimated fallback
    counts are used when the stream carried no usage payload.

    Returns:
        LettaUsageStatistics with token counts from the stream.
    """
    from letta.schemas.usage import LettaUsageStatistics

    prompt = self.input_tokens or self.fallback_input_tokens
    completion = self.output_tokens or self.fallback_output_tokens
    return LettaUsageStatistics(
        prompt_tokens=prompt or 0,
        completion_tokens=completion or 0,
        total_tokens=(prompt or 0) + (completion or 0),
        # OpenAI: input_tokens already includes cached tokens (subset, not additive).
        cached_input_tokens=self.cached_tokens,
        cache_write_tokens=None,  # OpenAI doesn't have cache write tokens
        reasoning_tokens=self.reasoning_tokens,
    )
async def process(
self,
stream: AsyncStream[ChatCompletionChunk],
@@ -715,14 +769,15 @@ class SimpleOpenAIStreamingInterface:
message_index += 1
prev_message_type = new_message_type
yield message
except asyncio.CancelledError as e:
except (asyncio.CancelledError, RunCancelledException) as e:
import traceback
self.stream_was_cancelled = True
logger.warning(
"Stream was cancelled (CancelledError). Attempting to process current event. "
"Stream was cancelled (%s). Attempting to process current event. "
f"Events received so far: {self.total_events_received}, last event: {self.last_event_type}. "
f"Error: %s, trace: %s",
type(e).__name__,
e,
traceback.format_exc(),
)
@@ -764,6 +819,10 @@ class SimpleOpenAIStreamingInterface:
yield LettaStopReason(stop_reason=StopReasonType.error)
raise e
finally:
# Check if cancellation was signaled via shared event
if self.cancellation_event and self.cancellation_event.is_set():
self.stream_was_cancelled = True
logger.info(
f"SimpleOpenAIStreamingInterface: Stream processing complete. "
f"Received {self.total_events_received} events, "
@@ -932,6 +991,7 @@ class SimpleOpenAIResponsesStreamingInterface:
model: str = None,
run_id: str | None = None,
step_id: str | None = None,
cancellation_event: Optional["asyncio.Event"] = None,
):
self.is_openai_proxy = is_openai_proxy
self.messages = messages
@@ -946,6 +1006,7 @@ class SimpleOpenAIResponsesStreamingInterface:
self.message_id = None
self.run_id = run_id
self.step_id = step_id
self.cancellation_event = cancellation_event
# Premake IDs for database writes
self.letta_message_id = Message.generate_id()
@@ -1063,6 +1124,24 @@ class SimpleOpenAIResponsesStreamingInterface:
raise ValueError("No tool calls available")
return calls[0]
def get_usage_statistics(self) -> "LettaUsageStatistics":
    """Assemble usage statistics from the token counts gathered during the stream.

    Returns:
        LettaUsageStatistics with token counts from the stream.
    """
    from letta.schemas.usage import LettaUsageStatistics

    input_count = self.input_tokens or 0
    output_count = self.output_tokens or 0
    return LettaUsageStatistics(
        prompt_tokens=input_count,
        completion_tokens=output_count,
        total_tokens=input_count + output_count,
        # OpenAI Responses API: input_tokens already includes cached tokens.
        cached_input_tokens=self.cached_tokens,
        cache_write_tokens=None,  # OpenAI doesn't have cache write tokens
        reasoning_tokens=self.reasoning_tokens,
    )
async def process(
self,
stream: AsyncStream[ResponseStreamEvent],
@@ -1102,14 +1181,15 @@ class SimpleOpenAIResponsesStreamingInterface:
)
# Continue to next event rather than killing the stream
continue
except asyncio.CancelledError as e:
except (asyncio.CancelledError, RunCancelledException) as e:
import traceback
self.stream_was_cancelled = True
logger.warning(
"Stream was cancelled (CancelledError). Attempting to process current event. "
"Stream was cancelled (%s). Attempting to process current event. "
f"Events received so far: {self.total_events_received}, last event: {self.last_event_type}. "
f"Error: %s, trace: %s",
type(e).__name__,
e,
traceback.format_exc(),
)
@@ -1136,6 +1216,10 @@ class SimpleOpenAIResponsesStreamingInterface:
yield LettaStopReason(stop_reason=StopReasonType.error)
raise e
finally:
# Check if cancellation was signaled via shared event
if self.cancellation_event and self.cancellation_event.is_set():
self.stream_was_cancelled = True
logger.info(
f"ResponsesAPI Stream processing complete. "
f"Received {self.total_events_received} events, "

View File

@@ -48,6 +48,7 @@ from letta.schemas.openai.chat_completion_response import (
UsageStatistics,
)
from letta.schemas.response_format import JsonSchemaResponseFormat
from letta.schemas.usage import LettaUsageStatistics
from letta.settings import model_settings
DUMMY_FIRST_USER_MESSAGE = "User initializing bootup sequence."
@@ -777,6 +778,18 @@ class AnthropicClient(LLMClientBase):
if not block.get("text", "").strip():
block["text"] = "."
# Strip trailing whitespace from final assistant message
# Anthropic API rejects messages where "final assistant content cannot end with trailing whitespace"
if is_final_assistant:
if isinstance(content, str):
msg["content"] = content.rstrip()
elif isinstance(content, list) and len(content) > 0:
# Find and strip trailing whitespace from the last text block
for block in reversed(content):
if isinstance(block, dict) and block.get("type") == "text":
block["text"] = block.get("text", "").rstrip()
break
try:
count_params = {
"model": model or "claude-3-7-sonnet-20250219",
@@ -976,6 +989,35 @@ class AnthropicClient(LLMClientBase):
return super().handle_llm_error(e)
def extract_usage_statistics(self, response_data: dict | None, llm_config: LLMConfig) -> LettaUsageStatistics:
    """Extract usage statistics from an Anthropic response and return LettaUsageStatistics.

    Per Anthropic docs: "Total input tokens in a request is the summation of
    input_tokens, cache_creation_input_tokens, and cache_read_input_tokens."
    """
    if not response_data:
        return LettaUsageStatistics()

    usage = AnthropicMessage(**response_data).usage
    # None means the provider did not report the field; 0 means "reported as 0".
    cache_read_tokens = getattr(usage, "cache_read_input_tokens", None)
    cache_creation_tokens = getattr(usage, "cache_creation_input_tokens", None)

    total_input = usage.input_tokens + (cache_read_tokens or 0) + (cache_creation_tokens or 0)
    return LettaUsageStatistics(
        prompt_tokens=total_input,
        completion_tokens=usage.output_tokens,
        total_tokens=total_input + usage.output_tokens,
        cached_input_tokens=cache_read_tokens,
        cache_write_tokens=cache_creation_tokens,
    )
# TODO: Input messages doesn't get used here
# TODO: Clean up this interface
@trace_method
@@ -1020,10 +1062,13 @@ class AnthropicClient(LLMClientBase):
}
"""
response = AnthropicMessage(**response_data)
prompt_tokens = response.usage.input_tokens
completion_tokens = response.usage.output_tokens
finish_reason = remap_finish_reason(str(response.stop_reason))
# Extract usage via centralized method
from letta.schemas.enums import ProviderType
usage_stats = self.extract_usage_statistics(response_data, llm_config).to_usage(ProviderType.anthropic)
content = None
reasoning_content = None
reasoning_content_signature = None
@@ -1088,35 +1133,12 @@ class AnthropicClient(LLMClientBase):
),
)
# Build prompt tokens details with cache data if available
prompt_tokens_details = None
cache_read_tokens = 0
cache_creation_tokens = 0
if hasattr(response.usage, "cache_read_input_tokens") or hasattr(response.usage, "cache_creation_input_tokens"):
from letta.schemas.openai.chat_completion_response import UsageStatisticsPromptTokenDetails
cache_read_tokens = getattr(response.usage, "cache_read_input_tokens", 0) or 0
cache_creation_tokens = getattr(response.usage, "cache_creation_input_tokens", 0) or 0
prompt_tokens_details = UsageStatisticsPromptTokenDetails(
cache_read_tokens=cache_read_tokens,
cache_creation_tokens=cache_creation_tokens,
)
# Per Anthropic docs: "Total input tokens in a request is the summation of
# input_tokens, cache_creation_input_tokens, and cache_read_input_tokens."
actual_input_tokens = prompt_tokens + cache_read_tokens + cache_creation_tokens
chat_completion_response = ChatCompletionResponse(
id=response.id,
choices=[choice],
created=get_utc_time_int(),
model=response.model,
usage=UsageStatistics(
prompt_tokens=actual_input_tokens,
completion_tokens=completion_tokens,
total_tokens=actual_input_tokens + completion_tokens,
prompt_tokens_details=prompt_tokens_details,
),
usage=usage_stats,
)
if llm_config.put_inner_thoughts_in_kwargs:
chat_completion_response = unpack_all_inner_thoughts_from_kwargs(

View File

@@ -54,6 +54,7 @@ from letta.schemas.openai.chat_completion_response import (
UsageStatistics,
)
from letta.schemas.providers.chatgpt_oauth import ChatGPTOAuthCredentials, ChatGPTOAuthProvider
from letta.schemas.usage import LettaUsageStatistics
logger = get_logger(__name__)
@@ -511,6 +512,25 @@ class ChatGPTOAuthClient(LLMClientBase):
# Response should already be in ChatCompletion format after transformation
return ChatCompletionResponse(**response_data)
def extract_usage_statistics(self, response_data: dict | None, llm_config: LLMConfig) -> LettaUsageStatistics:
    """Extract usage statistics from a ChatGPT OAuth (OpenAI-format) response.

    Returns empty statistics when the response or its usage payload is missing.
    """
    usage = (response_data or {}).get("usage")
    if not usage:
        return LettaUsageStatistics()

    prompt = usage.get("prompt_tokens") or 0
    completion = usage.get("completion_tokens") or 0
    # Fall back to the sum when the provider omits total_tokens.
    total = usage.get("total_tokens") or (prompt + completion)
    return LettaUsageStatistics(
        prompt_tokens=prompt,
        completion_tokens=completion,
        total_tokens=total,
    )
@trace_method
async def stream_async(
self,

View File

@@ -39,6 +39,7 @@ from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import Tool, Tool as OpenAITool
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics
from letta.schemas.usage import LettaUsageStatistics
from letta.settings import model_settings, settings
from letta.utils import get_tool_call_id
@@ -415,6 +416,34 @@ class GoogleVertexClient(LLMClientBase):
return request_data
def extract_usage_statistics(self, response_data: dict | None, llm_config: LLMConfig) -> LettaUsageStatistics:
    """Extract usage statistics from a Gemini response and return LettaUsageStatistics.

    Args:
        response_data: Raw Gemini response payload, or None when unavailable.
        llm_config: The LLM configuration used for the request (unused here).

    Returns:
        LettaUsageStatistics populated from ``usage_metadata``; empty statistics
        when the response or its usage metadata is missing.
    """
    if not response_data:
        return LettaUsageStatistics()
    response = GenerateContentResponse(**response_data)
    usage_metadata = response.usage_metadata
    if not usage_metadata:
        return LettaUsageStatistics()

    # getattr keeps compatibility with SDK versions lacking these fields;
    # None means "not reported", 0 means "provider reported 0".
    cached_tokens = getattr(usage_metadata, "cached_content_token_count", None)
    reasoning_tokens = getattr(usage_metadata, "thoughts_token_count", None)

    prompt_tokens = usage_metadata.prompt_token_count or 0
    completion_tokens = usage_metadata.candidates_token_count or 0
    # Bug fix: previously `total_token_count or 0` reported 0 when the provider
    # omitted the total even though prompt/completion counts were present.
    # Fall back to their sum instead (Gemini's prompt count already includes
    # cached tokens, so no cache addition is needed).
    total_tokens = usage_metadata.total_token_count or (prompt_tokens + completion_tokens)

    return LettaUsageStatistics(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=total_tokens,
        cached_input_tokens=cached_tokens,
        reasoning_tokens=reasoning_tokens,
    )
@trace_method
async def convert_response_to_chat_completion(
self,
@@ -642,36 +671,10 @@ class GoogleVertexClient(LLMClientBase):
# "totalTokenCount": 36
# }
if response.usage_metadata:
# Extract cache token data if available (Gemini uses cached_content_token_count)
# Use `is not None` to capture 0 values (meaning "provider reported 0 cached tokens")
prompt_tokens_details = None
if (
hasattr(response.usage_metadata, "cached_content_token_count")
and response.usage_metadata.cached_content_token_count is not None
):
from letta.schemas.openai.chat_completion_response import UsageStatisticsPromptTokenDetails
# Extract usage via centralized method
from letta.schemas.enums import ProviderType
prompt_tokens_details = UsageStatisticsPromptTokenDetails(
cached_tokens=response.usage_metadata.cached_content_token_count,
)
# Extract thinking/reasoning token data if available (Gemini uses thoughts_token_count)
# Use `is not None` to capture 0 values (meaning "provider reported 0 reasoning tokens")
completion_tokens_details = None
if hasattr(response.usage_metadata, "thoughts_token_count") and response.usage_metadata.thoughts_token_count is not None:
from letta.schemas.openai.chat_completion_response import UsageStatisticsCompletionTokenDetails
completion_tokens_details = UsageStatisticsCompletionTokenDetails(
reasoning_tokens=response.usage_metadata.thoughts_token_count,
)
usage = UsageStatistics(
prompt_tokens=response.usage_metadata.prompt_token_count,
completion_tokens=response.usage_metadata.candidates_token_count,
total_tokens=response.usage_metadata.total_token_count,
prompt_tokens_details=prompt_tokens_details,
completion_tokens_details=completion_tokens_details,
)
usage = self.extract_usage_statistics(response_data, llm_config).to_usage(ProviderType.google_ai)
else:
# Count it ourselves using the Gemini token counting API
assert input_messages is not None, "Didn't get UsageMetadata from the API response, so input_messages is required"

View File

@@ -43,6 +43,14 @@ class GroqClient(OpenAIClient):
data["logprobs"] = False
data["n"] = 1
# for openai.BadRequestError: Error code: 400 - {'error': {'message': "'messages.2' : for 'role:assistant' the following must be satisfied[('messages.2' : property 'reasoning_content' is unsupported)]", 'type': 'invalid_request_error'}}
if "messages" in data:
for message in data["messages"]:
if "reasoning_content" in message:
del message["reasoning_content"]
if "reasoning_content_signature" in message:
del message["reasoning_content_signature"]
return data
@trace_method

View File

@@ -167,8 +167,8 @@ def create(
printd("unsetting function_call because functions is None")
function_call = None
# openai
if llm_config.model_endpoint_type == "openai":
# openai and openrouter (OpenAI-compatible)
if llm_config.model_endpoint_type in ["openai", "openrouter"]:
if model_settings.openai_api_key is None and llm_config.model_endpoint == "https://api.openai.com/v1":
# only is a problem if we are *not* using an openai proxy
raise LettaConfigurationError(message="OpenAI key is missing from letta config file", missing_fields=["openai_api_key"])

View File

@@ -93,6 +93,21 @@ class LLMClient:
put_inner_thoughts_first=put_inner_thoughts_first,
actor=actor,
)
case ProviderType.minimax:
from letta.llm_api.minimax_client import MiniMaxClient
return MiniMaxClient(
put_inner_thoughts_first=put_inner_thoughts_first,
actor=actor,
)
case ProviderType.openrouter:
# OpenRouter uses OpenAI-compatible API, so we can use the OpenAI client directly
from letta.llm_api.openai_client import OpenAIClient
return OpenAIClient(
put_inner_thoughts_first=put_inner_thoughts_first,
actor=actor,
)
case ProviderType.deepseek:
from letta.llm_api.deepseek_client import DeepseekClient

View File

@@ -15,6 +15,7 @@ from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.schemas.provider_trace import ProviderTrace
from letta.schemas.usage import LettaUsageStatistics
from letta.services.telemetry_manager import TelemetryManager
from letta.settings import settings
@@ -43,6 +44,10 @@ class LLMClientBase:
self._telemetry_run_id: Optional[str] = None
self._telemetry_step_id: Optional[str] = None
self._telemetry_call_type: Optional[str] = None
self._telemetry_org_id: Optional[str] = None
self._telemetry_user_id: Optional[str] = None
self._telemetry_compaction_settings: Optional[Dict] = None
self._telemetry_llm_config: Optional[Dict] = None
def set_telemetry_context(
self,
@@ -52,6 +57,10 @@ class LLMClientBase:
run_id: Optional[str] = None,
step_id: Optional[str] = None,
call_type: Optional[str] = None,
org_id: Optional[str] = None,
user_id: Optional[str] = None,
compaction_settings: Optional[Dict] = None,
llm_config: Optional[Dict] = None,
) -> None:
"""Set telemetry context for provider trace logging."""
self._telemetry_manager = telemetry_manager
@@ -60,6 +69,14 @@ class LLMClientBase:
self._telemetry_run_id = run_id
self._telemetry_step_id = step_id
self._telemetry_call_type = call_type
self._telemetry_org_id = org_id
self._telemetry_user_id = user_id
self._telemetry_compaction_settings = compaction_settings
self._telemetry_llm_config = llm_config
def extract_usage_statistics(self, response_data: Optional[dict], llm_config: LLMConfig) -> LettaUsageStatistics:
    """Provider-specific usage parsing hook (override in subclasses).

    Args:
        response_data: Raw provider response payload, or None when unavailable.
        llm_config: The LLM configuration used for the request.

    Returns:
        LettaUsageStatistics: Empty (default) statistics in this base
        implementation; subclasses parse provider-specific usage fields.
    """
    return LettaUsageStatistics()
async def request_async_with_telemetry(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""Wrapper around request_async that logs telemetry for all requests including errors.
@@ -96,6 +113,10 @@ class LLMClientBase:
agent_tags=self._telemetry_agent_tags,
run_id=self._telemetry_run_id,
call_type=self._telemetry_call_type,
org_id=self._telemetry_org_id,
user_id=self._telemetry_user_id,
compaction_settings=self._telemetry_compaction_settings,
llm_config=self._telemetry_llm_config,
),
)
except Exception as e:
@@ -137,6 +158,10 @@ class LLMClientBase:
agent_tags=self._telemetry_agent_tags,
run_id=self._telemetry_run_id,
call_type=self._telemetry_call_type,
org_id=self._telemetry_org_id,
user_id=self._telemetry_user_id,
compaction_settings=self._telemetry_compaction_settings,
llm_config=self._telemetry_llm_config,
),
)
except Exception as e:

View File

@@ -0,0 +1,175 @@
from typing import List, Optional, Union
import anthropic
from anthropic import AsyncStream
from anthropic.types.beta import BetaMessage, BetaRawMessageStreamEvent
from letta.llm_api.anthropic_client import AnthropicClient
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.settings import model_settings
logger = get_logger(__name__)
class MiniMaxClient(AnthropicClient):
"""
MiniMax LLM client using Anthropic-compatible API.
Uses the beta messages API to ensure compatibility with Anthropic streaming interfaces.
Temperature must be in range (0.0, 1.0].
Some Anthropic params are ignored: top_k, stop_sequences, service_tier, etc.
Documentation: https://platform.minimax.io/docs/api-reference/text-anthropic-api
Note: We override client creation to always use llm_config.model_endpoint as base_url
(required for BYOK where provider_name is user's custom name, not "minimax").
We also override request methods to avoid passing Anthropic-specific beta headers.
"""
@trace_method
def _get_anthropic_client(
    self, llm_config: LLMConfig, async_client: bool = False
) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
    """Create an Anthropic SDK client configured for the MiniMax API."""
    api_key, _, _ = self.get_byok_overrides(llm_config)
    # BYOK key takes precedence; otherwise fall back to the configured MiniMax key.
    api_key = api_key or model_settings.minimax_api_key
    # Always use model_endpoint for base_url (works for both base and BYOK providers).
    client_cls = anthropic.AsyncAnthropic if async_client else anthropic.Anthropic
    return client_cls(
        api_key=api_key,
        base_url=llm_config.model_endpoint,
        max_retries=model_settings.anthropic_max_retries,
    )
@trace_method
async def _get_anthropic_client_async(
    self, llm_config: LLMConfig, async_client: bool = False
) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
    """Create an Anthropic SDK client configured for the MiniMax API (async BYOK lookup)."""
    api_key, _, _ = await self.get_byok_overrides_async(llm_config)
    # BYOK key takes precedence; otherwise fall back to the configured MiniMax key.
    api_key = api_key or model_settings.minimax_api_key
    # Always use model_endpoint for base_url (works for both base and BYOK providers).
    client_cls = anthropic.AsyncAnthropic if async_client else anthropic.Anthropic
    return client_cls(
        api_key=api_key,
        base_url=llm_config.model_endpoint,
        max_retries=model_settings.anthropic_max_retries,
    )
@trace_method
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
    """
    Synchronous request to MiniMax API.
    Uses beta messages API for compatibility with Anthropic streaming interfaces.

    Args:
        request_data: Fully-built Anthropic-format request payload.
        llm_config: LLM configuration providing the endpoint and credentials.

    Returns:
        The provider response serialized to a plain dict via ``model_dump()``.
    """
    client = self._get_anthropic_client(llm_config, async_client=False)
    response: BetaMessage = client.beta.messages.create(**request_data)
    return response.model_dump()
@trace_method
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
    """
    Asynchronous request to MiniMax API.
    Uses beta messages API for compatibility with Anthropic streaming interfaces.

    Args:
        request_data: Fully-built Anthropic-format request payload.
        llm_config: LLM configuration providing the endpoint and credentials.

    Returns:
        The provider response serialized to a plain dict via ``model_dump()``.

    Raises:
        ValueError: Re-raised when the SDK rejects the request for a reason
            other than "streaming is required".
    """
    client = await self._get_anthropic_client_async(llm_config, async_client=True)
    try:
        response: BetaMessage = await client.beta.messages.create(**request_data)
        return response.model_dump()
    except ValueError as e:
        # Handle streaming fallback if needed (similar to Anthropic client):
        # some requests are only accepted in streaming mode, so retry that way.
        if "streaming is required" in str(e).lower():
            logger.warning(
                "[MiniMax] Non-streaming request rejected. Falling back to streaming mode. Error: %s",
                str(e),
            )
            # No Anthropic-specific beta headers are passed for MiniMax.
            return await self._request_via_streaming(request_data, llm_config, betas=[])
        raise
@trace_method
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[BetaRawMessageStreamEvent]:
    """Open a streaming request to the MiniMax API.

    Uses the Anthropic beta messages API for compatibility with Anthropic
    streaming interfaces.

    Note: mutates ``request_data`` in place by forcing ``stream=True``.

    Raises:
        Exception: any SDK error is logged and re-raised unchanged.
    """
    client = await self._get_anthropic_client_async(llm_config, async_client=True)
    request_data["stream"] = True
    try:
        return await client.beta.messages.create(**request_data)
    except Exception as e:
        logger.error(f"Error streaming MiniMax request: {e}")
        # Bare `raise` preserves the original traceback; `raise e` would
        # re-raise with the traceback rebound at this frame.
        raise
@trace_method
def build_request_data(
    self,
    agent_type: AgentType,
    messages: List[PydanticMessage],
    llm_config: LLMConfig,
    tools: Optional[List[dict]] = None,
    force_tool_call: Optional[str] = None,
    requires_subsequent_tool_call: bool = False,
    tool_return_truncation_chars: Optional[int] = None,
) -> dict:
    """Assemble the request payload for the MiniMax API.

    Delegates to ``AnthropicClient.build_request_data`` and then applies the
    one MiniMax-specific adjustment: temperature must lie in the half-open
    range (0.0, 1.0], so out-of-range values are clamped (with a warning).
    """
    payload = super().build_request_data(
        agent_type,
        messages,
        llm_config,
        tools,
        force_tool_call,
        requires_subsequent_tool_call,
        tool_return_truncation_chars,
    )
    requested_temp = payload.get("temperature")
    if requested_temp is not None:
        # MiniMax accepts temperature in (0.0, 1.0]; recommended value is 1.
        if requested_temp <= 0:
            payload["temperature"] = 0.01  # smallest accepted value (0 itself is excluded)
            logger.warning(f"[MiniMax] Temperature {requested_temp} is invalid. Clamped to 0.01.")
        elif requested_temp > 1.0:
            payload["temperature"] = 1.0  # largest accepted value
            logger.warning(f"[MiniMax] Temperature {requested_temp} is invalid. Clamped to 1.0.")
    # Anthropic-only parameters are intentionally left in the payload:
    # MiniMax silently ignores them, so stripping them gains nothing.
    return payload
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
"""
All MiniMax M2.x models support native interleaved thinking.
Unlike Anthropic where only certain models (Claude 3.7+) support extended thinking,
all MiniMax models natively support thinking blocks without beta headers.
"""
return True
def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
"""MiniMax models support all tool choice modes."""
return False
def supports_structured_output(self, llm_config: LLMConfig) -> bool:
"""MiniMax doesn't currently advertise structured output support."""
return False

View File

@@ -60,6 +60,7 @@ from letta.schemas.openai.chat_completion_response import (
)
from letta.schemas.openai.responses_request import ResponsesRequest
from letta.schemas.response_format import JsonSchemaResponseFormat
from letta.schemas.usage import LettaUsageStatistics
from letta.settings import model_settings
logger = get_logger(__name__)
@@ -169,6 +170,7 @@ def supports_content_none(llm_config: LLMConfig) -> bool:
class OpenAIClient(LLMClientBase):
def _prepare_client_kwargs(self, llm_config: LLMConfig) -> dict:
api_key, _, _ = self.get_byok_overrides(llm_config)
has_byok_key = api_key is not None # Track if we got a BYOK key
# Default to global OpenAI key when no BYOK override
if not api_key:
@@ -181,9 +183,11 @@ class OpenAIClient(LLMClientBase):
llm_config.provider_name == "openrouter"
)
if is_openrouter:
or_key = model_settings.openrouter_api_key or os.environ.get("OPENROUTER_API_KEY")
if or_key:
kwargs["api_key"] = or_key
# Only use prod OpenRouter key if no BYOK key was provided
if not has_byok_key:
or_key = model_settings.openrouter_api_key or os.environ.get("OPENROUTER_API_KEY")
if or_key:
kwargs["api_key"] = or_key
# Attach optional headers if provided
headers = {}
if model_settings.openrouter_referer:
@@ -207,6 +211,7 @@ class OpenAIClient(LLMClientBase):
async def _prepare_client_kwargs_async(self, llm_config: LLMConfig) -> dict:
api_key, _, _ = await self.get_byok_overrides_async(llm_config)
has_byok_key = api_key is not None # Track if we got a BYOK key
if not api_key:
api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
@@ -216,9 +221,11 @@ class OpenAIClient(LLMClientBase):
llm_config.provider_name == "openrouter"
)
if is_openrouter:
or_key = model_settings.openrouter_api_key or os.environ.get("OPENROUTER_API_KEY")
if or_key:
kwargs["api_key"] = or_key
# Only use prod OpenRouter key if no BYOK key was provided
if not has_byok_key:
or_key = model_settings.openrouter_api_key or os.environ.get("OPENROUTER_API_KEY")
if or_key:
kwargs["api_key"] = or_key
headers = {}
if model_settings.openrouter_referer:
headers["HTTP-Referer"] = model_settings.openrouter_referer
@@ -591,6 +598,66 @@ class OpenAIClient(LLMClientBase):
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
return is_openai_reasoning_model(llm_config.model)
def extract_usage_statistics(self, response_data: dict | None, llm_config: LLMConfig) -> LettaUsageStatistics:
    """Extract token usage from an OpenAI response as LettaUsageStatistics.

    Handles both the Responses API payload (``object == "response"``, used by
    reasoning models like o1/o3) and the standard Chat Completions payload.
    Returns empty statistics when the payload is missing, malformed, or
    carries no usage block.
    """
    if not response_data:
        return LettaUsageStatistics()

    if response_data.get("object") == "response":
        # Responses API: usage lives under input_tokens/output_tokens with
        # *_tokens_details sub-dicts for cached and reasoning breakdowns.
        usage_block = response_data.get("usage") or {}
        in_tokens = usage_block.get("input_tokens") or 0
        out_tokens = usage_block.get("output_tokens") or 0
        in_details = usage_block.get("input_tokens_details") or {}
        out_details = usage_block.get("output_tokens_details") or {}
        return LettaUsageStatistics(
            prompt_tokens=in_tokens,
            completion_tokens=out_tokens,
            total_tokens=usage_block.get("total_tokens") or (in_tokens + out_tokens),
            cached_input_tokens=in_details.get("cached_tokens"),
            reasoning_tokens=out_details.get("reasoning_tokens"),
        )

    # Chat Completions API: validate through the pydantic model so malformed
    # payloads degrade to empty statistics instead of raising.
    from openai.types.chat import ChatCompletion

    try:
        parsed = ChatCompletion.model_validate(response_data)
    except Exception:
        return LettaUsageStatistics()
    if not parsed.usage:
        return LettaUsageStatistics()

    usage = parsed.usage
    in_tokens = usage.prompt_tokens or 0
    out_tokens = usage.completion_tokens or 0
    cached = usage.prompt_tokens_details.cached_tokens if usage.prompt_tokens_details else None
    reasoning = usage.completion_tokens_details.reasoning_tokens if usage.completion_tokens_details else None
    return LettaUsageStatistics(
        prompt_tokens=in_tokens,
        completion_tokens=out_tokens,
        total_tokens=usage.total_tokens or (in_tokens + out_tokens),
        cached_input_tokens=cached,
        reasoning_tokens=reasoning,
    )
@trace_method
async def convert_response_to_chat_completion(
self,
@@ -607,30 +674,10 @@ class OpenAIClient(LLMClientBase):
# See example payload in tests/integration_test_send_message_v2.py
model = response_data.get("model")
# Extract usage
usage = response_data.get("usage", {}) or {}
prompt_tokens = usage.get("input_tokens") or 0
completion_tokens = usage.get("output_tokens") or 0
total_tokens = usage.get("total_tokens") or (prompt_tokens + completion_tokens)
# Extract usage via centralized method
from letta.schemas.enums import ProviderType
# Extract detailed token breakdowns (Responses API uses input_tokens_details/output_tokens_details)
prompt_tokens_details = None
input_details = usage.get("input_tokens_details", {}) or {}
if input_details.get("cached_tokens"):
from letta.schemas.openai.chat_completion_response import UsageStatisticsPromptTokenDetails
prompt_tokens_details = UsageStatisticsPromptTokenDetails(
cached_tokens=input_details.get("cached_tokens") or 0,
)
completion_tokens_details = None
output_details = usage.get("output_tokens_details", {}) or {}
if output_details.get("reasoning_tokens"):
from letta.schemas.openai.chat_completion_response import UsageStatisticsCompletionTokenDetails
completion_tokens_details = UsageStatisticsCompletionTokenDetails(
reasoning_tokens=output_details.get("reasoning_tokens") or 0,
)
usage_stats = self.extract_usage_statistics(response_data, llm_config).to_usage(ProviderType.openai)
# Extract assistant message text from the outputs list
outputs = response_data.get("output") or []
@@ -698,13 +745,7 @@ class OpenAIClient(LLMClientBase):
choices=[choice],
created=int(response_data.get("created_at") or 0),
model=model or (llm_config.model if hasattr(llm_config, "model") else None),
usage=UsageStatistics(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
prompt_tokens_details=prompt_tokens_details,
completion_tokens_details=completion_tokens_details,
),
usage=usage_stats,
)
return chat_completion_response

View File

@@ -0,0 +1 @@
"""Model specification utilities for Letta."""

View File

@@ -0,0 +1,120 @@
"""
Utility functions for working with litellm model specifications.
This module provides access to model specifications from the litellm model_prices_and_context_window.json file.
The data is synced from: https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json
"""
import json
import os
from typing import Optional
import aiofiles
from async_lru import alru_cache
from letta.log import get_logger
logger = get_logger(__name__)
# Path to the litellm model specs JSON file
MODEL_SPECS_PATH = os.path.join(os.path.dirname(__file__), "model_prices_and_context_window.json")
@alru_cache(maxsize=1)
async def load_model_specs() -> dict:
    """Load the litellm model specifications from the bundled JSON file.

    The result is cached for the lifetime of the process (maxsize=1).

    Returns:
        dict: The model specifications, or {} when the file is missing,
        unreadable, or contains invalid JSON (a warning/error is logged in
        each case — this function never raises for those failures).
    """
    if not os.path.exists(MODEL_SPECS_PATH):
        logger.warning(f"Model specs file not found at {MODEL_SPECS_PATH}")
        return {}
    try:
        async with aiofiles.open(MODEL_SPECS_PATH, "r") as f:
            content = await f.read()
        return json.loads(content)
    except OSError as e:
        # Covers races with file removal and permission errors during the
        # read; degrade to empty specs to honor the "return {} on failure"
        # contract instead of propagating.
        logger.error(f"Failed to read model specs file: {e}")
        return {}
    except json.JSONDecodeError as e:
        logger.error(f"Failed to parse model specs JSON: {e}")
        return {}
async def get_model_spec(model_name: str) -> Optional[dict]:
    """Look up the spec entry for a single model.

    Args:
        model_name: Model identifier, e.g. "gpt-4o" or "gpt-4o-mini".

    Returns:
        Optional[dict]: The model's spec dict, or None when the model is unknown.
    """
    all_specs = await load_model_specs()
    return all_specs.get(model_name)
async def get_max_input_tokens(model_name: str) -> Optional[int]:
    """Return the model's max input token count.

    Args:
        model_name: The name of the model.

    Returns:
        Optional[int]: The max input tokens, or None when the model (or the
        field) is unknown.
    """
    spec = await get_model_spec(model_name)
    return spec.get("max_input_tokens") if spec else None
async def get_max_output_tokens(model_name: str) -> Optional[int]:
    """Return the model's max output token count.

    Prefers the "max_output_tokens" field and falls back to the legacy
    "max_tokens" field.

    Args:
        model_name: The name of the model.

    Returns:
        Optional[int]: The max output tokens, or None when unknown.
    """
    spec = await get_model_spec(model_name)
    if not spec:
        return None
    return spec.get("max_output_tokens") or spec.get("max_tokens")
async def get_context_window(model_name: str) -> Optional[int]:
    """Return the model's context window size.

    For most models the context window equals max_input_tokens, so this is a
    thin alias over get_max_input_tokens.

    Args:
        model_name: The name of the model.

    Returns:
        Optional[int]: The context window size, or None when unknown.
    """
    return await get_max_input_tokens(model_name)
async def get_litellm_provider(model_name: str) -> Optional[str]:
    """Return the litellm provider name for a model.

    Args:
        model_name: The name of the model.

    Returns:
        Optional[str]: The provider name, or None when the model (or the
        field) is unknown.
    """
    spec = await get_model_spec(model_name)
    return spec.get("litellm_provider") if spec else None

File diff suppressed because it is too large Load Diff

View File

@@ -31,6 +31,7 @@ from letta.orm.prompt import Prompt
from letta.orm.provider import Provider
from letta.orm.provider_model import ProviderModel
from letta.orm.provider_trace import ProviderTrace
from letta.orm.provider_trace_metadata import ProviderTraceMetadata
from letta.orm.run import Run
from letta.orm.run_metrics import RunMetrics
from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable

View File

@@ -46,8 +46,8 @@ class Archive(SqlalchemyBase, OrganizationMixin):
default=VectorDBProvider.NATIVE,
doc="The vector database provider used for this archive's passages",
)
embedding_config: Mapped[dict] = mapped_column(
EmbeddingConfigColumn, nullable=False, doc="Embedding configuration for passages in this archive"
embedding_config: Mapped[Optional[dict]] = mapped_column(
EmbeddingConfigColumn, nullable=True, doc="Embedding configuration for passages in this archive"
)
metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="Additional metadata for the archive")
_vector_db_namespace: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="Private field for vector database namespace")

View File

@@ -25,18 +25,18 @@ class BasePassage(SqlalchemyBase, OrganizationMixin):
id: Mapped[str] = mapped_column(primary_key=True, doc="Unique passage identifier")
text: Mapped[str] = mapped_column(doc="Passage text content")
embedding_config: Mapped[dict] = mapped_column(EmbeddingConfigColumn, doc="Embedding configuration")
embedding_config: Mapped[Optional[dict]] = mapped_column(EmbeddingConfigColumn, nullable=True, doc="Embedding configuration")
metadata_: Mapped[dict] = mapped_column(JSON, doc="Additional metadata")
# dual storage: json column for fast retrieval, junction table for efficient queries
tags: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True, doc="Tags associated with this passage")
# Vector embedding field based on database type
# Vector embedding field based on database type - nullable for text-only search
if settings.database_engine is DatabaseChoice.POSTGRES:
from pgvector.sqlalchemy import Vector
embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
embedding = mapped_column(Vector(MAX_EMBEDDING_DIM), nullable=True)
else:
embedding = Column(CommonVector)
embedding = Column(CommonVector, nullable=True)
@declared_attr
def organization(cls) -> Mapped["Organization"]:

View File

@@ -1,6 +1,7 @@
from datetime import datetime
from typing import TYPE_CHECKING, Optional
from sqlalchemy import ForeignKey, String, Text, UniqueConstraint
from sqlalchemy import DateTime, ForeignKey, String, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin
@@ -41,6 +42,11 @@ class Provider(SqlalchemyBase, OrganizationMixin):
api_key_enc: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Encrypted API key or secret key for the provider.")
access_key_enc: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Encrypted access key for the provider.")
# sync tracking
last_synced: Mapped[Optional[datetime]] = mapped_column(
DateTime(timezone=True), nullable=True, doc="Last time models were synced for this provider."
)
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="providers")
models: Mapped[list["ProviderModel"]] = relationship("ProviderModel", back_populates="provider", cascade="all, delete-orphan")

View File

@@ -32,5 +32,15 @@ class ProviderTrace(SqlalchemyBase, OrganizationMixin):
String, nullable=True, doc="Source service that generated this trace (memgpt-server, lettuce-py)"
)
# v2 protocol fields
org_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="ID of the organization")
user_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="ID of the user who initiated the request")
compaction_settings: Mapped[Optional[dict]] = mapped_column(
JSON, nullable=True, doc="Compaction/summarization settings (summarization calls only)"
)
llm_config: Mapped[Optional[dict]] = mapped_column(
JSON, nullable=True, doc="LLM configuration used for this call (non-summarization calls only)"
)
# Relationships
organization: Mapped["Organization"] = relationship("Organization", lazy="selectin")

View File

@@ -0,0 +1,45 @@
import uuid
from datetime import datetime
from typing import Optional
from sqlalchemy import JSON, DateTime, Index, String, UniqueConstraint, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.provider_trace import ProviderTraceMetadata as PydanticProviderTraceMetadata
class ProviderTraceMetadata(SqlalchemyBase, OrganizationMixin):
    """Metadata-only provider trace storage (no request/response JSON)."""

    __tablename__ = "provider_trace_metadata"
    __pydantic_model__ = PydanticProviderTraceMetadata
    __table_args__ = (
        # Non-unique index: step_id is the usual lookup key when correlating
        # traces with agent steps.
        Index("ix_provider_trace_metadata_step_id", "step_id"),
        # id is unique on its own even though the primary key is composite.
        UniqueConstraint("id", name="uq_provider_trace_metadata_id"),
    )

    # Composite primary key is (created_at, id) with created_at first —
    # presumably to support time-ordered storage/partitioning; TODO confirm.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), primary_key=True, server_default=func.now(), doc="Timestamp when the trace was created"
    )
    id: Mapped[str] = mapped_column(
        String, primary_key=True, doc="Unique provider trace identifier", default=lambda: f"provider_trace-{uuid.uuid4()}"
    )
    step_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="ID of the step that this trace is associated with")

    # Telemetry context fields
    agent_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="ID of the agent that generated this trace")
    agent_tags: Mapped[Optional[list]] = mapped_column(JSON, nullable=True, doc="Tags associated with the agent for filtering")
    call_type: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="Type of call (agent_step, summarization, etc.)")
    run_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="ID of the run this trace is associated with")
    source: Mapped[Optional[str]] = mapped_column(
        String, nullable=True, doc="Source service that generated this trace (memgpt-server, lettuce-py)"
    )

    # v2 protocol fields
    org_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="ID of the organization")
    user_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="ID of the user who initiated the request")

    # Relationships
    organization: Mapped["Organization"] = relationship("Organization", lazy="selectin")

View File

@@ -42,7 +42,9 @@ def handle_db_timeout(func):
logger.error(f"Timeout while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
raise DatabaseTimeoutError(message=f"Timeout occurred in {func.__name__}.", original_exception=e)
except QueryCanceledError as e:
logger.error(f"Query canceled (statement timeout) while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
logger.error(
f"Query canceled (statement timeout) while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}"
)
raise DatabaseTimeoutError(message=f"Query canceled due to statement timeout in {func.__name__}.", original_exception=e)
return wrapper
@@ -56,7 +58,9 @@ def handle_db_timeout(func):
logger.error(f"Timeout while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
raise DatabaseTimeoutError(message=f"Timeout occurred in {func.__name__}.", original_exception=e)
except QueryCanceledError as e:
logger.error(f"Query canceled (statement timeout) while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
logger.error(
f"Query canceled (statement timeout) while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}"
)
raise DatabaseTimeoutError(message=f"Query canceled due to statement timeout in {func.__name__}.", original_exception=e)
return async_wrapper
@@ -207,6 +211,10 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
"""
Constructs the query for listing records.
"""
# Security check: if the model has organization_id column, actor should be provided
if actor is None and hasattr(cls, "organization_id"):
logger.warning(f"SECURITY: Listing org-scoped model {cls.__name__} without actor. This bypasses organization filtering.")
query = select(cls)
if join_model and join_conditions:
@@ -446,6 +454,14 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
):
logger.debug(f"Reading {cls.__name__} with ID(s): {identifiers} with actor={actor}")
# Security check: if the model has organization_id column, actor should be provided
# to ensure proper org-scoping. Log a warning if actor is None.
if actor is None and hasattr(cls, "organization_id"):
logger.warning(
f"SECURITY: Reading org-scoped model {cls.__name__} without actor. "
f"IDs: {identifiers}. This bypasses organization filtering."
)
# Start the query
query = select(cls)
# Collect query conditions for better error reporting
@@ -681,6 +697,12 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
**kwargs,
):
logger.debug(f"Calculating size for {cls.__name__} with filters {kwargs}")
# Security check: if the model has organization_id column, actor should be provided
if actor is None and hasattr(cls, "organization_id"):
logger.warning(
f"SECURITY: Calculating size for org-scoped model {cls.__name__} without actor. This bypasses organization filtering."
)
query = select(func.count(1)).select_from(cls)
if actor:

View File

@@ -17,7 +17,7 @@ class ArchiveBase(OrmMetadataBase):
vector_db_provider: VectorDBProvider = Field(
default=VectorDBProvider.NATIVE, description="The vector database provider used for this archive's passages"
)
embedding_config: EmbeddingConfig = Field(..., description="Embedding configuration for passages in this archive")
embedding_config: Optional[EmbeddingConfig] = Field(None, description="Embedding configuration for passages in this archive")
metadata: Optional[Dict] = Field(default_factory=dict, validation_alias="metadata_", description="Additional metadata")

View File

@@ -63,11 +63,14 @@ class ProviderType(str, Enum):
hugging_face = "hugging-face"
letta = "letta"
lmstudio_openai = "lmstudio_openai"
minimax = "minimax"
mistral = "mistral"
ollama = "ollama"
openai = "openai"
together = "together"
vllm = "vllm"
sglang = "sglang"
openrouter = "openrouter"
xai = "xai"
zai = "zai"

View File

@@ -5,6 +5,7 @@ from pydantic import BaseModel, Field
from letta.schemas.enums import PrimitiveType
from letta.schemas.letta_base import LettaBase
from letta.validators import AgentId, BlockId
class ManagerType(str, Enum):
@@ -93,43 +94,43 @@ class RoundRobinManagerUpdate(ManagerConfig):
class SupervisorManager(ManagerConfig):
manager_type: Literal[ManagerType.supervisor] = Field(ManagerType.supervisor, description="")
manager_agent_id: str = Field(..., description="")
manager_agent_id: AgentId = Field(..., description="")
class SupervisorManagerUpdate(ManagerConfig):
manager_type: Literal[ManagerType.supervisor] = Field(ManagerType.supervisor, description="")
manager_agent_id: Optional[str] = Field(..., description="")
manager_agent_id: Optional[AgentId] = Field(..., description="")
class DynamicManager(ManagerConfig):
manager_type: Literal[ManagerType.dynamic] = Field(ManagerType.dynamic, description="")
manager_agent_id: str = Field(..., description="")
manager_agent_id: AgentId = Field(..., description="")
termination_token: Optional[str] = Field("DONE!", description="")
max_turns: Optional[int] = Field(None, description="")
class DynamicManagerUpdate(ManagerConfig):
manager_type: Literal[ManagerType.dynamic] = Field(ManagerType.dynamic, description="")
manager_agent_id: Optional[str] = Field(None, description="")
manager_agent_id: Optional[AgentId] = Field(None, description="")
termination_token: Optional[str] = Field(None, description="")
max_turns: Optional[int] = Field(None, description="")
class SleeptimeManager(ManagerConfig):
manager_type: Literal[ManagerType.sleeptime] = Field(ManagerType.sleeptime, description="")
manager_agent_id: str = Field(..., description="")
manager_agent_id: AgentId = Field(..., description="")
sleeptime_agent_frequency: Optional[int] = Field(None, description="")
class SleeptimeManagerUpdate(ManagerConfig):
manager_type: Literal[ManagerType.sleeptime] = Field(ManagerType.sleeptime, description="")
manager_agent_id: Optional[str] = Field(None, description="")
manager_agent_id: Optional[AgentId] = Field(None, description="")
sleeptime_agent_frequency: Optional[int] = Field(None, description="")
class VoiceSleeptimeManager(ManagerConfig):
manager_type: Literal[ManagerType.voice_sleeptime] = Field(ManagerType.voice_sleeptime, description="")
manager_agent_id: str = Field(..., description="")
manager_agent_id: AgentId = Field(..., description="")
max_message_buffer_length: Optional[int] = Field(
None,
description="The desired maximum length of messages in the context window of the convo agent. This is a best effort, and may be off slightly due to user/assistant interleaving.",
@@ -142,7 +143,7 @@ class VoiceSleeptimeManager(ManagerConfig):
class VoiceSleeptimeManagerUpdate(ManagerConfig):
manager_type: Literal[ManagerType.voice_sleeptime] = Field(ManagerType.voice_sleeptime, description="")
manager_agent_id: Optional[str] = Field(None, description="")
manager_agent_id: Optional[AgentId] = Field(None, description="")
max_message_buffer_length: Optional[int] = Field(
None,
description="The desired maximum length of messages in the context window of the convo agent. This is a best effort, and may be off slightly due to user/assistant interleaving.",
@@ -170,11 +171,11 @@ ManagerConfigUpdateUnion = Annotated[
class GroupCreate(BaseModel):
agent_ids: List[str] = Field(..., description="")
agent_ids: List[AgentId] = Field(..., description="")
description: str = Field(..., description="")
manager_config: ManagerConfigUnion = Field(RoundRobinManager(), description="")
project_id: Optional[str] = Field(None, description="The associated project id.")
shared_block_ids: List[str] = Field([], description="", deprecated=True)
shared_block_ids: List[BlockId] = Field([], description="", deprecated=True)
hidden: Optional[bool] = Field(
None,
description="If set to True, the group will be hidden.",
@@ -190,8 +191,8 @@ class InternalTemplateGroupCreate(GroupCreate):
class GroupUpdate(BaseModel):
agent_ids: Optional[List[str]] = Field(None, description="")
agent_ids: Optional[List[AgentId]] = Field(None, description="")
description: Optional[str] = Field(None, description="")
manager_config: Optional[ManagerConfigUpdateUnion] = Field(None, description="")
project_id: Optional[str] = Field(None, description="The associated project id.")
shared_block_ids: Optional[List[str]] = Field(None, description="", deprecated=True)
shared_block_ids: Optional[List[BlockId]] = Field(None, description="", deprecated=True)

View File

@@ -5,6 +5,7 @@ from pydantic import Field
from letta.schemas.enums import PrimitiveType
from letta.schemas.letta_base import LettaBase
from letta.validators import AgentId, BlockId
class IdentityType(str, Enum):
@@ -57,8 +58,8 @@ class IdentityCreate(LettaBase):
name: str = Field(..., description="The name of the identity.")
identity_type: IdentityType = Field(..., description="The type of the identity.")
project_id: Optional[str] = Field(None, description="The project id of the identity, if applicable.")
agent_ids: Optional[List[str]] = Field(None, description="The agent ids that are associated with the identity.", deprecated=True)
block_ids: Optional[List[str]] = Field(None, description="The IDs of the blocks associated with the identity.", deprecated=True)
agent_ids: Optional[List[AgentId]] = Field(None, description="The agent ids that are associated with the identity.", deprecated=True)
block_ids: Optional[List[BlockId]] = Field(None, description="The IDs of the blocks associated with the identity.", deprecated=True)
properties: Optional[List[IdentityProperty]] = Field(None, description="List of properties associated with the identity.")
@@ -67,8 +68,8 @@ class IdentityUpsert(LettaBase):
name: str = Field(..., description="The name of the identity.")
identity_type: IdentityType = Field(..., description="The type of the identity.")
project_id: Optional[str] = Field(None, description="The project id of the identity, if applicable.")
agent_ids: Optional[List[str]] = Field(None, description="The agent ids that are associated with the identity.", deprecated=True)
block_ids: Optional[List[str]] = Field(None, description="The IDs of the blocks associated with the identity.", deprecated=True)
agent_ids: Optional[List[AgentId]] = Field(None, description="The agent ids that are associated with the identity.", deprecated=True)
block_ids: Optional[List[BlockId]] = Field(None, description="The IDs of the blocks associated with the identity.", deprecated=True)
properties: Optional[List[IdentityProperty]] = Field(None, description="List of properties associated with the identity.")
@@ -76,8 +77,8 @@ class IdentityUpdate(LettaBase):
identifier_key: Optional[str] = Field(None, description="External, user-generated identifier key of the identity.")
name: Optional[str] = Field(None, description="The name of the identity.")
identity_type: Optional[IdentityType] = Field(None, description="The type of the identity.")
agent_ids: Optional[List[str]] = Field(None, description="The agent ids that are associated with the identity.", deprecated=True)
block_ids: Optional[List[str]] = Field(None, description="The IDs of the blocks associated with the identity.", deprecated=True)
agent_ids: Optional[List[AgentId]] = Field(None, description="The agent ids that are associated with the identity.", deprecated=True)
block_ids: Optional[List[BlockId]] = Field(None, description="The IDs of the blocks associated with the identity.", deprecated=True)
properties: Optional[List[IdentityProperty]] = Field(None, description="List of properties associated with the identity.")

View File

@@ -7,8 +7,10 @@ from pydantic import BaseModel, Field, field_serializer, field_validator
from letta.schemas.letta_message_content import (
LettaAssistantMessageContentUnion,
LettaToolReturnContentUnion,
LettaUserMessageContentUnion,
get_letta_assistant_message_content_union_str_json_schema,
get_letta_tool_return_content_union_str_json_schema,
get_letta_user_message_content_union_str_json_schema,
)
@@ -35,7 +37,11 @@ class ApprovalReturn(MessageReturn):
class ToolReturn(MessageReturn):
type: Literal[MessageReturnType.tool] = Field(default=MessageReturnType.tool, description="The message type to be created.")
tool_return: str
tool_return: Union[str, List[LettaToolReturnContentUnion]] = Field(
...,
description="The tool return value - either a string or list of content parts (text/image)",
json_schema_extra=get_letta_tool_return_content_union_str_json_schema(),
)
status: Literal["success", "error"]
tool_call_id: str
stdout: Optional[List[str]] = None
@@ -563,6 +569,10 @@ class SystemMessageListResult(UpdateSystemMessage):
default=None,
description="The unique identifier of the agent that owns the message.",
)
conversation_id: str | None = Field(
default=None,
description="The unique identifier of the conversation that the message belongs to.",
)
created_at: datetime = Field(..., description="The time the message was created in ISO format.")
@@ -581,6 +591,10 @@ class UserMessageListResult(UpdateUserMessage):
default=None,
description="The unique identifier of the agent that owns the message.",
)
conversation_id: str | None = Field(
default=None,
description="The unique identifier of the conversation that the message belongs to.",
)
created_at: datetime = Field(..., description="The time the message was created in ISO format.")
@@ -599,6 +613,10 @@ class ReasoningMessageListResult(UpdateReasoningMessage):
default=None,
description="The unique identifier of the agent that owns the message.",
)
conversation_id: str | None = Field(
default=None,
description="The unique identifier of the conversation that the message belongs to.",
)
created_at: datetime = Field(..., description="The time the message was created in ISO format.")
@@ -617,6 +635,10 @@ class AssistantMessageListResult(UpdateAssistantMessage):
default=None,
description="The unique identifier of the agent that owns the message.",
)
conversation_id: str | None = Field(
default=None,
description="The unique identifier of the conversation that the message belongs to.",
)
created_at: datetime = Field(..., description="The time the message was created in ISO format.")

View File

@@ -138,6 +138,48 @@ def get_letta_user_message_content_union_str_json_schema():
}
# -------------------------------
# Tool Return Content Types
# -------------------------------
# Discriminated union of the content parts permitted inside a tool return
# (text or image), discriminated on each part's `type` field.
LettaToolReturnContentUnion = Annotated[
    Union[TextContent, ImageContent],
    Field(discriminator="type"),
]
def create_letta_tool_return_content_union_schema():
    """Build the OpenAPI discriminated-union schema for tool-return content parts.

    The union covers the two allowed part types (text and image), discriminated
    on each part's ``type`` property.
    """
    ref_prefix = "#/components/schemas/"
    # Mapping from discriminator value to the component schema it selects.
    mapping = {
        "text": ref_prefix + "TextContent",
        "image": ref_prefix + "ImageContent",
    }
    return {
        "oneOf": [{"$ref": ref} for ref in mapping.values()],
        "discriminator": {
            "propertyName": "type",
            "mapping": mapping,
        },
    }
def get_letta_tool_return_content_union_str_json_schema():
    """Schema that accepts either string or list of content parts for tool returns."""
    # Array variant: a list of discriminated content parts (text/image).
    list_variant = {
        "type": "array",
        "items": {"$ref": "#/components/schemas/LettaToolReturnContentUnion"},
    }
    # Plain-string variant kept for backward compatibility with string-only returns.
    return {"anyOf": [list_variant, {"type": "string"}]}
# -------------------------------
# Assistant Content Types
# -------------------------------

View File

@@ -7,6 +7,7 @@ from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MES
from letta.schemas.letta_message import MessageType
from letta.schemas.letta_message_content import LettaMessageContentUnion
from letta.schemas.message import MessageCreate, MessageCreateUnion, MessageRole
from letta.validators import AgentId
class ClientToolSchema(BaseModel):
@@ -125,12 +126,33 @@ class LettaStreamingRequest(LettaRequest):
)
class ConversationMessageRequest(LettaRequest):
    """Request for sending messages to a conversation. Streams by default."""

    # Defaults to SSE streaming; set False to receive one complete JSON response.
    streaming: bool = Field(
        default=True,
        description="If True (default), returns a streaming response (Server-Sent Events). If False, returns a complete JSON response.",
    )
    # Token-level streaming (vs. per-step events) — only consulted when streaming=True.
    stream_tokens: bool = Field(
        default=False,
        description="Flag to determine if individual tokens should be streamed, rather than streaming per step (only used when streaming=true).",
    )
    # Keepalive pings guard against proxy/idle connection timeouts on long steps.
    include_pings: bool = Field(
        default=True,
        description="Whether to include periodic keepalive ping messages in the stream to prevent connection timeouts (only used when streaming=true).",
    )
    # Background processing of the request — only consulted when streaming=True.
    background: bool = Field(
        default=False,
        description="Whether to process the request in the background (only used when streaming=true).",
    )
class LettaAsyncRequest(LettaRequest):
callback_url: Optional[str] = Field(None, description="Optional callback URL to POST to when the job completes")
class LettaBatchRequest(LettaRequest):
agent_id: str = Field(..., description="The ID of the agent to send this batch request for")
agent_id: AgentId = Field(..., description="The ID of the agent to send this batch request for")
class CreateBatch(BaseModel):

View File

@@ -43,12 +43,14 @@ class LLMConfig(BaseModel):
"koboldcpp",
"vllm",
"hugging-face",
"minimax",
"mistral",
"together", # completions endpoint
"bedrock",
"deepseek",
"xai",
"zai",
"openrouter",
"chatgpt_oauth",
] = Field(..., description="The endpoint type for the model.")
model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
@@ -320,9 +322,10 @@ class LLMConfig(BaseModel):
GoogleAIModelSettings,
GoogleVertexModelSettings,
GroqModelSettings,
Model,
ModelSettings,
OpenAIModelSettings,
OpenAIReasoning,
OpenRouterModelSettings,
TogetherModelSettings,
XAIModelSettings,
ZAIModelSettings,
@@ -395,15 +398,30 @@ class LLMConfig(BaseModel):
max_output_tokens=self.max_tokens or 4096,
temperature=self.temperature,
)
elif self.model_endpoint_type == "openrouter":
return OpenRouterModelSettings(
max_output_tokens=self.max_tokens or 4096,
temperature=self.temperature,
)
elif self.model_endpoint_type == "chatgpt_oauth":
return ChatGPTOAuthModelSettings(
max_output_tokens=self.max_tokens or 4096,
temperature=self.temperature,
reasoning=ChatGPTOAuthReasoning(reasoning_effort=self.reasoning_effort or "medium"),
)
elif self.model_endpoint_type == "minimax":
# MiniMax uses Anthropic-compatible API
thinking_type = "enabled" if self.enable_reasoner else "disabled"
return AnthropicModelSettings(
max_output_tokens=self.max_tokens or 4096,
temperature=self.temperature,
thinking=AnthropicThinking(type=thinking_type, budget_tokens=self.max_reasoning_tokens or 1024),
verbosity=self.verbosity,
strict=self.strict,
)
else:
# If we don't know the model type, use the default Model schema
return Model(max_output_tokens=self.max_tokens or 4096)
# If we don't know the model type, use the base ModelSettings schema
return ModelSettings(max_output_tokens=self.max_tokens or 4096)
@classmethod
def is_openai_reasoning_model(cls, config: "LLMConfig") -> bool:

View File

@@ -50,6 +50,7 @@ from letta.schemas.letta_message_content import (
ImageContent,
ImageSourceType,
LettaMessageContentUnion,
LettaToolReturnContentUnion,
OmittedReasoningContent,
ReasoningContent,
RedactedReasoningContent,
@@ -71,6 +72,34 @@ def truncate_tool_return(content: Optional[str], limit: Optional[int]) -> Option
return content[:limit] + f"... [truncated {len(content) - limit} chars]"
def _get_text_from_part(part: Union[TextContent, ImageContent, dict]) -> Optional[str]:
    """Extract the textual payload of a content part; images and unknown parts yield None."""
    if isinstance(part, TextContent):
        return part.text
    # Dict-form parts: only those tagged as text carry extractable text.
    is_text_dict = isinstance(part, dict) and part.get("type") == "text"
    return part.get("text", "") if is_text_dict else None
def tool_return_to_text(func_response: Optional[Union[str, List]]) -> Optional[str]:
    """Flatten tool-return content into a plain string, substituting placeholders for images.

    Strings pass through unchanged; a list of content parts is joined by newlines,
    with a trailing "[... omitted]" note when images were present. Returns None
    when there is nothing textual to show.
    """
    if func_response is None:
        return None
    if isinstance(func_response, str):
        return func_response

    texts: List[str] = []
    images = 0
    for part in func_response:
        extracted = _get_text_from_part(part)
        if extracted:
            texts.append(extracted)
        if isinstance(part, ImageContent) or (isinstance(part, dict) and part.get("type") == "image"):
            images += 1

    result = "\n".join(texts)
    if images:
        note = f"[{images} images omitted]" if images > 1 else "[Image omitted]"
        result = f"{result} {note}" if result else note
    return result or None
def add_inner_thoughts_to_tool_call(
tool_call: OpenAIToolCall,
inner_thoughts: str,
@@ -366,6 +395,7 @@ class Message(BaseMessage):
message_type=lm.message_type,
content=lm.content,
agent_id=message.agent_id,
conversation_id=message.conversation_id,
created_at=message.created_at,
)
)
@@ -376,6 +406,7 @@ class Message(BaseMessage):
message_type=lm.message_type,
content=lm.content,
agent_id=message.agent_id,
conversation_id=message.conversation_id,
created_at=message.created_at,
)
)
@@ -386,6 +417,7 @@ class Message(BaseMessage):
message_type=lm.message_type,
reasoning=lm.reasoning,
agent_id=message.agent_id,
conversation_id=message.conversation_id,
created_at=message.created_at,
)
)
@@ -396,6 +428,7 @@ class Message(BaseMessage):
message_type=lm.message_type,
content=lm.content,
agent_id=message.agent_id,
conversation_id=message.conversation_id,
created_at=message.created_at,
)
)
@@ -786,8 +819,14 @@ class Message(BaseMessage):
for tool_return in self.tool_returns:
parsed_data = self._parse_tool_response(tool_return.func_response)
# Preserve multi-modal content (ToolReturn supports Union[str, List])
if isinstance(tool_return.func_response, list):
tool_return_value = tool_return.func_response
else:
tool_return_value = parsed_data["message"]
tool_return_obj = LettaToolReturn(
tool_return=parsed_data["message"],
tool_return=tool_return_value,
status=parsed_data["status"],
tool_call_id=tool_return.tool_call_id,
stdout=tool_return.stdout,
@@ -801,11 +840,18 @@ class Message(BaseMessage):
first_tool_return = all_tool_returns[0]
# Convert deprecated string-only field to text (preserve images in tool_returns list)
deprecated_tool_return_text = (
tool_return_to_text(first_tool_return.tool_return)
if isinstance(first_tool_return.tool_return, list)
else first_tool_return.tool_return
)
return ToolReturnMessage(
id=self.id,
date=self.created_at,
# deprecated top-level fields populated from first tool return
tool_return=first_tool_return.tool_return,
tool_return=deprecated_tool_return_text,
status=first_tool_return.status,
tool_call_id=first_tool_return.tool_call_id,
stdout=first_tool_return.stdout,
@@ -840,11 +886,11 @@ class Message(BaseMessage):
"""Check if message has exactly one text content item."""
return self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent)
def _parse_tool_response(self, response_text: str) -> dict:
def _parse_tool_response(self, response_text: Union[str, List]) -> dict:
"""Parse tool response JSON and extract message and status.
Args:
response_text: Raw JSON response text
response_text: Raw JSON response text OR list of content parts (for multi-modal)
Returns:
Dictionary with 'message' and 'status' keys
@@ -852,6 +898,14 @@ class Message(BaseMessage):
Raises:
ValueError: If JSON parsing fails
"""
# Handle multi-modal content (list with text/images)
if isinstance(response_text, list):
text_representation = tool_return_to_text(response_text) or "[Multi-modal content]"
return {
"message": text_representation,
"status": "success",
}
try:
function_return = parse_json(response_text)
return {
@@ -1301,7 +1355,9 @@ class Message(BaseMessage):
tool_return = self.tool_returns[0]
if not tool_return.tool_call_id:
raise TypeError("OpenAI API requires tool_call_id to be set.")
func_response = truncate_tool_return(tool_return.func_response, tool_return_truncation_chars)
# Convert to text first (replaces images with placeholders), then truncate
func_response_text = tool_return_to_text(tool_return.func_response)
func_response = truncate_tool_return(func_response_text, tool_return_truncation_chars)
openai_message = {
"content": func_response,
"role": self.role,
@@ -1356,8 +1412,9 @@ class Message(BaseMessage):
for tr in m.tool_returns:
if not tr.tool_call_id:
raise TypeError("ToolReturn came back without a tool_call_id.")
# Ensure explicit tool_returns are truncated for Chat Completions
func_response = truncate_tool_return(tr.func_response, tool_return_truncation_chars)
# Convert multi-modal to text (images → placeholders), then truncate
func_response_text = tool_return_to_text(tr.func_response)
func_response = truncate_tool_return(func_response_text, tool_return_truncation_chars)
result.append(
{
"content": func_response,
@@ -1418,7 +1475,10 @@ class Message(BaseMessage):
message_dicts.append(user_dict)
elif self.role == "assistant" or self.role == "approval":
assert self.tool_calls is not None or (self.content is not None and len(self.content) > 0)
# Validate that message has content OpenAI Responses API can process
if self.tool_calls is None and (self.content is None or len(self.content) == 0):
# Skip this message (similar to Anthropic handling at line 1308)
return message_dicts
# A few things may be in here, firstly reasoning content, secondly assistant messages, thirdly tool calls
# TODO check if OpenAI Responses is capable of R->A->T like Anthropic?
@@ -1456,17 +1516,17 @@ class Message(BaseMessage):
)
elif self.role == "tool":
# Handle tool returns - similar pattern to Anthropic
# Handle tool returns - supports images via content arrays
if self.tool_returns:
for tool_return in self.tool_returns:
if not tool_return.tool_call_id:
raise TypeError("OpenAI Responses API requires tool_call_id to be set.")
func_response = truncate_tool_return(tool_return.func_response, tool_return_truncation_chars)
output = self._tool_return_to_responses_output(tool_return.func_response, tool_return_truncation_chars)
message_dicts.append(
{
"type": "function_call_output",
"call_id": tool_return.tool_call_id[:max_tool_id_length] if max_tool_id_length else tool_return.tool_call_id,
"output": func_response,
"output": output,
}
)
else:
@@ -1534,6 +1594,50 @@ class Message(BaseMessage):
return None
@staticmethod
def _image_dict_to_data_url(part: dict) -> Optional[str]:
"""Convert image dict to data URL."""
source = part.get("source", {})
if source.get("type") == "base64" and source.get("data"):
media_type = source.get("media_type", "image/png")
return f"data:{media_type};base64,{source['data']}"
elif source.get("type") == "url":
return source.get("url")
return None
@staticmethod
def _tool_return_to_responses_output(
func_response: Optional[Union[str, List]],
tool_return_truncation_chars: Optional[int] = None,
) -> Union[str, List[dict]]:
"""Convert tool return to OpenAI Responses API format."""
if func_response is None:
return ""
if isinstance(func_response, str):
return truncate_tool_return(func_response, tool_return_truncation_chars) or ""
output_parts: List[dict] = []
for part in func_response:
if isinstance(part, TextContent):
text = truncate_tool_return(part.text, tool_return_truncation_chars) or ""
output_parts.append({"type": "input_text", "text": text})
elif isinstance(part, ImageContent):
image_url = Message._image_source_to_data_url(part)
if image_url:
detail = getattr(part.source, "detail", None) or "auto"
output_parts.append({"type": "input_image", "image_url": image_url, "detail": detail})
elif isinstance(part, dict):
if part.get("type") == "text":
text = truncate_tool_return(part.get("text", ""), tool_return_truncation_chars) or ""
output_parts.append({"type": "input_text", "text": text})
elif part.get("type") == "image":
image_url = Message._image_dict_to_data_url(part)
if image_url:
detail = part.get("source", {}).get("detail", "auto")
output_parts.append({"type": "input_image", "image_url": image_url, "detail": detail})
return output_parts if output_parts else ""
@staticmethod
def to_openai_responses_dicts_from_list(
messages: List[Message],
@@ -1550,6 +1654,68 @@ class Message(BaseMessage):
)
return result
@staticmethod
def _get_base64_image_data(part: Union[ImageContent, dict]) -> Optional[tuple[str, str]]:
    """Return (base64_data, media_type) for an image part carrying inline data, else None."""
    if isinstance(part, ImageContent):
        src = part.source
        if src.type == ImageSourceType.base64:
            return src.data, src.media_type
        # Letta-hosted images may also carry inline data; default media type to PNG.
        if src.type == ImageSourceType.letta and getattr(src, "data", None):
            return src.data, getattr(src, "media_type", None) or "image/png"
        return None
    if isinstance(part, dict) and part.get("type") == "image":
        src = part.get("source", {})
        if src.get("type") == "base64" and src.get("data"):
            return src["data"], src.get("media_type", "image/png")
    return None
@staticmethod
def _tool_return_to_google_parts(
func_response: Optional[Union[str, List]],
tool_return_truncation_chars: Optional[int] = None,
) -> tuple[str, List[dict]]:
"""Extract text and image parts for Google API format."""
if isinstance(func_response, str):
return truncate_tool_return(func_response, tool_return_truncation_chars) or "", []
text_parts = []
image_parts = []
for part in func_response:
if text := _get_text_from_part(part):
text_parts.append(text)
elif image_data := Message._get_base64_image_data(part):
data, media_type = image_data
image_parts.append({"inlineData": {"data": data, "mimeType": media_type}})
text = truncate_tool_return("\n".join(text_parts), tool_return_truncation_chars) or ""
if image_parts:
suffix = f"[{len(image_parts)} image(s) attached]"
text = f"{text}\n{suffix}" if text else suffix
return text, image_parts
@staticmethod
def _tool_return_to_anthropic_content(
func_response: Optional[Union[str, List]],
tool_return_truncation_chars: Optional[int] = None,
) -> Union[str, List[dict]]:
"""Convert tool return to Anthropic tool_result content format."""
if func_response is None:
return ""
if isinstance(func_response, str):
return truncate_tool_return(func_response, tool_return_truncation_chars) or ""
content: List[dict] = []
for part in func_response:
if text := _get_text_from_part(part):
text = truncate_tool_return(text, tool_return_truncation_chars) or ""
content.append({"type": "text", "text": text})
elif image_data := Message._get_base64_image_data(part):
data, media_type = image_data
content.append({"type": "image", "source": {"type": "base64", "data": data, "media_type": media_type}})
return content if content else ""
def to_anthropic_dict(
self,
current_model: str,
@@ -1628,8 +1794,11 @@ class Message(BaseMessage):
}
elif self.role == "assistant" or self.role == "approval":
# assert self.tool_calls is not None or text_content is not None, vars(self)
assert self.tool_calls is not None or len(self.content) > 0
# Validate that message has content Anthropic API can process
if self.tool_calls is None and (self.content is None or len(self.content) == 0):
# Skip this message (consistent with OpenAI dict handling)
return None
anthropic_message = {
"role": "assistant",
}
@@ -1759,12 +1928,13 @@ class Message(BaseMessage):
f"Message ID: {self.id}, Tool: {self.name or 'unknown'}, "
f"Tool return index: {idx}/{len(self.tool_returns)}"
)
func_response = truncate_tool_return(tool_return.func_response, tool_return_truncation_chars)
# Convert to Anthropic format (supports images)
tool_result_content = self._tool_return_to_anthropic_content(tool_return.func_response, tool_return_truncation_chars)
content.append(
{
"type": "tool_result",
"tool_use_id": resolved_tool_call_id,
"content": func_response,
"content": tool_result_content,
}
)
if content:
@@ -1884,7 +2054,16 @@ class Message(BaseMessage):
}
elif self.role == "assistant" or self.role == "approval":
assert self.tool_calls is not None or text_content is not None or len(self.content) > 1
# Validate that message has content Google API can process
if self.tool_calls is None and text_content is None and len(self.content) <= 1:
# Message has no tool calls, no extractable text, and not multi-part
logger.warning(
f"Assistant/approval message {self.id} has no content Google API can convert: "
f"tool_calls={self.tool_calls}, text_content={text_content}, content={self.content}"
)
# Return None to skip this message (similar to approval messages without tool_calls at line 1998)
return None
google_ai_message = {
"role": "model", # NOTE: different
}
@@ -2003,7 +2182,7 @@ class Message(BaseMessage):
elif self.role == "tool":
# NOTE: Significantly different tool calling format, more similar to function calling format
# Handle tool returns - similar pattern to Anthropic
# Handle tool returns - Google supports images as sibling inlineData parts
if self.tool_returns:
parts = []
for tool_return in self.tool_returns:
@@ -2013,26 +2192,24 @@ class Message(BaseMessage):
# Use the function name if available, otherwise use tool_call_id
function_name = self.name if self.name else tool_return.tool_call_id
# Truncate the tool return if needed
func_response = truncate_tool_return(tool_return.func_response, tool_return_truncation_chars)
text_content, image_parts = Message._tool_return_to_google_parts(
tool_return.func_response, tool_return_truncation_chars
)
# NOTE: Google AI API wants the function response as JSON only, no string
try:
function_response = parse_json(func_response)
function_response = parse_json(text_content)
except:
function_response = {"function_response": func_response}
function_response = {"function_response": text_content}
parts.append(
{
"functionResponse": {
"name": function_name,
"response": {
"name": function_name, # NOTE: name twice... why?
"content": function_response,
},
"response": {"name": function_name, "content": function_response},
}
}
)
parts.extend(image_parts)
google_ai_message = {
"role": "function",
@@ -2325,7 +2502,9 @@ class ToolReturn(BaseModel):
status: Literal["success", "error"] = Field(..., description="The status of the tool call")
stdout: Optional[List[str]] = Field(default=None, description="Captured stdout (e.g. prints, logs) from the tool invocation")
stderr: Optional[List[str]] = Field(default=None, description="Captured stderr from the tool invocation")
func_response: Optional[str] = Field(None, description="The function response string")
func_response: Optional[Union[str, List[LettaToolReturnContentUnion]]] = Field(
None, description="The function response - either a string or list of content parts (text/image)"
)
class MessageSearchRequest(BaseModel):

View File

@@ -42,12 +42,14 @@ class Model(LLMConfig, ModelBase):
"koboldcpp",
"vllm",
"hugging-face",
"minimax",
"mistral",
"together",
"bedrock",
"deepseek",
"xai",
"zai",
"openrouter",
"chatgpt_oauth",
] = Field(..., description="Deprecated: Use 'provider_type' field instead. The endpoint type for the model.", deprecated=True)
context_window: int = Field(
@@ -138,6 +140,7 @@ class Model(LLMConfig, ModelBase):
ProviderType.deepseek: DeepseekModelSettings,
ProviderType.together: TogetherModelSettings,
ProviderType.bedrock: BedrockModelSettings,
ProviderType.openrouter: OpenRouterModelSettings,
}
settings_class = PROVIDER_SETTINGS_MAP.get(self.provider_type)
@@ -456,6 +459,23 @@ class BedrockModelSettings(ModelSettings):
}
class OpenRouterModelSettings(ModelSettings):
    """OpenRouter model configuration (OpenAI-compatible)."""

    provider_type: Literal[ProviderType.openrouter] = Field(ProviderType.openrouter, description="The type of the provider.")
    temperature: float = Field(0.7, description="The temperature of the model.")
    response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the model.")

    def _to_legacy_config_params(self) -> dict:
        """Translate these settings into the legacy LLMConfig-style parameter dict."""
        params = {
            "temperature": self.temperature,
            "max_tokens": self.max_output_tokens,
            "response_format": self.response_format,
            "parallel_tool_calls": self.parallel_tool_calls,
        }
        # OpenRouter does not support strict mode
        params["strict"] = False
        return params
class ChatGPTOAuthReasoning(BaseModel):
"""Reasoning configuration for ChatGPT OAuth models (GPT-5.x, o-series)."""
@@ -495,6 +515,7 @@ ModelSettingsUnion = Annotated[
DeepseekModelSettings,
TogetherModelSettings,
BedrockModelSettings,
OpenRouterModelSettings,
ChatGPTOAuthModelSettings,
],
Field(discriminator="provider_type"),

View File

@@ -44,7 +44,7 @@ class Passage(PassageBase):
embedding: Optional[List[float]] = Field(..., description="The embedding of the passage.")
embedding_config: Optional[EmbeddingConfig] = Field(..., description="The embedding configuration used by the passage.")
created_at: datetime = Field(default_factory=get_utc_time, description="The creation date of the passage.")
created_at: Optional[datetime] = Field(default_factory=get_utc_time, description="The creation date of the passage.")
@field_validator("embedding", mode="before")
@classmethod
@@ -83,6 +83,7 @@ class PassageCreate(PassageBase):
# optionally provide embeddings
embedding: Optional[List[float]] = Field(None, description="The embedding of the passage.")
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the passage.")
created_at: Optional[datetime] = Field(None, description="Optional creation datetime for the passage.")
class PassageUpdate(PassageCreate):

View File

@@ -29,6 +29,9 @@ class ProviderTrace(BaseProviderTrace):
run_id (str): ID of the run this trace is associated with.
source (str): Source service that generated this trace (memgpt-server, lettuce-py).
organization_id (str): The unique identifier of the organization.
user_id (str): The unique identifier of the user who initiated the request.
compaction_settings (Dict[str, Any]): Compaction/summarization settings (only for summarization calls).
llm_config (Dict[str, Any]): LLM configuration used for this call (only for non-summarization calls).
created_at (datetime): The timestamp when the object was created.
"""
@@ -44,4 +47,30 @@ class ProviderTrace(BaseProviderTrace):
run_id: Optional[str] = Field(None, description="ID of the run this trace is associated with")
source: Optional[str] = Field(None, description="Source service that generated this trace (memgpt-server, lettuce-py)")
# v2 protocol fields
org_id: Optional[str] = Field(None, description="ID of the organization")
user_id: Optional[str] = Field(None, description="ID of the user who initiated the request")
compaction_settings: Optional[Dict[str, Any]] = Field(None, description="Compaction/summarization settings (summarization calls only)")
llm_config: Optional[Dict[str, Any]] = Field(None, description="LLM configuration used for this call (non-summarization calls only)")
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
class ProviderTraceMetadata(BaseProviderTrace):
    """Metadata-only representation of a provider trace (no request/response JSON).

    Carries the same telemetry-context and v2-protocol fields as ProviderTrace,
    minus the heavyweight payloads, for cheap listing/filtering.
    """

    id: str = BaseProviderTrace.generate_id_field()
    step_id: Optional[str] = Field(None, description="ID of the step that this trace is associated with")

    # Telemetry context fields
    agent_id: Optional[str] = Field(None, description="ID of the agent that generated this trace")
    agent_tags: Optional[list[str]] = Field(None, description="Tags associated with the agent for filtering")
    call_type: Optional[str] = Field(None, description="Type of call (agent_step, summarization, etc.)")
    run_id: Optional[str] = Field(None, description="ID of the run this trace is associated with")
    source: Optional[str] = Field(None, description="Source service that generated this trace (memgpt-server, lettuce-py)")

    # v2 protocol fields
    org_id: Optional[str] = Field(None, description="ID of the organization")
    user_id: Optional[str] = Field(None, description="ID of the user who initiated the request")

    created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")

View File

@@ -12,10 +12,12 @@ from .google_vertex import GoogleVertexProvider
from .groq import GroqProvider
from .letta import LettaProvider
from .lmstudio import LMStudioOpenAIProvider
from .minimax import MiniMaxProvider
from .mistral import MistralProvider
from .ollama import OllamaProvider
from .openai import OpenAIProvider
from .openrouter import OpenRouterProvider
from .sglang import SGLangProvider
from .together import TogetherProvider
from .vllm import VLLMProvider
from .xai import XAIProvider
@@ -40,11 +42,13 @@ __all__ = [
"GroqProvider",
"LettaProvider",
"LMStudioOpenAIProvider",
"MiniMaxProvider",
"MistralProvider",
"OllamaProvider",
"OpenAIProvider",
"TogetherProvider",
"VLLMProvider", # Replaces ChatCompletions and Completions
"SGLangProvider",
"XAIProvider",
"ZAIProvider",
"OpenRouterProvider",

View File

@@ -32,6 +32,7 @@ class Provider(ProviderBase):
api_version: str | None = Field(None, description="API version used for requests to the provider.")
organization_id: str | None = Field(None, description="The organization id of the user")
updated_at: datetime | None = Field(None, description="The last update timestamp of the provider.")
last_synced: datetime | None = Field(None, description="The last time models were synced for this provider.")
# Encrypted fields (stored as Secret objects, serialized to strings for DB)
# Secret class handles validation and serialization automatically via __get_pydantic_core_schema__
@@ -191,9 +192,12 @@ class Provider(ProviderBase):
GroqProvider,
LettaProvider,
LMStudioOpenAIProvider,
MiniMaxProvider,
MistralProvider,
OllamaProvider,
OpenAIProvider,
OpenRouterProvider,
SGLangProvider,
TogetherProvider,
VLLMProvider,
XAIProvider,
@@ -224,6 +228,8 @@ class Provider(ProviderBase):
return OllamaProvider(**self.model_dump(exclude_none=True))
case ProviderType.vllm:
return VLLMProvider(**self.model_dump(exclude_none=True)) # Removed support for CompletionsProvider
case ProviderType.sglang:
return SGLangProvider(**self.model_dump(exclude_none=True))
case ProviderType.mistral:
return MistralProvider(**self.model_dump(exclude_none=True))
case ProviderType.deepseek:
@@ -240,6 +246,10 @@ class Provider(ProviderBase):
return LMStudioOpenAIProvider(**self.model_dump(exclude_none=True))
case ProviderType.bedrock:
return BedrockProvider(**self.model_dump(exclude_none=True))
case ProviderType.minimax:
return MiniMaxProvider(**self.model_dump(exclude_none=True))
case ProviderType.openrouter:
return OpenRouterProvider(**self.model_dump(exclude_none=True))
case _:
raise ValueError(f"Unknown provider type: {self.provider_type}")

View File

@@ -18,6 +18,7 @@ logger = get_logger(__name__)
class BedrockProvider(Provider):
provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.")
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
base_url: str = Field("bedrock", description="Identifier for Bedrock endpoint (used for model_endpoint)")
access_key: str | None = Field(None, description="AWS access key ID for Bedrock")
api_key: str | None = Field(None, description="AWS secret access key for Bedrock")
region: str = Field(..., description="AWS region for Bedrock")
@@ -99,7 +100,7 @@ class BedrockProvider(Provider):
LLMConfig(
model=model_name,
model_endpoint_type=self.provider_type.value,
model_endpoint=None,
model_endpoint="bedrock",
context_window=self.get_model_context_window(inference_profile_id),
# Store the full inference profile ID in the handle for API calls
handle=self.get_handle(inference_profile_id),

View File

@@ -0,0 +1,105 @@
from typing import Literal
import anthropic
from pydantic import Field
from letta.errors import ErrorCode, LLMAuthenticationError, LLMError
from letta.log import get_logger
from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.providers.base import Provider
logger = get_logger(__name__)
# MiniMax model specifications from official documentation
# https://platform.minimax.io/docs/guides/models-intro
# Authoritative, hardcoded model catalog: MiniMax exposes no model-listing endpoint,
# so list_llm_models_async builds its configs from these entries.
MODEL_LIST = [
    {
        "name": "MiniMax-M2.1",
        "context_window": 200000,
        "max_output": 128000,
        "description": "Polyglot code mastery, precision code refactoring (~60 tps)",
    },
    {
        "name": "MiniMax-M2.1-lightning",
        "context_window": 200000,
        "max_output": 128000,
        "description": "Same performance as M2.1, significantly faster (~100 tps)",
    },
    {
        "name": "MiniMax-M2",
        "context_window": 200000,
        "max_output": 128000,
        "description": "Agentic capabilities, advanced reasoning",
    },
]
class MiniMaxProvider(Provider):
    """
    MiniMax provider using Anthropic-compatible API.

    MiniMax models support native interleaved thinking without requiring beta headers.
    The API uses the standard messages endpoint (not beta).

    Documentation: https://platform.minimax.io/docs/api-reference/text-anthropic-api
    """

    provider_type: Literal[ProviderType.minimax] = Field(ProviderType.minimax, description="The type of the provider.")
    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
    api_key: str | None = Field(None, description="API key for the MiniMax API.", deprecated=True)
    base_url: str = Field("https://api.minimax.io/anthropic", description="Base URL for the MiniMax Anthropic-compatible API.")

    async def check_api_key(self):
        """Check if the API key is valid by making a test request to the MiniMax API.

        Raises:
            ValueError: If no API key is configured.
            LLMAuthenticationError: If MiniMax rejects the key.
            LLMError: For any other failure reaching the API.
        """
        api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
        if not api_key:
            raise ValueError("No API key provided")

        try:
            # Use async Anthropic client pointed at MiniMax's Anthropic-compatible endpoint
            client = anthropic.AsyncAnthropic(api_key=api_key, base_url=self.base_url)
            # Use count_tokens as a lightweight check - similar to Anthropic provider
            await client.messages.count_tokens(model=MODEL_LIST[-1]["name"], messages=[{"role": "user", "content": "a"}])
        except anthropic.AuthenticationError as e:
            # Chain the original exception so the underlying auth failure is preserved.
            raise LLMAuthenticationError(message=f"Failed to authenticate with MiniMax: {e}", code=ErrorCode.UNAUTHENTICATED) from e
        except Exception as e:
            raise LLMError(message=f"{e}", code=ErrorCode.INTERNAL_SERVER_ERROR) from e

    def get_default_max_output_tokens(self, model_name: str) -> int:
        """Get the default max output tokens for MiniMax models."""
        # All MiniMax models support 128K output tokens
        return 128000

    def get_model_context_window_size(self, model_name: str) -> int | None:
        """Get the context window size for a MiniMax model.

        Falls back to 200K for unknown model names (all current models use 200K).
        """
        return next((m["context_window"] for m in MODEL_LIST if m["name"] == model_name), 200000)

    async def list_llm_models_async(self) -> list[LLMConfig]:
        """
        Return available MiniMax models.

        MiniMax doesn't have a models listing endpoint, so we use a hardcoded list.
        """
        configs = []
        for model in MODEL_LIST:
            configs.append(
                LLMConfig(
                    model=model["name"],
                    model_endpoint_type="minimax",
                    model_endpoint=self.base_url,
                    context_window=model["context_window"],
                    handle=self.get_handle(model["name"]),
                    max_tokens=model["max_output"],
                    # MiniMax models support native thinking, similar to Claude's extended thinking
                    put_inner_thoughts_in_kwargs=True,
                    provider_name=self.name,
                    provider_category=self.provider_category,
                )
            )
        return configs

View File

@@ -42,22 +42,37 @@ class OpenAIProvider(Provider):
raise LLMError(message=f"{e}", code=ErrorCode.INTERNAL_SERVER_ERROR)
def get_default_max_output_tokens(self, model_name: str) -> int:
"""Get the default max output tokens for OpenAI models."""
if model_name.startswith("gpt-5"):
return 16384
elif model_name.startswith("o1") or model_name.startswith("o3"):
return 100000
return 16384 # default for openai
"""Get the default max output tokens for OpenAI models (sync fallback)."""
# Simple default for openai
return 16384
async def get_default_max_output_tokens_async(self, model_name: str) -> int:
"""Get the default max output tokens for OpenAI models.
Uses litellm model specifications with a simple fallback.
"""
from letta.model_specs.litellm_model_specs import get_max_output_tokens
# Try litellm specs
max_output = await get_max_output_tokens(model_name)
if max_output is not None:
return max_output
# Simple default for openai
return 16384
async def _get_models_async(self) -> list[dict]:
from letta.llm_api.openai import openai_get_model_list_async
# Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
# See: https://openrouter.ai/docs/requests
extra_params = {"supported_parameters": "tools"} if "openrouter.ai" in self.base_url else None
# Similar to Nebius
extra_params = {"verbose": True} if "nebius.com" in self.base_url else None
# Provider-specific extra parameters for model listing
extra_params = None
if "openrouter.ai" in self.base_url:
# OpenRouter: filter for models with tool calling support
# See: https://openrouter.ai/docs/requests
extra_params = {"supported_parameters": "tools"}
elif "nebius.com" in self.base_url:
# Nebius: use verbose mode for better model info
extra_params = {"verbose": True}
# Decrypt API key before using
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
@@ -76,7 +91,7 @@ class OpenAIProvider(Provider):
async def list_llm_models_async(self) -> list[LLMConfig]:
data = await self._get_models_async()
return self._list_llm_models(data)
return await self._list_llm_models(data)
async def list_embedding_models_async(self) -> list[EmbeddingConfig]:
"""Return known OpenAI embedding models.
@@ -116,13 +131,13 @@ class OpenAIProvider(Provider):
),
]
def _list_llm_models(self, data: list[dict]) -> list[LLMConfig]:
async def _list_llm_models(self, data: list[dict]) -> list[LLMConfig]:
"""
This handles filtering out LLM Models by provider that meet Letta's requirements.
"""
configs = []
for model in data:
check = self._do_model_checks_for_name_and_context_size(model)
check = await self._do_model_checks_for_name_and_context_size_async(model)
if check is None:
continue
model_name, context_window_size = check
@@ -174,7 +189,7 @@ class OpenAIProvider(Provider):
model_endpoint=self.base_url,
context_window=context_window_size,
handle=handle,
max_tokens=self.get_default_max_output_tokens(model_name),
max_tokens=await self.get_default_max_output_tokens_async(model_name),
provider_name=self.name,
provider_category=self.provider_category,
)
@@ -188,12 +203,30 @@ class OpenAIProvider(Provider):
return configs
def _do_model_checks_for_name_and_context_size(self, model: dict, length_key: str = "context_length") -> tuple[str, int] | None:
"""Sync version - uses sync get_model_context_window_size (for subclasses with hardcoded values)."""
if "id" not in model:
logger.warning("Model missing 'id' field for provider: %s and model: %s", self.provider_type, model)
return None
model_name = model["id"]
context_window_size = model.get(length_key) or self.get_model_context_window_size(model_name)
context_window_size = self.get_model_context_window_size(model_name)
if not context_window_size:
logger.info("No context window size found for model: %s", model_name)
return None
return model_name, context_window_size
async def _do_model_checks_for_name_and_context_size_async(
self, model: dict, length_key: str = "context_length"
) -> tuple[str, int] | None:
"""Async version - uses async get_model_context_window_size_async (for litellm lookup)."""
if "id" not in model:
logger.warning("Model missing 'id' field for provider: %s and model: %s", self.provider_type, model)
return None
model_name = model["id"]
context_window_size = await self.get_model_context_window_size_async(model_name)
if not context_window_size:
logger.info("No context window size found for model: %s", model_name)
@@ -211,19 +244,30 @@ class OpenAIProvider(Provider):
return llm_config
def get_model_context_window_size(self, model_name: str) -> int | None:
if model_name in LLM_MAX_CONTEXT_WINDOW:
return LLM_MAX_CONTEXT_WINDOW[model_name]
else:
logger.debug(
"Model %s on %s for provider %s not found in LLM_MAX_CONTEXT_WINDOW. Using default of {LLM_MAX_CONTEXT_WINDOW['DEFAULT']}",
model_name,
self.base_url,
self.__class__.__name__,
)
return LLM_MAX_CONTEXT_WINDOW["DEFAULT"]
"""Get the context window size for a model (sync fallback)."""
return LLM_MAX_CONTEXT_WINDOW["DEFAULT"]
async def get_model_context_window_size_async(self, model_name: str) -> int | None:
"""Get the context window size for a model.
Uses litellm model specifications which covers all OpenAI models.
"""
from letta.model_specs.litellm_model_specs import get_context_window
context_window = await get_context_window(model_name)
if context_window is not None:
return context_window
# Simple fallback
logger.debug(
"Model %s not found in litellm specs. Using default of %s",
model_name,
LLM_MAX_CONTEXT_WINDOW["DEFAULT"],
)
return LLM_MAX_CONTEXT_WINDOW["DEFAULT"]
    def get_model_context_window(self, model_name: str) -> int | None:
        """Return the context window size for *model_name* (sync path).

        Delegates to get_model_context_window_size.
        """
        return self.get_model_context_window_size(model_name)
async def get_model_context_window_async(self, model_name: str) -> int | None:
return self.get_model_context_window_size(model_name)
return await self.get_model_context_window_size_async(model_name)

View File

@@ -1,52 +1,106 @@
from typing import Literal
from openai import AsyncOpenAI, AuthenticationError
from pydantic import Field
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LLM_MAX_CONTEXT_WINDOW
from letta.errors import ErrorCode, LLMAuthenticationError, LLMError
from letta.log import get_logger
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.providers.openai import OpenAIProvider
logger = get_logger(__name__)
# ALLOWED_PREFIXES = {"gpt-4", "gpt-5", "o1", "o3", "o4"}
# DISALLOWED_KEYWORDS = {"transcribe", "search", "realtime", "tts", "audio", "computer", "o1-mini", "o1-preview", "o1-pro", "chat"}
# DEFAULT_EMBEDDING_BATCH_SIZE = 1024
# Default context window for models not in the API response
DEFAULT_CONTEXT_WINDOW = 128000
class OpenRouterProvider(OpenAIProvider):
provider_type: Literal[ProviderType.openai] = Field(ProviderType.openai, description="The type of the provider.")
"""
OpenRouter provider - https://openrouter.ai/
OpenRouter is an OpenAI-compatible API gateway that provides access to
multiple LLM providers (Anthropic, Meta, Mistral, etc.) through a unified API.
"""
provider_type: Literal[ProviderType.openrouter] = Field(ProviderType.openrouter, description="The type of the provider.")
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
api_key: str | None = Field(None, description="API key for the OpenRouter API.", deprecated=True)
base_url: str = Field("https://openrouter.ai/api/v1", description="Base URL for the OpenRouter API.")
def _list_llm_models(self, data: list[dict]) -> list[LLMConfig]:
async def check_api_key(self):
"""Check if the API key is valid by making a test request to the OpenRouter API."""
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
if not api_key:
raise ValueError("No API key provided")
try:
# Use async OpenAI client pointed at OpenRouter's endpoint
client = AsyncOpenAI(api_key=api_key, base_url=self.base_url)
# Just list models to verify API key works
await client.models.list()
except AuthenticationError as e:
raise LLMAuthenticationError(message=f"Failed to authenticate with OpenRouter: {e}", code=ErrorCode.UNAUTHENTICATED)
except Exception as e:
raise LLMError(message=f"{e}", code=ErrorCode.INTERNAL_SERVER_ERROR)
def get_model_context_window_size(self, model_name: str) -> int | None:
"""Get the context window size for an OpenRouter model.
OpenRouter models provide context_length in the API response,
so this is mainly a fallback.
"""
This handles filtering out LLM Models by provider that meet Letta's requirements.
return DEFAULT_CONTEXT_WINDOW
async def list_llm_models_async(self) -> list[LLMConfig]:
"""
Return available OpenRouter models that support tool calling.
OpenRouter provides a models endpoint that supports filtering by supported_parameters.
We filter for models that support 'tools' to ensure Letta compatibility.
"""
from letta.llm_api.openai import openai_get_model_list_async
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
# OpenRouter supports filtering models by supported parameters
# See: https://openrouter.ai/docs/requests
extra_params = {"supported_parameters": "tools"}
response = await openai_get_model_list_async(
self.base_url,
api_key=api_key,
extra_params=extra_params,
)
data = response.get("data", response)
configs = []
for model in data:
check = self._do_model_checks_for_name_and_context_size(model)
if check is None:
if "id" not in model:
logger.warning(f"OpenRouter model missing 'id' field: {model}")
continue
model_name, context_window_size = check
handle = self.get_handle(model_name)
model_name = model["id"]
config = LLMConfig(
model=model_name,
model_endpoint_type="openai",
model_endpoint=self.base_url,
context_window=context_window_size,
handle=handle,
max_tokens=self.get_default_max_output_tokens(model_name),
provider_name=self.name,
provider_category=self.provider_category,
# OpenRouter returns context_length in the model listing
if "context_length" in model and model["context_length"]:
context_window_size = model["context_length"]
else:
context_window_size = self.get_model_context_window_size(model_name)
logger.debug(f"Model {model_name} missing context_length, using default: {context_window_size}")
configs.append(
LLMConfig(
model=model_name,
model_endpoint_type="openrouter",
model_endpoint=self.base_url,
context_window=context_window_size,
handle=self.get_handle(model_name),
max_tokens=self.get_default_max_output_tokens(model_name),
provider_name=self.name,
provider_category=self.provider_category,
)
)
config = self._set_model_parameter_tuned_defaults(model_name, config)
configs.append(config)
return configs

View File

@@ -0,0 +1,80 @@
"""
SGLang provider for Letta.
SGLang is a high-performance inference engine that exposes OpenAI-compatible API endpoints.
"""
from typing import Literal
from pydantic import Field
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.providers.base import Provider
class SGLangProvider(Provider):
    """Provider for SGLang, a high-performance inference engine exposing
    OpenAI-compatible API endpoints."""

    provider_type: Literal[ProviderType.sglang] = Field(
        ProviderType.sglang,
        description="The type of the provider."
    )
    provider_category: ProviderCategory = Field(
        ProviderCategory.base,
        description="The category of the provider (base or byok)"
    )
    base_url: str = Field(
        ...,
        description="Base URL for the SGLang API (e.g., http://localhost:30000)."
    )
    api_key: str | None = Field(
        None,
        description="API key for the SGLang API (optional for local instances)."
    )
    default_prompt_formatter: str | None = Field(
        default=None,
        description="Default prompt formatter (aka model wrapper)."
    )
    handle_base: str | None = Field(
        None,
        description="Custom handle base name for model handles."
    )

    async def list_llm_models_async(self) -> list[LLMConfig]:
        """List models served by this SGLang instance via its OpenAI-compatible endpoint."""
        from letta.llm_api.openai import openai_get_model_list_async

        # SGLang uses the same convention as vLLM: the API lives under /v1.
        endpoint = self.base_url.rstrip("/")
        if not endpoint.endswith("/v1"):
            endpoint += "/v1"

        # Decrypt the API key if present; local instances may run without one.
        plaintext_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
        listing = await openai_get_model_list_async(endpoint, api_key=plaintext_key)
        models = listing.get("data", listing)

        llm_configs = []
        for entry in models:
            name = entry["id"]
            if self.handle_base:
                handle = self.get_handle(name, base_name=self.handle_base)
            else:
                handle = self.get_handle(name)
            llm_configs.append(
                LLMConfig(
                    model=name,
                    model_endpoint_type="openai",  # SGLang is OpenAI-compatible
                    model_endpoint=endpoint,
                    model_wrapper=self.default_prompt_formatter,
                    # Listing entries may carry max_model_len; default to 8192 otherwise.
                    context_window=entry.get("max_model_len", 8192),
                    handle=handle,
                    max_tokens=self.get_default_max_output_tokens(name),
                    provider_name=self.name,
                    provider_category=self.provider_category,
                )
            )
        return llm_configs

    async def list_embedding_models_async(self) -> list[EmbeddingConfig]:
        """Return no embedding models; SGLang embedding support is uncommon here."""
        return []

View File

@@ -126,3 +126,53 @@ class LettaUsageStatistics(BaseModel):
reasoning_tokens: Optional[int] = Field(
None, description="The number of reasoning/thinking tokens generated. None if not reported by provider."
)
def to_usage(self, provider_type: Optional["ProviderType"] = None) -> "UsageStatistics":
"""Convert to UsageStatistics (OpenAI-compatible format).
Args:
provider_type: ProviderType enum indicating which provider format to use.
Used to determine which cache field to populate.
Returns:
UsageStatistics object with nested prompt/completion token details.
"""
from letta.schemas.enums import ProviderType
from letta.schemas.openai.chat_completion_response import (
UsageStatistics,
UsageStatisticsCompletionTokenDetails,
UsageStatisticsPromptTokenDetails,
)
# Providers that use Anthropic-style cache fields (cache_read_tokens, cache_creation_tokens)
anthropic_style_providers = {ProviderType.anthropic, ProviderType.bedrock}
# Build prompt_tokens_details if we have cache data
prompt_tokens_details = None
if self.cached_input_tokens is not None or self.cache_write_tokens is not None:
if provider_type in anthropic_style_providers:
# Anthropic uses cache_read_tokens and cache_creation_tokens
prompt_tokens_details = UsageStatisticsPromptTokenDetails(
cache_read_tokens=self.cached_input_tokens,
cache_creation_tokens=self.cache_write_tokens,
)
else:
# OpenAI/Gemini use cached_tokens
prompt_tokens_details = UsageStatisticsPromptTokenDetails(
cached_tokens=self.cached_input_tokens,
)
# Build completion_tokens_details if we have reasoning tokens
completion_tokens_details = None
if self.reasoning_tokens is not None:
completion_tokens_details = UsageStatisticsCompletionTokenDetails(
reasoning_tokens=self.reasoning_tokens,
)
return UsageStatistics(
prompt_tokens=self.prompt_tokens,
completion_tokens=self.completion_tokens,
total_tokens=self.total_tokens,
prompt_tokens_details=prompt_tokens_details,
completion_tokens_details=completion_tokens_details,
)

View File

@@ -38,6 +38,7 @@ from letta.errors import (
HandleNotFoundError,
LettaAgentNotFoundError,
LettaExpiredError,
LettaImageFetchError,
LettaInvalidArgumentError,
LettaInvalidMCPSchemaError,
LettaMCPConnectionError,
@@ -64,6 +65,7 @@ from letta.schemas.letta_message import create_letta_error_message_schema, creat
from letta.schemas.letta_message_content import (
create_letta_assistant_message_content_union_schema,
create_letta_message_content_union_schema,
create_letta_tool_return_content_union_schema,
create_letta_user_message_content_union_schema,
)
from letta.server.constants import REST_DEFAULT_PORT
@@ -105,6 +107,7 @@ def generate_openapi_schema(app: FastAPI):
letta_docs["components"]["schemas"]["LettaMessageUnion"] = create_letta_message_union_schema()
letta_docs["components"]["schemas"]["LettaMessageContentUnion"] = create_letta_message_content_union_schema()
letta_docs["components"]["schemas"]["LettaAssistantMessageContentUnion"] = create_letta_assistant_message_content_union_schema()
letta_docs["components"]["schemas"]["LettaToolReturnContentUnion"] = create_letta_tool_return_content_union_schema()
letta_docs["components"]["schemas"]["LettaUserMessageContentUnion"] = create_letta_user_message_content_union_schema()
letta_docs["components"]["schemas"]["LettaErrorMessage"] = create_letta_error_message_schema()
@@ -163,10 +166,18 @@ async def lifespan(app_: FastAPI):
except Exception as e:
logger.warning(f"[Worker {worker_id}] Failed to download NLTK data: {e}")
# logger.info(f"[Worker {worker_id}] Starting lifespan initialization")
# logger.info(f"[Worker {worker_id}] Initializing database connections")
# db_registry.initialize_async()
# logger.info(f"[Worker {worker_id}] Database connections initialized")
# Log effective database timeout settings for debugging
try:
from sqlalchemy import text
from letta.server.db import db_registry
async with db_registry.async_session() as session:
result = await session.execute(text("SHOW statement_timeout"))
statement_timeout = result.scalar()
logger.warning(f"[Worker {worker_id}] PostgreSQL statement_timeout: {statement_timeout}")
except Exception as e:
logger.warning(f"[Worker {worker_id}] Failed to query statement_timeout: {e}")
if should_use_pinecone():
if settings.upsert_pinecone_indices:
@@ -180,7 +191,7 @@ async def lifespan(app_: FastAPI):
logger.info(f"[Worker {worker_id}] Starting scheduler with leader election")
global server
await server.init_async()
await server.init_async(init_with_default_org_and_user=not settings.no_default_actor)
try:
await start_scheduler_with_leader_election(server)
logger.info(f"[Worker {worker_id}] Scheduler initialization completed")
@@ -475,6 +486,7 @@ def create_application() -> "FastAPI":
app.add_exception_handler(LettaToolNameConflictError, _error_handler_400)
app.add_exception_handler(AgentFileImportError, _error_handler_400)
app.add_exception_handler(EmbeddingConfigRequiredError, _error_handler_400)
app.add_exception_handler(LettaImageFetchError, _error_handler_400)
app.add_exception_handler(ValueError, _error_handler_400)
# 404 Not Found errors

View File

@@ -3,7 +3,10 @@ from typing import TYPE_CHECKING, Optional
from fastapi import Header
from pydantic import BaseModel
from letta.errors import LettaInvalidArgumentError
from letta.otel.tracing import tracer
from letta.schemas.enums import PrimitiveType
from letta.validators import PRIMITIVE_ID_PATTERNS
if TYPE_CHECKING:
from letta.server.server import SyncServer
@@ -42,6 +45,12 @@ def get_headers(
) -> HeaderParams:
"""Dependency injection function to extract common headers from requests."""
with tracer.start_as_current_span("dependency.get_headers"):
if actor_id is not None and PRIMITIVE_ID_PATTERNS[PrimitiveType.USER.value].match(actor_id) is None:
raise LettaInvalidArgumentError(
message=(f"Invalid user ID format: {actor_id}. Expected format: '{PrimitiveType.USER.value}-<uuid4>'"),
argument_name="user_id",
)
return HeaderParams(
actor_id=actor_id,
user_agent=user_agent,

View File

@@ -239,11 +239,11 @@ async def create_background_stream_processor(
if isinstance(chunk, tuple):
chunk = chunk[0]
# Track terminal events
# Track terminal events (check at line start to avoid false positives in message content)
if isinstance(chunk, str):
if "data: [DONE]" in chunk:
if "\ndata: [DONE]" in chunk or chunk.startswith("data: [DONE]"):
saw_done = True
if "event: error" in chunk:
if "\nevent: error" in chunk or chunk.startswith("event: error"):
saw_error = True
# Best-effort extraction of the error payload so we can persist it on the run.

View File

@@ -308,6 +308,7 @@ async def _import_agent(
strip_messages: bool = False,
env_vars: Optional[dict[str, Any]] = None,
override_embedding_handle: Optional[str] = None,
override_model_handle: Optional[str] = None,
) -> List[str]:
"""
Import an agent using the new AgentFileSchema format.
@@ -319,6 +320,11 @@ async def _import_agent(
else:
embedding_config_override = None
if override_model_handle:
llm_config_override = await server.get_llm_config_from_handle_async(actor=actor, handle=override_model_handle)
else:
llm_config_override = None
import_result = await server.agent_serialization_manager.import_file(
schema=agent_schema,
actor=actor,
@@ -327,6 +333,7 @@ async def _import_agent(
override_existing_tools=override_existing_tools,
env_vars=env_vars,
override_embedding_config=embedding_config_override,
override_llm_config=llm_config_override,
project_id=project_id,
)
@@ -362,6 +369,10 @@ async def import_agent(
None,
description="Embedding handle to override with.",
),
model: Optional[str] = Form(
None,
description="Model handle to override the agent's default model. This allows the imported agent to use a different model while keeping other defaults (e.g., context size) from the original configuration.",
),
# Deprecated fields (maintain backward compatibility)
append_copy_suffix: bool = Form(
True,
@@ -378,6 +389,11 @@ async def import_agent(
description="Override import with specific embedding handle. Use 'embedding' instead.",
deprecated=True,
),
override_model_handle: Optional[str] = Form(
None,
description="Model handle to override the agent's default model. Use 'model' instead.",
deprecated=True,
),
project_id: str | None = Form(
None, description="The project ID to associate the uploaded agent with. This is now passed via headers.", deprecated=True
),
@@ -408,6 +424,7 @@ async def import_agent(
# Handle backward compatibility: prefer new field names over deprecated ones
final_name = name or override_name
final_embedding_handle = embedding or override_embedding_handle or x_override_embedding_model
final_model_handle = model or override_model_handle
# Parse secrets (new) or env_vars_json (deprecated)
env_vars = None
@@ -440,6 +457,7 @@ async def import_agent(
strip_messages=strip_messages,
env_vars=env_vars,
override_embedding_handle=final_embedding_handle,
override_model_handle=final_model_handle,
)
else:
# This is a legacy AgentSchema
@@ -628,7 +646,9 @@ async def run_tool_for_agent(
# Get agent with all relationships
agent = await server.agent_manager.get_agent_by_id_async(
agent_id, actor, include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools"]
agent_id,
actor,
include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools", "tags"],
)
# Find the tool by name among attached tools
@@ -701,7 +721,7 @@ async def attach_source(
await server.agent_manager.insert_files_into_context_window(agent_state=agent_state, file_metadata_with_content=files, actor=actor)
if agent_state.enable_sleeptime:
source = await server.source_manager.get_source_by_id(source_id=source_id)
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
safe_create_task(server.sleeptime_document_ingest_async(agent_state, source, actor), label="sleeptime_document_ingest_async")
return agent_state
@@ -728,7 +748,7 @@ async def attach_folder_to_agent(
await server.agent_manager.insert_files_into_context_window(agent_state=agent_state, file_metadata_with_content=files, actor=actor)
if agent_state.enable_sleeptime:
source = await server.source_manager.get_source_by_id(source_id=folder_id)
source = await server.source_manager.get_source_by_id(source_id=folder_id, actor=actor)
safe_create_task(server.sleeptime_document_ingest_async(agent_state, source, actor), label="sleeptime_document_ingest_async")
if is_1_0_sdk_version(headers):
@@ -759,7 +779,7 @@ async def detach_source(
if agent_state.enable_sleeptime:
try:
source = await server.source_manager.get_source_by_id(source_id=source_id)
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
block = await server.agent_manager.get_block_with_label_async(agent_id=agent_state.id, block_label=source.name, actor=actor)
await server.block_manager.delete_block_async(block.id, actor)
except:
@@ -791,7 +811,7 @@ async def detach_folder_from_agent(
if agent_state.enable_sleeptime:
try:
source = await server.source_manager.get_source_by_id(source_id=folder_id)
source = await server.source_manager.get_source_by_id(source_id=folder_id, actor=actor)
block = await server.agent_manager.get_block_with_label_async(agent_id=agent_state.id, block_label=source.name, actor=actor)
await server.block_manager.delete_block_async(block.id, actor)
except:
@@ -1256,7 +1276,7 @@ async def detach_identity_from_agent(
return None
@router.get("/{agent_id}/archival-memory", response_model=list[Passage], operation_id="list_passages", deprecated=True)
@router.get("/{agent_id}/archival-memory", response_model=list[Passage], operation_id="list_passages")
async def list_passages(
agent_id: AgentId,
server: "SyncServer" = Depends(get_letta_server),
@@ -1285,7 +1305,7 @@ async def list_passages(
)
@router.post("/{agent_id}/archival-memory", response_model=list[Passage], operation_id="create_passage", deprecated=True)
@router.post("/{agent_id}/archival-memory", response_model=list[Passage], operation_id="create_passage")
async def create_passage(
agent_id: AgentId,
request: CreateArchivalMemory = Body(...),
@@ -1306,7 +1326,6 @@ async def create_passage(
"/{agent_id}/archival-memory/search",
response_model=ArchivalMemorySearchResponse,
operation_id="search_archival_memory",
deprecated=True,
)
async def search_archival_memory(
agent_id: AgentId,
@@ -1354,7 +1373,7 @@ async def search_archival_memory(
# TODO(ethan): query or path parameter for memory_id?
# @router.delete("/{agent_id}/archival")
@router.delete("/{agent_id}/archival-memory/{memory_id}", response_model=None, operation_id="delete_passage", deprecated=True)
@router.delete("/{agent_id}/archival-memory/{memory_id}", response_model=None, operation_id="delete_passage")
async def delete_passage(
memory_id: str,
agent_id: AgentId,
@@ -1520,7 +1539,9 @@ async def send_message(
MetricRegistry().user_message_counter.add(1, get_ctx_attributes())
# TODO: This is redundant, remove soon
agent = await server.agent_manager.get_agent_by_id_async(
agent_id, actor, include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools"]
agent_id,
actor,
include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools", "tags"],
)
# Handle model override if specified in the request
@@ -1799,7 +1820,9 @@ async def _process_message_background(
try:
agent = await server.agent_manager.get_agent_by_id_async(
agent_id, actor, include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools"]
agent_id,
actor,
include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools", "tags"],
)
# Handle model override if specified
@@ -1853,15 +1876,24 @@ async def _process_message_background(
runs_manager = RunManager()
from letta.schemas.enums import RunStatus
from letta.schemas.letta_stop_reason import StopReasonType
if result.stop_reason.stop_reason == "cancelled":
# Handle cases where stop_reason might be None (defensive)
if result.stop_reason and result.stop_reason.stop_reason == "cancelled":
run_status = RunStatus.cancelled
else:
stop_reason = result.stop_reason.stop_reason
elif result.stop_reason:
run_status = RunStatus.completed
stop_reason = result.stop_reason.stop_reason
else:
# Fallback: no stop_reason set (shouldn't happen but defensive)
logger.error(f"Run {run_id} completed without stop_reason in result, defaulting to end_turn")
run_status = RunStatus.completed
stop_reason = StopReasonType.end_turn
await runs_manager.update_run_by_id_async(
run_id=run_id,
update=RunUpdate(status=run_status, stop_reason=result.stop_reason.stop_reason),
update=RunUpdate(status=run_status, stop_reason=stop_reason),
actor=actor,
)
@@ -1869,20 +1901,22 @@ async def _process_message_background(
# Update run status to failed with specific error info
runs_manager = RunManager()
from letta.schemas.enums import RunStatus
from letta.schemas.letta_stop_reason import StopReasonType
await runs_manager.update_run_by_id_async(
run_id=run_id,
update=RunUpdate(status=RunStatus.failed, metadata={"error": str(e)}),
update=RunUpdate(status=RunStatus.failed, stop_reason=StopReasonType.error, metadata={"error": str(e)}),
actor=actor,
)
except Exception as e:
# Update run status to failed
runs_manager = RunManager()
from letta.schemas.enums import RunStatus
from letta.schemas.letta_stop_reason import StopReasonType
await runs_manager.update_run_by_id_async(
run_id=run_id,
update=RunUpdate(status=RunStatus.failed, metadata={"error": str(e)}),
update=RunUpdate(status=RunStatus.failed, stop_reason=StopReasonType.error, metadata={"error": str(e)}),
actor=actor,
)
finally:
@@ -1966,7 +2000,9 @@ async def send_message_async(
if use_lettuce:
agent_state = await server.agent_manager.get_agent_by_id_async(
agent_id, actor, include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools"]
agent_id,
actor,
include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools", "tags"],
)
# Allow V1 agents only if the message async flag is enabled
is_v1_message_async_enabled = (
@@ -2020,10 +2056,11 @@ async def send_message_async(
async def update_failed_run():
runs_manager = RunManager()
from letta.schemas.enums import RunStatus
from letta.schemas.letta_stop_reason import StopReasonType
await runs_manager.update_run_by_id_async(
run_id=run.id,
update=RunUpdate(status=RunStatus.failed, metadata={"error": error_str}),
update=RunUpdate(status=RunStatus.failed, stop_reason=StopReasonType.error, metadata={"error": error_str}),
actor=actor,
)

View File

@@ -48,6 +48,13 @@ class PassageCreateRequest(BaseModel):
text: str = Field(..., description="The text content of the passage")
metadata: Optional[Dict] = Field(default=None, description="Optional metadata for the passage")
tags: Optional[List[str]] = Field(default=None, description="Optional tags for categorizing the passage")
created_at: Optional[str] = Field(default=None, description="Optional creation datetime for the passage (ISO 8601 format)")
class PassageBatchCreateRequest(BaseModel):
"""Request model for creating multiple passages in an archive."""
passages: List[PassageCreateRequest] = Field(..., description="Passages to create in the archive")
@router.post("/", response_model=PydanticArchive, operation_id="create_archive")
@@ -65,16 +72,14 @@ async def create_archive(
if embedding_config is None:
embedding_handle = archive.embedding
if embedding_handle is None:
if settings.default_embedding_handle is None:
raise LettaInvalidArgumentError(
"Must specify either embedding or embedding_config in request", argument_name="default_embedding_handle"
)
else:
embedding_handle = settings.default_embedding_handle
embedding_config = await server.get_embedding_config_from_handle_async(
handle=embedding_handle,
actor=actor,
)
embedding_handle = settings.default_embedding_handle
# Only resolve embedding config if we have an embedding handle
if embedding_handle is not None:
embedding_config = await server.get_embedding_config_from_handle_async(
handle=embedding_handle,
actor=actor,
)
# Otherwise, embedding_config remains None (text search only)
return await server.archive_manager.create_archive_async(
name=archive.name,
@@ -227,6 +232,27 @@ async def create_passage_in_archive(
text=passage.text,
metadata=passage.metadata,
tags=passage.tags,
created_at=passage.created_at,
actor=actor,
)
@router.post("/{archive_id}/passages/batch", response_model=List[Passage], operation_id="create_passages_in_archive")
async def create_passages_in_archive(
archive_id: ArchiveId,
payload: PassageBatchCreateRequest = Body(...),
server: "SyncServer" = Depends(get_letta_server),
headers: HeaderParams = Depends(get_headers),
):
"""
Create multiple passages in an archive.
This adds passages to the archive and creates embeddings for vector storage.
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
return await server.archive_manager.create_passages_in_archive_async(
archive_id=archive_id,
passages=[passage.model_dump() for passage in payload.passages],
actor=actor,
)

View File

@@ -5,16 +5,20 @@ from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
from pydantic import BaseModel, Field
from starlette.responses import StreamingResponse
from letta.agents.agent_loop import AgentLoop
from letta.agents.letta_agent_v3 import LettaAgentV3
from letta.constants import REDIS_RUN_ID_PREFIX
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
from letta.errors import LettaExpiredError, LettaInvalidArgumentError, NoActiveRunsToCancelError
from letta.helpers.datetime_helpers import get_utc_time
from letta.log import get_logger
from letta.schemas.conversation import Conversation, CreateConversation, UpdateConversation
from letta.schemas.enums import RunStatus
from letta.schemas.job import LettaRequestConfig
from letta.schemas.letta_message import LettaMessageUnion
from letta.schemas.letta_request import LettaStreamingRequest, RetrieveStreamRequest
from letta.schemas.letta_request import ConversationMessageRequest, LettaStreamingRequest, RetrieveStreamRequest
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
from letta.schemas.run import Run as PydanticRun
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
from letta.server.rest_api.redis_stream_manager import redis_sse_stream_generator
from letta.server.rest_api.streaming_response import (
@@ -60,6 +64,7 @@ async def list_conversations(
agent_id: str = Query(..., description="The agent ID to list conversations for"),
limit: int = Query(50, description="Maximum number of conversations to return"),
after: Optional[str] = Query(None, description="Cursor for pagination (conversation ID)"),
summary_search: Optional[str] = Query(None, description="Search for text within conversation summaries"),
server: SyncServer = Depends(get_letta_server),
headers: HeaderParams = Depends(get_headers),
):
@@ -70,6 +75,7 @@ async def list_conversations(
actor=actor,
limit=limit,
after=after,
summary_search=summary_search,
)
@@ -154,51 +160,112 @@ async def list_conversation_messages(
@router.post(
"/{conversation_id}/messages",
response_model=LettaStreamingResponse,
response_model=LettaResponse,
operation_id="send_conversation_message",
responses={
200: {
"description": "Successful response",
"content": {
"text/event-stream": {"description": "Server-Sent Events stream"},
"text/event-stream": {"description": "Server-Sent Events stream (default, when streaming=true)"},
"application/json": {"description": "JSON response (when streaming=false)"},
},
}
},
)
async def send_conversation_message(
conversation_id: ConversationId,
request: LettaStreamingRequest = Body(...),
request: ConversationMessageRequest = Body(...),
server: SyncServer = Depends(get_letta_server),
headers: HeaderParams = Depends(get_headers),
) -> StreamingResponse | LettaResponse:
"""
Send a message to a conversation and get a streaming response.
Send a message to a conversation and get a response.
This endpoint sends a message to an existing conversation and streams
the agent's response back.
This endpoint sends a message to an existing conversation.
By default (streaming=true), returns a streaming response (Server-Sent Events).
Set streaming=false to get a complete JSON response.
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
# Get the conversation to find the agent_id
if not request.messages or len(request.messages) == 0:
raise HTTPException(status_code=422, detail="Messages must not be empty")
conversation = await conversation_manager.get_conversation_by_id(
conversation_id=conversation_id,
actor=actor,
)
# Force streaming mode for this endpoint
request.streaming = True
# Streaming mode (default)
if request.streaming:
# Convert to LettaStreamingRequest for StreamingService compatibility
streaming_request = LettaStreamingRequest(
messages=request.messages,
streaming=True,
stream_tokens=request.stream_tokens,
include_pings=request.include_pings,
background=request.background,
max_steps=request.max_steps,
use_assistant_message=request.use_assistant_message,
assistant_message_tool_name=request.assistant_message_tool_name,
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
include_return_message_types=request.include_return_message_types,
override_model=request.override_model,
client_tools=request.client_tools,
)
streaming_service = StreamingService(server)
run, result = await streaming_service.create_agent_stream(
agent_id=conversation.agent_id,
actor=actor,
request=streaming_request,
run_type="send_conversation_message",
conversation_id=conversation_id,
)
return result
# Use streaming service
streaming_service = StreamingService(server)
run, result = await streaming_service.create_agent_stream(
agent_id=conversation.agent_id,
actor=actor,
request=request,
run_type="send_conversation_message",
conversation_id=conversation_id,
# Non-streaming mode
agent = await server.agent_manager.get_agent_by_id_async(
conversation.agent_id,
actor,
include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools", "tags"],
)
return result
if request.override_model:
override_llm_config = await server.get_llm_config_from_handle_async(
actor=actor,
handle=request.override_model,
)
agent = agent.model_copy(update={"llm_config": override_llm_config})
# Create a run for execution tracking
run = None
if settings.track_agent_run:
runs_manager = RunManager()
run = await runs_manager.create_run(
pydantic_run=PydanticRun(
agent_id=conversation.agent_id,
background=False,
metadata={
"run_type": "send_conversation_message",
},
request_config=LettaRequestConfig.from_letta_request(request),
),
actor=actor,
)
# Set run_id in Redis for cancellation support
redis_client = await get_redis_client()
await redis_client.set(f"{REDIS_RUN_ID_PREFIX}:{conversation.agent_id}", run.id if run else None)
agent_loop = AgentLoop.load(agent_state=agent, actor=actor)
return await agent_loop.step(
request.messages,
max_steps=request.max_steps,
run_id=run.id if run else None,
use_assistant_message=request.use_assistant_message,
include_return_message_types=request.include_return_message_types,
client_tools=request.client_tools,
conversation_id=conversation_id,
)
@router.post(
@@ -289,11 +356,14 @@ async def retrieve_conversation_stream(
)
if settings.enable_cancellation_aware_streaming:
from letta.server.rest_api.streaming_response import cancellation_aware_stream_wrapper, get_cancellation_event_for_run
stream = cancellation_aware_stream_wrapper(
stream_generator=stream,
run_manager=server.run_manager,
run_id=run.id,
actor=actor,
cancellation_event=get_cancellation_event_for_run(run.id),
)
if request and request.include_pings and settings.enable_keepalive:

View File

@@ -594,7 +594,7 @@ async def load_file_to_source_async(server: SyncServer, source_id: str, job_id:
async def sleeptime_document_ingest_async(server: SyncServer, source_id: str, actor: User, clear_history: bool = False):
source = await server.source_manager.get_source_by_id(source_id=source_id)
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
agents = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor)
for agent in agents:
if agent.enable_sleeptime:

View File

@@ -231,7 +231,7 @@ async def list_messages_for_batch(
# Get messages directly using our efficient method
messages = await server.batch_manager.get_messages_for_letta_batch_async(
letta_batch_job_id=batch_id, limit=limit, actor=actor, agent_id=agent_id, ascending=(order == "asc"), before=before, after=after
letta_batch_job_id=batch_id, actor=actor, limit=limit, agent_id=agent_id, sort_descending=(order == "desc"), cursor=after
)
return LettaBatchMessages(messages=messages)

View File

@@ -4,18 +4,62 @@ from typing import List, Literal, Optional
from fastapi import APIRouter, Body, Depends
from pydantic import BaseModel, Field
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import TagMatchMode
from letta.schemas.passage import Passage
from letta.schemas.user import User as PydanticUser
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
from letta.server.server import SyncServer
router = APIRouter(prefix="/passages", tags=["passages"])
async def _get_embedding_config_for_search(
server: SyncServer,
actor: PydanticUser,
agent_id: Optional[str],
archive_id: Optional[str],
) -> Optional[EmbeddingConfig]:
"""Determine which embedding config to use for a passage search.
Args:
server: The SyncServer instance
actor: The user making the request
agent_id: Optional agent ID to get embedding config from
archive_id: Optional archive ID to get embedding config from
Returns:
The embedding config to use, or None if not found
Priority:
1. If agent_id is provided, use that agent's embedding config
2. If archive_id is provided, use that archive's embedding config
3. Otherwise, try to get embedding config from any existing agent
4. Fall back to server default if no agents exist
"""
if agent_id:
agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
return agent_state.embedding_config
if archive_id:
archive = await server.archive_manager.get_archive_by_id_async(archive_id=archive_id, actor=actor)
return archive.embedding_config
# Search across all passages - try to get embedding config from any agent
agent_count = await server.agent_manager.size_async(actor=actor)
if agent_count > 0:
agents = await server.agent_manager.list_agents_async(actor=actor, limit=1)
if agents:
return agents[0].embedding_config
# Fall back to server default
return server.default_embedding_config
class PassageSearchRequest(BaseModel):
"""Request model for searching passages across archives."""
query: str = Field(..., description="Text query for semantic search")
query: Optional[str] = Field(None, description="Text query for semantic search")
agent_id: Optional[str] = Field(None, description="Filter passages by agent ID")
archive_id: Optional[str] = Field(None, description="Filter passages by archive ID")
tags: Optional[List[str]] = Field(None, description="Optional list of tags to filter search results")
@@ -56,29 +100,16 @@ async def search_passages(
# Convert tag_match_mode to enum
tag_mode = TagMatchMode.ANY if request.tag_match_mode == "any" else TagMatchMode.ALL
# Determine which embedding config to use
# Determine embedding config (only needed when query text is provided)
embed_query = bool(request.query)
embedding_config = None
if request.agent_id:
# Search by agent
agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=request.agent_id, actor=actor)
embedding_config = agent_state.embedding_config
elif request.archive_id:
# Search by archive_id
archive = await server.archive_manager.get_archive_by_id_async(archive_id=request.archive_id, actor=actor)
embedding_config = archive.embedding_config
else:
# Search across all passages in the organization
# Get default embedding config from any agent or use server default
agent_count = await server.agent_manager.size_async(actor=actor)
if agent_count > 0:
# Get first agent to derive embedding config
agents = await server.agent_manager.list_agents_async(actor=actor, limit=1)
if agents:
embedding_config = agents[0].embedding_config
if not embedding_config:
# Fall back to server default
embedding_config = server.default_embedding_config
if embed_query:
embedding_config = await _get_embedding_config_for_search(
server=server,
actor=actor,
agent_id=request.agent_id,
archive_id=request.archive_id,
)
# Search passages
passages_with_metadata = await server.agent_manager.query_agent_passages_async(
@@ -88,7 +119,7 @@ async def search_passages(
query_text=request.query,
limit=request.limit,
embedding_config=embedding_config,
embed_query=True,
embed_query=embed_query,
tags=request.tags,
tag_match_mode=tag_mode,
start_date=request.start_date,

View File

@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, List, Literal, Optional
from fastapi import APIRouter, Body, Depends, Query, status
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
from fastapi.responses import JSONResponse
from letta.schemas.enums import ProviderCategory, ProviderType
@@ -144,6 +144,27 @@ async def check_existing_provider(
)
@router.patch("/{provider_id}/refresh", response_model=Provider, operation_id="refresh_provider_models")
async def refresh_provider_models(
provider_id: ProviderId,
headers: HeaderParams = Depends(get_headers),
server: "SyncServer" = Depends(get_letta_server),
):
"""
Refresh models for a BYOK provider by querying the provider's API.
Adds new models and removes ones no longer available.
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
provider = await server.provider_manager.get_provider_async(provider_id=provider_id, actor=actor)
# Only allow refresh for BYOK providers
if provider.provider_category != ProviderCategory.byok:
raise HTTPException(status_code=400, detail="Refresh is only supported for BYOK providers")
await server.provider_manager._sync_default_models_for_provider(provider, actor)
return await server.provider_manager.get_provider_async(provider_id=provider_id, actor=actor)
@router.delete("/{provider_id}", response_model=None, operation_id="delete_provider")
async def delete_provider(
provider_id: ProviderId,

View File

@@ -393,11 +393,14 @@ async def retrieve_stream_for_run(
)
if settings.enable_cancellation_aware_streaming:
from letta.server.rest_api.streaming_response import cancellation_aware_stream_wrapper, get_cancellation_event_for_run
stream = cancellation_aware_stream_wrapper(
stream_generator=stream,
run_manager=server.run_manager,
run_id=run_id,
actor=actor,
cancellation_event=get_cancellation_event_for_run(run_id),
)
if request.include_pings and settings.enable_keepalive:

View File

@@ -485,7 +485,7 @@ async def load_file_to_source_async(server: SyncServer, source_id: str, job_id:
async def sleeptime_document_ingest_async(server: SyncServer, source_id: str, actor: User, clear_history: bool = False):
source = await server.source_manager.get_source_by_id(source_id=source_id)
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
agents = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor)
for agent in agents:
if agent.enable_sleeptime:

View File

@@ -4,6 +4,7 @@ from fastapi import APIRouter, Body, Depends, Query
from letta.schemas.user import User, UserCreate, UserUpdate
from letta.server.rest_api.dependencies import get_letta_server
from letta.validators import UserIdQueryRequired
if TYPE_CHECKING:
from letta.schemas.user import User
@@ -52,7 +53,7 @@ async def update_user(
@router.delete("/", tags=["admin"], response_model=User, operation_id="delete_user")
async def delete_user(
user_id: str = Query(..., description="The user_id key to be deleted."),
user_id: UserIdQueryRequired,
server: "SyncServer" = Depends(get_letta_server),
):
# TODO make a soft deletion, instead of a hard deletion

View File

@@ -7,6 +7,7 @@ import json
import re
from collections.abc import AsyncIterator
from datetime import datetime, timezone
from typing import Dict, Optional
from uuid import uuid4
import anyio
@@ -26,6 +27,17 @@ from letta.utils import safe_create_task
logger = get_logger(__name__)
# Global registry of cancellation events per run_id
# Note: Events are small and we don't bother cleaning them up
_cancellation_events: Dict[str, asyncio.Event] = {}
def get_cancellation_event_for_run(run_id: str) -> asyncio.Event:
"""Get or create a cancellation event for a run."""
if run_id not in _cancellation_events:
_cancellation_events[run_id] = asyncio.Event()
return _cancellation_events[run_id]
class RunCancelledException(Exception):
"""Exception raised when a run is explicitly cancelled (not due to client timeout)"""
@@ -125,6 +137,7 @@ async def cancellation_aware_stream_wrapper(
run_id: str,
actor: User,
cancellation_check_interval: float = 0.5,
cancellation_event: Optional[asyncio.Event] = None,
) -> AsyncIterator[str | bytes]:
"""
Wraps a stream generator to provide real-time run cancellation checking.
@@ -156,11 +169,22 @@ async def cancellation_aware_stream_wrapper(
run = await run_manager.get_run_by_id(run_id=run_id, actor=actor)
if run.status == RunStatus.cancelled:
logger.info(f"Stream cancelled for run {run_id}, interrupting stream")
# Signal cancellation via shared event if available
if cancellation_event:
cancellation_event.set()
logger.info(f"Set cancellation event for run {run_id}")
# Send cancellation event to client
cancellation_event = {"message_type": "stop_reason", "stop_reason": "cancelled"}
yield f"data: {json.dumps(cancellation_event)}\n\n"
# Raise custom exception for explicit run cancellation
raise RunCancelledException(run_id, f"Run {run_id} was cancelled")
stop_event = {"message_type": "stop_reason", "stop_reason": "cancelled"}
yield f"data: {json.dumps(stop_event)}\n\n"
# Inject exception INTO the generator so its except blocks can catch it
try:
await stream_generator.athrow(RunCancelledException(run_id, f"Run {run_id} was cancelled"))
except (StopAsyncIteration, RunCancelledException):
# Generator closed gracefully or raised the exception back
break
except RunCancelledException:
# Re-raise cancellation immediately, don't catch it
raise
@@ -173,9 +197,10 @@ async def cancellation_aware_stream_wrapper(
yield chunk
except RunCancelledException:
# Re-raise RunCancelledException to distinguish from client timeout
# Don't re-raise - we already injected the exception into the generator
# The generator has handled it and set its stream_was_cancelled flag
logger.info(f"Stream for run {run_id} was explicitly cancelled and cleaned up")
raise
# Don't raise - let it exit gracefully
except asyncio.CancelledError:
# Re-raise CancelledError (likely client timeout) to ensure proper cleanup
logger.info(f"Stream for run {run_id} was cancelled (likely client timeout) and cleaned up")

View File

@@ -20,7 +20,7 @@ from letta.constants import (
)
from letta.errors import ContextWindowExceededError, RateLimitExceededError
from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns, ns_to_ms
from letta.helpers.message_helper import convert_message_creates_to_messages
from letta.helpers.message_helper import convert_message_creates_to_messages, resolve_tool_return_images
from letta.log import get_logger
from letta.otel.context import get_ctx_attributes
from letta.otel.metric_registry import MetricRegistry
@@ -171,18 +171,26 @@ async def create_input_messages(
return messages
def create_approval_response_message_from_input(
async def create_approval_response_message_from_input(
agent_state: AgentState, input_message: ApprovalCreate, run_id: Optional[str] = None
) -> List[Message]:
def maybe_convert_tool_return_message(maybe_tool_return: LettaToolReturn):
async def maybe_convert_tool_return_message(maybe_tool_return: LettaToolReturn):
if isinstance(maybe_tool_return, LettaToolReturn):
packaged_function_response = package_function_response(
maybe_tool_return.status == "success", maybe_tool_return.tool_return, agent_state.timezone
)
tool_return_content = maybe_tool_return.tool_return
# Handle tool_return content - can be string or list of content parts (text/image)
if isinstance(tool_return_content, str):
# String content - wrap with package_function_response as before
func_response = package_function_response(maybe_tool_return.status == "success", tool_return_content, agent_state.timezone)
else:
# List of content parts (text/image) - resolve URL images to base64 first
resolved_content = await resolve_tool_return_images(tool_return_content)
func_response = resolved_content
return ToolReturn(
tool_call_id=maybe_tool_return.tool_call_id,
status=maybe_tool_return.status,
func_response=packaged_function_response,
func_response=func_response,
stdout=maybe_tool_return.stdout,
stderr=maybe_tool_return.stderr,
)
@@ -196,6 +204,11 @@ def create_approval_response_message_from_input(
getattr(input_message, "approval_request_id", None),
)
# Process all tool returns concurrently (for async image resolution)
import asyncio
converted_approvals = await asyncio.gather(*[maybe_convert_tool_return_message(approval) for approval in approvals_list])
return [
Message(
role=MessageRole.approval,
@@ -204,7 +217,7 @@ def create_approval_response_message_from_input(
approval_request_id=input_message.approval_request_id,
approve=input_message.approve,
denial_reason=input_message.reason,
approvals=[maybe_convert_tool_return_message(approval) for approval in approvals_list],
approvals=list(converted_approvals),
run_id=run_id,
group_id=input_message.group_id
if input_message.group_id

View File

@@ -19,7 +19,6 @@ from letta.config import LettaConfig
from letta.constants import LETTA_TOOL_EXECUTION_DIR
from letta.data_sources.connectors import DataConnector, load_data
from letta.errors import (
EmbeddingConfigRequiredError,
HandleNotFoundError,
LettaInvalidArgumentError,
LettaMCPConnectionError,
@@ -68,10 +67,12 @@ from letta.schemas.providers import (
GroqProvider,
LettaProvider,
LMStudioOpenAIProvider,
MiniMaxProvider,
OllamaProvider,
OpenAIProvider,
OpenRouterProvider,
Provider,
SGLangProvider,
TogetherProvider,
VLLMProvider,
XAIProvider,
@@ -283,15 +284,33 @@ class SyncServer(object):
# NOTE: to use the /chat/completions endpoint, you need to specify extra flags on vLLM startup
# see: https://docs.vllm.ai/en/stable/features/tool_calling.html
# e.g. "... --enable-auto-tool-choice --tool-call-parser hermes"
# Auto-append /v1 to the base URL
vllm_url = (
model_settings.vllm_api_base if model_settings.vllm_api_base.endswith("/v1") else model_settings.vllm_api_base + "/v1"
)
self._enabled_providers.append(
VLLMProvider(
name="vllm",
base_url=model_settings.vllm_api_base,
base_url=vllm_url,
default_prompt_formatter=model_settings.default_prompt_formatter,
handle_base=model_settings.vllm_handle_base,
)
)
if model_settings.sglang_api_base:
# Auto-append /v1 to the base URL
sglang_url = (
model_settings.sglang_api_base if model_settings.sglang_api_base.endswith("/v1") else model_settings.sglang_api_base + "/v1"
)
self._enabled_providers.append(
SGLangProvider(
name="sglang",
base_url=sglang_url,
default_prompt_formatter=model_settings.default_prompt_formatter,
handle_base=model_settings.sglang_handle_base,
)
)
if model_settings.aws_access_key_id and model_settings.aws_secret_access_key and model_settings.aws_default_region:
self._enabled_providers.append(
BedrockProvider(
@@ -324,6 +343,13 @@ class SyncServer(object):
api_key_enc=Secret.from_plaintext(model_settings.xai_api_key),
)
)
if model_settings.minimax_api_key:
self._enabled_providers.append(
MiniMaxProvider(
name="minimax",
api_key_enc=Secret.from_plaintext(model_settings.minimax_api_key),
)
)
if model_settings.zai_api_key:
self._enabled_providers.append(
ZAIProvider(
@@ -443,6 +469,8 @@ class SyncServer(object):
embedding_models=embedding_models,
organization_id=None, # Global models
)
# Update last_synced timestamp
await self.provider_manager.update_provider_last_synced_async(persisted_provider.id)
logger.info(
f"Synced {len(llm_models)} LLM models and {len(embedding_models)} embedding models for provider {persisted_provider.name}"
)
@@ -628,9 +656,10 @@ class SyncServer(object):
actor=actor,
)
async def create_sleeptime_agent_async(self, main_agent: AgentState, actor: User) -> AgentState:
async def create_sleeptime_agent_async(self, main_agent: AgentState, actor: User) -> Optional[AgentState]:
if main_agent.embedding_config is None:
raise EmbeddingConfigRequiredError(agent_id=main_agent.id, operation="create_sleeptime_agent")
logger.warning(f"Skipping sleeptime agent creation for agent {main_agent.id}: no embedding config provided")
return None
request = CreateAgent(
name=main_agent.name + "-sleeptime",
agent_type=AgentType.sleeptime_agent,
@@ -662,9 +691,10 @@ class SyncServer(object):
)
return await self.agent_manager.get_agent_by_id_async(agent_id=main_agent.id, actor=actor)
async def create_voice_sleeptime_agent_async(self, main_agent: AgentState, actor: User) -> AgentState:
async def create_voice_sleeptime_agent_async(self, main_agent: AgentState, actor: User) -> Optional[AgentState]:
if main_agent.embedding_config is None:
raise EmbeddingConfigRequiredError(agent_id=main_agent.id, operation="create_voice_sleeptime_agent")
logger.warning(f"Skipping voice sleeptime agent creation for agent {main_agent.id}: no embedding config provided")
return None
# TODO: Inject system
request = CreateAgent(
name=main_agent.name + "-sleeptime",
@@ -956,7 +986,7 @@ class SyncServer(object):
from letta.data_sources.connectors import DirectoryConnector
# TODO: move this into a thread
source = await self.source_manager.get_source_by_id(source_id=source_id)
source = await self.source_manager.get_source_by_id(source_id=source_id, actor=actor)
connector = DirectoryConnector(input_files=[file_path])
num_passages, num_documents = await self.load_data(user_id=source.created_by_id, source_name=source.name, connector=connector)
@@ -1041,9 +1071,10 @@ class SyncServer(object):
async def create_document_sleeptime_agent_async(
self, main_agent: AgentState, source: Source, actor: User, clear_history: bool = False
) -> AgentState:
) -> Optional[AgentState]:
if main_agent.embedding_config is None:
raise EmbeddingConfigRequiredError(agent_id=main_agent.id, operation="create_document_sleeptime_agent")
logger.warning(f"Skipping document sleeptime agent creation for agent {main_agent.id}: no embedding config provided")
return None
try:
block = await self.agent_manager.get_block_with_label_async(agent_id=main_agent.id, block_label=source.name, actor=actor)
except:
@@ -1151,10 +1182,18 @@ class SyncServer(object):
if provider_type and provider.provider_type != provider_type:
continue
# For bedrock, use schema default for base_url since DB may have NULL
# TODO: can maybe do this for all models but want to isolate change so we don't break any other providers
if provider.provider_type == ProviderType.bedrock:
typed_provider = provider.cast_to_subtype()
model_endpoint = typed_provider.base_url
else:
model_endpoint = provider.base_url
llm_config = LLMConfig(
model=model.name,
model_endpoint_type=model.model_endpoint_type,
model_endpoint=provider.base_url or model.model_endpoint_type,
model_endpoint=model_endpoint,
context_window=model.max_context_window or 16384,
handle=model.handle,
provider_name=provider.name,
@@ -1162,7 +1201,7 @@ class SyncServer(object):
)
llm_models.append(llm_config)
# Get BYOK provider models by hitting provider endpoints directly
# Get BYOK provider models - sync if not synced yet, then read from DB
if include_byok:
byok_providers = await self.provider_manager.list_providers_async(
actor=actor,
@@ -1173,9 +1212,39 @@ class SyncServer(object):
for provider in byok_providers:
try:
# Get typed provider to access schema defaults (e.g., base_url)
typed_provider = provider.cast_to_subtype()
models = await typed_provider.list_llm_models_async()
llm_models.extend(models)
# Sync models if not synced yet
if provider.last_synced is None:
models = await typed_provider.list_llm_models_async()
embedding_models = await typed_provider.list_embedding_models_async()
await self.provider_manager.sync_provider_models_async(
provider=provider,
llm_models=models,
embedding_models=embedding_models,
organization_id=provider.organization_id,
)
await self.provider_manager.update_provider_last_synced_async(provider.id, actor=actor)
# Read from database
provider_llm_models = await self.provider_manager.list_models_async(
actor=actor,
model_type="llm",
provider_id=provider.id,
enabled=True,
)
for model in provider_llm_models:
llm_config = LLMConfig(
model=model.name,
model_endpoint_type=model.model_endpoint_type,
model_endpoint=typed_provider.base_url,
context_window=model.max_context_window or constants.DEFAULT_CONTEXT_WINDOW,
handle=model.handle,
provider_name=provider.name,
provider_category=ProviderCategory.byok,
)
llm_models.append(llm_config)
except Exception as e:
logger.warning(f"Failed to fetch models from BYOK provider {provider.name}: {e}")
@@ -1217,7 +1286,7 @@ class SyncServer(object):
)
embedding_models.append(embedding_config)
# Get BYOK provider models by hitting provider endpoints directly
# Get BYOK provider models - sync if not synced yet, then read from DB
byok_providers = await self.provider_manager.list_providers_async(
actor=actor,
provider_category=[ProviderCategory.byok],
@@ -1225,9 +1294,38 @@ class SyncServer(object):
for provider in byok_providers:
try:
# Get typed provider to access schema defaults (e.g., base_url)
typed_provider = provider.cast_to_subtype()
models = await typed_provider.list_embedding_models_async()
embedding_models.extend(models)
# Sync models if not synced yet
if provider.last_synced is None:
llm_models = await typed_provider.list_llm_models_async()
emb_models = await typed_provider.list_embedding_models_async()
await self.provider_manager.sync_provider_models_async(
provider=provider,
llm_models=llm_models,
embedding_models=emb_models,
organization_id=provider.organization_id,
)
await self.provider_manager.update_provider_last_synced_async(provider.id, actor=actor)
# Read from database
provider_embedding_models = await self.provider_manager.list_models_async(
actor=actor,
model_type="embedding",
provider_id=provider.id,
enabled=True,
)
for model in provider_embedding_models:
embedding_config = EmbeddingConfig(
embedding_model=model.name,
embedding_endpoint_type=model.model_endpoint_type,
embedding_endpoint=typed_provider.base_url,
embedding_dim=model.embedding_dim or 1536,
embedding_chunk_size=constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
handle=model.handle,
)
embedding_models.append(embedding_config)
except Exception as e:
logger.warning(f"Failed to fetch embedding models from BYOK provider {provider.name}: {e}")

View File

@@ -357,7 +357,11 @@ class AgentManager:
)
agent_create.llm_config = LLMConfig.apply_reasoning_setting_to_config(
agent_create.llm_config,
agent_create.reasoning if agent_create.reasoning is not None else default_reasoning,
agent_create.reasoning
if agent_create.reasoning is not None
else (
agent_create.llm_config.enable_reasoner if agent_create.llm_config.enable_reasoner is not None else default_reasoning
),
agent_create.agent_type,
)
else:
@@ -2042,10 +2046,12 @@ class AgentManager:
if other_agent_id != agent_id:
try:
other_agent = await AgentModel.read_async(db_session=session, identifier=other_agent_id, actor=actor)
if other_agent.agent_type == AgentType.sleeptime_agent and block not in other_agent.core_memory:
other_agent.core_memory.append(block)
# await other_agent.update_async(session, actor=actor, no_commit=True)
await other_agent.update_async(session, actor=actor)
if other_agent.agent_type == AgentType.sleeptime_agent:
# Check if block with same label already exists
existing_block = next((b for b in other_agent.core_memory if b.label == block.label), None)
if not existing_block:
other_agent.core_memory.append(block)
await other_agent.update_async(session, actor=actor)
except NoResultFound:
# Agent might not exist anymore, skip
continue
@@ -2321,15 +2327,6 @@ class AgentManager:
# Use Turbopuffer for vector search if archive is configured for TPUF
if archive.vector_db_provider == VectorDBProvider.TPUF:
from letta.helpers.tpuf_client import TurbopufferClient
from letta.llm_api.llm_client import LLMClient
# Generate embedding for query
embedding_client = LLMClient.create(
provider_type=embedding_config.embedding_endpoint_type,
actor=actor,
)
embeddings = await embedding_client.request_embeddings([query_text], embedding_config)
query_embedding = embeddings[0]
# Query Turbopuffer - use hybrid search when text is available
tpuf_client = TurbopufferClient()
@@ -2488,13 +2485,15 @@ class AgentManager:
# Get results using existing passage query method
limit = top_k if top_k is not None else RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
# Only use embedding-based search if embedding config is available
use_embedding_search = agent_state.embedding_config is not None
passages_with_metadata = await self.query_agent_passages_async(
actor=actor,
agent_id=agent_id,
query_text=query,
limit=limit,
embedding_config=agent_state.embedding_config,
embed_query=True,
embed_query=use_embedding_search,
tags=tags,
tag_match_mode=tag_mode,
start_date=start_date,
@@ -3053,10 +3052,19 @@ class AgentManager:
)
# Apply cursor-based pagination
if before:
query = query.where(BlockModel.id < before)
if after:
query = query.where(BlockModel.id > after)
# Note: cursor direction must account for sort order
# - ascending order: "after X" means id > X, "before X" means id < X
# - descending order: "after X" means id < X, "before X" means id > X
if ascending:
if before:
query = query.where(BlockModel.id < before)
if after:
query = query.where(BlockModel.id > after)
else:
if before:
query = query.where(BlockModel.id > before)
if after:
query = query.where(BlockModel.id < after)
# Apply sorting - use id instead of created_at for core memory blocks
if ascending:

View File

@@ -33,6 +33,7 @@ from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import FileProcessingStatus, VectorDBProvider
from letta.schemas.file import FileMetadata
from letta.schemas.group import Group, GroupCreate
from letta.schemas.llm_config import LLMConfig
from letta.schemas.mcp import MCPServer
from letta.schemas.message import Message
from letta.schemas.source import Source
@@ -472,6 +473,7 @@ class AgentSerializationManager:
dry_run: bool = False,
env_vars: Optional[Dict[str, Any]] = None,
override_embedding_config: Optional[EmbeddingConfig] = None,
override_llm_config: Optional[LLMConfig] = None,
project_id: Optional[str] = None,
) -> ImportResult:
"""
@@ -672,6 +674,11 @@ class AgentSerializationManager:
agent_schema.embedding_config = override_embedding_config
agent_schema.embedding = override_embedding_config.handle
# Override llm_config if provided (keeps other defaults like context size)
if override_llm_config:
agent_schema.llm_config = override_llm_config
agent_schema.model = override_llm_config.handle
# Convert AgentSchema back to CreateAgent, remapping tool/block IDs
agent_data = agent_schema.model_dump(exclude={"id", "in_context_message_ids", "messages"})

View File

@@ -4,7 +4,6 @@ from typing import Dict, List, Optional
from sqlalchemy import delete, or_, select
from letta.errors import EmbeddingConfigRequiredError
from letta.helpers.tpuf_client import should_use_tpuf
from letta.log import get_logger
from letta.orm import ArchivalPassage, Archive as ArchiveModel, ArchivesAgents
@@ -32,7 +31,7 @@ class ArchiveManager:
async def create_archive_async(
self,
name: str,
embedding_config: EmbeddingConfig,
embedding_config: Optional[EmbeddingConfig] = None,
description: Optional[str] = None,
actor: PydanticUser = None,
) -> PydanticArchive:
@@ -289,6 +288,7 @@ class ArchiveManager:
text: str,
metadata: Optional[Dict] = None,
tags: Optional[List[str]] = None,
created_at: Optional[str] = None,
actor: PydanticUser = None,
) -> PydanticPassage:
"""Create a passage in an archive.
@@ -298,6 +298,7 @@ class ArchiveManager:
text: The text content of the passage
metadata: Optional metadata for the passage
tags: Optional tags for categorizing the passage
created_at: Optional creation datetime in ISO 8601 format
actor: User performing the operation
Returns:
@@ -312,13 +313,20 @@ class ArchiveManager:
# Verify the archive exists and user has access
archive = await self.get_archive_by_id_async(archive_id=archive_id, actor=actor)
# Generate embeddings for the text
embedding_client = LLMClient.create(
provider_type=archive.embedding_config.embedding_endpoint_type,
actor=actor,
)
embeddings = await embedding_client.request_embeddings([text], archive.embedding_config)
embedding = embeddings[0] if embeddings else None
# Generate embeddings for the text if embedding config is available
embedding = None
if archive.embedding_config is not None:
embedding_client = LLMClient.create(
provider_type=archive.embedding_config.embedding_endpoint_type,
actor=actor,
)
embeddings = await embedding_client.request_embeddings([text], archive.embedding_config)
embedding = embeddings[0] if embeddings else None
# Parse created_at from ISO string if provided
parsed_created_at = None
if created_at:
parsed_created_at = datetime.fromisoformat(created_at)
# Create the passage object with embedding
passage = PydanticPassage(
@@ -329,6 +337,7 @@ class ArchiveManager:
tags=tags,
embedding_config=archive.embedding_config,
embedding=embedding,
created_at=parsed_created_at,
)
# Use PassageManager to create the passage
@@ -345,13 +354,14 @@ class ArchiveManager:
tpuf_client = TurbopufferClient()
# Insert to Turbopuffer with the same ID as SQL
# Insert to Turbopuffer with the same ID as SQL, reusing existing embedding
await tpuf_client.insert_archival_memories(
archive_id=archive.id,
text_chunks=[created_passage.text],
passage_ids=[created_passage.id],
organization_id=actor.organization_id,
actor=actor,
embeddings=[created_passage.embedding],
)
logger.info(f"Uploaded passage {created_passage.id} to Turbopuffer for archive {archive_id}")
except Exception as e:
@@ -362,6 +372,92 @@ class ArchiveManager:
logger.info(f"Created passage {created_passage.id} in archive {archive_id}")
return created_passage
@enforce_types
@raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE)
@trace_method
async def create_passages_in_archive_async(
    self,
    archive_id: str,
    passages: List[Dict],
    actor: PydanticUser = None,
) -> List[PydanticPassage]:
    """Create multiple passages in an archive.

    Args:
        archive_id: ID of the archive to add the passages to
        passages: Passage create payloads (dicts with "text" and optional
            "metadata", "tags", and ISO-8601 "created_at" entries)
        actor: User performing the operation

    Returns:
        The created passages

    Raises:
        NoResultFound: If archive not found
        ValueError: If the embedding service returns a mismatched count
    """
    if not passages:
        return []

    from letta.llm_api.llm_client import LLMClient
    from letta.services.passage_manager import PassageManager

    # Verify the archive exists and the actor has access
    archive = await self.get_archive_by_id_async(archive_id=archive_id, actor=actor)

    texts = [passage["text"] for passage in passages]

    # Generate embeddings only when the archive has an embedding config.
    # Archives may exist without one (embedding_config is Optional on
    # create_archive_async), in which case passages are stored without
    # embeddings (text search only) — mirroring the single-passage
    # create_passage_in_archive path instead of crashing on None.
    if archive.embedding_config is not None:
        embedding_client = LLMClient.create(
            provider_type=archive.embedding_config.embedding_endpoint_type,
            actor=actor,
        )
        embeddings = await embedding_client.request_embeddings(texts, archive.embedding_config)
        if len(embeddings) != len(passages):
            raise ValueError("Embedding response count does not match passages count")
    else:
        embeddings = [None] * len(passages)

    # Build PydanticPassage objects for batch creation
    pydantic_passages: List[PydanticPassage] = []
    for passage_payload, embedding in zip(passages, embeddings):
        # Parse created_at from ISO string if provided
        created_at = passage_payload.get("created_at")
        if created_at and isinstance(created_at, str):
            created_at = datetime.fromisoformat(created_at)

        passage = PydanticPassage(
            text=passage_payload["text"],
            archive_id=archive_id,
            organization_id=actor.organization_id,
            metadata=passage_payload.get("metadata") or {},
            tags=passage_payload.get("tags"),
            embedding_config=archive.embedding_config,
            embedding=embedding,
            created_at=created_at,
        )
        pydantic_passages.append(passage)

    # Use batch create for efficient single-transaction insert
    passage_manager = PassageManager()
    created_passages = await passage_manager.create_agent_passages_async(
        pydantic_passages=pydantic_passages,
        actor=actor,
    )

    # Mirror passages into Turbopuffer when the archive uses it as its vector
    # DB; failures here are logged but do not roll back the SQL insert.
    if archive.vector_db_provider == VectorDBProvider.TPUF:
        try:
            from letta.helpers.tpuf_client import TurbopufferClient

            tpuf_client = TurbopufferClient()
            await tpuf_client.insert_archival_memories(
                archive_id=archive.id,
                text_chunks=[passage.text for passage in created_passages],
                passage_ids=[passage.id for passage in created_passages],
                organization_id=actor.organization_id,
                actor=actor,
            )
            logger.info(f"Uploaded {len(created_passages)} passages to Turbopuffer for archive {archive_id}")
        except Exception as e:
            logger.error(f"Failed to upload passages to Turbopuffer: {e}")

    logger.info(f"Created {len(created_passages)} passages in archive {archive_id}")
    return created_passages
@enforce_types
@raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE)
@raise_on_invalid_id(param_name="passage_id", expected_prefix=PrimitiveType.PASSAGE)
@@ -433,9 +529,7 @@ class ArchiveManager:
)
return archive
# Create a default archive for this agent
if agent_state.embedding_config is None:
raise EmbeddingConfigRequiredError(agent_id=agent_state.id, operation="create_default_archive")
# Create a default archive for this agent (embedding_config is optional)
archive_name = f"{agent_state.name}'s Archive"
archive = await self.create_archive_async(
name=archive_name,

View File

@@ -508,7 +508,7 @@ class BlockManager:
@enforce_types
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
@trace_method
async def get_block_by_id_async(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]:
async def get_block_by_id_async(self, block_id: str, actor: PydanticUser) -> Optional[PydanticBlock]:
"""Retrieve a block by its ID, including tags."""
async with db_registry.async_session() as session:
try:
@@ -523,7 +523,7 @@ class BlockManager:
@enforce_types
@trace_method
async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]:
async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: PydanticUser) -> List[PydanticBlock]:
"""Retrieve blocks by their ids without loading unnecessary relationships. Async implementation."""
if not block_ids:
return []
@@ -540,9 +540,8 @@ class BlockManager:
noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups), noload(BlockModel.tags)
)
# Apply access control if actor is provided
if actor:
query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
# Apply access control - actor is required for org-scoping
query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
# TODO: Add soft delete filter if applicable
# if hasattr(BlockModel, "is_deleted"):

View File

@@ -105,9 +105,53 @@ class ConversationManager:
actor: PydanticUser,
limit: int = 50,
after: Optional[str] = None,
summary_search: Optional[str] = None,
) -> List[PydanticConversation]:
"""List conversations for an agent with cursor-based pagination."""
"""List conversations for an agent with cursor-based pagination.
Args:
agent_id: The agent ID to list conversations for
actor: The user performing the action
limit: Maximum number of conversations to return
after: Cursor for pagination (conversation ID)
summary_search: Optional text to search for within the summary field
Returns:
List of conversations matching the criteria
"""
async with db_registry.async_session() as session:
# If summary search is provided, use custom query
if summary_search:
from sqlalchemy import and_
stmt = (
select(ConversationModel)
.where(
and_(
ConversationModel.agent_id == agent_id,
ConversationModel.organization_id == actor.organization_id,
ConversationModel.summary.isnot(None),
ConversationModel.summary.contains(summary_search),
)
)
.order_by(ConversationModel.created_at.desc())
.limit(limit)
)
if after:
# Add cursor filtering
after_conv = await ConversationModel.read_async(
db_session=session,
identifier=after,
actor=actor,
)
stmt = stmt.where(ConversationModel.created_at < after_conv.created_at)
result = await session.execute(stmt)
conversations = result.scalars().all()
return [conv.to_pydantic() for conv in conversations]
# Use default list logic
conversations = await ConversationModel.list_async(
db_session=session,
actor=actor,

View File

@@ -91,18 +91,17 @@ class FileManager:
await session.rollback()
return await self.get_file_by_id(file_metadata.id, actor=actor)
# TODO: We make actor optional for now, but should most likely be enforced due to security reasons
@enforce_types
@raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE)
@trace_method
# @async_redis_cache(
# key_func=lambda self, file_id, actor=None, include_content=False, strip_directory_prefix=False: f"{file_id}:{actor.organization_id if actor else 'none'}:{include_content}:{strip_directory_prefix}",
# key_func=lambda self, file_id, actor, include_content=False, strip_directory_prefix=False: f"{file_id}:{actor.organization_id}:{include_content}:{strip_directory_prefix}",
# prefix="file_content",
# ttl_s=3600,
# model_class=PydanticFileMetadata,
# )
async def get_file_by_id(
self, file_id: str, actor: Optional[PydanticUser] = None, *, include_content: bool = False, strip_directory_prefix: bool = False
self, file_id: str, actor: PydanticUser, *, include_content: bool = False, strip_directory_prefix: bool = False
) -> Optional[PydanticFileMetadata]:
"""Retrieve a file by its ID.
@@ -479,7 +478,7 @@ class FileManager:
async def delete_file(self, file_id: str, actor: PydanticUser) -> PydanticFileMetadata:
"""Delete a file by its ID."""
async with db_registry.async_session() as session:
file = await FileMetadataModel.read_async(db_session=session, identifier=file_id)
file = await FileMetadataModel.read_async(db_session=session, identifier=file_id, actor=actor)
# invalidate cache for this file before deletion
await self._invalidate_file_caches(file_id, actor, file.original_file_name, file.source_id)

View File

@@ -1194,9 +1194,9 @@ async def build_agent_passage_query(
"""
# Handle embedding for vector search
# If embed_query is True but no embedding_config, fall through to text search
embedded_text = None
if embed_query:
assert embedding_config is not None, "embedding_config must be specified for vector search"
if embed_query and embedding_config is not None:
assert query_text is not None, "query_text must be specified for vector search"
# Use the new LLMClient for embeddings

View File

@@ -63,7 +63,7 @@ class LLMBatchManager:
self,
llm_batch_id: str,
status: JobStatus,
actor: Optional[PydanticUser] = None,
actor: PydanticUser,
latest_polling_response: Optional[BetaMessageBatch] = None,
) -> PydanticLLMBatchJob:
"""Update a batch job's status and optionally its polling response."""
@@ -107,8 +107,8 @@ class LLMBatchManager:
async def list_llm_batch_jobs_async(
self,
letta_batch_id: str,
actor: PydanticUser,
limit: Optional[int] = None,
actor: Optional[PydanticUser] = None,
after: Optional[str] = None,
) -> List[PydanticLLMBatchJob]:
"""
@@ -153,8 +153,8 @@ class LLMBatchManager:
async def get_messages_for_letta_batch_async(
self,
letta_batch_job_id: str,
actor: PydanticUser,
limit: int = 100,
actor: Optional[PydanticUser] = None,
agent_id: Optional[str] = None,
sort_descending: bool = True,
cursor: Optional[str] = None, # Message ID as cursor

View File

@@ -419,6 +419,9 @@ class MCPManager:
"""
# Create base MCPServer object
if isinstance(server_config, StdioServerConfig):
# Check if stdio MCP servers are disabled (not suitable for multi-tenant deployments)
if tool_settings.mcp_disable_stdio:
raise ValueError("MCP stdio servers are disabled. Set MCP_DISABLE_STDIO=false to enable them.")
mcp_server = MCPServer(server_name=server_config.server_name, server_type=server_config.type, stdio_config=server_config)
elif isinstance(server_config, SSEServerConfig):
mcp_server = MCPServer(
@@ -832,6 +835,9 @@ class MCPManager:
server_config = SSEServerConfig(**server_config.model_dump())
return AsyncFastMCPSSEClient(server_config=server_config, oauth=oauth, agent_id=agent_id)
elif server_config.type == MCPServerType.STDIO:
# Check if stdio MCP servers are disabled (not suitable for multi-tenant deployments)
if tool_settings.mcp_disable_stdio:
raise ValueError("MCP stdio servers are disabled. Set MCP_DISABLE_STDIO=false to enable them.")
server_config = StdioServerConfig(**server_config.model_dump())
return AsyncStdioMCPClient(server_config=server_config, oauth_provider=None, agent_id=agent_id)
elif server_config.type == MCPServerType.STREAMABLE_HTTP:

View File

@@ -516,6 +516,9 @@ class MCPServerManager:
"""
# Create base MCPServer object
if isinstance(server_config, StdioServerConfig):
# Check if stdio MCP servers are disabled (not suitable for multi-tenant deployments)
if tool_settings.mcp_disable_stdio:
raise ValueError("MCP stdio servers are disabled. Set MCP_DISABLE_STDIO=false to enable them.")
mcp_server = MCPServer(server_name=server_config.server_name, server_type=server_config.type, stdio_config=server_config)
elif isinstance(server_config, SSEServerConfig):
mcp_server = MCPServer(
@@ -1003,6 +1006,9 @@ class MCPServerManager:
server_config = SSEServerConfig(**server_config.model_dump())
return AsyncFastMCPSSEClient(server_config=server_config, oauth=oauth, agent_id=agent_id)
elif server_config.type == MCPServerType.STDIO:
# Check if stdio MCP servers are disabled (not suitable for multi-tenant deployments)
if tool_settings.mcp_disable_stdio:
raise ValueError("MCP stdio servers are disabled. Set MCP_DISABLE_STDIO=false to enable them.")
server_config = StdioServerConfig(**server_config.model_dump())
return AsyncStdioMCPClient(server_config=server_config, oauth_provider=None, agent_id=agent_id)
elif server_config.type == MCPServerType.STREAMABLE_HTTP:

View File

@@ -8,7 +8,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import noload
from letta.constants import MAX_EMBEDDING_DIM
from letta.errors import EmbeddingConfigRequiredError
from letta.helpers.decorators import async_redis_cache
from letta.llm_api.llm_client import LLMClient
from letta.log import get_logger
@@ -193,6 +192,93 @@ class PassageManager:
return passage.to_pydantic()
@enforce_types
@trace_method
async def create_agent_passages_async(self, pydantic_passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]:
    """Create multiple agent passages in a single database transaction.

    Args:
        pydantic_passages: List of passages to create; each must carry an
            archive_id and must NOT carry a source_id
        actor: User performing the operation

    Returns:
        List of created passages

    Raises:
        ValueError: If any passage lacks archive_id or carries a source_id
    """
    if not pydantic_passages:
        return []

    import numpy as np

    from letta.helpers.tpuf_client import should_use_tpuf

    # When Turbopuffer is the vector DB, embeddings are not padded for pgvector
    use_tpuf = should_use_tpuf()

    passage_objects: List[ArchivalPassage] = []
    # (passage_index, tags) pairs, recorded so tags can be created after the
    # passages themselves exist (tag rows reference the passage IDs)
    all_tags_data: List[tuple] = []

    for idx, pydantic_passage in enumerate(pydantic_passages):
        # Agent passages are archive-scoped: archive_id is required and
        # source_id must be absent (source passages use a separate path)
        if not pydantic_passage.archive_id:
            raise ValueError("Agent passage must have archive_id")
        if pydantic_passage.source_id:
            raise ValueError("Agent passage cannot have source_id")

        data = pydantic_passage.model_dump(to_orm=True)

        # Deduplicate tags if provided (for dual storage consistency)
        tags = data.get("tags")
        if tags:
            tags = list(set(tags))
            all_tags_data.append((idx, tags))

        # Pad embeddings to MAX_EMBEDDING_DIM for pgvector (only when using
        # Postgres as the vector DB; the column is fixed-width)
        embedding = data["embedding"]
        if embedding and not use_tpuf:
            np_embedding = np.array(embedding)
            if np_embedding.shape[0] != MAX_EMBEDDING_DIM:
                embedding = np.pad(np_embedding, (0, MAX_EMBEDDING_DIM - np_embedding.shape[0]), mode="constant").tolist()

        # Sanitize text to remove null bytes which PostgreSQL rejects
        text = data["text"]
        if text and "\x00" in text:
            text = text.replace("\x00", "")
            logger.warning(f"Removed null bytes from passage text (length: {len(data['text'])} -> {len(text)})")

        common_fields = {
            "id": data.get("id"),
            "text": text,
            "embedding": embedding,
            "embedding_config": data["embedding_config"],
            "organization_id": data["organization_id"],
            "metadata_": data.get("metadata_", {}),
            "tags": tags,
            "is_deleted": data.get("is_deleted", False),
            # Default creation time to "now" (UTC) when the caller did not set one
            "created_at": data.get("created_at", datetime.now(timezone.utc)),
        }
        agent_fields = {"archive_id": data["archive_id"]}
        passage = ArchivalPassage(**common_fields, **agent_fields)
        passage_objects.append(passage)

    async with db_registry.async_session() as session:
        # Batch create all passages in a single transaction
        created_passages = await ArchivalPassage.batch_create_async(
            items=passage_objects,
            db_session=session,
            actor=actor,
        )

        # Create tag rows for the passages that had tags, within the same session
        for idx, tags in all_tags_data:
            created_passage = created_passages[idx]
            await self._create_tags_for_passage(
                session=session,
                passage_id=created_passage.id,
                archive_id=created_passage.archive_id,
                organization_id=created_passage.organization_id,
                tags=tags,
                actor=actor,
            )

    return [p.to_pydantic() for p in created_passages]
@enforce_types
@trace_method
async def create_source_passage_async(
@@ -474,15 +560,6 @@ class PassageManager:
Returns:
List of created passage objects
"""
if agent_state.embedding_config is None:
raise EmbeddingConfigRequiredError(agent_id=agent_state.id, operation="insert_passage")
embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size
embedding_client = LLMClient.create(
provider_type=agent_state.embedding_config.embedding_endpoint_type,
actor=actor,
)
# Get or create the default archive for the agent
archive = await self.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=agent_state, actor=actor)
@@ -493,8 +570,16 @@ class PassageManager:
return []
try:
# Generate embeddings for all chunks using the new async API
embeddings = await embedding_client.request_embeddings(text_chunks, agent_state.embedding_config)
# Generate embeddings if embedding config is available
if agent_state.embedding_config is not None:
embedding_client = LLMClient.create(
provider_type=agent_state.embedding_config.embedding_endpoint_type,
actor=actor,
)
embeddings = await embedding_client.request_embeddings(text_chunks, agent_state.embedding_config)
else:
# No embedding config - store passages without embeddings (text search only)
embeddings = [None] * len(text_chunks)
passages = []
@@ -525,20 +610,21 @@ class PassageManager:
tpuf_client = TurbopufferClient()
# Extract IDs and texts from the created passages
# Extract IDs, texts, and embeddings from the created passages
passage_ids = [p.id for p in passages]
passage_texts = [p.text for p in passages]
passage_embeddings = [p.embedding for p in passages]
# Insert to Turbopuffer with the same IDs as SQL
# TurbopufferClient will generate embeddings internally using default config
# Insert to Turbopuffer with the same IDs as SQL, reusing existing embeddings
await tpuf_client.insert_archival_memories(
archive_id=archive.id,
text_chunks=passage_texts,
passage_ids=passage_ids, # Use same IDs as SQL
passage_ids=passage_ids,
organization_id=actor.organization_id,
actor=actor,
tags=tags,
created_at=passages[0].created_at if passages else None,
embeddings=passage_embeddings,
)
except Exception as e:
logger.error(f"Failed to insert passages to Turbopuffer: {e}")

View File

@@ -98,9 +98,35 @@ class ProviderManager:
deleted_provider.access_key_enc = access_key_secret.get_encrypted()
await deleted_provider.update_async(session, actor=actor)
# Also restore any soft-deleted models associated with this provider
# This is needed because the unique constraint on provider_models doesn't include is_deleted,
# so soft-deleted models would block creation of new models with the same handle
from sqlalchemy import update
restore_models_stmt = (
update(ProviderModelORM)
.where(
and_(
ProviderModelORM.provider_id == deleted_provider.id,
ProviderModelORM.is_deleted == True,
)
)
.values(is_deleted=False)
)
result = await session.execute(restore_models_stmt)
if result.rowcount > 0:
logger.info(f"Restored {result.rowcount} soft-deleted model(s) for provider '{request.name}'")
# Commit the provider and model restoration before syncing
# This is needed because _sync_default_models_for_provider opens a new session
# that can't see uncommitted changes from this session
await session.commit()
provider_pydantic = deleted_provider.to_pydantic()
# For BYOK providers, automatically sync available models
# This will add any new models and remove any that are no longer available
if is_byok:
await self._sync_default_models_for_provider(provider_pydantic, actor)
@@ -119,6 +145,14 @@ class ProviderManager:
# if provider.name == provider.provider_type.value:
# raise ValueError("Provider name must be unique and different from provider type")
# Fill in schema-default base_url if not provided
# This ensures providers like ZAI get their default endpoint persisted to DB
# rather than relying on cast_to_subtype() at read time
if provider.base_url is None:
typed_provider = provider.cast_to_subtype()
if typed_provider.base_url is not None:
provider.base_url = typed_provider.base_url
# Only assign organization id for non-base providers
# Base providers should be globally accessible (org_id = None)
if is_byok:
@@ -201,6 +235,21 @@ class ProviderManager:
await existing_provider.update_async(session, actor=actor)
return existing_provider.to_pydantic()
@enforce_types
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
async def update_provider_last_synced_async(self, provider_id: str, actor: Optional[PydanticUser] = None) -> None:
    """Stamp a provider's last_synced field with the current UTC time.

    The actor is optional so that system-level callers (e.g. server
    initialization syncing global providers) can update the timestamp;
    when an actor is supplied, org-scoping is enforced by the ORM read.
    """
    from datetime import datetime, timezone

    sync_time = datetime.now(timezone.utc)
    async with db_registry.async_session() as session:
        record = await ProviderModel.read_async(db_session=session, identifier=provider_id, actor=actor)
        record.last_synced = sync_time
        await session.commit()
@enforce_types
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
@trace_method
@@ -476,81 +525,19 @@ class ProviderManager:
async def _sync_default_models_for_provider(self, provider: PydanticProvider, actor: PydanticUser) -> None:
"""Sync models for a newly created BYOK provider by querying the provider's API."""
from letta.log import get_logger
logger = get_logger(__name__)
try:
# Get the provider class and create an instance
from letta.schemas.enums import ProviderType
from letta.schemas.providers.anthropic import AnthropicProvider
from letta.schemas.providers.azure import AzureProvider
from letta.schemas.providers.bedrock import BedrockProvider
from letta.schemas.providers.google_gemini import GoogleAIProvider
from letta.schemas.providers.groq import GroqProvider
from letta.schemas.providers.ollama import OllamaProvider
from letta.schemas.providers.openai import OpenAIProvider
from letta.schemas.providers.zai import ZAIProvider
# Use cast_to_subtype() which properly handles all provider types and preserves api_key_enc
typed_provider = provider.cast_to_subtype()
llm_models = await typed_provider.list_llm_models_async()
embedding_models = await typed_provider.list_embedding_models_async()
# ChatGPT OAuth requires cast_to_subtype to preserve api_key_enc and id
# (needed for OAuth token refresh and database persistence)
if provider.provider_type == ProviderType.chatgpt_oauth:
provider_instance = provider.cast_to_subtype()
else:
provider_type_to_class = {
"openai": OpenAIProvider,
"anthropic": AnthropicProvider,
"groq": GroqProvider,
"google": GoogleAIProvider,
"ollama": OllamaProvider,
"bedrock": BedrockProvider,
"azure": AzureProvider,
"zai": ZAIProvider,
}
provider_type = provider.provider_type.value if hasattr(provider.provider_type, "value") else str(provider.provider_type)
provider_class = provider_type_to_class.get(provider_type)
if not provider_class:
logger.warning(f"No provider class found for type '{provider_type}'")
return
# Create provider instance with necessary parameters
api_key = await provider.api_key_enc.get_plaintext_async() if provider.api_key_enc else None
access_key = await provider.access_key_enc.get_plaintext_async() if provider.access_key_enc else None
kwargs = {
"name": provider.name,
"api_key": api_key,
"provider_category": provider.provider_category,
}
if provider.base_url:
kwargs["base_url"] = provider.base_url
if access_key:
kwargs["access_key"] = access_key
if provider.region:
kwargs["region"] = provider.region
if provider.api_version:
kwargs["api_version"] = provider.api_version
provider_instance = provider_class(**kwargs)
# Query the provider's API for available models
llm_models = await provider_instance.list_llm_models_async()
embedding_models = await provider_instance.list_embedding_models_async()
# Update handles and provider_name for BYOK providers
for model in llm_models:
model.provider_name = provider.name
model.handle = f"{provider.name}/{model.model}"
model.provider_category = provider.provider_category
for model in embedding_models:
model.handle = f"{provider.name}/{model.embedding_model}"
# Use existing sync_provider_models_async to save to database
await self.sync_provider_models_async(
provider=provider, llm_models=llm_models, embedding_models=embedding_models, organization_id=actor.organization_id
provider=provider,
llm_models=llm_models,
embedding_models=embedding_models,
organization_id=actor.organization_id,
)
await self.update_provider_last_synced_async(provider.id, actor=actor)
except Exception as e:
logger.error(f"Failed to sync models for provider '{provider.name}': {e}")
@@ -713,7 +700,7 @@ class ProviderManager:
enabled=True,
model_endpoint_type=llm_config.model_endpoint_type,
max_context_window=llm_config.context_window,
supports_token_streaming=llm_config.model_endpoint_type in ["openai", "anthropic", "deepseek"],
supports_token_streaming=llm_config.model_endpoint_type in ["openai", "anthropic", "deepseek", "openrouter"],
supports_tool_calling=True, # Assume true for LLMs for now
)
@@ -737,14 +724,27 @@ class ProviderManager:
# Roll back the session to clear the failed transaction
await session.rollback()
else:
# Check if max_context_window needs to be updated
# Check if max_context_window or model_endpoint_type needs to be updated
existing_model = existing[0]
needs_update = False
if existing_model.max_context_window != llm_config.context_window:
logger.info(
f" Updating LLM model {llm_config.handle} max_context_window: "
f"{existing_model.max_context_window} -> {llm_config.context_window}"
)
existing_model.max_context_window = llm_config.context_window
needs_update = True
if existing_model.model_endpoint_type != llm_config.model_endpoint_type:
logger.info(
f" Updating LLM model {llm_config.handle} model_endpoint_type: "
f"{existing_model.model_endpoint_type} -> {llm_config.model_endpoint_type}"
)
existing_model.model_endpoint_type = llm_config.model_endpoint_type
needs_update = True
if needs_update:
await existing_model.update_async(session)
else:
logger.info(f" LLM model {llm_config.handle} already exists (ID: {existing[0].id}), skipping")
@@ -801,7 +801,17 @@ class ProviderManager:
# Roll back the session to clear the failed transaction
await session.rollback()
else:
logger.info(f" Embedding model {embedding_config.handle} already exists (ID: {existing[0].id}), skipping")
# Check if model_endpoint_type needs to be updated
existing_model = existing[0]
if existing_model.model_endpoint_type != embedding_config.embedding_endpoint_type:
logger.info(
f" Updating embedding model {embedding_config.handle} model_endpoint_type: "
f"{existing_model.model_endpoint_type} -> {embedding_config.embedding_endpoint_type}"
)
existing_model.model_endpoint_type = embedding_config.embedding_endpoint_type
await existing_model.update_async(session)
else:
logger.info(f" Embedding model {embedding_config.handle} already exists (ID: {existing[0].id}), skipping")
@enforce_types
@trace_method
@@ -972,8 +982,8 @@ class ProviderManager:
# Determine the model endpoint - use provider's base_url if set,
# otherwise use provider-specific defaults
if provider.base_url:
model_endpoint = provider.base_url
if typed_provider.base_url:
model_endpoint = typed_provider.base_url
elif provider.provider_type == ProviderType.chatgpt_oauth:
# ChatGPT OAuth uses the ChatGPT backend API, not a generic endpoint pattern
from letta.schemas.providers.chatgpt_oauth import CHATGPT_CODEX_ENDPOINT

View File

@@ -2,10 +2,12 @@
from letta.helpers.json_helpers import json_dumps, json_loads
from letta.orm.provider_trace import ProviderTrace as ProviderTraceModel
from letta.schemas.provider_trace import ProviderTrace
from letta.orm.provider_trace_metadata import ProviderTraceMetadata as ProviderTraceMetadataModel
from letta.schemas.provider_trace import ProviderTrace, ProviderTraceMetadata
from letta.schemas.user import User
from letta.server.db import db_registry
from letta.services.provider_trace_backends.base import ProviderTraceBackendClient
from letta.settings import telemetry_settings
class PostgresProviderTraceBackend(ProviderTraceBackendClient):
@@ -15,7 +17,17 @@ class PostgresProviderTraceBackend(ProviderTraceBackendClient):
self,
actor: User,
provider_trace: ProviderTrace,
) -> ProviderTrace | ProviderTraceMetadata:
if telemetry_settings.provider_trace_pg_metadata_only:
return await self._create_metadata_only_async(actor, provider_trace)
return await self._create_full_async(actor, provider_trace)
async def _create_full_async(
self,
actor: User,
provider_trace: ProviderTrace,
) -> ProviderTrace:
"""Write full provider trace to provider_traces table."""
async with db_registry.async_session() as session:
provider_trace_model = ProviderTraceModel(**provider_trace.model_dump())
provider_trace_model.organization_id = actor.organization_id
@@ -31,11 +43,44 @@ class PostgresProviderTraceBackend(ProviderTraceBackendClient):
await provider_trace_model.create_async(session, actor=actor, no_commit=True, no_refresh=True)
return provider_trace_model.to_pydantic()
async def _create_metadata_only_async(
self,
actor: User,
provider_trace: ProviderTrace,
) -> ProviderTraceMetadata:
"""Write metadata-only trace to provider_trace_metadata table."""
metadata = ProviderTraceMetadata(
id=provider_trace.id,
step_id=provider_trace.step_id,
agent_id=provider_trace.agent_id,
agent_tags=provider_trace.agent_tags,
call_type=provider_trace.call_type,
run_id=provider_trace.run_id,
source=provider_trace.source,
org_id=provider_trace.org_id,
user_id=provider_trace.user_id,
)
metadata_model = ProviderTraceMetadataModel(**metadata.model_dump())
metadata_model.organization_id = actor.organization_id
async with db_registry.async_session() as session:
await metadata_model.create_async(session, actor=actor, no_commit=True, no_refresh=True)
return metadata_model.to_pydantic()
async def get_by_step_id_async(
self,
step_id: str,
actor: User,
) -> ProviderTrace | None:
"""Read from provider_traces table. Always reads from full table regardless of write flag."""
return await self._get_full_by_step_id_async(step_id, actor)
async def _get_full_by_step_id_async(
self,
step_id: str,
actor: User,
) -> ProviderTrace | None:
"""Read from provider_traces table."""
async with db_registry.async_session() as session:
provider_trace_model = await ProviderTraceModel.read_async(
db_session=session,
@@ -43,3 +88,17 @@ class PostgresProviderTraceBackend(ProviderTraceBackendClient):
actor=actor,
)
return provider_trace_model.to_pydantic() if provider_trace_model else None
async def _get_metadata_by_step_id_async(
self,
step_id: str,
actor: User,
) -> ProviderTraceMetadata | None:
"""Read from provider_trace_metadata table."""
async with db_registry.async_session() as session:
metadata_model = await ProviderTraceMetadataModel.read_async(
db_session=session,
step_id=step_id,
actor=actor,
)
return metadata_model.to_pydantic() if metadata_model else None

View File

@@ -17,7 +17,8 @@ logger = get_logger(__name__)
# Protocol version for crouton communication.
# Bump this when making breaking changes to the record schema.
# Must match ProtocolVersion in apps/crouton/main.go.
PROTOCOL_VERSION = 1
# v2: Added user_id, compaction_settings (summarization), llm_config (non-summarization)
PROTOCOL_VERSION = 2
class SocketProviderTraceBackend(ProviderTraceBackendClient):
@@ -94,6 +95,11 @@ class SocketProviderTraceBackend(ProviderTraceBackendClient):
"error": error,
"error_type": error_type,
"timestamp": datetime.now(timezone.utc).isoformat(),
# v2 protocol fields
"org_id": provider_trace.org_id,
"user_id": provider_trace.user_id,
"compaction_settings": provider_trace.compaction_settings,
"llm_config": provider_trace.llm_config,
}
# Fire-and-forget in background thread

View File

@@ -455,9 +455,11 @@ class RunManager:
# Dispatch callback outside of database session if needed
if needs_callback:
if refresh_result_messages:
# Defensive: ensure stop_reason is never None
stop_reason_value = pydantic_run.stop_reason if pydantic_run.stop_reason else StopReasonType.completed
result = LettaResponse(
messages=await self.get_run_messages(run_id=run_id, actor=actor),
stop_reason=LettaStopReason(stop_reason=pydantic_run.stop_reason),
stop_reason=LettaStopReason(stop_reason=stop_reason_value),
usage=await self.get_run_usage(run_id=run_id, actor=actor),
)
final_metadata["result"] = result.model_dump()
@@ -719,7 +721,7 @@ class RunManager:
)
# Use the standard function to create properly formatted approval response messages
approval_response_messages = create_approval_response_message_from_input(
approval_response_messages = await create_approval_response_message_from_input(
agent_state=agent_state,
input_message=approval_input,
run_id=run_id,

View File

@@ -167,9 +167,7 @@ class SandboxConfigManager:
@enforce_types
@trace_method
async def get_sandbox_config_by_type_async(
self, type: SandboxType, actor: Optional[PydanticUser] = None
) -> Optional[PydanticSandboxConfig]:
async def get_sandbox_config_by_type_async(self, type: SandboxType, actor: PydanticUser) -> Optional[PydanticSandboxConfig]:
"""Retrieve a sandbox config by its type."""
async with db_registry.async_session() as session:
try:
@@ -345,7 +343,7 @@ class SandboxConfigManager:
@raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG)
@trace_method
async def get_sandbox_env_var_by_key_and_sandbox_config_id_async(
self, key: str, sandbox_config_id: str, actor: Optional[PydanticUser] = None
self, key: str, sandbox_config_id: str, actor: PydanticUser
) -> Optional[PydanticEnvVar]:
"""Retrieve a sandbox environment variable by its key and sandbox_config_id."""
async with db_registry.async_session() as session:

Some files were not shown because too many files have changed in this diff Show More