chore: bump v0.16.4 (#3168)
43
.skills/llm-provider-usage-statistics/SKILL.md
Normal file
@@ -0,0 +1,43 @@
---
name: llm-provider-usage-statistics
description: Reference guide for token counting and prefix caching across LLM providers (OpenAI, Anthropic, Gemini). Use when debugging token counts or optimizing prefix caching.
---

# LLM Provider Usage Statistics

Reference documentation for how different LLM providers report token usage.

## Quick Reference: Token Counting Semantics

| Provider | `input_tokens` meaning | Cache tokens | Must add cache to get total? |
|----------|------------------------|--------------|------------------------------|
| OpenAI | TOTAL (includes cached) | `cached_tokens` is subset | No |
| Anthropic | NON-cached only | `cache_read_input_tokens` + `cache_creation_input_tokens` | **Yes** |
| Gemini | TOTAL (includes cached) | `cached_content_token_count` is subset | No |

**Critical difference:** Anthropic's `input_tokens` excludes cached tokens, so you must add them:

```python
total_input = input_tokens + cache_read_input_tokens + cache_creation_input_tokens
```
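The table above can be collapsed into a small helper. This is an illustrative sketch, not any SDK's API — the `provider` strings and dict-style usage access are assumptions:

```python
def total_input_tokens(provider: str, usage: dict) -> int:
    """Compute total input tokens using each provider's semantics."""
    if provider == "anthropic":
        # Anthropic reports non-cached tokens only, so both cache fields are additive
        return (
            usage.get("input_tokens", 0)
            + usage.get("cache_read_input_tokens", 0)
            + usage.get("cache_creation_input_tokens", 0)
        )
    # OpenAI and Gemini already report the total; cached tokens are a subset
    return usage.get("input_tokens", 0)
```

Adding `cached_tokens` on top of an OpenAI/Gemini total would double-count, which is exactly the asymmetry the table is warning about.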
## Quick Reference: Prefix Caching

| Provider | Min tokens | How to enable | TTL |
|----------|-----------|---------------|-----|
| OpenAI | 1,024 | Automatic | ~5-10 min |
| Anthropic | 1,024 | Requires `cache_control` breakpoints | 5 min |
| Gemini 2.0+ | 1,024 | Automatic (implicit) | Variable |

## Quick Reference: Reasoning/Thinking Tokens

| Provider | Field name | Models |
|----------|-----------|--------|
| OpenAI | `reasoning_tokens` | o1, o3 models |
| Anthropic | N/A | (thinking is in content blocks, not usage) |
| Gemini | `thoughts_token_count` | Gemini 2.0 with thinking enabled |

## Provider Reference Files

- **OpenAI:** [references/openai.md](references/openai.md) - Chat Completions vs Responses API, reasoning models, cached_tokens
- **Anthropic:** [references/anthropic.md](references/anthropic.md) - cache_control setup, beta headers, cache token fields
- **Gemini:** [references/gemini.md](references/gemini.md) - implicit caching, thinking tokens, usage_metadata fields
@@ -0,0 +1,83 @@
# Anthropic Usage Statistics

## Response Format

```python
response.usage.input_tokens                 # NON-cached input tokens only
response.usage.output_tokens                # Output tokens
response.usage.cache_read_input_tokens      # Tokens read from cache
response.usage.cache_creation_input_tokens  # Tokens written to cache
```

## Critical: Token Calculation

**Anthropic's `input_tokens` is NOT the total.** To get total input tokens:

```python
total_input = input_tokens + cache_read_input_tokens + cache_creation_input_tokens
```

This is different from OpenAI/Gemini where `prompt_tokens` is already the total.
## Prefix Caching (Prompt Caching)

**Requirements:**
- Minimum 1,024 tokens for Claude Sonnet/Opus models
- Minimum 2,048 tokens for Claude Haiku models
- Requires explicit `cache_control` breakpoints in messages
- TTL: 5 minutes

**How to enable:**
Add `cache_control` to message content:
```python
{
    "role": "user",
    "content": [
        {
            "type": "text",
            "text": "...",
            "cache_control": {"type": "ephemeral"}
        }
    ]
}
```

**Beta header required:**
```python
betas = ["prompt-caching-2024-07-31"]
```
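As a sanity check, the snippet above can be wrapped in a tiny builder (a hypothetical helper for illustration only; the payload shape is the one shown above):

```python
def with_cache_breakpoint(text: str) -> dict:
    """Wrap a long prompt as a user message carrying a cache_control breakpoint."""
    return {
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": text,
                # Everything up to and including this block becomes the cacheable prefix
                "cache_control": {"type": "ephemeral"},
            }
        ],
    }
```

Place the breakpoint at the end of the stable prefix (system context, tool definitions) so later turns can reuse it.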
## Cache Behavior

- `cache_creation_input_tokens`: Tokens that were cached on this request (cache write)
- `cache_read_input_tokens`: Tokens that were read from existing cache (cache hit)
- On first request: expect `cache_creation_input_tokens > 0`
- On subsequent requests with same prefix: expect `cache_read_input_tokens > 0`
## Streaming

In streaming mode, usage is reported in two events:

1. **`message_start`**: Initial usage (may have cache info)
   ```python
   event.message.usage.input_tokens
   event.message.usage.output_tokens
   event.message.usage.cache_read_input_tokens
   event.message.usage.cache_creation_input_tokens
   ```

2. **`message_delta`**: Cumulative output tokens
   ```python
   event.usage.output_tokens  # This is CUMULATIVE, not incremental
   ```

**Important:** Per Anthropic docs, `message_delta` token counts are cumulative, so assign (don't accumulate).
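A minimal sketch of the assign-don't-accumulate rule, using plain dicts in place of SDK event objects (the event shapes here are assumptions for illustration):

```python
def track_output_tokens(events: list[dict]) -> int:
    """Return the final output token count from a stream of events."""
    output_tokens = 0
    for event in events:
        if event["type"] == "message_delta":
            # The count is cumulative: assign it, don't add it,
            # or the total ends up overstated
            output_tokens = event["usage"]["output_tokens"]
    return output_tokens


events = [
    {"type": "message_start"},
    {"type": "message_delta", "usage": {"output_tokens": 10}},
    {"type": "message_delta", "usage": {"output_tokens": 25}},
]
```

Summing the deltas here would report 35 instead of the correct 25.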
## Letta Implementation

- **Client:** `letta/llm_api/anthropic_client.py`
- **Streaming interfaces:**
  - `letta/interfaces/anthropic_streaming_interface.py`
  - `letta/interfaces/anthropic_parallel_tool_call_streaming_interface.py` (tracks cache tokens)
- **Extract method:** `AnthropicClient.extract_usage_statistics()`
- **Cache control:** `_add_cache_control_to_system_message()`, `_add_cache_control_to_messages()`
81
.skills/llm-provider-usage-statistics/references/gemini.md
Normal file
@@ -0,0 +1,81 @@
# Gemini Usage Statistics

## Response Format

Gemini returns usage in `usage_metadata`:

```python
response.usage_metadata.prompt_token_count          # Total input tokens
response.usage_metadata.candidates_token_count      # Output tokens
response.usage_metadata.total_token_count           # Sum
response.usage_metadata.cached_content_token_count  # Tokens from cache (optional)
response.usage_metadata.thoughts_token_count        # Reasoning tokens (optional)
```
## Token Counting

- `prompt_token_count` is the TOTAL (includes cached)
- `cached_content_token_count` is a subset (when present)
- Similar to OpenAI's semantics

## Implicit Caching (Gemini 2.0+)

**Requirements:**
- Minimum 1,024 tokens
- Automatic (no opt-in required)
- Available on Gemini 2.0 Flash and later models

**Behavior:**
- Caching is probabilistic and server-side
- `cached_content_token_count` may or may not be present
- When present, indicates tokens that were served from cache

**Note:** Unlike Anthropic, Gemini doesn't have explicit `cache_control`. Caching is implicit and managed by Google's infrastructure.
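Because the cached count is a subset of the total and may be absent, computing a hit ratio takes one guard. An illustrative sketch over the `usage_metadata` fields listed above (dict access is an assumption; the SDK exposes them as attributes):

```python
def cache_hit_ratio(usage_metadata: dict) -> float:
    """Fraction of prompt tokens served from Gemini's implicit cache."""
    prompt = usage_metadata.get("prompt_token_count") or 0
    # cached_content_token_count is a subset of prompt_token_count and may be absent
    cached = usage_metadata.get("cached_content_token_count") or 0
    return cached / prompt if prompt else 0.0
```

A ratio stuck at 0.0 across repeated identical prefixes usually means the prompt is under the 1,024-token minimum or the implicit cache simply didn't trigger.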
## Reasoning/Thinking Tokens

For models with extended thinking (like Gemini 2.0 with thinking enabled):
- `thoughts_token_count` reports tokens used for reasoning
- These are similar to OpenAI's `reasoning_tokens`

**Enabling thinking:**
```python
generation_config = {
    "thinking_config": {
        "thinking_budget": 1024  # Max thinking tokens
    }
}
```
## Streaming

In streaming mode:
- `usage_metadata` is typically in the **final chunk**
- Same fields as non-streaming
- May not be present in intermediate chunks

**Important:** `stream_async()` returns an async generator (not awaitable):
```python
# Correct:
stream = client.stream_async(request_data, llm_config)
async for chunk in stream:
    ...

# Incorrect (will error):
stream = await client.stream_async(...)  # TypeError!
```

## APIs

Gemini has two APIs:
- **Google AI (google_ai):** Uses `google.genai` SDK
- **Vertex AI (google_vertex):** Uses same SDK with different auth

Both share the same response format.

## Letta Implementation

- **Client:** `letta/llm_api/google_vertex_client.py` (handles both google_ai and google_vertex)
- **Streaming interface:** `letta/interfaces/gemini_streaming_interface.py`
- **Extract method:** `GoogleVertexClient.extract_usage_statistics()`
- Response is a `GenerateContentResponse` object with `.usage_metadata` attribute
61
.skills/llm-provider-usage-statistics/references/openai.md
Normal file
@@ -0,0 +1,61 @@
# OpenAI Usage Statistics

## APIs and Response Formats

OpenAI has two APIs with different response structures:

### Chat Completions API
```python
response.usage.prompt_tokens                               # Total input tokens (includes cached)
response.usage.completion_tokens                           # Output tokens
response.usage.total_tokens                                # Sum
response.usage.prompt_tokens_details.cached_tokens         # Subset that was cached
response.usage.completion_tokens_details.reasoning_tokens  # For o1/o3 models
```

### Responses API (newer)
```python
response.usage.input_tokens                            # Total input tokens
response.usage.output_tokens                           # Output tokens
response.usage.total_tokens                            # Sum
response.usage.input_tokens_details.cached_tokens      # Subset that was cached
response.usage.output_tokens_details.reasoning_tokens  # For reasoning models
```
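Code that handles both APIs can normalize the two field layouts into one shape. An illustrative sketch, assuming the usage object has already been converted to a plain dict (e.g. via the pydantic model's `model_dump()`):

```python
def normalize_usage(usage: dict) -> dict:
    """Map either API's usage payload onto a single shape."""
    if "prompt_tokens" in usage:
        # Chat Completions field names
        details = usage.get("prompt_tokens_details") or {}
        return {
            "input_tokens": usage["prompt_tokens"],
            "output_tokens": usage["completion_tokens"],
            "cached_tokens": details.get("cached_tokens", 0),
        }
    # Responses API field names
    details = usage.get("input_tokens_details") or {}
    return {
        "input_tokens": usage["input_tokens"],
        "output_tokens": usage["output_tokens"],
        "cached_tokens": details.get("cached_tokens", 0),
    }
```

In both layouts the input count is already the total, so no cache addition is needed (unlike Anthropic).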
## Prefix Caching

**Requirements:**
- Minimum 1,024 tokens in the prefix
- Automatic (no opt-in required)
- Cached in 128-token increments
- TTL: approximately 5-10 minutes of inactivity

**Supported models:** GPT-4o, GPT-4o-mini, o1, o1-mini, o3-mini

**Cache behavior:**
- `cached_tokens` will be a multiple of 128
- Cache hit means those tokens were not re-processed
- Cost: cached tokens are cheaper than non-cached

## Reasoning Models (o1, o3)

For reasoning models, additional tokens are used for "thinking":
- `reasoning_tokens` in `completion_tokens_details`
- These are output tokens used for internal reasoning
- Not visible in the response content
## Streaming

In streaming mode, usage is reported in the **final chunk** when `stream_options.include_usage=True`:
```python
request_data["stream_options"] = {"include_usage": True}
```

The final chunk will have `chunk.usage` with the same structure as non-streaming.
## Letta Implementation
|
||||
|
||||
- **Client:** `letta/llm_api/openai_client.py`
|
||||
- **Streaming interface:** `letta/interfaces/openai_streaming_interface.py`
|
||||
- **Extract method:** `OpenAIClient.extract_usage_statistics()`
|
||||
- Uses OpenAI SDK's pydantic models (`ChatCompletion`) for type-safe parsing
|
||||
@@ -0,0 +1,36 @@
"""nullable embedding for archives and passages

Revision ID: 297e8217e952
Revises: 308a180244fc
Create Date: 2026-01-20 14:11:21.137232

"""

from typing import Sequence, Union

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "297e8217e952"
down_revision: Union[str, None] = "308a180244fc"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.alter_column("archival_passages", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
    op.alter_column("archives", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
    op.alter_column("source_passages", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.alter_column("source_passages", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
    op.alter_column("archives", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
    op.alter_column("archival_passages", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
    # ### end Alembic commands ###
@@ -0,0 +1,31 @@
"""last_synced column for providers

Revision ID: 308a180244fc
Revises: 82feb220a9b8
Create Date: 2026-01-05 18:54:15.996786

"""

from typing import Sequence, Union

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "308a180244fc"
down_revision: Union[str, None] = "82feb220a9b8"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("providers", sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True))
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("providers", "last_synced")
    # ### end Alembic commands ###
@@ -0,0 +1,32 @@
"""Add v2 protocol fields to provider_traces

Revision ID: 9275f62ad282
Revises: 297e8217e952
Create Date: 2026-01-22

"""

from typing import Sequence, Union

import sqlalchemy as sa

from alembic import op

revision: str = "9275f62ad282"
down_revision: Union[str, None] = "297e8217e952"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.add_column("provider_traces", sa.Column("org_id", sa.String(), nullable=True))
    op.add_column("provider_traces", sa.Column("user_id", sa.String(), nullable=True))
    op.add_column("provider_traces", sa.Column("compaction_settings", sa.JSON(), nullable=True))
    op.add_column("provider_traces", sa.Column("llm_config", sa.JSON(), nullable=True))


def downgrade() -> None:
    op.drop_column("provider_traces", "llm_config")
    op.drop_column("provider_traces", "compaction_settings")
    op.drop_column("provider_traces", "user_id")
    op.drop_column("provider_traces", "org_id")
@@ -0,0 +1,59 @@
"""create provider_trace_metadata table

Revision ID: a1b2c3d4e5f8
Revises: 9275f62ad282
Create Date: 2026-01-28

"""

from typing import Sequence, Union

import sqlalchemy as sa

from alembic import op
from letta.settings import settings

revision: str = "a1b2c3d4e5f8"
down_revision: Union[str, None] = "9275f62ad282"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    if not settings.letta_pg_uri_no_default:
        return

    op.create_table(
        "provider_trace_metadata",
        sa.Column("id", sa.String(), nullable=False),
        sa.Column("step_id", sa.String(), nullable=True),
        sa.Column("agent_id", sa.String(), nullable=True),
        sa.Column("agent_tags", sa.JSON(), nullable=True),
        sa.Column("call_type", sa.String(), nullable=True),
        sa.Column("run_id", sa.String(), nullable=True),
        sa.Column("source", sa.String(), nullable=True),
        sa.Column("org_id", sa.String(), nullable=True),
        sa.Column("user_id", sa.String(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
        sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
        sa.Column("_created_by_id", sa.String(), nullable=True),
        sa.Column("_last_updated_by_id", sa.String(), nullable=True),
        sa.Column("organization_id", sa.String(), nullable=False),
        sa.ForeignKeyConstraint(
            ["organization_id"],
            ["organizations.id"],
        ),
        sa.PrimaryKeyConstraint("created_at", "id"),
    )
    op.create_index("ix_provider_trace_metadata_step_id", "provider_trace_metadata", ["step_id"], unique=False)
    op.create_index("ix_provider_trace_metadata_id", "provider_trace_metadata", ["id"], unique=True)


def downgrade() -> None:
    if not settings.letta_pg_uri_no_default:
        return

    op.drop_index("ix_provider_trace_metadata_id", table_name="provider_trace_metadata")
    op.drop_index("ix_provider_trace_metadata_step_id", table_name="provider_trace_metadata")
    op.drop_table("provider_trace_metadata")
@@ -14,7 +14,7 @@ services:
       - ./.persist/pgdata-test:/var/lib/postgresql/data
       - ./init.sql:/docker-entrypoint-initdb.d/init.sql
     ports:
-      - "5432:5432"
+      - '5432:5432'
   letta_server:
     image: letta/letta:latest
     hostname: letta
@@ -25,8 +25,8 @@ services:
     depends_on:
       - letta_db
     ports:
-      - "8083:8083"
-      - "8283:8283"
+      - '8083:8083'
+      - '8283:8283'
     environment:
       - LETTA_PG_DB=${LETTA_PG_DB:-letta}
       - LETTA_PG_USER=${LETTA_PG_USER:-letta}
48587
fern/openapi.json
Normal file
File diff suppressed because it is too large
@@ -5,7 +5,7 @@ try:
     __version__ = version("letta")
 except PackageNotFoundError:
     # Fallback for development installations
-    __version__ = "0.16.3"
+    __version__ = "0.16.4"

 if os.environ.get("LETTA_VERSION"):
     __version__ = os.environ["LETTA_VERSION"]
@@ -27,12 +27,16 @@ class LettaLLMAdapter(ABC):
         agent_id: str | None = None,
         agent_tags: list[str] | None = None,
         run_id: str | None = None,
+        org_id: str | None = None,
+        user_id: str | None = None,
     ) -> None:
         self.llm_client: LLMClientBase = llm_client
         self.llm_config: LLMConfig = llm_config
         self.agent_id: str | None = agent_id
         self.agent_tags: list[str] | None = agent_tags
         self.run_id: str | None = run_id
+        self.org_id: str | None = org_id
+        self.user_id: str | None = user_id
         self.message_id: str | None = None
         self.request_data: dict | None = None
         self.response_data: dict | None = None
@@ -127,6 +127,9 @@ class LettaLLMRequestAdapter(LettaLLMAdapter):
                 agent_id=self.agent_id,
                 agent_tags=self.agent_tags,
                 run_id=self.run_id,
+                org_id=self.org_id,
+                user_id=self.user_id,
+                llm_config=self.llm_config.model_dump() if self.llm_config else None,
             ),
         ),
         label="create_provider_trace",
@@ -33,8 +33,10 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
         agent_id: str | None = None,
         agent_tags: list[str] | None = None,
         run_id: str | None = None,
+        org_id: str | None = None,
+        user_id: str | None = None,
     ) -> None:
-        super().__init__(llm_client, llm_config, agent_id=agent_id, agent_tags=agent_tags, run_id=run_id)
+        super().__init__(llm_client, llm_config, agent_id=agent_id, agent_tags=agent_tags, run_id=run_id, org_id=org_id, user_id=user_id)
         self.interface: OpenAIStreamingInterface | AnthropicStreamingInterface | None = None

     async def invoke_llm(
@@ -60,7 +62,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
         self.request_data = request_data

         # Instantiate streaming interface
-        if self.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock]:
+        if self.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock, ProviderType.minimax]:
             self.interface = AnthropicStreamingInterface(
                 use_assistant_message=use_assistant_message,
                 put_inner_thoughts_in_kwarg=self.llm_config.put_inner_thoughts_in_kwargs,
@@ -68,7 +70,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
                 run_id=self.run_id,
                 step_id=step_id,
             )
-        elif self.llm_config.model_endpoint_type == ProviderType.openai:
+        elif self.llm_config.model_endpoint_type in [ProviderType.openai, ProviderType.openrouter]:
             # For non-v1 agents, always use Chat Completions streaming interface
             self.interface = OpenAIStreamingInterface(
                 use_assistant_message=use_assistant_message,
@@ -114,64 +116,9 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
         # Extract reasoning content from the interface
         self.reasoning_content = self.interface.get_reasoning_content()

-        # Extract usage statistics
-        # Some providers don't provide usage in streaming, use fallback if needed
-        if hasattr(self.interface, "input_tokens") and hasattr(self.interface, "output_tokens"):
-            # Handle cases where tokens might not be set (e.g., LMStudio)
-            input_tokens = self.interface.input_tokens
-            output_tokens = self.interface.output_tokens
-
-            # Fallback to estimated values if not provided
-            if not input_tokens and hasattr(self.interface, "fallback_input_tokens"):
-                input_tokens = self.interface.fallback_input_tokens
-            if not output_tokens and hasattr(self.interface, "fallback_output_tokens"):
-                output_tokens = self.interface.fallback_output_tokens
-
-            # Extract cache token data (OpenAI/Gemini use cached_tokens, Anthropic uses cache_read_tokens)
-            # None means provider didn't report, 0 means provider reported 0
-            cached_input_tokens = None
-            if hasattr(self.interface, "cached_tokens") and self.interface.cached_tokens is not None:
-                cached_input_tokens = self.interface.cached_tokens
-            elif hasattr(self.interface, "cache_read_tokens") and self.interface.cache_read_tokens is not None:
-                cached_input_tokens = self.interface.cache_read_tokens
-
-            # Extract cache write tokens (Anthropic only)
-            cache_write_tokens = None
-            if hasattr(self.interface, "cache_creation_tokens") and self.interface.cache_creation_tokens is not None:
-                cache_write_tokens = self.interface.cache_creation_tokens
-
-            # Extract reasoning tokens (OpenAI o1/o3 models use reasoning_tokens, Gemini uses thinking_tokens)
-            reasoning_tokens = None
-            if hasattr(self.interface, "reasoning_tokens") and self.interface.reasoning_tokens is not None:
-                reasoning_tokens = self.interface.reasoning_tokens
-            elif hasattr(self.interface, "thinking_tokens") and self.interface.thinking_tokens is not None:
-                reasoning_tokens = self.interface.thinking_tokens
-
-            # Calculate actual total input tokens
-            #
-            # ANTHROPIC: input_tokens is NON-cached only, must add cache tokens
-            # Total = input_tokens + cache_read_input_tokens + cache_creation_input_tokens
-            #
-            # OPENAI/GEMINI: input_tokens is already TOTAL
-            # cached_tokens is a subset, NOT additive
-            is_anthropic = hasattr(self.interface, "cache_read_tokens") or hasattr(self.interface, "cache_creation_tokens")
-            if is_anthropic:
-                actual_input_tokens = (input_tokens or 0) + (cached_input_tokens or 0) + (cache_write_tokens or 0)
-            else:
-                actual_input_tokens = input_tokens or 0
-
-            self.usage = LettaUsageStatistics(
-                step_count=1,
-                completion_tokens=output_tokens or 0,
-                prompt_tokens=actual_input_tokens,
-                total_tokens=actual_input_tokens + (output_tokens or 0),
-                cached_input_tokens=cached_input_tokens,
-                cache_write_tokens=cache_write_tokens,
-                reasoning_tokens=reasoning_tokens,
-            )
-        else:
-            # Default usage statistics if not available
-            self.usage = LettaUsageStatistics(step_count=1, completion_tokens=0, prompt_tokens=0, total_tokens=0)
+        # Extract usage statistics from the streaming interface
+        self.usage = self.interface.get_usage_statistics()
+        self.usage.step_count = 1

         # Store any additional data from the interface
         self.message_id = self.interface.letta_message_id
@@ -236,6 +183,9 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
                 agent_id=self.agent_id,
                 agent_tags=self.agent_tags,
                 run_id=self.run_id,
+                org_id=self.org_id,
+                user_id=self.user_id,
+                llm_config=self.llm_config.model_dump() if self.llm_config else None,
             ),
         ),
         label="create_provider_trace",
@@ -46,6 +46,9 @@ class SimpleLLMRequestAdapter(LettaLLMRequestAdapter):
             agent_tags=self.agent_tags,
             run_id=self.run_id,
             call_type="agent_step",
+            org_id=self.org_id,
+            user_id=self.user_id,
+            llm_config=self.llm_config.model_dump() if self.llm_config else None,
         )
         try:
             self.response_data = await self.llm_client.request_async_with_telemetry(request_data, self.llm_config)
@@ -14,8 +14,8 @@ from letta.schemas.enums import ProviderType
 from letta.schemas.letta_message import LettaMessage
 from letta.schemas.letta_message_content import LettaMessageContentUnion
 from letta.schemas.provider_trace import ProviderTrace
-from letta.schemas.usage import LettaUsageStatistics
 from letta.schemas.user import User
+from letta.server.rest_api.streaming_response import get_cancellation_event_for_run
 from letta.settings import settings
 from letta.utils import safe_create_task
@@ -70,8 +70,11 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
         # Store request data
         self.request_data = request_data

+        # Get cancellation event for this run to enable graceful cancellation (before branching)
+        cancellation_event = get_cancellation_event_for_run(self.run_id) if self.run_id else None
+
         # Instantiate streaming interface
-        if self.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock]:
+        if self.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock, ProviderType.minimax]:
             # NOTE: different
             self.interface = SimpleAnthropicStreamingInterface(
                 requires_approval_tools=requires_approval_tools,
@@ -81,6 +84,7 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
         elif self.llm_config.model_endpoint_type in [
             ProviderType.openai,
             ProviderType.deepseek,
+            ProviderType.openrouter,
             ProviderType.zai,
             ProviderType.chatgpt_oauth,
         ]:
@@ -102,6 +106,7 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
                 requires_approval_tools=requires_approval_tools,
                 run_id=self.run_id,
                 step_id=step_id,
+                cancellation_event=cancellation_event,
             )
         else:
             self.interface = SimpleOpenAIStreamingInterface(
@@ -112,12 +117,14 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
                 model=self.llm_config.model,
                 run_id=self.run_id,
                 step_id=step_id,
+                cancellation_event=cancellation_event,
             )
         elif self.llm_config.model_endpoint_type in [ProviderType.google_ai, ProviderType.google_vertex]:
             self.interface = SimpleGeminiStreamingInterface(
                 requires_approval_tools=requires_approval_tools,
                 run_id=self.run_id,
                 step_id=step_id,
+                cancellation_event=cancellation_event,
             )
         else:
             raise ValueError(f"Streaming not supported for provider {self.llm_config.model_endpoint_type}")
@@ -157,68 +164,10 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
|
||||
# Extract all content parts
|
||||
self.content: List[LettaMessageContentUnion] = self.interface.get_content()
|
||||
|
||||
# Extract usage statistics
|
||||
# Some providers don't provide usage in streaming, use fallback if needed
|
||||
if hasattr(self.interface, "input_tokens") and hasattr(self.interface, "output_tokens"):
|
||||
# Handle cases where tokens might not be set (e.g., LMStudio)
|
||||
input_tokens = self.interface.input_tokens
|
||||
output_tokens = self.interface.output_tokens
|
||||
|
||||
# Fallback to estimated values if not provided
|
||||
if not input_tokens and hasattr(self.interface, "fallback_input_tokens"):
|
||||
input_tokens = self.interface.fallback_input_tokens
|
||||
if not output_tokens and hasattr(self.interface, "fallback_output_tokens"):
|
||||
output_tokens = self.interface.fallback_output_tokens
|
||||
|
||||
# Extract cache token data (OpenAI/Gemini use cached_tokens)
|
||||
# None means provider didn't report, 0 means provider reported 0
|
||||
cached_input_tokens = None
|
||||
if hasattr(self.interface, "cached_tokens") and self.interface.cached_tokens is not None:
|
||||
cached_input_tokens = self.interface.cached_tokens
|
||||
# Anthropic uses cache_read_tokens for cache hits
|
||||
elif hasattr(self.interface, "cache_read_tokens") and self.interface.cache_read_tokens is not None:
|
||||
cached_input_tokens = self.interface.cache_read_tokens
|
||||
|
||||
# Extract cache write tokens (Anthropic only)
|
||||
        # None means provider didn't report, 0 means provider reported 0
        cache_write_tokens = None
        if hasattr(self.interface, "cache_creation_tokens") and self.interface.cache_creation_tokens is not None:
            cache_write_tokens = self.interface.cache_creation_tokens

        # Extract reasoning tokens (OpenAI o1/o3 models use reasoning_tokens, Gemini uses thinking_tokens)
        # None means provider didn't report, 0 means provider reported 0
        reasoning_tokens = None
        if hasattr(self.interface, "reasoning_tokens") and self.interface.reasoning_tokens is not None:
            reasoning_tokens = self.interface.reasoning_tokens
        elif hasattr(self.interface, "thinking_tokens") and self.interface.thinking_tokens is not None:
            reasoning_tokens = self.interface.thinking_tokens

        # Calculate actual total input tokens for context window limit checks (summarization trigger).
        #
        # ANTHROPIC: input_tokens is NON-cached only, must add cache tokens
        #            Total = input_tokens + cache_read_input_tokens + cache_creation_input_tokens
        #
        # OPENAI/GEMINI: input_tokens (prompt_tokens/prompt_token_count) is already TOTAL
        #                cached_tokens is a subset, NOT additive
        #                Total = input_tokens (don't add cached_tokens or it double-counts!)
        is_anthropic = hasattr(self.interface, "cache_read_tokens") or hasattr(self.interface, "cache_creation_tokens")
        if is_anthropic:
            actual_input_tokens = (input_tokens or 0) + (cached_input_tokens or 0) + (cache_write_tokens or 0)
        else:
            actual_input_tokens = input_tokens or 0

        self.usage = LettaUsageStatistics(
            step_count=1,
            completion_tokens=output_tokens or 0,
            prompt_tokens=actual_input_tokens,
            total_tokens=actual_input_tokens + (output_tokens or 0),
            cached_input_tokens=cached_input_tokens,
            cache_write_tokens=cache_write_tokens,
            reasoning_tokens=reasoning_tokens,
        )
        else:
            # Default usage statistics if not available
            self.usage = LettaUsageStatistics(step_count=1, completion_tokens=0, prompt_tokens=0, total_tokens=0)
        # Extract usage statistics from the interface
        # Each interface implements get_usage_statistics() with provider-specific logic
        self.usage = self.interface.get_usage_statistics()
        self.usage.step_count = 1

        # Store any additional data from the interface
        self.message_id = self.interface.letta_message_id
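The comments in the hunk above encode the key provider difference: Anthropic's `input_tokens` excludes cache reads/writes, while OpenAI and Gemini already report the total. A minimal standalone sketch of that accounting (hypothetical helper, not Letta code):

```python
def total_input_tokens(provider: str, input_tokens: int, cache_read: int = 0, cache_write: int = 0) -> int:
    """Compute the true total input tokens for a request.

    Anthropic reports input_tokens EXCLUDING cached tokens, so cache reads
    and cache writes must be added back. OpenAI/Gemini report input_tokens
    as the TOTAL; their cached-token counts are a subset, not additive.
    """
    if provider == "anthropic":
        return input_tokens + cache_read + cache_write
    return input_tokens


# The same request, accounted for differently by each provider:
anthropic_total = total_input_tokens("anthropic", input_tokens=200, cache_read=5000)
openai_total = total_input_tokens("openai", input_tokens=5200)  # cached_tokens=5000 is already included
```

Adding `cached_tokens` on the OpenAI/Gemini side would double-count, which is exactly the bug the `is_anthropic` branch guards against.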
@@ -283,6 +232,9 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
                    agent_id=self.agent_id,
                    agent_tags=self.agent_tags,
                    run_id=self.run_id,
                    org_id=self.org_id,
                    user_id=self.user_id,
                    llm_config=self.llm_config.model_dump() if self.llm_config else None,
                ),
            ),
            label="create_provider_trace",
@@ -97,6 +97,25 @@ async def _prepare_in_context_messages_async(
    return current_in_context_messages, new_in_context_messages


@trace_method
def validate_persisted_tool_call_ids(tool_return_message: Message, approval_response_message: ApprovalCreate) -> bool:
    persisted_tool_returns = tool_return_message.tool_returns
    if not persisted_tool_returns:
        return False
    persisted_tool_call_ids = [tool_return.tool_call_id for tool_return in persisted_tool_returns]

    approval_responses = approval_response_message.approvals
    if not approval_responses:
        return False
    approval_response_tool_call_ids = [approval_response.tool_call_id for approval_response in approval_responses]

    request_response_diff = set(persisted_tool_call_ids).symmetric_difference(set(approval_response_tool_call_ids))
    if request_response_diff:
        return False

    return True

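The validator above matches two sides by tool-call ID using a set symmetric difference: any ID present on only one side makes the sets disagree, regardless of ordering. A standalone sketch of that check (illustrative helper, not Letta code):

```python
def ids_match(persisted_ids: list[str], response_ids: list[str]) -> bool:
    """True only when both sides contain exactly the same tool_call_ids.

    symmetric_difference yields every ID that appears on one side but not
    the other; an empty result means the two lists agree (order-insensitive).
    """
    return not set(persisted_ids).symmetric_difference(set(response_ids))


ids_match(["call_1", "call_2"], ["call_2", "call_1"])  # order doesn't matter
ids_match(["call_1"], ["call_1", "call_2"])            # extra ID on one side -> mismatch
```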
@trace_method
def validate_approval_tool_call_ids(approval_request_message: Message, approval_response_message: ApprovalCreate):
    approval_requests = approval_request_message.tool_calls
@@ -227,6 +246,36 @@ async def _prepare_in_context_messages_no_persist_async(
    if input_messages[0].type == "approval":
        # User is trying to send an approval response
        if current_in_context_messages and current_in_context_messages[-1].role != "approval":
            # No pending approval request - check if this is an idempotent retry
            # Check last few messages for a tool return matching the approval's tool_call_ids
            # (approved tool return should be recent, but server-side tool calls may come after it)
            approval_already_processed = False
            recent_messages = current_in_context_messages[-10:]  # Only check last 10 messages
            for msg in reversed(recent_messages):
                if msg.role == "tool" and validate_persisted_tool_call_ids(msg, input_messages[0]):
                    logger.info(
                        f"Idempotency check: Found matching tool return in recent history. "
                        f"tool_returns={msg.tool_returns}, approval_response.approvals={input_messages[0].approvals}"
                    )
                    approval_already_processed = True
                    break

            if approval_already_processed:
                # Approval already handled, just process follow-up messages if any or manually inject keep-alive message
                keep_alive_messages = input_messages[1:] or [
                    MessageCreate(
                        role="user",
                        content=[
                            TextContent(
                                text="<system-alert>Automated keep-alive ping. Ignore this message and continue from where you stopped.</system-alert>"
                            )
                        ],
                    )
                ]
                new_in_context_messages = await create_input_messages(
                    input_messages=keep_alive_messages, agent_id=agent_state.id, timezone=agent_state.timezone, run_id=run_id, actor=actor
                )
                return current_in_context_messages, new_in_context_messages
            logger.warning(
                f"Cannot process approval response: No tool call is currently awaiting approval. Last message: {current_in_context_messages[-1]}"
            )
@@ -235,7 +284,7 @@ async def _prepare_in_context_messages_no_persist_async(
                "Please send a regular message to interact with the agent."
            )
        validate_approval_tool_call_ids(current_in_context_messages[-1], input_messages[0])
        new_in_context_messages = create_approval_response_message_from_input(
        new_in_context_messages = await create_approval_response_message_from_input(
            agent_state=agent_state, input_message=input_messages[0], run_id=run_id
        )
        if len(input_messages) > 1:
@@ -218,6 +218,7 @@ class LettaAgent(BaseAgent):
        use_assistant_message: bool = True,
        request_start_timestamp_ns: int | None = None,
        include_return_message_types: list[MessageType] | None = None,
        run_id: str | None = None,
    ):
        agent_state = await self.agent_manager.get_agent_by_id_async(
            agent_id=self.agent_id,
@@ -330,6 +331,7 @@ class LettaAgent(BaseAgent):
                tool_rules_solver,
                agent_step_span,
                step_metrics,
                run_id=run_id,
            )
            in_context_messages = current_in_context_messages + new_in_context_messages

@@ -418,6 +420,9 @@ class LettaAgent(BaseAgent):
                    agent_id=self.agent_id,
                    agent_tags=agent_state.tags,
                    run_id=self.current_run_id,
                    org_id=self.actor.organization_id,
                    user_id=self.actor.id,
                    llm_config=self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None,
                ),
            )
            step_progression = StepProgression.LOGGED_TRACE
@@ -549,6 +554,7 @@ class LettaAgent(BaseAgent):
                llm_config=agent_state.llm_config,
                total_tokens=usage.total_tokens,
                force=False,
                run_id=run_id,
            )

        await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=False)
@@ -677,6 +683,7 @@ class LettaAgent(BaseAgent):
                tool_rules_solver,
                agent_step_span,
                step_metrics,
                run_id=run_id,
            )
            in_context_messages = current_in_context_messages + new_in_context_messages

@@ -766,6 +773,9 @@ class LettaAgent(BaseAgent):
                    agent_id=self.agent_id,
                    agent_tags=agent_state.tags,
                    run_id=self.current_run_id,
                    org_id=self.actor.organization_id,
                    user_id=self.actor.id,
                    llm_config=self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None,
                ),
            )
            step_progression = StepProgression.LOGGED_TRACE
@@ -882,6 +892,7 @@ class LettaAgent(BaseAgent):
                llm_config=agent_state.llm_config,
                total_tokens=usage.total_tokens,
                force=False,
                run_id=run_id,
            )

        await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=False)
@@ -908,6 +919,7 @@ class LettaAgent(BaseAgent):
        use_assistant_message: bool = True,
        request_start_timestamp_ns: int | None = None,
        include_return_message_types: list[MessageType] | None = None,
        run_id: str | None = None,
    ) -> AsyncGenerator[str, None]:
        """
        Carries out an invocation of the agent loop in a streaming fashion that yields partial tokens.
@@ -1027,6 +1039,8 @@ class LettaAgent(BaseAgent):
                agent_state,
                llm_client,
                tool_rules_solver,
                run_id=run_id,
                step_id=step_id,
            )

            step_progression = StepProgression.STREAM_RECEIVED
@@ -1234,6 +1248,9 @@ class LettaAgent(BaseAgent):
                    agent_id=self.agent_id,
                    agent_tags=agent_state.tags,
                    run_id=self.current_run_id,
                    org_id=self.actor.organization_id,
                    user_id=self.actor.id,
                    llm_config=self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None,
                ),
            )
            step_progression = StepProgression.LOGGED_TRACE
@@ -1378,6 +1395,7 @@ class LettaAgent(BaseAgent):
                llm_config=agent_state.llm_config,
                total_tokens=usage.total_tokens,
                force=False,
                run_id=run_id,
            )

        await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=False)
@@ -1441,6 +1459,7 @@ class LettaAgent(BaseAgent):
        tool_rules_solver: ToolRulesSolver,
        agent_step_span: "Span",
        step_metrics: StepMetrics,
        run_id: str | None = None,
    ) -> tuple[dict, dict, list[Message], list[Message], list[str]] | None:
        for attempt in range(self.max_summarization_retries + 1):
            try:
@@ -1461,6 +1480,7 @@ class LettaAgent(BaseAgent):
                    agent_id=self.agent_id,
                    agent_tags=agent_state.tags,
                    run_id=self.current_run_id,
                    step_id=step_metrics.id,
                    call_type="agent_step",
                )
                response = await llm_client.request_async_with_telemetry(request_data, agent_state.llm_config)
@@ -1488,6 +1508,7 @@ class LettaAgent(BaseAgent):
                    new_letta_messages=new_in_context_messages,
                    llm_config=agent_state.llm_config,
                    force=True,
                    run_id=run_id,
                )
                new_in_context_messages = []
                log_event(f"agent.stream_no_tokens.retry_attempt.{attempt + 1}")
@@ -1503,6 +1524,8 @@ class LettaAgent(BaseAgent):
        agent_state: AgentState,
        llm_client: LLMClientBase,
        tool_rules_solver: ToolRulesSolver,
        run_id: str | None = None,
        step_id: str | None = None,
    ) -> tuple[dict, AsyncStream[ChatCompletionChunk], list[Message], list[Message], list[str], int] | None:
        for attempt in range(self.max_summarization_retries + 1):
            try:
@@ -1530,6 +1553,7 @@ class LettaAgent(BaseAgent):
                    agent_id=self.agent_id,
                    agent_tags=agent_state.tags,
                    run_id=self.current_run_id,
                    step_id=step_id,
                    call_type="agent_step",
                )

@@ -1555,6 +1579,7 @@ class LettaAgent(BaseAgent):
                    new_letta_messages=new_in_context_messages,
                    llm_config=agent_state.llm_config,
                    force=True,
                    run_id=run_id,
                )
                new_in_context_messages: list[Message] = []
                log_event(f"agent.stream_no_tokens.retry_attempt.{attempt + 1}")
@@ -1568,10 +1593,17 @@ class LettaAgent(BaseAgent):
        new_letta_messages: list[Message],
        llm_config: LLMConfig,
        force: bool,
        run_id: str | None = None,
        step_id: str | None = None,
    ) -> list[Message]:
        if isinstance(e, ContextWindowExceededError):
            return await self._rebuild_context_window(
                in_context_messages=in_context_messages, new_letta_messages=new_letta_messages, llm_config=llm_config, force=force
                in_context_messages=in_context_messages,
                new_letta_messages=new_letta_messages,
                llm_config=llm_config,
                force=force,
                run_id=run_id,
                step_id=step_id,
            )
        else:
            raise llm_client.handle_llm_error(e)
@@ -1584,6 +1616,8 @@ class LettaAgent(BaseAgent):
        llm_config: LLMConfig,
        total_tokens: int | None = None,
        force: bool = False,
        run_id: str | None = None,
        step_id: str | None = None,
    ) -> list[Message]:
        # If total tokens is reached, we truncate down
        # TODO: This can be broken by bad configs, e.g. lower bound too high, initial messages too fat, etc.
@@ -1597,6 +1631,8 @@ class LettaAgent(BaseAgent):
                new_letta_messages=new_letta_messages,
                force=True,
                clear=True,
                run_id=run_id,
                step_id=step_id,
            )
        else:
            # NOTE (Sarah): Seems like this is doing nothing?
@@ -1606,6 +1642,8 @@ class LettaAgent(BaseAgent):
        new_in_context_messages, updated = await self.summarizer.summarize(
            in_context_messages=in_context_messages,
            new_letta_messages=new_letta_messages,
            run_id=run_id,
            step_id=step_id,
        )
        await self.agent_manager.update_message_ids_async(
            agent_id=self.agent_id,

@@ -156,7 +156,11 @@ class LettaAgentV2(BaseAgentV2):
            run_id=None,
            messages=in_context_messages + input_messages_to_persist,
            llm_adapter=LettaLLMRequestAdapter(
                llm_client=self.llm_client, llm_config=self.agent_state.llm_config, agent_tags=self.agent_state.tags
                llm_client=self.llm_client,
                llm_config=self.agent_state.llm_config,
                agent_tags=self.agent_state.tags,
                org_id=self.actor.organization_id,
                user_id=self.actor.id,
            ),
            dry_run=True,
            enforce_run_id_set=False,
@@ -213,6 +217,8 @@ class LettaAgentV2(BaseAgentV2):
                agent_id=self.agent_state.id,
                agent_tags=self.agent_state.tags,
                run_id=run_id,
                org_id=self.actor.organization_id,
                user_id=self.actor.id,
            ),
            run_id=run_id,
            use_assistant_message=use_assistant_message,
@@ -236,6 +242,7 @@ class LettaAgentV2(BaseAgentV2):
                    new_letta_messages=self.response_messages,
                    total_tokens=self.usage.total_tokens,
                    force=False,
                    run_id=run_id,
                )

            if self.stop_reason is None:
@@ -297,6 +304,8 @@ class LettaAgentV2(BaseAgentV2):
                agent_id=self.agent_state.id,
                agent_tags=self.agent_state.tags,
                run_id=run_id,
                org_id=self.actor.organization_id,
                user_id=self.actor.id,
            )
        else:
            llm_adapter = LettaLLMRequestAdapter(
@@ -305,6 +314,8 @@ class LettaAgentV2(BaseAgentV2):
                agent_id=self.agent_state.id,
                agent_tags=self.agent_state.tags,
                run_id=run_id,
                org_id=self.actor.organization_id,
                user_id=self.actor.id,
            )

        try:
@@ -343,6 +354,7 @@ class LettaAgentV2(BaseAgentV2):
                    new_letta_messages=self.response_messages,
                    total_tokens=self.usage.total_tokens,
                    force=False,
                    run_id=run_id,
                )

        except:
@@ -488,6 +500,8 @@ class LettaAgentV2(BaseAgentV2):
                in_context_messages=messages,
                new_letta_messages=self.response_messages,
                force=True,
                run_id=run_id,
                step_id=step_id,
            )
        else:
            raise e
@@ -1246,6 +1260,8 @@ class LettaAgentV2(BaseAgentV2):
        new_letta_messages: list[Message],
        total_tokens: int | None = None,
        force: bool = False,
        run_id: str | None = None,
        step_id: str | None = None,
    ) -> list[Message]:
        self.logger.warning("Running deprecated v2 summarizer. This should be removed in the future.")
        # always skip summarization if last message is an approval request message
@@ -1268,6 +1284,8 @@ class LettaAgentV2(BaseAgentV2):
                new_letta_messages=new_letta_messages,
                force=True,
                clear=True,
                run_id=run_id,
                step_id=step_id,
            )
        else:
            # NOTE (Sarah): Seems like this is doing nothing?
@@ -1277,6 +1295,8 @@ class LettaAgentV2(BaseAgentV2):
            new_in_context_messages, updated = await self.summarizer.summarize(
                in_context_messages=in_context_messages,
                new_letta_messages=new_letta_messages,
                run_id=run_id,
                step_id=step_id,
            )
        except Exception as e:
            self.logger.error(f"Failed to summarize conversation history: {e}")

@@ -41,6 +41,7 @@ from letta.schemas.step import StepProgression
from letta.schemas.step_metrics import StepMetrics
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.server.rest_api.utils import (
    create_approval_request_message_from_llm_response,
    create_letta_messages_from_llm_response,
@@ -72,6 +73,16 @@ class LettaAgentV3(LettaAgentV2):
    * Support Gemini / OpenAI client
    """

    def __init__(
        self,
        agent_state: AgentState,
        actor: User,
        conversation_id: str | None = None,
    ):
        super().__init__(agent_state, actor)
        # Set conversation_id after parent init (which calls _initialize_state)
        self.conversation_id = conversation_id

    def _initialize_state(self):
        super()._initialize_state()
        self._require_tool_call = False
@@ -168,7 +179,13 @@ class LettaAgentV3(LettaAgentV2):
            input_messages_to_persist=input_messages_to_persist,
            # TODO need to support non-streaming adapter too
            llm_adapter=SimpleLLMRequestAdapter(
                llm_client=self.llm_client, llm_config=self.agent_state.llm_config, agent_id=self.agent_state.id, run_id=run_id
                llm_client=self.llm_client,
                llm_config=self.agent_state.llm_config,
                agent_id=self.agent_state.id,
                agent_tags=self.agent_state.tags,
                run_id=run_id,
                org_id=self.actor.organization_id,
                user_id=self.actor.id,
            ),
            run_id=run_id,
            # use_assistant_message=use_assistant_message,
@@ -310,14 +327,20 @@ class LettaAgentV3(LettaAgentV2):
                llm_client=self.llm_client,
                llm_config=self.agent_state.llm_config,
                agent_id=self.agent_state.id,
                agent_tags=self.agent_state.tags,
                run_id=run_id,
                org_id=self.actor.organization_id,
                user_id=self.actor.id,
            )
        else:
            llm_adapter = SimpleLLMRequestAdapter(
                llm_client=self.llm_client,
                llm_config=self.agent_state.llm_config,
                agent_id=self.agent_state.id,
                agent_tags=self.agent_state.tags,
                run_id=run_id,
                org_id=self.actor.organization_id,
                user_id=self.actor.id,
            )

        try:
@@ -390,7 +413,9 @@ class LettaAgentV3(LettaAgentV2):
                self.stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)

        except Exception as e:
            self.logger.warning(f"Error during agent stream: {e}", exc_info=True)
            # Use repr() if str() is empty (happens with Exception() with no args)
            error_detail = str(e) or repr(e)
            self.logger.warning(f"Error during agent stream: {error_detail}", exc_info=True)

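The `str(e) or repr(e)` idiom in the hunk above exists because `str()` of an exception constructed with no arguments is the empty string, which would log a blank error. A quick standalone illustration:

```python
def describe_error(e: Exception) -> str:
    # str(Exception()) == "" (falsy), so fall back to repr() for a usable detail string.
    return str(e) or repr(e)


describe_error(Exception())            # falls back to repr: "Exception()"
describe_error(ValueError("bad input"))  # str is non-empty: "bad input"
```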
            # Set stop_reason if not already set
            if self.stop_reason is None:
@@ -411,7 +436,7 @@ class LettaAgentV3(LettaAgentV2):
                run_id=run_id,
                error_type="internal_error",
                message="An error occurred during agent execution.",
                detail=str(e),
                detail=error_detail,
            )
            yield f"event: error\ndata: {error_message.model_dump_json()}\n\n"

@@ -486,10 +511,11 @@ class LettaAgentV3(LettaAgentV2):
            new_messages: The new messages to persist
            in_context_messages: The current in-context messages
        """
        # make sure all the new messages have the correct run_id and step_id
        # make sure all the new messages have the correct run_id, step_id, and conversation_id
        for message in new_messages:
            message.step_id = step_id
            message.run_id = run_id
            message.conversation_id = self.conversation_id

        # persist the new message objects - ONLY place where messages are persisted
        persisted_messages = await self.message_manager.create_many_messages_async(
@@ -653,7 +679,15 @@ class LettaAgentV3(LettaAgentV2):
                    return

                step_id = approval_request.step_id
                step_metrics = await self.step_manager.get_step_metrics_async(step_id=step_id, actor=self.actor)
                if step_id is None:
                    # Old approval messages may not have step_id set - generate a new one
                    self.logger.warning(f"Approval request message {approval_request.id} has no step_id, generating new step_id")
                    step_id = generate_step_id()
                    step_progression, logged_step, step_metrics, agent_step_span = await self._step_checkpoint_start(
                        step_id=step_id, run_id=run_id
                    )
                else:
                    step_metrics = await self.step_manager.get_step_metrics_async(step_id=step_id, actor=self.actor)
            else:
                # Check for job cancellation at the start of each step
                if run_id and await self._check_run_cancellation(run_id):
@@ -760,7 +794,10 @@ class LettaAgentV3(LettaAgentV2):
                    # TODO: might want to delay this checkpoint in case of corrupted state
                    try:
                        summary_message, messages, _ = await self.compact(
                            messages, trigger_threshold=self.agent_state.llm_config.context_window
                            messages,
                            trigger_threshold=self.agent_state.llm_config.context_window,
                            run_id=run_id,
                            step_id=step_id,
                        )
                        self.logger.info("Summarization succeeded, continuing to retry LLM request")
                        continue
@@ -776,7 +813,10 @@ class LettaAgentV3(LettaAgentV2):

                # update the messages
                await self._checkpoint_messages(
                    run_id=run_id, step_id=step_id, new_messages=[summary_message], in_context_messages=messages
                    run_id=run_id,
                    step_id=step_id,
                    new_messages=[summary_message],
                    in_context_messages=messages,
                )

            else:
@@ -879,20 +919,30 @@ class LettaAgentV3(LettaAgentV2):
                    self.logger.info(
                        f"Context window exceeded (current: {self.context_token_estimate}, threshold: {self.agent_state.llm_config.context_window}), trying to compact messages"
                    )
                    summary_message, messages, _ = await self.compact(messages, trigger_threshold=self.agent_state.llm_config.context_window)
                    summary_message, messages, _ = await self.compact(
                        messages,
                        trigger_threshold=self.agent_state.llm_config.context_window,
                        run_id=run_id,
                        step_id=step_id,
                    )
                    # TODO: persist + return the summary message
                    # TODO: convert this to a SummaryMessage
                    self.response_messages.append(summary_message)
                    for message in Message.to_letta_messages(summary_message):
                        yield message
                    await self._checkpoint_messages(
                        run_id=run_id, step_id=step_id, new_messages=[summary_message], in_context_messages=messages
                        run_id=run_id,
                        step_id=step_id,
                        new_messages=[summary_message],
                        in_context_messages=messages,
                    )

        except Exception as e:
            # NOTE: message persistence does not happen in the case of an exception (rollback to previous state)
            self.logger.warning(f"Error during step processing: {e}")
            self.job_update_metadata = {"error": str(e)}
            # Use repr() if str() is empty (happens with Exception() with no args)
            error_detail = str(e) or repr(e)
            self.logger.warning(f"Error during step processing: {error_detail}")
            self.job_update_metadata = {"error": error_detail}

            # This indicates we failed after we decided to stop stepping, which indicates a bug with our flow.
            if not self.stop_reason:
@@ -1445,7 +1495,12 @@ class LettaAgentV3(LettaAgentV2):

    @trace_method
    async def compact(
        self, messages, trigger_threshold: Optional[int] = None, compaction_settings: Optional["CompactionSettings"] = None
        self,
        messages,
        trigger_threshold: Optional[int] = None,
        compaction_settings: Optional["CompactionSettings"] = None,
        run_id: Optional[str] = None,
        step_id: Optional[str] = None,
    ) -> tuple[Message, list[Message], str]:
        """Compact the current in-context messages for this agent.

@@ -1472,7 +1527,7 @@ class LettaAgentV3(LettaAgentV2):
            summarizer_config = CompactionSettings(model=handle)

        # Build the LLMConfig used for summarization
        summarizer_llm_config = self._build_summarizer_llm_config(
        summarizer_llm_config = await self._build_summarizer_llm_config(
            agent_llm_config=self.agent_state.llm_config,
            summarizer_config=summarizer_config,
        )
@@ -1484,6 +1539,10 @@ class LettaAgentV3(LettaAgentV2):
                llm_config=summarizer_llm_config,
                summarizer_config=summarizer_config,
                in_context_messages=messages,
                agent_id=self.agent_state.id,
                agent_tags=self.agent_state.tags,
                run_id=run_id,
                step_id=step_id,
            )
        elif summarizer_config.mode == "sliding_window":
            try:
@@ -1492,6 +1551,10 @@ class LettaAgentV3(LettaAgentV2):
                    llm_config=summarizer_llm_config,
                    summarizer_config=summarizer_config,
                    in_context_messages=messages,
                    agent_id=self.agent_state.id,
                    agent_tags=self.agent_state.tags,
                    run_id=run_id,
                    step_id=step_id,
                )
            except Exception as e:
                self.logger.error(f"Sliding window summarization failed with exception: {str(e)}. Falling back to all mode.")
@@ -1500,6 +1563,10 @@ class LettaAgentV3(LettaAgentV2):
                    llm_config=summarizer_llm_config,
                    summarizer_config=summarizer_config,
                    in_context_messages=messages,
                    agent_id=self.agent_state.id,
                    agent_tags=self.agent_state.tags,
                    run_id=run_id,
                    step_id=step_id,
                )
                summarization_mode_used = "all"
        else:
@@ -1533,6 +1600,10 @@ class LettaAgentV3(LettaAgentV2):
                llm_config=self.agent_state.llm_config,
                summarizer_config=summarizer_config,
                in_context_messages=compacted_messages,
                agent_id=self.agent_state.id,
                agent_tags=self.agent_state.tags,
                run_id=run_id,
                step_id=step_id,
            )
            summarization_mode_used = "all"

@@ -1584,8 +1655,8 @@ class LettaAgentV3(LettaAgentV2):

        return summary_message_obj, final_messages, summary

    @staticmethod
    def _build_summarizer_llm_config(
    async def _build_summarizer_llm_config(
        self,
        agent_llm_config: LLMConfig,
        summarizer_config: CompactionSettings,
    ) -> LLMConfig:
@@ -1611,12 +1682,41 @@ class LettaAgentV3(LettaAgentV2):
            model_name = summarizer_config.model

        # Start from the agent's config and override model + provider_name + handle
        # Note: model_endpoint_type is NOT overridden - the parsed provider_name
        # is a custom label (e.g. "claude-pro-max"), not the endpoint type (e.g. "anthropic")
        base = agent_llm_config.model_copy()
        base.provider_name = provider_name
        base.model = model_name
        base.handle = summarizer_config.model
        # Check if the summarizer's provider matches the agent's provider
        # If they match, we can safely use the agent's config as a base
        # If they don't match, we need to load the default config for the new provider
        from letta.schemas.enums import ProviderType

        provider_matches = False
        try:
            # Check if provider_name is a valid ProviderType that matches agent's endpoint type
            provider_type = ProviderType(provider_name)
            provider_matches = provider_type.value == agent_llm_config.model_endpoint_type
        except ValueError:
            # provider_name is a custom label - check if it matches agent's provider_name
            provider_matches = provider_name == agent_llm_config.provider_name

        if provider_matches:
            # Same provider - use agent's config as base and override model/handle
            base = agent_llm_config.model_copy()
            base.model = model_name
            base.handle = summarizer_config.model
        else:
            # Different provider - load default config for this handle
            from letta.services.provider_manager import ProviderManager

            provider_manager = ProviderManager()
            try:
                base = await provider_manager.get_llm_config_from_handle(
                    handle=summarizer_config.model,
                    actor=self.actor,
                )
            except Exception as e:
                self.logger.warning(
                    f"Failed to load LLM config for summarizer handle '{summarizer_config.model}': {e}. "
                    f"Falling back to agent's LLM config."
                )
                return agent_llm_config

        # If explicit model_settings are provided for the summarizer, apply
        # them just like server.create_agent_async does for agents.

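The provider-matching logic added above tries to parse the name as a known provider enum, and on `ValueError` treats it as a custom label compared against the agent's own provider name. A self-contained sketch of that pattern (the enum here is a simplified stand-in for `letta.schemas.enums.ProviderType`, with hypothetical members):

```python
from enum import Enum


class ProviderType(Enum):
    # Simplified stand-in; the real enum lives in letta.schemas.enums
    anthropic = "anthropic"
    openai = "openai"


def provider_matches(provider_name: str, endpoint_type: str, agent_provider_name: str) -> bool:
    try:
        # Known provider: compare against the agent's endpoint type.
        return ProviderType(provider_name).value == endpoint_type
    except ValueError:
        # Custom label (e.g. "claude-pro-max"): compare against the agent's provider_name.
        return provider_name == agent_provider_name
```

The `Enum(...)` constructor raising `ValueError` for unknown values is what lets one `try/except` cleanly separate "real provider" from "user-defined label".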
@@ -25,7 +25,7 @@ PROVIDER_ORDER = {
    "xai": 12,
    "lmstudio": 13,
    "zai": 14,
    "openrouter": 15,  # Note: OpenRouter uses OpenRouterProvider, not a ProviderType enum
    "openrouter": 15,
}

ADMIN_PREFIX = "/v1/admin"

@@ -7,8 +7,52 @@ from typing import Dict, Optional, Tuple
from letta.errors import LettaToolCreateError
from letta.types import JsonDict

_ALLOWED_TYPING_NAMES = {name: obj for name, obj in vars(typing).items() if not name.startswith("_")}
_ALLOWED_BUILTIN_TYPES = {name: obj for name, obj in vars(builtins).items() if isinstance(obj, type)}
_ALLOWED_TYPE_NAMES = {**_ALLOWED_TYPING_NAMES, **_ALLOWED_BUILTIN_TYPES, "typing": typing}

def resolve_type(annotation: str):

def _resolve_annotation_node(node: ast.AST):
    if isinstance(node, ast.Name):
        if node.id == "None":
            return type(None)
        if node.id in _ALLOWED_TYPE_NAMES:
            return _ALLOWED_TYPE_NAMES[node.id]
        raise ValueError(f"Unsupported annotation name: {node.id}")

    if isinstance(node, ast.Attribute):
        if isinstance(node.value, ast.Name) and node.value.id == "typing" and node.attr in _ALLOWED_TYPING_NAMES:
            return _ALLOWED_TYPING_NAMES[node.attr]
        raise ValueError("Unsupported annotation attribute")

    if isinstance(node, ast.Subscript):
        origin = _resolve_annotation_node(node.value)
        args = _resolve_subscript_slice(node.slice)
        return origin[args]

    if isinstance(node, ast.Tuple):
        return tuple(_resolve_annotation_node(elt) for elt in node.elts)

    if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
        left = _resolve_annotation_node(node.left)
        right = _resolve_annotation_node(node.right)
        return left | right

    if isinstance(node, ast.Constant) and node.value is None:
        return type(None)

    raise ValueError("Unsupported annotation expression")


def _resolve_subscript_slice(slice_node: ast.AST):
    if isinstance(slice_node, ast.Index):
        slice_node = slice_node.value
    if isinstance(slice_node, ast.Tuple):
        return tuple(_resolve_annotation_node(elt) for elt in slice_node.elts)
    return _resolve_annotation_node(slice_node)

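The helpers above replace `eval`-based annotation resolution with an AST walk over an allowlist of `typing` and builtin names, so a malicious annotation string can never execute code. A compact standalone sketch of the same idea (simplified: only bare names and subscripts, no `typing.X` attributes or `|` unions):

```python
import ast
import builtins
import typing

# Allowlist: public typing names plus builtin types (a simplified version of the module-level tables above).
_SAFE = {
    **{name: obj for name, obj in vars(typing).items() if not name.startswith("_")},
    **{name: obj for name, obj in vars(builtins).items() if isinstance(obj, type)},
}


def resolve_annotation(annotation: str):
    """Resolve an annotation string like "dict[str, int]" without eval()."""

    def _resolve(node: ast.AST):
        if isinstance(node, ast.Name):
            return _SAFE[node.id]
        if isinstance(node, ast.Subscript):
            origin = _resolve(node.value)
            s = node.slice
            if isinstance(s, ast.Tuple):
                return origin[tuple(_resolve(e) for e in s.elts)]
            return origin[_resolve(s)]
        raise ValueError(f"Unsupported annotation: {annotation}")

    return _resolve(ast.parse(annotation, mode="eval").body)
```

Unlike `eval("__import__('os').system(...)")`, anything outside the allowlisted node types simply raises `ValueError`.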
def resolve_type(annotation: str, *, allow_unsafe_eval: bool = False, extra_globals: Optional[Dict[str, object]] = None):
|
||||
"""
|
||||
Resolve a type annotation string into a Python type.
|
||||
Previously, primitive support for int, float, str, dict, list, set, tuple, bool.
|
||||
@@ -23,15 +67,23 @@ def resolve_type(annotation: str):
|
||||
ValueError: If the annotation is unsupported or invalid.
|
||||
"""
|
||||
python_types = {**vars(typing), **vars(builtins)}
|
||||
if extra_globals:
|
||||
python_types.update(extra_globals)
|
||||
|
||||
if annotation in python_types:
|
||||
return python_types[annotation]
|
||||
|
||||
try:
|
||||
# Allow use of typing and builtins in a safe eval context
|
||||
return eval(annotation, python_types)
|
||||
parsed = ast.parse(annotation, mode="eval")
|
||||
return _resolve_annotation_node(parsed.body)
|
||||
except Exception:
|
||||
raise ValueError(f"Unsupported annotation: {annotation}")
|
||||
if allow_unsafe_eval:
|
||||
try:
|
||||
return eval(annotation, python_types)
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Unsupported annotation: {annotation}") from exc
|
||||
|
||||
raise ValueError(f"Unsupported annotation: {annotation}")


# TODO :: THIS MUST BE EDITED TO HANDLE THINGS
@@ -62,14 +114,34 @@ def get_function_annotations_from_source(source_code: str, function_name: str) -


 # NOW json_loads -> ast.literal_eval -> typing.get_origin
-def coerce_dict_args_by_annotations(function_args: JsonDict, annotations: Dict[str, str]) -> dict:
+def coerce_dict_args_by_annotations(
+    function_args: JsonDict,
+    annotations: Dict[str, object],
+    *,
+    allow_unsafe_eval: bool = False,
+    extra_globals: Optional[Dict[str, object]] = None,
+) -> dict:
     coerced_args = dict(function_args) # Shallow copy

     for arg_name, value in coerced_args.items():
         if arg_name in annotations:
-            annotation_str = annotations[arg_name]
             try:
-                arg_type = resolve_type(annotation_str)
+                annotation_value = annotations[arg_name]
+                if isinstance(annotation_value, str):
+                    arg_type = resolve_type(
+                        annotation_value,
+                        allow_unsafe_eval=allow_unsafe_eval,
+                        extra_globals=extra_globals,
+                    )
+                elif isinstance(annotation_value, typing.ForwardRef):
+                    arg_type = resolve_type(
+                        annotation_value.__forward_arg__,
+                        allow_unsafe_eval=allow_unsafe_eval,
+                        extra_globals=extra_globals,
+                    )
+                else:
+                    arg_type = annotation_value

                 # Always parse strings using literal_eval or json if possible
                 if isinstance(value, str):
@@ -383,12 +383,12 @@ def memory_replace(agent_state: "AgentState", label: str, old_str: str, new_str:
     # snippet = "\n".join(new_value.split("\n")[start_line : end_line + 1])

     # Prepare the success message
-    success_msg = f"The core memory block with label `{label}` has been edited. "
-    # success_msg += self._make_output(
-    #     snippet, f"a snippet of {path}", start_line + 1
-    # )
-    # success_msg += f"A snippet of core memory block `{label}`:\n{snippet}\n"
-    success_msg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the memory block again if necessary."
+    success_msg = (
+        f"The core memory block with label `{label}` has been successfully edited. "
+        f"Your system prompt has been recompiled with the updated memory contents and is now active in your context. "
+        f"Review the changes and make sure they are as expected (correct indentation, "
+        f"no duplicate lines, etc). Edit the memory block again if necessary."
+    )

     # return None
     return success_msg
@@ -454,14 +454,12 @@ def memory_insert(agent_state: "AgentState", label: str, new_str: str, insert_li
     agent_state.memory.update_block_value(label=label, value=new_value)

     # Prepare the success message
-    success_msg = f"The core memory block with label `{label}` has been edited. "
-    # success_msg += self._make_output(
-    #     snippet,
-    #     "a snippet of the edited file",
-    #     max(1, insert_line - SNIPPET_LINES + 1),
-    # )
-    # success_msg += f"A snippet of core memory block `{label}`:\n{snippet}\n"
-    success_msg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the memory block again if necessary."
+    success_msg = (
+        f"The core memory block with label `{label}` has been successfully edited. "
+        f"Your system prompt has been recompiled with the updated memory contents and is now active in your context. "
+        f"Review the changes and make sure they are as expected (correct indentation, "
+        f"no duplicate lines, etc). Edit the memory block again if necessary."
+    )

     return success_msg

@@ -532,12 +530,12 @@ def memory_rethink(agent_state: "AgentState", label: str, new_memory: str) -> No
     agent_state.memory.update_block_value(label=label, value=new_memory)

     # Prepare the success message
-    success_msg = f"The core memory block with label `{label}` has been edited. "
-    # success_msg += self._make_output(
-    #     snippet, f"a snippet of {path}", start_line + 1
-    # )
-    # success_msg += f"A snippet of core memory block `{label}`:\n{snippet}\n"
-    success_msg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the memory block again if necessary."
+    success_msg = (
+        f"The core memory block with label `{label}` has been successfully edited. "
+        f"Your system prompt has been recompiled with the updated memory contents and is now active in your context. "
+        f"Review the changes and make sure they are as expected (correct indentation, "
+        f"no duplicate lines, etc). Edit the memory block again if necessary."
+    )

     # return None
     return success_msg

@@ -166,3 +166,61 @@ async def _convert_message_create_to_message(
         batch_item_id=message_create.batch_item_id,
         run_id=run_id,
     )
+
+
+async def _resolve_url_to_base64(url: str) -> tuple[str, str]:
+    """Resolve URL to base64 data and media type."""
+    if url.startswith("file://"):
+        parsed = urlparse(url)
+        file_path = unquote(parsed.path)
+        image_bytes = await asyncio.to_thread(lambda: open(file_path, "rb").read())
+        media_type, _ = mimetypes.guess_type(file_path)
+        media_type = media_type or "image/jpeg"
+    else:
+        image_bytes, media_type = await _fetch_image_from_url(url)
+        media_type = media_type or mimetypes.guess_type(url)[0] or "image/png"
+
+    image_data = base64.standard_b64encode(image_bytes).decode("utf-8")
+    return image_data, media_type
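The `file://` branch above can be exercised synchronously. This hypothetical `file_to_base64` mirrors it, with one deliberate difference: it uses a context manager, whereas the `lambda: open(...).read()` in the diff never closes the file handle.

```python
import base64
import mimetypes
import os
import tempfile


def file_to_base64(file_path: str) -> tuple[str, str]:
    """Hypothetical synchronous mirror of _resolve_url_to_base64's file:// branch."""
    with open(file_path, "rb") as f:  # context manager: the handle is always closed
        image_bytes = f.read()
    # Same fallback chain as the diff: guess from the extension, default to image/jpeg
    media_type = mimetypes.guess_type(file_path)[0] or "image/jpeg"
    return base64.standard_b64encode(image_bytes).decode("utf-8"), media_type


# Round-trip check with a throwaway file
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
    tmp.write(b"not-really-a-png")
data, media = file_to_base64(tmp.name)
os.unlink(tmp.name)
print(media)                            # image/png (guessed from the suffix)
print(base64.standard_b64decode(data))  # b'not-really-a-png'
```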
+
+
+async def resolve_tool_return_images(func_response: str | list) -> str | list:
+    """Resolve URL and LettaImage sources to base64 for tool returns."""
+    if isinstance(func_response, str):
+        return func_response
+
+    resolved = []
+    for part in func_response:
+        if isinstance(part, ImageContent):
+            if part.source.type == ImageSourceType.url:
+                image_data, media_type = await _resolve_url_to_base64(part.source.url)
+                part.source = Base64Image(media_type=media_type, data=image_data)
+            elif part.source.type == ImageSourceType.letta and not part.source.data:
+                pass
+            resolved.append(part)
+        elif isinstance(part, TextContent):
+            resolved.append(part)
+        elif isinstance(part, dict):
+            if part.get("type") == "image" and part.get("source", {}).get("type") == "url":
+                url = part["source"].get("url")
+                if url:
+                    image_data, media_type = await _resolve_url_to_base64(url)
+                    resolved.append(
+                        ImageContent(
+                            source=Base64Image(
+                                media_type=media_type,
+                                data=image_data,
+                                detail=part.get("source", {}).get("detail"),
+                            )
+                        )
+                    )
+                else:
+                    resolved.append(part)
+            elif part.get("type") == "text":
+                resolved.append(TextContent(text=part.get("text", "")))
+            else:
+                resolved.append(part)
+        else:
+            resolved.append(part)
+
+    return resolved

@@ -7,6 +7,7 @@ from datetime import datetime, timezone
 from typing import Any, Callable, List, Optional, Tuple

 from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
+from letta.errors import LettaInvalidArgumentError
 from letta.otel.tracing import trace_method
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.enums import MessageRole, TagMatchMode
@@ -321,6 +322,7 @@ class TurbopufferClient:
         actor: "PydanticUser",
         tags: Optional[List[str]] = None,
         created_at: Optional[datetime] = None,
+        embeddings: Optional[List[List[float]]] = None,
     ) -> List[PydanticPassage]:
         """Insert passages into Turbopuffer.

@@ -332,6 +334,7 @@ class TurbopufferClient:
             actor: User actor for embedding generation
             tags: Optional list of tags to attach to all passages
             created_at: Optional timestamp for retroactive entries (defaults to current UTC time)
+            embeddings: Optional pre-computed embeddings (must match 1:1 with text_chunks). If provided, skips embedding generation.

         Returns:
             List of PydanticPassage objects that were inserted
@@ -345,9 +348,30 @@ class TurbopufferClient:
             logger.warning("All text chunks were empty, skipping insertion")
             return []

-        # generate embeddings using the default config
         filtered_texts = [text for _, text in filtered_chunks]
-        embeddings = await self._generate_embeddings(filtered_texts, actor)

+        # use provided embeddings only if dimensions match TPUF's expected dimension
+        use_provided_embeddings = False
+        if embeddings is not None:
+            if len(embeddings) != len(text_chunks):
+                raise LettaInvalidArgumentError(
+                    f"embeddings length ({len(embeddings)}) must match text_chunks length ({len(text_chunks)})",
+                    argument_name="embeddings",
+                )
+            # check if first non-empty embedding has correct dimensions
+            filtered_indices = [i for i, _ in filtered_chunks]
+            sample_embedding = embeddings[filtered_indices[0]] if filtered_indices else None
+            if sample_embedding is not None and len(sample_embedding) == self.default_embedding_config.embedding_dim:
+                use_provided_embeddings = True
+                filtered_embeddings = [embeddings[i] for i, _ in filtered_chunks]
+            else:
+                logger.debug(
+                    f"Embedding dimension mismatch (got {len(sample_embedding) if sample_embedding else 'None'}, "
+                    f"expected {self.default_embedding_config.embedding_dim}), regenerating embeddings"
+                )
+
+        if not use_provided_embeddings:
+            filtered_embeddings = await self._generate_embeddings(filtered_texts, actor)

         namespace_name = await self._get_archive_namespace_name(archive_id)

@@ -379,7 +403,7 @@ class TurbopufferClient:
         tags_arrays = [] # Store tags as arrays
         passages = []

-        for (original_idx, text), embedding in zip(filtered_chunks, embeddings):
+        for (original_idx, text), embedding in zip(filtered_chunks, filtered_embeddings):
             passage_id = passage_ids[original_idx]

             # append to columns
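The validation in the hunk above reduces to one small decision: trust caller-provided vectors only when they line up 1:1 with the texts and match the expected dimension, otherwise regenerate. `select_embeddings` below is a hypothetical distillation of that guard, not the Letta API.

```python
from typing import Callable, List, Optional


def select_embeddings(
    provided: Optional[List[List[float]]],
    texts: List[str],
    expected_dim: int,
    regenerate: Callable[[List[str]], List[List[float]]],
) -> List[List[float]]:
    """Use caller-provided vectors only when counts and dimensions match; else regenerate."""
    if provided is not None:
        if len(provided) != len(texts):
            raise ValueError(f"embeddings length ({len(provided)}) must match texts length ({len(texts)})")
        if provided and len(provided[0]) == expected_dim:
            return provided
        # dimension mismatch: fall through and regenerate
    return regenerate(texts)


fake = lambda texts: [[0.0, 0.0] for _ in texts]  # stand-in for an embedding model
print(select_embeddings([[1.0, 2.0]], ["hi"], 2, fake))       # [[1.0, 2.0]]  (accepted)
print(select_embeddings([[1.0, 2.0, 3.0]], ["hi"], 2, fake))  # [[0.0, 0.0]]  (regenerated)
```

The length mismatch is a hard error (the caller's contract is broken), while a dimension mismatch only triggers a silent regeneration, mirroring the raise-vs-`logger.debug` split in the diff.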

@@ -39,6 +39,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
 from letta.schemas.message import Message
 from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
 from letta.server.rest_api.json_parser import JSONParser, PydanticJSONParser
+from letta.server.rest_api.streaming_response import RunCancelledException
 from letta.server.rest_api.utils import decrement_message_uuid

 logger = get_logger(__name__)
@@ -145,6 +146,26 @@ class SimpleAnthropicStreamingInterface:
             return tool_calls[0]
         return None

+    def get_usage_statistics(self) -> "LettaUsageStatistics":
+        """Extract usage statistics from accumulated streaming data.
+
+        Returns:
+            LettaUsageStatistics with token counts from the stream.
+        """
+        from letta.schemas.usage import LettaUsageStatistics
+
+        # Anthropic: input_tokens is NON-cached only, must add cache tokens for total
+        actual_input_tokens = (self.input_tokens or 0) + (self.cache_read_tokens or 0) + (self.cache_creation_tokens or 0)
+
+        return LettaUsageStatistics(
+            prompt_tokens=actual_input_tokens,
+            completion_tokens=self.output_tokens or 0,
+            total_tokens=actual_input_tokens + (self.output_tokens or 0),
+            cached_input_tokens=self.cache_read_tokens if self.cache_read_tokens else None,
+            cache_write_tokens=self.cache_creation_tokens if self.cache_creation_tokens else None,
+            reasoning_tokens=None, # Anthropic doesn't report reasoning tokens separately
+        )
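The additive Anthropic accounting used by `get_usage_statistics` above can be checked in isolation. The field names follow Anthropic's documented usage object; the dict-based helper itself is only illustrative.

```python
def anthropic_total_input_tokens(usage: dict) -> int:
    """Anthropic's input_tokens EXCLUDES cached tokens, so cache reads and
    cache writes must be added back in to get the true input total."""
    return (
        (usage.get("input_tokens") or 0)
        + (usage.get("cache_read_input_tokens") or 0)
        + (usage.get("cache_creation_input_tokens") or 0)
    )


usage = {"input_tokens": 12, "cache_read_input_tokens": 4096, "cache_creation_input_tokens": 200}
print(anthropic_total_input_tokens(usage))  # 4308
```

Forgetting this addition is the classic symptom of a cache-heavy Anthropic request appearing to cost almost nothing: `input_tokens` alone reports only the tiny uncached suffix.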
+
     def get_reasoning_content(self) -> list[TextContent | ReasoningContent | RedactedReasoningContent]:
         def _process_group(
             group: list[ReasoningMessage | HiddenReasoningMessage | AssistantMessage],
@@ -228,10 +249,10 @@ class SimpleAnthropicStreamingInterface:
                     prev_message_type = new_message_type
                     # print(f"Yielding message: {message}")
                     yield message
-        except asyncio.CancelledError as e:
+        except (asyncio.CancelledError, RunCancelledException) as e:
             import traceback

-            logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc())
+            logger.info("Cancelled stream attempt but overriding (%s) %s: %s", type(e).__name__, e, traceback.format_exc())
             async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
                 new_message_type = message.message_type
                 if new_message_type != prev_message_type:

@@ -41,6 +41,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
 from letta.schemas.message import Message
 from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
 from letta.server.rest_api.json_parser import JSONParser, PydanticJSONParser
+from letta.server.rest_api.streaming_response import RunCancelledException

 logger = get_logger(__name__)

@@ -127,6 +128,25 @@ class AnthropicStreamingInterface:
         arguments = str(json.dumps(tool_input, indent=2))
         return ToolCall(id=self.tool_call_id, function=FunctionCall(arguments=arguments, name=self.tool_call_name))

+    def get_usage_statistics(self) -> "LettaUsageStatistics":
+        """Extract usage statistics from accumulated streaming data.
+
+        Returns:
+            LettaUsageStatistics with token counts from the stream.
+        """
+        from letta.schemas.usage import LettaUsageStatistics
+
+        # Anthropic: input_tokens is NON-cached only in streaming
+        # This interface doesn't track cache tokens, so we just use the raw values
+        return LettaUsageStatistics(
+            prompt_tokens=self.input_tokens or 0,
+            completion_tokens=self.output_tokens or 0,
+            total_tokens=(self.input_tokens or 0) + (self.output_tokens or 0),
+            cached_input_tokens=None, # This interface doesn't track cache tokens
+            cache_write_tokens=None,
+            reasoning_tokens=None,
+        )
+
     def _check_inner_thoughts_complete(self, combined_args: str) -> bool:
         """
         Check if inner thoughts are complete in the current tool call arguments
@@ -218,10 +238,10 @@ class AnthropicStreamingInterface:
                     message_index += 1
                     prev_message_type = new_message_type
                     yield message
-        except asyncio.CancelledError as e:
+        except (asyncio.CancelledError, RunCancelledException) as e:
             import traceback

-            logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc())
+            logger.info("Cancelled stream attempt but overriding (%s) %s: %s", type(e).__name__, e, traceback.format_exc())
             async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
                 new_message_type = message.message_type
                 if new_message_type != prev_message_type:
@@ -636,6 +656,25 @@ class SimpleAnthropicStreamingInterface:
         arguments = str(json.dumps(tool_input, indent=2))
         return ToolCall(id=self.tool_call_id, function=FunctionCall(arguments=arguments, name=self.tool_call_name))

+    def get_usage_statistics(self) -> "LettaUsageStatistics":
+        """Extract usage statistics from accumulated streaming data.
+
+        Returns:
+            LettaUsageStatistics with token counts from the stream.
+        """
+        from letta.schemas.usage import LettaUsageStatistics
+
+        # Anthropic: input_tokens is NON-cached only in streaming
+        # This interface doesn't track cache tokens, so we just use the raw values
+        return LettaUsageStatistics(
+            prompt_tokens=self.input_tokens or 0,
+            completion_tokens=self.output_tokens or 0,
+            total_tokens=(self.input_tokens or 0) + (self.output_tokens or 0),
+            cached_input_tokens=None, # This interface doesn't track cache tokens
+            cache_write_tokens=None,
+            reasoning_tokens=None,
+        )
+
     def get_reasoning_content(self) -> list[TextContent | ReasoningContent | RedactedReasoningContent]:
         def _process_group(
             group: list[ReasoningMessage | HiddenReasoningMessage | AssistantMessage],
@@ -726,10 +765,10 @@ class SimpleAnthropicStreamingInterface:
                     prev_message_type = new_message_type
                     # print(f"Yielding message: {message}")
                     yield message
-        except asyncio.CancelledError as e:
+        except (asyncio.CancelledError, RunCancelledException) as e:
             import traceback

-            logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc())
+            logger.info("Cancelled stream attempt but overriding (%s) %s: %s", type(e).__name__, e, traceback.format_exc())
             async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
                 new_message_type = message.message_type
                 if new_message_type != prev_message_type:
@@ -26,6 +26,7 @@ from letta.schemas.letta_message_content import (
 from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
 from letta.schemas.message import Message
 from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
+from letta.server.rest_api.streaming_response import RunCancelledException
 from letta.server.rest_api.utils import decrement_message_uuid
 from letta.utils import get_tool_call_id

@@ -43,9 +44,11 @@ class SimpleGeminiStreamingInterface:
         requires_approval_tools: list = [],
         run_id: str | None = None,
         step_id: str | None = None,
+        cancellation_event: Optional["asyncio.Event"] = None,
     ):
         self.run_id = run_id
         self.step_id = step_id
+        self.cancellation_event = cancellation_event

         # self.messages = messages
         # self.tools = tools
@@ -89,6 +92,9 @@ class SimpleGeminiStreamingInterface:
         # Raw usage from provider (for transparent logging in provider trace)
         self.raw_usage: dict | None = None

+        # Track cancellation status
+        self.stream_was_cancelled: bool = False
+
     def get_content(self) -> List[ReasoningContent | TextContent | ToolCallContent]:
         """This is (unusually) in chunked format, instead of merged"""
         for content in self.content_parts:
@@ -116,6 +122,27 @@ class SimpleGeminiStreamingInterface:
         """Return all finalized tool calls collected during this message (parallel supported)."""
         return list(self.collected_tool_calls)

+    def get_usage_statistics(self) -> "LettaUsageStatistics":
+        """Extract usage statistics from accumulated streaming data.
+
+        Returns:
+            LettaUsageStatistics with token counts from the stream.
+
+        Note:
+            Gemini uses `thinking_tokens` instead of `reasoning_tokens` (OpenAI o1/o3).
+        """
+        from letta.schemas.usage import LettaUsageStatistics
+
+        return LettaUsageStatistics(
+            prompt_tokens=self.input_tokens or 0,
+            completion_tokens=self.output_tokens or 0,
+            total_tokens=(self.input_tokens or 0) + (self.output_tokens or 0),
+            # Gemini: input_tokens is already total, cached_tokens is a subset (not additive)
+            cached_input_tokens=self.cached_tokens,
+            cache_write_tokens=None, # Gemini doesn't report cache write tokens
+            reasoning_tokens=self.thinking_tokens, # Gemini uses thinking_tokens
+        )

     async def process(
         self,
         stream: AsyncIterator[GenerateContentResponse],
@@ -137,10 +164,10 @@ class SimpleGeminiStreamingInterface:
                     message_index += 1
                     prev_message_type = new_message_type
                     yield message
-        except asyncio.CancelledError as e:
+        except (asyncio.CancelledError, RunCancelledException) as e:
             import traceback

-            logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc())
+            logger.info("Cancelled stream attempt but overriding (%s) %s: %s", type(e).__name__, e, traceback.format_exc())
             async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
                 new_message_type = message.message_type
                 if new_message_type != prev_message_type:
@@ -164,7 +191,11 @@ class SimpleGeminiStreamingInterface:
             yield LettaStopReason(stop_reason=StopReasonType.error)
             raise e
         finally:
-            logger.info("GeminiStreamingInterface: Stream processing complete.")
+            # Check if cancellation was signaled via shared event
+            if self.cancellation_event and self.cancellation_event.is_set():
+                self.stream_was_cancelled = True
+
+            logger.info(f"GeminiStreamingInterface: Stream processing complete. stream was cancelled: {self.stream_was_cancelled}")

     async def _process_event(
         self,
@@ -54,6 +54,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
 from letta.schemas.message import Message
 from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
 from letta.server.rest_api.json_parser import OptimisticJSONParser
+from letta.server.rest_api.streaming_response import RunCancelledException
 from letta.server.rest_api.utils import decrement_message_uuid
 from letta.services.context_window_calculator.token_counter import create_token_counter
 from letta.streaming_utils import (
@@ -82,6 +83,7 @@ class OpenAIStreamingInterface:
         requires_approval_tools: list = [],
         run_id: str | None = None,
         step_id: str | None = None,
+        cancellation_event: Optional["asyncio.Event"] = None,
     ):
         self.use_assistant_message = use_assistant_message

@@ -93,6 +95,7 @@ class OpenAIStreamingInterface:
         self.put_inner_thoughts_in_kwarg = put_inner_thoughts_in_kwarg
         self.run_id = run_id
         self.step_id = step_id
+        self.cancellation_event = cancellation_event

         self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
         self.function_args_reader = JSONInnerThoughtsExtractor(wait_for_first_key=put_inner_thoughts_in_kwarg)
@@ -191,6 +194,28 @@ class OpenAIStreamingInterface:
             function=FunctionCall(arguments=self._get_current_function_arguments(), name=function_name),
         )

+    def get_usage_statistics(self) -> "LettaUsageStatistics":
+        """Extract usage statistics from accumulated streaming data.
+
+        Returns:
+            LettaUsageStatistics with token counts from the stream.
+        """
+        from letta.schemas.usage import LettaUsageStatistics
+
+        # Use actual tokens if available, otherwise fall back to estimated
+        input_tokens = self.input_tokens if self.input_tokens else self.fallback_input_tokens
+        output_tokens = self.output_tokens if self.output_tokens else self.fallback_output_tokens
+
+        return LettaUsageStatistics(
+            prompt_tokens=input_tokens or 0,
+            completion_tokens=output_tokens or 0,
+            total_tokens=(input_tokens or 0) + (output_tokens or 0),
+            # OpenAI: input_tokens is already total, cached_tokens is a subset (not additive)
+            cached_input_tokens=None, # This interface doesn't track cache tokens
+            cache_write_tokens=None,
+            reasoning_tokens=None, # This interface doesn't track reasoning tokens
+        )
+
     async def process(
         self,
         stream: AsyncStream[ChatCompletionChunk],
@@ -226,14 +251,15 @@ class OpenAIStreamingInterface:
                     message_index += 1
                     prev_message_type = new_message_type
                     yield message
-        except asyncio.CancelledError as e:
+        except (asyncio.CancelledError, RunCancelledException) as e:
             import traceback

             self.stream_was_cancelled = True
             logger.warning(
-                "Stream was cancelled (CancelledError). Attempting to process current event. "
+                "Stream was cancelled (%s). Attempting to process current event. "
                 f"Events received so far: {self.total_events_received}, last event: {self.last_event_type}. "
                 f"Error: %s, trace: %s",
+                type(e).__name__,
                 e,
                 traceback.format_exc(),
             )
@@ -267,6 +293,10 @@ class OpenAIStreamingInterface:
             yield LettaStopReason(stop_reason=StopReasonType.error)
             raise e
         finally:
+            # Check if cancellation was signaled via shared event
+            if self.cancellation_event and self.cancellation_event.is_set():
+                self.stream_was_cancelled = True
+
             logger.info(
                 f"OpenAIStreamingInterface: Stream processing complete. "
                 f"Received {self.total_events_received} events, "
@@ -561,9 +591,11 @@ class SimpleOpenAIStreamingInterface:
         model: str = None,
         run_id: str | None = None,
         step_id: str | None = None,
+        cancellation_event: Optional["asyncio.Event"] = None,
     ):
         self.run_id = run_id
         self.step_id = step_id
+        self.cancellation_event = cancellation_event
         # Premake IDs for database writes
         self.letta_message_id = Message.generate_id()

@@ -662,6 +694,28 @@ class SimpleOpenAIStreamingInterface:
             raise ValueError("No tool calls available")
         return calls[0]

+    def get_usage_statistics(self) -> "LettaUsageStatistics":
+        """Extract usage statistics from accumulated streaming data.
+
+        Returns:
+            LettaUsageStatistics with token counts from the stream.
+        """
+        from letta.schemas.usage import LettaUsageStatistics
+
+        # Use actual tokens if available, otherwise fall back to estimated
+        input_tokens = self.input_tokens if self.input_tokens else self.fallback_input_tokens
+        output_tokens = self.output_tokens if self.output_tokens else self.fallback_output_tokens
+
+        return LettaUsageStatistics(
+            prompt_tokens=input_tokens or 0,
+            completion_tokens=output_tokens or 0,
+            total_tokens=(input_tokens or 0) + (output_tokens or 0),
+            # OpenAI: input_tokens is already total, cached_tokens is a subset (not additive)
+            cached_input_tokens=self.cached_tokens,
+            cache_write_tokens=None, # OpenAI doesn't have cache write tokens
+            reasoning_tokens=self.reasoning_tokens,
+        )
+
     async def process(
         self,
         stream: AsyncStream[ChatCompletionChunk],
@@ -715,14 +769,15 @@ class SimpleOpenAIStreamingInterface:
                     message_index += 1
                     prev_message_type = new_message_type
                     yield message
-        except asyncio.CancelledError as e:
+        except (asyncio.CancelledError, RunCancelledException) as e:
             import traceback

             self.stream_was_cancelled = True
             logger.warning(
-                "Stream was cancelled (CancelledError). Attempting to process current event. "
+                "Stream was cancelled (%s). Attempting to process current event. "
                 f"Events received so far: {self.total_events_received}, last event: {self.last_event_type}. "
                 f"Error: %s, trace: %s",
+                type(e).__name__,
                 e,
                 traceback.format_exc(),
             )
@@ -764,6 +819,10 @@ class SimpleOpenAIStreamingInterface:
             yield LettaStopReason(stop_reason=StopReasonType.error)
             raise e
         finally:
+            # Check if cancellation was signaled via shared event
+            if self.cancellation_event and self.cancellation_event.is_set():
+                self.stream_was_cancelled = True
+
             logger.info(
                 f"SimpleOpenAIStreamingInterface: Stream processing complete. "
                 f"Received {self.total_events_received} events, "
@@ -932,6 +991,7 @@ class SimpleOpenAIResponsesStreamingInterface:
         model: str = None,
         run_id: str | None = None,
         step_id: str | None = None,
+        cancellation_event: Optional["asyncio.Event"] = None,
     ):
         self.is_openai_proxy = is_openai_proxy
         self.messages = messages
@@ -946,6 +1006,7 @@ class SimpleOpenAIResponsesStreamingInterface:
         self.message_id = None
         self.run_id = run_id
         self.step_id = step_id
+        self.cancellation_event = cancellation_event

         # Premake IDs for database writes
         self.letta_message_id = Message.generate_id()
@@ -1063,6 +1124,24 @@ class SimpleOpenAIResponsesStreamingInterface:
             raise ValueError("No tool calls available")
         return calls[0]

+    def get_usage_statistics(self) -> "LettaUsageStatistics":
+        """Extract usage statistics from accumulated streaming data.
+
+        Returns:
+            LettaUsageStatistics with token counts from the stream.
+        """
+        from letta.schemas.usage import LettaUsageStatistics
+
+        return LettaUsageStatistics(
+            prompt_tokens=self.input_tokens or 0,
+            completion_tokens=self.output_tokens or 0,
+            total_tokens=(self.input_tokens or 0) + (self.output_tokens or 0),
+            # OpenAI Responses API: input_tokens is already total
+            cached_input_tokens=self.cached_tokens,
+            cache_write_tokens=None, # OpenAI doesn't have cache write tokens
+            reasoning_tokens=self.reasoning_tokens,
+        )
+
     async def process(
         self,
         stream: AsyncStream[ResponseStreamEvent],
@@ -1102,14 +1181,15 @@ class SimpleOpenAIResponsesStreamingInterface:
                     )
                     # Continue to next event rather than killing the stream
                     continue
-        except asyncio.CancelledError as e:
+        except (asyncio.CancelledError, RunCancelledException) as e:
             import traceback

             self.stream_was_cancelled = True
             logger.warning(
-                "Stream was cancelled (CancelledError). Attempting to process current event. "
+                "Stream was cancelled (%s). Attempting to process current event. "
                 f"Events received so far: {self.total_events_received}, last event: {self.last_event_type}. "
                 f"Error: %s, trace: %s",
+                type(e).__name__,
                 e,
                 traceback.format_exc(),
             )
@@ -1136,6 +1216,10 @@ class SimpleOpenAIResponsesStreamingInterface:
             yield LettaStopReason(stop_reason=StopReasonType.error)
             raise e
         finally:
+            # Check if cancellation was signaled via shared event
+            if self.cancellation_event and self.cancellation_event.is_set():
+                self.stream_was_cancelled = True
+
             logger.info(
                 f"ResponsesAPI Stream processing complete. "
                 f"Received {self.total_events_received} events, "

@@ -48,6 +48,7 @@ from letta.schemas.openai.chat_completion_response import (
|
||||
UsageStatistics,
|
||||
)
|
||||
from letta.schemas.response_format import JsonSchemaResponseFormat
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.settings import model_settings
|
||||
|
||||
DUMMY_FIRST_USER_MESSAGE = "User initializing bootup sequence."
|
||||
@@ -777,6 +778,18 @@ class AnthropicClient(LLMClientBase):
|
||||
if not block.get("text", "").strip():
|
||||
block["text"] = "."
|
||||
|
||||
# Strip trailing whitespace from final assistant message
|
||||
# Anthropic API rejects messages where "final assistant content cannot end with trailing whitespace"
|
||||
if is_final_assistant:
|
||||
if isinstance(content, str):
|
||||
msg["content"] = content.rstrip()
|
||||
elif isinstance(content, list) and len(content) > 0:
|
||||
# Find and strip trailing whitespace from the last text block
|
||||
for block in reversed(content):
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
block["text"] = block.get("text", "").rstrip()
|
||||
break
|
||||
|
||||
try:
|
||||
count_params = {
|
||||
"model": model or "claude-3-7-sonnet-20250219",
|
||||
@@ -976,6 +989,35 @@ class AnthropicClient(LLMClientBase):
|
||||
|
||||
return super().handle_llm_error(e)
|
||||
|
||||
def extract_usage_statistics(self, response_data: dict | None, llm_config: LLMConfig) -> LettaUsageStatistics:
|
||||
"""Extract usage statistics from Anthropic response and return as LettaUsageStatistics."""
|
||||
if not response_data:
|
||||
return LettaUsageStatistics()
|
||||
|
||||
response = AnthropicMessage(**response_data)
|
||||
prompt_tokens = response.usage.input_tokens
|
||||
completion_tokens = response.usage.output_tokens
|
||||
|
||||
# Extract cache data if available (None means not reported, 0 means reported as 0)
|
||||
cache_read_tokens = None
|
||||
cache_creation_tokens = None
|
||||
if hasattr(response.usage, "cache_read_input_tokens"):
|
||||
cache_read_tokens = response.usage.cache_read_input_tokens
|
||||
if hasattr(response.usage, "cache_creation_input_tokens"):
|
||||
cache_creation_tokens = response.usage.cache_creation_input_tokens
|
||||
|
||||
# Per Anthropic docs: "Total input tokens in a request is the summation of
|
||||
# input_tokens, cache_creation_input_tokens, and cache_read_input_tokens."
|
||||
actual_input_tokens = prompt_tokens + (cache_read_tokens or 0) + (cache_creation_tokens or 0)
|
||||
|
||||
return LettaUsageStatistics(
|
||||
prompt_tokens=actual_input_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=actual_input_tokens + completion_tokens,
|
||||
cached_input_tokens=cache_read_tokens,
|
||||
cache_write_tokens=cache_creation_tokens,
|
||||
)
|
||||
|
||||
# TODO: Input messages doesn't get used here
|
||||
# TODO: Clean up this interface
|
||||
@trace_method
|
||||
@@ -1020,10 +1062,13 @@ class AnthropicClient(LLMClientBase):
        }
        """
        response = AnthropicMessage(**response_data)
        prompt_tokens = response.usage.input_tokens
        completion_tokens = response.usage.output_tokens
        finish_reason = remap_finish_reason(str(response.stop_reason))

        # Extract usage via centralized method
        from letta.schemas.enums import ProviderType

        usage_stats = self.extract_usage_statistics(response_data, llm_config).to_usage(ProviderType.anthropic)

        content = None
        reasoning_content = None
        reasoning_content_signature = None
@@ -1088,35 +1133,12 @@ class AnthropicClient(LLMClientBase):
            ),
        )

        # Build prompt tokens details with cache data if available
        prompt_tokens_details = None
        cache_read_tokens = 0
        cache_creation_tokens = 0
        if hasattr(response.usage, "cache_read_input_tokens") or hasattr(response.usage, "cache_creation_input_tokens"):
            from letta.schemas.openai.chat_completion_response import UsageStatisticsPromptTokenDetails

            cache_read_tokens = getattr(response.usage, "cache_read_input_tokens", 0) or 0
            cache_creation_tokens = getattr(response.usage, "cache_creation_input_tokens", 0) or 0
            prompt_tokens_details = UsageStatisticsPromptTokenDetails(
                cache_read_tokens=cache_read_tokens,
                cache_creation_tokens=cache_creation_tokens,
            )

        # Per Anthropic docs: "Total input tokens in a request is the summation of
        # input_tokens, cache_creation_input_tokens, and cache_read_input_tokens."
        actual_input_tokens = prompt_tokens + cache_read_tokens + cache_creation_tokens

        chat_completion_response = ChatCompletionResponse(
            id=response.id,
            choices=[choice],
            created=get_utc_time_int(),
            model=response.model,
            usage=UsageStatistics(
                prompt_tokens=actual_input_tokens,
                completion_tokens=completion_tokens,
                total_tokens=actual_input_tokens + completion_tokens,
                prompt_tokens_details=prompt_tokens_details,
            ),
            usage=usage_stats,
        )
        if llm_config.put_inner_thoughts_in_kwargs:
            chat_completion_response = unpack_all_inner_thoughts_from_kwargs(
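Both Anthropic hunks above apply the same accounting rule: Anthropic's `input_tokens` covers only the NON-cached portion of the prompt, so the cache-read and cache-write counts must be added back to recover the true input total. A minimal standalone sketch of that summation (the helper name and the example payload are hypothetical; the field names come from Anthropic's usage object as used in the diff):

```python
def anthropic_total_input_tokens(usage: dict) -> int:
    """Sum non-cached, cache-read, and cache-write input tokens.

    Per Anthropic's docs, `input_tokens` excludes cached tokens, so the
    cache fields must be added to get the real input total.
    """
    return (
        (usage.get("input_tokens") or 0)
        + (usage.get("cache_read_input_tokens") or 0)
        + (usage.get("cache_creation_input_tokens") or 0)
    )

# Hypothetical usage payload shaped like an Anthropic Messages response
usage = {"input_tokens": 12, "cache_read_input_tokens": 2000, "cache_creation_input_tokens": 0}
print(anthropic_total_input_tokens(usage))  # → 2012
```

This is the opposite of OpenAI and Gemini, where the reported prompt total already includes cached tokens and nothing needs to be added back.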
@@ -54,6 +54,7 @@ from letta.schemas.openai.chat_completion_response import (
    UsageStatistics,
)
from letta.schemas.providers.chatgpt_oauth import ChatGPTOAuthCredentials, ChatGPTOAuthProvider
from letta.schemas.usage import LettaUsageStatistics

logger = get_logger(__name__)
@@ -511,6 +512,25 @@ class ChatGPTOAuthClient(LLMClientBase):
        # Response should already be in ChatCompletion format after transformation
        return ChatCompletionResponse(**response_data)

    def extract_usage_statistics(self, response_data: dict | None, llm_config: LLMConfig) -> LettaUsageStatistics:
        """Extract usage statistics from a ChatGPT OAuth response and return them as LettaUsageStatistics."""
        if not response_data:
            return LettaUsageStatistics()

        usage = response_data.get("usage")
        if not usage:
            return LettaUsageStatistics()

        prompt_tokens = usage.get("prompt_tokens") or 0
        completion_tokens = usage.get("completion_tokens") or 0
        total_tokens = usage.get("total_tokens") or (prompt_tokens + completion_tokens)

        return LettaUsageStatistics(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
        )

    @trace_method
    async def stream_async(
        self,
@@ -39,6 +39,7 @@ from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import Tool, Tool as OpenAITool
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics
from letta.schemas.usage import LettaUsageStatistics
from letta.settings import model_settings, settings
from letta.utils import get_tool_call_id
@@ -415,6 +416,34 @@ class GoogleVertexClient(LLMClientBase):

        return request_data

    def extract_usage_statistics(self, response_data: dict | None, llm_config: LLMConfig) -> LettaUsageStatistics:
        """Extract usage statistics from a Gemini response and return them as LettaUsageStatistics."""
        if not response_data:
            return LettaUsageStatistics()

        response = GenerateContentResponse(**response_data)
        if not response.usage_metadata:
            return LettaUsageStatistics()

        cached_tokens = None
        if (
            hasattr(response.usage_metadata, "cached_content_token_count")
            and response.usage_metadata.cached_content_token_count is not None
        ):
            cached_tokens = response.usage_metadata.cached_content_token_count

        reasoning_tokens = None
        if hasattr(response.usage_metadata, "thoughts_token_count") and response.usage_metadata.thoughts_token_count is not None:
            reasoning_tokens = response.usage_metadata.thoughts_token_count

        return LettaUsageStatistics(
            prompt_tokens=response.usage_metadata.prompt_token_count or 0,
            completion_tokens=response.usage_metadata.candidates_token_count or 0,
            total_tokens=response.usage_metadata.total_token_count or 0,
            cached_input_tokens=cached_tokens,
            reasoning_tokens=reasoning_tokens,
        )

    @trace_method
    async def convert_response_to_chat_completion(
        self,
@@ -642,36 +671,10 @@ class GoogleVertexClient(LLMClientBase):
        # "totalTokenCount": 36
        # }
        if response.usage_metadata:
            # Extract cache token data if available (Gemini uses cached_content_token_count)
            # Use `is not None` to capture 0 values (meaning "provider reported 0 cached tokens")
            prompt_tokens_details = None
            if (
                hasattr(response.usage_metadata, "cached_content_token_count")
                and response.usage_metadata.cached_content_token_count is not None
            ):
                from letta.schemas.openai.chat_completion_response import UsageStatisticsPromptTokenDetails
            # Extract usage via centralized method
            from letta.schemas.enums import ProviderType

                prompt_tokens_details = UsageStatisticsPromptTokenDetails(
                    cached_tokens=response.usage_metadata.cached_content_token_count,
                )

            # Extract thinking/reasoning token data if available (Gemini uses thoughts_token_count)
            # Use `is not None` to capture 0 values (meaning "provider reported 0 reasoning tokens")
            completion_tokens_details = None
            if hasattr(response.usage_metadata, "thoughts_token_count") and response.usage_metadata.thoughts_token_count is not None:
                from letta.schemas.openai.chat_completion_response import UsageStatisticsCompletionTokenDetails

                completion_tokens_details = UsageStatisticsCompletionTokenDetails(
                    reasoning_tokens=response.usage_metadata.thoughts_token_count,
                )

            usage = UsageStatistics(
                prompt_tokens=response.usage_metadata.prompt_token_count,
                completion_tokens=response.usage_metadata.candidates_token_count,
                total_tokens=response.usage_metadata.total_token_count,
                prompt_tokens_details=prompt_tokens_details,
                completion_tokens_details=completion_tokens_details,
            )
            usage = self.extract_usage_statistics(response_data, llm_config).to_usage(ProviderType.google_ai)
        else:
            # Count it ourselves using the Gemini token counting API
            assert input_messages is not None, "Didn't get UsageMetadata from the API response, so input_messages is required"
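The Gemini hunks above pass `prompt_token_count` through unchanged because, unlike Anthropic, Gemini's prompt count is a TOTAL that already includes cached tokens; `cached_content_token_count` is a subset of it. A small standalone sketch of that subset semantics (the helper name and payload are hypothetical; the field names match Gemini's `usageMetadata` as used in the diff):

```python
def gemini_non_cached_input(usage_metadata: dict) -> int:
    """Gemini's prompt_token_count already INCLUDES cached tokens
    (cached_content_token_count is a subset), so nothing is added to
    get the total; subtract instead to isolate the non-cached portion."""
    total = usage_metadata.get("prompt_token_count") or 0
    cached = usage_metadata.get("cached_content_token_count") or 0
    return total - cached

# Hypothetical usageMetadata payload
meta = {"prompt_token_count": 3000, "cached_content_token_count": 2048}
print(gemini_non_cached_input(meta))  # → 952
```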
@@ -43,6 +43,14 @@ class GroqClient(OpenAIClient):
        data["logprobs"] = False
        data["n"] = 1

        # Groq rejects assistant messages that carry a `reasoning_content` property with
        # openai.BadRequestError: Error code: 400 - {'error': {'message': "'messages.2' : for 'role:assistant' the following must be satisfied[('messages.2' : property 'reasoning_content' is unsupported)]", 'type': 'invalid_request_error'}}
        if "messages" in data:
            for message in data["messages"]:
                if "reasoning_content" in message:
                    del message["reasoning_content"]
                if "reasoning_content_signature" in message:
                    del message["reasoning_content_signature"]

        return data

    @trace_method
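The Groq hunk above strips fields the provider rejects before sending the request. The same cleanup can be sketched as a generic helper (the function name and default field tuple are illustrative, not part of the diff):

```python
def strip_unsupported_fields(messages, fields=("reasoning_content", "reasoning_content_signature")):
    """Remove provider-unsupported keys from chat messages in place,
    mirroring the Groq request-data cleanup above."""
    for message in messages:
        for field in fields:
            message.pop(field, None)  # no-op when the key is absent
    return messages

msgs = [{"role": "assistant", "content": "hi", "reasoning_content": "..."}]
print(strip_unsupported_fields(msgs))  # → [{'role': 'assistant', 'content': 'hi'}]
```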
@@ -167,8 +167,8 @@ def create(
        printd("unsetting function_call because functions is None")
        function_call = None

    # openai
    if llm_config.model_endpoint_type == "openai":
    # openai and openrouter (OpenAI-compatible)
    if llm_config.model_endpoint_type in ["openai", "openrouter"]:
        if model_settings.openai_api_key is None and llm_config.model_endpoint == "https://api.openai.com/v1":
            # this is only a problem if we are *not* using an openai proxy
            raise LettaConfigurationError(message="OpenAI key is missing from letta config file", missing_fields=["openai_api_key"])
@@ -93,6 +93,21 @@ class LLMClient:
                    put_inner_thoughts_first=put_inner_thoughts_first,
                    actor=actor,
                )
            case ProviderType.minimax:
                from letta.llm_api.minimax_client import MiniMaxClient

                return MiniMaxClient(
                    put_inner_thoughts_first=put_inner_thoughts_first,
                    actor=actor,
                )
            case ProviderType.openrouter:
                # OpenRouter uses an OpenAI-compatible API, so we can use the OpenAI client directly
                from letta.llm_api.openai_client import OpenAIClient

                return OpenAIClient(
                    put_inner_thoughts_first=put_inner_thoughts_first,
                    actor=actor,
                )
            case ProviderType.deepseek:
                from letta.llm_api.deepseek_client import DeepseekClient

@@ -15,6 +15,7 @@ from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.schemas.provider_trace import ProviderTrace
from letta.schemas.usage import LettaUsageStatistics
from letta.services.telemetry_manager import TelemetryManager
from letta.settings import settings
@@ -43,6 +44,10 @@ class LLMClientBase:
        self._telemetry_run_id: Optional[str] = None
        self._telemetry_step_id: Optional[str] = None
        self._telemetry_call_type: Optional[str] = None
        self._telemetry_org_id: Optional[str] = None
        self._telemetry_user_id: Optional[str] = None
        self._telemetry_compaction_settings: Optional[Dict] = None
        self._telemetry_llm_config: Optional[Dict] = None

    def set_telemetry_context(
        self,
@@ -52,6 +57,10 @@ class LLMClientBase:
        run_id: Optional[str] = None,
        step_id: Optional[str] = None,
        call_type: Optional[str] = None,
        org_id: Optional[str] = None,
        user_id: Optional[str] = None,
        compaction_settings: Optional[Dict] = None,
        llm_config: Optional[Dict] = None,
    ) -> None:
        """Set telemetry context for provider trace logging."""
        self._telemetry_manager = telemetry_manager
@@ -60,6 +69,14 @@ class LLMClientBase:
        self._telemetry_run_id = run_id
        self._telemetry_step_id = step_id
        self._telemetry_call_type = call_type
        self._telemetry_org_id = org_id
        self._telemetry_user_id = user_id
        self._telemetry_compaction_settings = compaction_settings
        self._telemetry_llm_config = llm_config

    def extract_usage_statistics(self, response_data: Optional[dict], llm_config: LLMConfig) -> LettaUsageStatistics:
        """Provider-specific usage parsing hook (override in subclasses). Returns LettaUsageStatistics."""
        return LettaUsageStatistics()

    async def request_async_with_telemetry(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """Wrapper around request_async that logs telemetry for all requests including errors.
@@ -96,6 +113,10 @@ class LLMClientBase:
                    agent_tags=self._telemetry_agent_tags,
                    run_id=self._telemetry_run_id,
                    call_type=self._telemetry_call_type,
                    org_id=self._telemetry_org_id,
                    user_id=self._telemetry_user_id,
                    compaction_settings=self._telemetry_compaction_settings,
                    llm_config=self._telemetry_llm_config,
                ),
            )
        except Exception as e:
@@ -137,6 +158,10 @@ class LLMClientBase:
                    agent_tags=self._telemetry_agent_tags,
                    run_id=self._telemetry_run_id,
                    call_type=self._telemetry_call_type,
                    org_id=self._telemetry_org_id,
                    user_id=self._telemetry_user_id,
                    compaction_settings=self._telemetry_compaction_settings,
                    llm_config=self._telemetry_llm_config,
                ),
            )
        except Exception as e:
175
letta/llm_api/minimax_client.py
Normal file
@@ -0,0 +1,175 @@
from typing import List, Optional, Union

import anthropic
from anthropic import AsyncStream
from anthropic.types.beta import BetaMessage, BetaRawMessageStreamEvent

from letta.llm_api.anthropic_client import AnthropicClient
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.settings import model_settings

logger = get_logger(__name__)


class MiniMaxClient(AnthropicClient):
    """
    MiniMax LLM client using the Anthropic-compatible API.

    Uses the beta messages API to ensure compatibility with Anthropic streaming interfaces.
    Temperature must be in range (0.0, 1.0].
    Some Anthropic params are ignored: top_k, stop_sequences, service_tier, etc.

    Documentation: https://platform.minimax.io/docs/api-reference/text-anthropic-api

    Note: We override client creation to always use llm_config.model_endpoint as base_url
    (required for BYOK, where provider_name is the user's custom name, not "minimax").
    We also override request methods to avoid passing Anthropic-specific beta headers.
    """

    @trace_method
    def _get_anthropic_client(
        self, llm_config: LLMConfig, async_client: bool = False
    ) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
        """Create an Anthropic client configured for the MiniMax API."""
        api_key, _, _ = self.get_byok_overrides(llm_config)

        if not api_key:
            api_key = model_settings.minimax_api_key

        # Always use model_endpoint for base_url (works for both base and BYOK providers)
        base_url = llm_config.model_endpoint

        if async_client:
            return anthropic.AsyncAnthropic(api_key=api_key, base_url=base_url, max_retries=model_settings.anthropic_max_retries)
        return anthropic.Anthropic(api_key=api_key, base_url=base_url, max_retries=model_settings.anthropic_max_retries)

    @trace_method
    async def _get_anthropic_client_async(
        self, llm_config: LLMConfig, async_client: bool = False
    ) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
        """Create an Anthropic client configured for the MiniMax API (async version)."""
        api_key, _, _ = await self.get_byok_overrides_async(llm_config)

        if not api_key:
            api_key = model_settings.minimax_api_key

        # Always use model_endpoint for base_url (works for both base and BYOK providers)
        base_url = llm_config.model_endpoint

        if async_client:
            return anthropic.AsyncAnthropic(api_key=api_key, base_url=base_url, max_retries=model_settings.anthropic_max_retries)
        return anthropic.Anthropic(api_key=api_key, base_url=base_url, max_retries=model_settings.anthropic_max_retries)

    @trace_method
    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Synchronous request to the MiniMax API.

        Uses the beta messages API for compatibility with Anthropic streaming interfaces.
        """
        client = self._get_anthropic_client(llm_config, async_client=False)

        response: BetaMessage = client.beta.messages.create(**request_data)
        return response.model_dump()

    @trace_method
    async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Asynchronous request to the MiniMax API.

        Uses the beta messages API for compatibility with Anthropic streaming interfaces.
        """
        client = await self._get_anthropic_client_async(llm_config, async_client=True)

        try:
            response: BetaMessage = await client.beta.messages.create(**request_data)
            return response.model_dump()
        except ValueError as e:
            # Handle streaming fallback if needed (similar to the Anthropic client)
            if "streaming is required" in str(e).lower():
                logger.warning(
                    "[MiniMax] Non-streaming request rejected. Falling back to streaming mode. Error: %s",
                    str(e),
                )
                return await self._request_via_streaming(request_data, llm_config, betas=[])
            raise

    @trace_method
    async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[BetaRawMessageStreamEvent]:
        """
        Asynchronous streaming request to the MiniMax API.

        Uses the beta messages API for compatibility with Anthropic streaming interfaces.
        """
        client = await self._get_anthropic_client_async(llm_config, async_client=True)
        request_data["stream"] = True

        try:
            return await client.beta.messages.create(**request_data)
        except Exception as e:
            logger.error(f"Error streaming MiniMax request: {e}")
            raise e

    @trace_method
    def build_request_data(
        self,
        agent_type: AgentType,
        messages: List[PydanticMessage],
        llm_config: LLMConfig,
        tools: Optional[List[dict]] = None,
        force_tool_call: Optional[str] = None,
        requires_subsequent_tool_call: bool = False,
        tool_return_truncation_chars: Optional[int] = None,
    ) -> dict:
        """
        Build request data for the MiniMax API.

        Inherits most logic from AnthropicClient, with MiniMax-specific adjustments:
        - Temperature must be in range (0.0, 1.0]
        """
        data = super().build_request_data(
            agent_type,
            messages,
            llm_config,
            tools,
            force_tool_call,
            requires_subsequent_tool_call,
            tool_return_truncation_chars,
        )

        # MiniMax temperature range is (0.0, 1.0], recommended value: 1
        if data.get("temperature") is not None:
            temp = data["temperature"]
            if temp <= 0:
                data["temperature"] = 0.01  # Minimum valid value (exclusive of 0)
                logger.warning(f"[MiniMax] Temperature {temp} is invalid. Clamped to 0.01.")
            elif temp > 1.0:
                data["temperature"] = 1.0  # Maximum valid value
                logger.warning(f"[MiniMax] Temperature {temp} is invalid. Clamped to 1.0.")

        # MiniMax silently ignores Anthropic-specific parameters such as top_k,
        # stop_sequences, and service_tier, so we leave them in place rather than
        # stripping them from the request.

        return data

    def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
        """
        All MiniMax M2.x models support native interleaved thinking.

        Unlike Anthropic, where only certain models (Claude 3.7+) support extended thinking,
        all MiniMax models natively support thinking blocks without beta headers.
        """
        return True

    def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
        """MiniMax models support all tool choice modes."""
        return False

    def supports_structured_output(self, llm_config: LLMConfig) -> bool:
        """MiniMax doesn't currently advertise structured output support."""
        return False
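The temperature handling in `build_request_data` above clamps into MiniMax's documented (0.0, 1.0] range, where 0 is excluded. A standalone sketch of that clamp (the function name is illustrative; the 0.01 floor is the diff's own choice for the open lower bound):

```python
def clamp_minimax_temperature(temp: float) -> float:
    """Clamp a sampling temperature into MiniMax's (0.0, 1.0] range.

    0 is excluded from the range, so non-positive values map to a small
    positive floor (0.01, matching the client above); values above 1.0
    map to 1.0; in-range values pass through unchanged.
    """
    if temp <= 0:
        return 0.01
    if temp > 1.0:
        return 1.0
    return temp

print(clamp_minimax_temperature(0.0))  # → 0.01
print(clamp_minimax_temperature(1.7))  # → 1.0
print(clamp_minimax_temperature(0.5))  # → 0.5
```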
@@ -60,6 +60,7 @@ from letta.schemas.openai.chat_completion_response import (
    UsageStatistics,
)
from letta.schemas.openai.responses_request import ResponsesRequest
from letta.schemas.response_format import JsonSchemaResponseFormat
from letta.schemas.usage import LettaUsageStatistics
from letta.settings import model_settings

logger = get_logger(__name__)
@@ -169,6 +170,7 @@ def supports_content_none(llm_config: LLMConfig) -> bool:
class OpenAIClient(LLMClientBase):
    def _prepare_client_kwargs(self, llm_config: LLMConfig) -> dict:
        api_key, _, _ = self.get_byok_overrides(llm_config)
        has_byok_key = api_key is not None  # Track if we got a BYOK key

        # Default to global OpenAI key when no BYOK override
        if not api_key:
@@ -181,9 +183,11 @@ class OpenAIClient(LLMClientBase):
            llm_config.provider_name == "openrouter"
        )
        if is_openrouter:
            or_key = model_settings.openrouter_api_key or os.environ.get("OPENROUTER_API_KEY")
            if or_key:
                kwargs["api_key"] = or_key
            # Only use the prod OpenRouter key if no BYOK key was provided
            if not has_byok_key:
                or_key = model_settings.openrouter_api_key or os.environ.get("OPENROUTER_API_KEY")
                if or_key:
                    kwargs["api_key"] = or_key
            # Attach optional headers if provided
            headers = {}
            if model_settings.openrouter_referer:
@@ -207,6 +211,7 @@ class OpenAIClient(LLMClientBase):

    async def _prepare_client_kwargs_async(self, llm_config: LLMConfig) -> dict:
        api_key, _, _ = await self.get_byok_overrides_async(llm_config)
        has_byok_key = api_key is not None  # Track if we got a BYOK key

        if not api_key:
            api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
@@ -216,9 +221,11 @@ class OpenAIClient(LLMClientBase):
            llm_config.provider_name == "openrouter"
        )
        if is_openrouter:
            or_key = model_settings.openrouter_api_key or os.environ.get("OPENROUTER_API_KEY")
            if or_key:
                kwargs["api_key"] = or_key
            # Only use the prod OpenRouter key if no BYOK key was provided
            if not has_byok_key:
                or_key = model_settings.openrouter_api_key or os.environ.get("OPENROUTER_API_KEY")
                if or_key:
                    kwargs["api_key"] = or_key
            headers = {}
            if model_settings.openrouter_referer:
                headers["HTTP-Referer"] = model_settings.openrouter_referer
@@ -591,6 +598,66 @@ class OpenAIClient(LLMClientBase):
    def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
        return is_openai_reasoning_model(llm_config.model)

    def extract_usage_statistics(self, response_data: dict | None, llm_config: LLMConfig) -> LettaUsageStatistics:
        """Extract usage statistics from an OpenAI response and return them as LettaUsageStatistics."""
        if not response_data:
            return LettaUsageStatistics()

        # Handle Responses API format (used by reasoning models like o1/o3)
        if response_data.get("object") == "response":
            usage = response_data.get("usage", {}) or {}
            prompt_tokens = usage.get("input_tokens") or 0
            completion_tokens = usage.get("output_tokens") or 0
            total_tokens = usage.get("total_tokens") or (prompt_tokens + completion_tokens)

            input_details = usage.get("input_tokens_details", {}) or {}
            cached_tokens = input_details.get("cached_tokens")

            output_details = usage.get("output_tokens_details", {}) or {}
            reasoning_tokens = output_details.get("reasoning_tokens")

            return LettaUsageStatistics(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
                cached_input_tokens=cached_tokens,
                reasoning_tokens=reasoning_tokens,
            )

        # Handle standard Chat Completions API format using pydantic models
        from openai.types.chat import ChatCompletion

        try:
            completion = ChatCompletion.model_validate(response_data)
        except Exception:
            return LettaUsageStatistics()

        if not completion.usage:
            return LettaUsageStatistics()

        usage = completion.usage
        prompt_tokens = usage.prompt_tokens or 0
        completion_tokens = usage.completion_tokens or 0
        total_tokens = usage.total_tokens or (prompt_tokens + completion_tokens)

        # Extract cached tokens from prompt_tokens_details
        cached_tokens = None
        if usage.prompt_tokens_details:
            cached_tokens = usage.prompt_tokens_details.cached_tokens

        # Extract reasoning tokens from completion_tokens_details
        reasoning_tokens = None
        if usage.completion_tokens_details:
            reasoning_tokens = usage.completion_tokens_details.reasoning_tokens

        return LettaUsageStatistics(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            cached_input_tokens=cached_tokens,
            reasoning_tokens=reasoning_tokens,
        )

    @trace_method
    async def convert_response_to_chat_completion(
        self,
@@ -607,30 +674,10 @@ class OpenAIClient(LLMClientBase):
        # See example payload in tests/integration_test_send_message_v2.py
        model = response_data.get("model")

        # Extract usage
        usage = response_data.get("usage", {}) or {}
        prompt_tokens = usage.get("input_tokens") or 0
        completion_tokens = usage.get("output_tokens") or 0
        total_tokens = usage.get("total_tokens") or (prompt_tokens + completion_tokens)
        # Extract usage via centralized method
        from letta.schemas.enums import ProviderType

        # Extract detailed token breakdowns (Responses API uses input_tokens_details/output_tokens_details)
        prompt_tokens_details = None
        input_details = usage.get("input_tokens_details", {}) or {}
        if input_details.get("cached_tokens"):
            from letta.schemas.openai.chat_completion_response import UsageStatisticsPromptTokenDetails

            prompt_tokens_details = UsageStatisticsPromptTokenDetails(
                cached_tokens=input_details.get("cached_tokens") or 0,
            )

        completion_tokens_details = None
        output_details = usage.get("output_tokens_details", {}) or {}
        if output_details.get("reasoning_tokens"):
            from letta.schemas.openai.chat_completion_response import UsageStatisticsCompletionTokenDetails

            completion_tokens_details = UsageStatisticsCompletionTokenDetails(
                reasoning_tokens=output_details.get("reasoning_tokens") or 0,
            )
        usage_stats = self.extract_usage_statistics(response_data, llm_config).to_usage(ProviderType.openai)

        # Extract assistant message text from the outputs list
        outputs = response_data.get("output") or []
@@ -698,13 +745,7 @@ class OpenAIClient(LLMClientBase):
            choices=[choice],
            created=int(response_data.get("created_at") or 0),
            model=model or (llm_config.model if hasattr(llm_config, "model") else None),
            usage=UsageStatistics(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
                prompt_tokens_details=prompt_tokens_details,
                completion_tokens_details=completion_tokens_details,
            ),
            usage=usage_stats,
        )

        return chat_completion_response
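The OpenAI hunks above branch on payload shape: Responses API results (`"object": "response"`) report `input_tokens`/`output_tokens`, while Chat Completions results report `prompt_tokens`/`completion_tokens`; in both, the cached count is a subset of the prompt total, so nothing is added back. A minimal sketch of that dual-shape normalization (the helper name and payloads are hypothetical; the field names come from the two OpenAI formats handled in the diff):

```python
def openai_usage(response: dict) -> dict:
    """Normalize token usage from either OpenAI API shape.

    Responses API payloads ("object": "response") use input_tokens /
    output_tokens; Chat Completions payloads use prompt_tokens /
    completion_tokens. Cached tokens are a SUBSET of the prompt total,
    so the total needs no cache adjustment.
    """
    usage = response.get("usage", {}) or {}
    if response.get("object") == "response":
        prompt = usage.get("input_tokens") or 0
        completion = usage.get("output_tokens") or 0
    else:
        prompt = usage.get("prompt_tokens") or 0
        completion = usage.get("completion_tokens") or 0
    return {
        "prompt_tokens": prompt,
        "completion_tokens": completion,
        "total_tokens": usage.get("total_tokens") or (prompt + completion),
    }

# Hypothetical Responses API payload
print(openai_usage({"object": "response", "usage": {"input_tokens": 10, "output_tokens": 5}}))
# → {'prompt_tokens': 10, 'completion_tokens': 5, 'total_tokens': 15}
```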
1
letta/model_specs/__init__.py
Normal file
@@ -0,0 +1 @@
"""Model specification utilities for Letta."""
120
letta/model_specs/litellm_model_specs.py
Normal file
120
letta/model_specs/litellm_model_specs.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
Utility functions for working with litellm model specifications.
|
||||
|
||||
This module provides access to model specifications from the litellm model_prices_and_context_window.json file.
|
||||
The data is synced from: https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json
"""

import json
import os
from typing import Optional

import aiofiles
from async_lru import alru_cache

from letta.log import get_logger

logger = get_logger(__name__)

# Path to the litellm model specs JSON file
MODEL_SPECS_PATH = os.path.join(os.path.dirname(__file__), "model_prices_and_context_window.json")


@alru_cache(maxsize=1)
async def load_model_specs() -> dict:
    """Load the litellm model specifications from the JSON file.

    Returns:
        dict: The model specifications data, or an empty dict if the file
            is missing or contains invalid JSON (both cases are logged
            rather than raised)
    """
    if not os.path.exists(MODEL_SPECS_PATH):
        logger.warning(f"Model specs file not found at {MODEL_SPECS_PATH}")
        return {}

    try:
        async with aiofiles.open(MODEL_SPECS_PATH, "r") as f:
            content = await f.read()
            return json.loads(content)
    except json.JSONDecodeError as e:
        logger.error(f"Failed to parse model specs JSON: {e}")
        return {}


async def get_model_spec(model_name: str) -> Optional[dict]:
    """Get the specification for a specific model.

    Args:
        model_name: The name of the model (e.g., "gpt-4o", "gpt-4o-mini")

    Returns:
        Optional[dict]: The model specification if found, None otherwise
    """
    specs = await load_model_specs()
    return specs.get(model_name)


async def get_max_input_tokens(model_name: str) -> Optional[int]:
    """Get the max input tokens for a model.

    Args:
        model_name: The name of the model

    Returns:
        Optional[int]: The max input tokens if found, None otherwise
    """
    spec = await get_model_spec(model_name)
    if not spec:
        return None

    return spec.get("max_input_tokens")


async def get_max_output_tokens(model_name: str) -> Optional[int]:
    """Get the max output tokens for a model.

    Args:
        model_name: The name of the model

    Returns:
        Optional[int]: The max output tokens if found, None otherwise
    """
    spec = await get_model_spec(model_name)
    if not spec:
        return None

    # Try max_output_tokens first, fall back to max_tokens
    return spec.get("max_output_tokens") or spec.get("max_tokens")


async def get_context_window(model_name: str) -> Optional[int]:
    """Get the context window size for a model.

    For most models, this is the max_input_tokens.

    Args:
        model_name: The name of the model

    Returns:
        Optional[int]: The context window size if found, None otherwise
    """
    return await get_max_input_tokens(model_name)


async def get_litellm_provider(model_name: str) -> Optional[str]:
    """Get the litellm provider for a model.

    Args:
        model_name: The name of the model

    Returns:
        Optional[str]: The provider name if found, None otherwise
    """
    spec = await get_model_spec(model_name)
    if not spec:
        return None

    return spec.get("litellm_provider")
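The loader above can be exercised without the async machinery; this sketch mirrors just the `max_output_tokens or max_tokens` fallback against an in-memory specs dict (the model names and spec values are illustrative placeholders, not real litellm data):

```python
# Minimal sketch of the spec-lookup fallback logic; the entries below are
# made-up placeholders, not values from the real litellm JSON file.
SPECS = {
    "example-model-a": {"max_input_tokens": 128000, "max_output_tokens": 16384},
    "example-model-b": {"max_input_tokens": 8192, "max_tokens": 4096},  # older entry: no max_output_tokens
}

def get_max_output_tokens(model_name: str):
    spec = SPECS.get(model_name)
    if not spec:
        return None
    # Try max_output_tokens first, fall back to max_tokens (as in the module above)
    return spec.get("max_output_tokens") or spec.get("max_tokens")

print(get_max_output_tokens("example-model-a"))  # 16384
print(get_max_output_tokens("example-model-b"))  # 4096, via the max_tokens fallback
print(get_max_output_tokens("unknown-model"))    # None
```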
31925	letta/model_specs/model_prices_and_context_window.json	(new file; diff suppressed because it is too large)
@@ -31,6 +31,7 @@ from letta.orm.prompt import Prompt
 from letta.orm.provider import Provider
 from letta.orm.provider_model import ProviderModel
 from letta.orm.provider_trace import ProviderTrace
+from letta.orm.provider_trace_metadata import ProviderTraceMetadata
 from letta.orm.run import Run
 from letta.orm.run_metrics import RunMetrics
 from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable
@@ -46,8 +46,8 @@ class Archive(SqlalchemyBase, OrganizationMixin):
         default=VectorDBProvider.NATIVE,
         doc="The vector database provider used for this archive's passages",
     )
-    embedding_config: Mapped[dict] = mapped_column(
-        EmbeddingConfigColumn, nullable=False, doc="Embedding configuration for passages in this archive"
+    embedding_config: Mapped[Optional[dict]] = mapped_column(
+        EmbeddingConfigColumn, nullable=True, doc="Embedding configuration for passages in this archive"
     )
     metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="Additional metadata for the archive")
     _vector_db_namespace: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="Private field for vector database namespace")
@@ -25,18 +25,18 @@ class BasePassage(SqlalchemyBase, OrganizationMixin):
 
     id: Mapped[str] = mapped_column(primary_key=True, doc="Unique passage identifier")
     text: Mapped[str] = mapped_column(doc="Passage text content")
-    embedding_config: Mapped[dict] = mapped_column(EmbeddingConfigColumn, doc="Embedding configuration")
+    embedding_config: Mapped[Optional[dict]] = mapped_column(EmbeddingConfigColumn, nullable=True, doc="Embedding configuration")
     metadata_: Mapped[dict] = mapped_column(JSON, doc="Additional metadata")
     # dual storage: json column for fast retrieval, junction table for efficient queries
     tags: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True, doc="Tags associated with this passage")
 
-    # Vector embedding field based on database type
+    # Vector embedding field based on database type - nullable for text-only search
     if settings.database_engine is DatabaseChoice.POSTGRES:
         from pgvector.sqlalchemy import Vector
 
-        embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
+        embedding = mapped_column(Vector(MAX_EMBEDDING_DIM), nullable=True)
     else:
-        embedding = Column(CommonVector)
+        embedding = Column(CommonVector, nullable=True)
 
     @declared_attr
     def organization(cls) -> Mapped["Organization"]:
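The hunk above makes the passage `embedding` column nullable so rows can exist without vectors (text-only search). A retrieval layer then has to branch on whether a vector is present; a rough, hypothetical sketch of that branching (`pick_search_mode` is illustrative, not a function from the codebase):

```python
# Hypothetical sketch: with embedding now nullable, retrieval must split
# passages into vector-searchable vs text-only sets.
def pick_search_mode(passages):
    with_vectors = [p for p in passages if p.get("embedding") is not None]
    text_only = [p for p in passages if p.get("embedding") is None]
    return {"vector": len(with_vectors), "text_only": len(text_only)}

passages = [
    {"text": "alpha", "embedding": [0.1, 0.2]},
    {"text": "beta", "embedding": None},  # allowed now that the column is nullable
]
print(pick_search_mode(passages))  # {'vector': 1, 'text_only': 1}
```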
@@ -1,6 +1,7 @@
+from datetime import datetime
 from typing import TYPE_CHECKING, Optional
 
-from sqlalchemy import ForeignKey, String, Text, UniqueConstraint
+from sqlalchemy import DateTime, ForeignKey, String, Text, UniqueConstraint
 from sqlalchemy.orm import Mapped, mapped_column, relationship
 
 from letta.orm.mixins import OrganizationMixin
@@ -41,6 +42,11 @@ class Provider(SqlalchemyBase, OrganizationMixin):
     api_key_enc: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Encrypted API key or secret key for the provider.")
     access_key_enc: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Encrypted access key for the provider.")
 
+    # sync tracking
+    last_synced: Mapped[Optional[datetime]] = mapped_column(
+        DateTime(timezone=True), nullable=True, doc="Last time models were synced for this provider."
+    )
+
     # relationships
     organization: Mapped["Organization"] = relationship("Organization", back_populates="providers")
     models: Mapped[list["ProviderModel"]] = relationship("ProviderModel", back_populates="provider", cascade="all, delete-orphan")
@@ -32,5 +32,15 @@ class ProviderTrace(SqlalchemyBase, OrganizationMixin):
         String, nullable=True, doc="Source service that generated this trace (memgpt-server, lettuce-py)"
     )
 
+    # v2 protocol fields
+    org_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="ID of the organization")
+    user_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="ID of the user who initiated the request")
+    compaction_settings: Mapped[Optional[dict]] = mapped_column(
+        JSON, nullable=True, doc="Compaction/summarization settings (summarization calls only)"
+    )
+    llm_config: Mapped[Optional[dict]] = mapped_column(
+        JSON, nullable=True, doc="LLM configuration used for this call (non-summarization calls only)"
+    )
+
     # Relationships
     organization: Mapped["Organization"] = relationship("Organization", lazy="selectin")
45	letta/orm/provider_trace_metadata.py	(new file)
@@ -0,0 +1,45 @@
+import uuid
+from datetime import datetime
+from typing import Optional
+
+from sqlalchemy import JSON, DateTime, Index, String, UniqueConstraint, func
+from sqlalchemy.orm import Mapped, mapped_column, relationship
+
+from letta.orm.mixins import OrganizationMixin
+from letta.orm.sqlalchemy_base import SqlalchemyBase
+from letta.schemas.provider_trace import ProviderTraceMetadata as PydanticProviderTraceMetadata
+
+
+class ProviderTraceMetadata(SqlalchemyBase, OrganizationMixin):
+    """Metadata-only provider trace storage (no request/response JSON)."""
+
+    __tablename__ = "provider_trace_metadata"
+    __pydantic_model__ = PydanticProviderTraceMetadata
+    __table_args__ = (
+        Index("ix_provider_trace_metadata_step_id", "step_id"),
+        UniqueConstraint("id", name="uq_provider_trace_metadata_id"),
+    )
+
+    created_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), primary_key=True, server_default=func.now(), doc="Timestamp when the trace was created"
+    )
+    id: Mapped[str] = mapped_column(
+        String, primary_key=True, doc="Unique provider trace identifier", default=lambda: f"provider_trace-{uuid.uuid4()}"
+    )
+    step_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="ID of the step that this trace is associated with")
+
+    # Telemetry context fields
+    agent_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="ID of the agent that generated this trace")
+    agent_tags: Mapped[Optional[list]] = mapped_column(JSON, nullable=True, doc="Tags associated with the agent for filtering")
+    call_type: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="Type of call (agent_step, summarization, etc.)")
+    run_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="ID of the run this trace is associated with")
+    source: Mapped[Optional[str]] = mapped_column(
+        String, nullable=True, doc="Source service that generated this trace (memgpt-server, lettuce-py)"
+    )
+
+    # v2 protocol fields
+    org_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="ID of the organization")
+    user_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="ID of the user who initiated the request")
+
+    # Relationships
+    organization: Mapped["Organization"] = relationship("Organization", lazy="selectin")
@@ -42,7 +42,9 @@ def handle_db_timeout(func):
             logger.error(f"Timeout while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
             raise DatabaseTimeoutError(message=f"Timeout occurred in {func.__name__}.", original_exception=e)
         except QueryCanceledError as e:
-            logger.error(f"Query canceled (statement timeout) while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
+            logger.error(
+                f"Query canceled (statement timeout) while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}"
+            )
             raise DatabaseTimeoutError(message=f"Query canceled due to statement timeout in {func.__name__}.", original_exception=e)
 
     return wrapper
@@ -56,7 +58,9 @@ def handle_db_timeout(func):
             logger.error(f"Timeout while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
             raise DatabaseTimeoutError(message=f"Timeout occurred in {func.__name__}.", original_exception=e)
         except QueryCanceledError as e:
-            logger.error(f"Query canceled (statement timeout) while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
+            logger.error(
+                f"Query canceled (statement timeout) while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}"
+            )
            raise DatabaseTimeoutError(message=f"Query canceled due to statement timeout in {func.__name__}.", original_exception=e)
 
     return async_wrapper
@@ -207,6 +211,10 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
         """
         Constructs the query for listing records.
         """
+        # Security check: if the model has organization_id column, actor should be provided
+        if actor is None and hasattr(cls, "organization_id"):
+            logger.warning(f"SECURITY: Listing org-scoped model {cls.__name__} without actor. This bypasses organization filtering.")
+
         query = select(cls)
 
         if join_model and join_conditions:
@@ -446,6 +454,14 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
     ):
         logger.debug(f"Reading {cls.__name__} with ID(s): {identifiers} with actor={actor}")
 
+        # Security check: if the model has organization_id column, actor should be provided
+        # to ensure proper org-scoping. Log a warning if actor is None.
+        if actor is None and hasattr(cls, "organization_id"):
+            logger.warning(
+                f"SECURITY: Reading org-scoped model {cls.__name__} without actor. "
+                f"IDs: {identifiers}. This bypasses organization filtering."
+            )
+
         # Start the query
         query = select(cls)
         # Collect query conditions for better error reporting
@@ -681,6 +697,12 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
         **kwargs,
     ):
         logger.debug(f"Calculating size for {cls.__name__} with filters {kwargs}")
 
+        # Security check: if the model has organization_id column, actor should be provided
+        if actor is None and hasattr(cls, "organization_id"):
+            logger.warning(
+                f"SECURITY: Calculating size for org-scoped model {cls.__name__} without actor. This bypasses organization filtering."
+            )
+
         query = select(func.count(1)).select_from(cls)
 
         if actor:
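The security checks added in the hunks above all share one guard: warn whenever an org-scoped model (detected via an `organization_id` attribute) is queried without an actor. The guard condition can be sketched in isolation (the class names here are illustrative, not from the codebase):

```python
# Illustrative sketch of the actor guard added above: org-scoped models
# are detected by the presence of an organization_id attribute.
class OrgScopedModel:
    organization_id = "org-123"

class GlobalModel:
    pass

def needs_actor_warning(cls, actor) -> bool:
    # Same condition as the diff: warn only when the model is org-scoped
    # and no actor was supplied.
    return actor is None and hasattr(cls, "organization_id")

print(needs_actor_warning(OrgScopedModel, None))   # True
print(needs_actor_warning(GlobalModel, None))      # False
```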
@@ -17,7 +17,7 @@ class ArchiveBase(OrmMetadataBase):
     vector_db_provider: VectorDBProvider = Field(
         default=VectorDBProvider.NATIVE, description="The vector database provider used for this archive's passages"
     )
-    embedding_config: EmbeddingConfig = Field(..., description="Embedding configuration for passages in this archive")
+    embedding_config: Optional[EmbeddingConfig] = Field(None, description="Embedding configuration for passages in this archive")
     metadata: Optional[Dict] = Field(default_factory=dict, validation_alias="metadata_", description="Additional metadata")
 
 
@@ -63,11 +63,14 @@ class ProviderType(str, Enum):
     hugging_face = "hugging-face"
     letta = "letta"
     lmstudio_openai = "lmstudio_openai"
+    minimax = "minimax"
     mistral = "mistral"
     ollama = "ollama"
     openai = "openai"
     together = "together"
     vllm = "vllm"
+    sglang = "sglang"
+    openrouter = "openrouter"
     xai = "xai"
     zai = "zai"
 
@@ -5,6 +5,7 @@ from pydantic import BaseModel, Field
 
 from letta.schemas.enums import PrimitiveType
 from letta.schemas.letta_base import LettaBase
+from letta.validators import AgentId, BlockId
 
 
 class ManagerType(str, Enum):
@@ -93,43 +94,43 @@ class RoundRobinManagerUpdate(ManagerConfig):
 
 class SupervisorManager(ManagerConfig):
     manager_type: Literal[ManagerType.supervisor] = Field(ManagerType.supervisor, description="")
-    manager_agent_id: str = Field(..., description="")
+    manager_agent_id: AgentId = Field(..., description="")
 
 
 class SupervisorManagerUpdate(ManagerConfig):
     manager_type: Literal[ManagerType.supervisor] = Field(ManagerType.supervisor, description="")
-    manager_agent_id: Optional[str] = Field(..., description="")
+    manager_agent_id: Optional[AgentId] = Field(..., description="")
 
 
 class DynamicManager(ManagerConfig):
     manager_type: Literal[ManagerType.dynamic] = Field(ManagerType.dynamic, description="")
-    manager_agent_id: str = Field(..., description="")
+    manager_agent_id: AgentId = Field(..., description="")
     termination_token: Optional[str] = Field("DONE!", description="")
     max_turns: Optional[int] = Field(None, description="")
 
 
 class DynamicManagerUpdate(ManagerConfig):
     manager_type: Literal[ManagerType.dynamic] = Field(ManagerType.dynamic, description="")
-    manager_agent_id: Optional[str] = Field(None, description="")
+    manager_agent_id: Optional[AgentId] = Field(None, description="")
     termination_token: Optional[str] = Field(None, description="")
     max_turns: Optional[int] = Field(None, description="")
 
 
 class SleeptimeManager(ManagerConfig):
     manager_type: Literal[ManagerType.sleeptime] = Field(ManagerType.sleeptime, description="")
-    manager_agent_id: str = Field(..., description="")
+    manager_agent_id: AgentId = Field(..., description="")
     sleeptime_agent_frequency: Optional[int] = Field(None, description="")
 
 
 class SleeptimeManagerUpdate(ManagerConfig):
     manager_type: Literal[ManagerType.sleeptime] = Field(ManagerType.sleeptime, description="")
-    manager_agent_id: Optional[str] = Field(None, description="")
+    manager_agent_id: Optional[AgentId] = Field(None, description="")
     sleeptime_agent_frequency: Optional[int] = Field(None, description="")
 
 
 class VoiceSleeptimeManager(ManagerConfig):
     manager_type: Literal[ManagerType.voice_sleeptime] = Field(ManagerType.voice_sleeptime, description="")
-    manager_agent_id: str = Field(..., description="")
+    manager_agent_id: AgentId = Field(..., description="")
     max_message_buffer_length: Optional[int] = Field(
         None,
         description="The desired maximum length of messages in the context window of the convo agent. This is a best effort, and may be off slightly due to user/assistant interleaving.",
@@ -142,7 +143,7 @@ class VoiceSleeptimeManager(ManagerConfig):
 
 class VoiceSleeptimeManagerUpdate(ManagerConfig):
     manager_type: Literal[ManagerType.voice_sleeptime] = Field(ManagerType.voice_sleeptime, description="")
-    manager_agent_id: Optional[str] = Field(None, description="")
+    manager_agent_id: Optional[AgentId] = Field(None, description="")
     max_message_buffer_length: Optional[int] = Field(
         None,
         description="The desired maximum length of messages in the context window of the convo agent. This is a best effort, and may be off slightly due to user/assistant interleaving.",
@@ -170,11 +171,11 @@ ManagerConfigUpdateUnion = Annotated[
 
 
 class GroupCreate(BaseModel):
-    agent_ids: List[str] = Field(..., description="")
+    agent_ids: List[AgentId] = Field(..., description="")
     description: str = Field(..., description="")
     manager_config: ManagerConfigUnion = Field(RoundRobinManager(), description="")
     project_id: Optional[str] = Field(None, description="The associated project id.")
-    shared_block_ids: List[str] = Field([], description="", deprecated=True)
+    shared_block_ids: List[BlockId] = Field([], description="", deprecated=True)
     hidden: Optional[bool] = Field(
         None,
         description="If set to True, the group will be hidden.",
@@ -190,8 +191,8 @@ class InternalTemplateGroupCreate(GroupCreate):
 
 
 class GroupUpdate(BaseModel):
-    agent_ids: Optional[List[str]] = Field(None, description="")
+    agent_ids: Optional[List[AgentId]] = Field(None, description="")
     description: Optional[str] = Field(None, description="")
     manager_config: Optional[ManagerConfigUpdateUnion] = Field(None, description="")
     project_id: Optional[str] = Field(None, description="The associated project id.")
-    shared_block_ids: Optional[List[str]] = Field(None, description="", deprecated=True)
+    shared_block_ids: Optional[List[BlockId]] = Field(None, description="", deprecated=True)
@@ -5,6 +5,7 @@ from pydantic import Field
 
 from letta.schemas.enums import PrimitiveType
 from letta.schemas.letta_base import LettaBase
+from letta.validators import AgentId, BlockId
 
 
 class IdentityType(str, Enum):
@@ -57,8 +58,8 @@ class IdentityCreate(LettaBase):
     name: str = Field(..., description="The name of the identity.")
     identity_type: IdentityType = Field(..., description="The type of the identity.")
     project_id: Optional[str] = Field(None, description="The project id of the identity, if applicable.")
-    agent_ids: Optional[List[str]] = Field(None, description="The agent ids that are associated with the identity.", deprecated=True)
-    block_ids: Optional[List[str]] = Field(None, description="The IDs of the blocks associated with the identity.", deprecated=True)
+    agent_ids: Optional[List[AgentId]] = Field(None, description="The agent ids that are associated with the identity.", deprecated=True)
+    block_ids: Optional[List[BlockId]] = Field(None, description="The IDs of the blocks associated with the identity.", deprecated=True)
     properties: Optional[List[IdentityProperty]] = Field(None, description="List of properties associated with the identity.")
@@ -67,8 +68,8 @@ class IdentityUpsert(LettaBase):
     name: str = Field(..., description="The name of the identity.")
     identity_type: IdentityType = Field(..., description="The type of the identity.")
     project_id: Optional[str] = Field(None, description="The project id of the identity, if applicable.")
-    agent_ids: Optional[List[str]] = Field(None, description="The agent ids that are associated with the identity.", deprecated=True)
-    block_ids: Optional[List[str]] = Field(None, description="The IDs of the blocks associated with the identity.", deprecated=True)
+    agent_ids: Optional[List[AgentId]] = Field(None, description="The agent ids that are associated with the identity.", deprecated=True)
+    block_ids: Optional[List[BlockId]] = Field(None, description="The IDs of the blocks associated with the identity.", deprecated=True)
     properties: Optional[List[IdentityProperty]] = Field(None, description="List of properties associated with the identity.")
@@ -76,8 +77,8 @@ class IdentityUpdate(LettaBase):
     identifier_key: Optional[str] = Field(None, description="External, user-generated identifier key of the identity.")
     name: Optional[str] = Field(None, description="The name of the identity.")
     identity_type: Optional[IdentityType] = Field(None, description="The type of the identity.")
-    agent_ids: Optional[List[str]] = Field(None, description="The agent ids that are associated with the identity.", deprecated=True)
-    block_ids: Optional[List[str]] = Field(None, description="The IDs of the blocks associated with the identity.", deprecated=True)
+    agent_ids: Optional[List[AgentId]] = Field(None, description="The agent ids that are associated with the identity.", deprecated=True)
+    block_ids: Optional[List[BlockId]] = Field(None, description="The IDs of the blocks associated with the identity.", deprecated=True)
     properties: Optional[List[IdentityProperty]] = Field(None, description="List of properties associated with the identity.")
@@ -7,8 +7,10 @@ from pydantic import BaseModel, Field, field_serializer, field_validator
 
 from letta.schemas.letta_message_content import (
     LettaAssistantMessageContentUnion,
+    LettaToolReturnContentUnion,
     LettaUserMessageContentUnion,
     get_letta_assistant_message_content_union_str_json_schema,
+    get_letta_tool_return_content_union_str_json_schema,
     get_letta_user_message_content_union_str_json_schema,
 )
@@ -35,7 +37,11 @@ class ApprovalReturn(MessageReturn):
 
 class ToolReturn(MessageReturn):
     type: Literal[MessageReturnType.tool] = Field(default=MessageReturnType.tool, description="The message type to be created.")
-    tool_return: str
+    tool_return: Union[str, List[LettaToolReturnContentUnion]] = Field(
+        ...,
+        description="The tool return value - either a string or list of content parts (text/image)",
+        json_schema_extra=get_letta_tool_return_content_union_str_json_schema(),
+    )
     status: Literal["success", "error"]
     tool_call_id: str
     stdout: Optional[List[str]] = None
@@ -563,6 +569,10 @@ class SystemMessageListResult(UpdateSystemMessage):
         default=None,
         description="The unique identifier of the agent that owns the message.",
     )
+    conversation_id: str | None = Field(
+        default=None,
+        description="The unique identifier of the conversation that the message belongs to.",
+    )
 
     created_at: datetime = Field(..., description="The time the message was created in ISO format.")
@@ -581,6 +591,10 @@ class UserMessageListResult(UpdateUserMessage):
         default=None,
         description="The unique identifier of the agent that owns the message.",
     )
+    conversation_id: str | None = Field(
+        default=None,
+        description="The unique identifier of the conversation that the message belongs to.",
+    )
 
     created_at: datetime = Field(..., description="The time the message was created in ISO format.")
@@ -599,6 +613,10 @@ class ReasoningMessageListResult(UpdateReasoningMessage):
         default=None,
         description="The unique identifier of the agent that owns the message.",
     )
+    conversation_id: str | None = Field(
+        default=None,
+        description="The unique identifier of the conversation that the message belongs to.",
+    )
 
     created_at: datetime = Field(..., description="The time the message was created in ISO format.")
@@ -617,6 +635,10 @@ class AssistantMessageListResult(UpdateAssistantMessage):
         default=None,
         description="The unique identifier of the agent that owns the message.",
    )
+    conversation_id: str | None = Field(
+        default=None,
+        description="The unique identifier of the conversation that the message belongs to.",
+    )
 
     created_at: datetime = Field(..., description="The time the message was created in ISO format.")
@@ -138,6 +138,48 @@ def get_letta_user_message_content_union_str_json_schema():
     }
 
 
+# -------------------------------
+# Tool Return Content Types
+# -------------------------------
+
+
+LettaToolReturnContentUnion = Annotated[
+    Union[TextContent, ImageContent],
+    Field(discriminator="type"),
+]
+
+
+def create_letta_tool_return_content_union_schema():
+    return {
+        "oneOf": [
+            {"$ref": "#/components/schemas/TextContent"},
+            {"$ref": "#/components/schemas/ImageContent"},
+        ],
+        "discriminator": {
+            "propertyName": "type",
+            "mapping": {
+                "text": "#/components/schemas/TextContent",
+                "image": "#/components/schemas/ImageContent",
+            },
+        },
+    }
+
+
+def get_letta_tool_return_content_union_str_json_schema():
+    """Schema that accepts either string or list of content parts for tool returns."""
+    return {
+        "anyOf": [
+            {
+                "type": "array",
+                "items": {
+                    "$ref": "#/components/schemas/LettaToolReturnContentUnion",
+                },
+            },
+            {"type": "string"},
+        ],
+    }
+
+
 # -------------------------------
 # Assistant Content Types
 # -------------------------------
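The `anyOf` schema added above accepts either a bare string or a list of content parts discriminated on `type` (`text` or `image`). A dependency-free sketch of the same acceptance rule (`is_valid_tool_return` is illustrative, not part of the codebase):

```python
# Pure-Python sketch of the string-or-content-parts rule encoded by the
# anyOf schema above; no jsonschema dependency. Illustrative only.
ALLOWED_PART_TYPES = {"text", "image"}  # mirrors the discriminator mapping

def is_valid_tool_return(value) -> bool:
    if isinstance(value, str):
        return True
    if isinstance(value, list):
        return all(isinstance(p, dict) and p.get("type") in ALLOWED_PART_TYPES for p in value)
    return False

print(is_valid_tool_return([{"type": "text", "text": "hi"}]))  # True
print(is_valid_tool_return([{"type": "audio"}]))               # False
```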
@@ -7,6 +7,7 @@ from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MES
 from letta.schemas.letta_message import MessageType
 from letta.schemas.letta_message_content import LettaMessageContentUnion
 from letta.schemas.message import MessageCreate, MessageCreateUnion, MessageRole
+from letta.validators import AgentId
 
 
 class ClientToolSchema(BaseModel):
@@ -125,12 +126,33 @@ class LettaStreamingRequest(LettaRequest):
     )
 
 
+class ConversationMessageRequest(LettaRequest):
+    """Request for sending messages to a conversation. Streams by default."""
+
+    streaming: bool = Field(
+        default=True,
+        description="If True (default), returns a streaming response (Server-Sent Events). If False, returns a complete JSON response.",
+    )
+    stream_tokens: bool = Field(
+        default=False,
+        description="Flag to determine if individual tokens should be streamed, rather than streaming per step (only used when streaming=true).",
+    )
+    include_pings: bool = Field(
+        default=True,
+        description="Whether to include periodic keepalive ping messages in the stream to prevent connection timeouts (only used when streaming=true).",
+    )
+    background: bool = Field(
+        default=False,
+        description="Whether to process the request in the background (only used when streaming=true).",
+    )
+
+
 class LettaAsyncRequest(LettaRequest):
     callback_url: Optional[str] = Field(None, description="Optional callback URL to POST to when the job completes")
 
 
 class LettaBatchRequest(LettaRequest):
-    agent_id: str = Field(..., description="The ID of the agent to send this batch request for")
+    agent_id: AgentId = Field(..., description="The ID of the agent to send this batch request for")
 
 
 class CreateBatch(BaseModel):
@@ -43,12 +43,14 @@ class LLMConfig(BaseModel):
         "koboldcpp",
         "vllm",
         "hugging-face",
+        "minimax",
         "mistral",
         "together",  # completions endpoint
         "bedrock",
         "deepseek",
         "xai",
         "zai",
         "openrouter",
+        "chatgpt_oauth",
     ] = Field(..., description="The endpoint type for the model.")
     model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
@@ -320,9 +322,10 @@ class LLMConfig(BaseModel):
             GoogleAIModelSettings,
             GoogleVertexModelSettings,
             GroqModelSettings,
-            Model,
+            ModelSettings,
             OpenAIModelSettings,
             OpenAIReasoning,
+            OpenRouterModelSettings,
             TogetherModelSettings,
             XAIModelSettings,
             ZAIModelSettings,
@@ -395,15 +398,30 @@ class LLMConfig(BaseModel):
                 max_output_tokens=self.max_tokens or 4096,
                 temperature=self.temperature,
             )
+        elif self.model_endpoint_type == "openrouter":
+            return OpenRouterModelSettings(
+                max_output_tokens=self.max_tokens or 4096,
+                temperature=self.temperature,
+            )
+        elif self.model_endpoint_type == "chatgpt_oauth":
+            return ChatGPTOAuthModelSettings(
+                max_output_tokens=self.max_tokens or 4096,
+                temperature=self.temperature,
+                reasoning=ChatGPTOAuthReasoning(reasoning_effort=self.reasoning_effort or "medium"),
+            )
+        elif self.model_endpoint_type == "minimax":
+            # MiniMax uses Anthropic-compatible API
+            thinking_type = "enabled" if self.enable_reasoner else "disabled"
+            return AnthropicModelSettings(
+                max_output_tokens=self.max_tokens or 4096,
+                temperature=self.temperature,
+                thinking=AnthropicThinking(type=thinking_type, budget_tokens=self.max_reasoning_tokens or 1024),
+                verbosity=self.verbosity,
+                strict=self.strict,
+            )
         else:
-            # If we don't know the model type, use the default Model schema
-            return Model(max_output_tokens=self.max_tokens or 4096)
+            # If we don't know the model type, use the base ModelSettings schema
+            return ModelSettings(max_output_tokens=self.max_tokens or 4096)
 
     @classmethod
     def is_openai_reasoning_model(cls, config: "LLMConfig") -> bool:
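The dispatch above maps each `model_endpoint_type` to a settings class, with unknown types falling back to the base `ModelSettings`. A toy, dict-based sketch of that shape (function and provider labels are hypothetical, not from the codebase):

```python
from typing import Optional

# Toy sketch of the endpoint-type dispatch above; unknown endpoint
# types fall back to a base settings shape. Names are hypothetical.
def settings_for(endpoint_type: str, max_tokens: Optional[int]) -> dict:
    base = {"max_output_tokens": max_tokens or 4096}
    if endpoint_type == "openrouter":
        return {"provider": "openrouter", **base}
    if endpoint_type == "minimax":
        # MiniMax reuses Anthropic-compatible settings in the diff above
        return {"provider": "anthropic-compatible", **base}
    return {"provider": "base", **base}  # unknown -> base ModelSettings

print(settings_for("minimax", None))  # {'provider': 'anthropic-compatible', 'max_output_tokens': 4096}
```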
@@ -50,6 +50,7 @@ from letta.schemas.letta_message_content import (
     ImageContent,
     ImageSourceType,
     LettaMessageContentUnion,
+    LettaToolReturnContentUnion,
     OmittedReasoningContent,
     ReasoningContent,
     RedactedReasoningContent,
@@ -71,6 +72,34 @@ def truncate_tool_return(content: Optional[str], limit: Optional[int]) -> Option
    return content[:limit] + f"... [truncated {len(content) - limit} chars]"


def _get_text_from_part(part: Union[TextContent, ImageContent, dict]) -> Optional[str]:
    """Extract text from a content part, returning None for images."""
    if isinstance(part, TextContent):
        return part.text
    elif isinstance(part, dict) and part.get("type") == "text":
        return part.get("text", "")
    return None


def tool_return_to_text(func_response: Optional[Union[str, List]]) -> Optional[str]:
    """Convert tool return content to text, replacing images with placeholders."""
    if func_response is None:
        return None
    if isinstance(func_response, str):
        return func_response

    text_parts = [text for part in func_response if (text := _get_text_from_part(part))]
    image_count = sum(
        1 for part in func_response if isinstance(part, ImageContent) or (isinstance(part, dict) and part.get("type") == "image")
    )

    result = "\n".join(text_parts)
    if image_count > 0:
        placeholder = "[Image omitted]" if image_count == 1 else f"[{image_count} images omitted]"
        result = (result + " " + placeholder) if result else placeholder
    return result if result else None


def add_inner_thoughts_to_tool_call(
    tool_call: OpenAIToolCall,
    inner_thoughts: str,
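The text/placeholder logic in `tool_return_to_text` can be exercised standalone. The sketch below mirrors it for plain dict parts only (the `TextContent`/`ImageContent` branches need letta's schema types, so they are omitted here); the function name is illustrative, not part of the diff.

```python
from typing import List, Optional, Union


def tool_return_to_text_sketch(func_response: Optional[Union[str, List[dict]]]) -> Optional[str]:
    """Dict-only mirror of tool_return_to_text: join text parts, summarize images."""
    if func_response is None:
        return None
    if isinstance(func_response, str):
        return func_response

    text_parts = [p["text"] for p in func_response if p.get("type") == "text" and p.get("text")]
    image_count = sum(1 for p in func_response if p.get("type") == "image")

    result = "\n".join(text_parts)
    if image_count > 0:
        placeholder = "[Image omitted]" if image_count == 1 else f"[{image_count} images omitted]"
        result = (result + " " + placeholder) if result else placeholder
    return result if result else None


parts = [
    {"type": "text", "text": "file saved"},
    {"type": "image", "source": {"type": "base64", "data": "..."}},
    {"type": "image", "source": {"type": "base64", "data": "..."}},
]
print(tool_return_to_text_sketch(parts))  # file saved [2 images omitted]
```

Note the edge cases: an image-only list still yields a placeholder string, while an empty list collapses to `None`.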
@@ -366,6 +395,7 @@ class Message(BaseMessage):
                    message_type=lm.message_type,
                    content=lm.content,
                    agent_id=message.agent_id,
                    conversation_id=message.conversation_id,
                    created_at=message.created_at,
                )
            )
@@ -376,6 +406,7 @@ class Message(BaseMessage):
                    message_type=lm.message_type,
                    content=lm.content,
                    agent_id=message.agent_id,
                    conversation_id=message.conversation_id,
                    created_at=message.created_at,
                )
            )
@@ -386,6 +417,7 @@ class Message(BaseMessage):
                    message_type=lm.message_type,
                    reasoning=lm.reasoning,
                    agent_id=message.agent_id,
                    conversation_id=message.conversation_id,
                    created_at=message.created_at,
                )
            )
@@ -396,6 +428,7 @@ class Message(BaseMessage):
                    message_type=lm.message_type,
                    content=lm.content,
                    agent_id=message.agent_id,
                    conversation_id=message.conversation_id,
                    created_at=message.created_at,
                )
            )
@@ -786,8 +819,14 @@ class Message(BaseMessage):
        for tool_return in self.tool_returns:
            parsed_data = self._parse_tool_response(tool_return.func_response)

            # Preserve multi-modal content (ToolReturn supports Union[str, List])
            if isinstance(tool_return.func_response, list):
                tool_return_value = tool_return.func_response
            else:
                tool_return_value = parsed_data["message"]

            tool_return_obj = LettaToolReturn(
-                tool_return=parsed_data["message"],
+                tool_return=tool_return_value,
                status=parsed_data["status"],
                tool_call_id=tool_return.tool_call_id,
                stdout=tool_return.stdout,
@@ -801,11 +840,18 @@ class Message(BaseMessage):

        first_tool_return = all_tool_returns[0]

        # Convert deprecated string-only field to text (preserve images in tool_returns list)
        deprecated_tool_return_text = (
            tool_return_to_text(first_tool_return.tool_return)
            if isinstance(first_tool_return.tool_return, list)
            else first_tool_return.tool_return
        )

        return ToolReturnMessage(
            id=self.id,
            date=self.created_at,
            # deprecated top-level fields populated from first tool return
-            tool_return=first_tool_return.tool_return,
+            tool_return=deprecated_tool_return_text,
            status=first_tool_return.status,
            tool_call_id=first_tool_return.tool_call_id,
            stdout=first_tool_return.stdout,
@@ -840,11 +886,11 @@ class Message(BaseMessage):
        """Check if message has exactly one text content item."""
        return self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent)

-    def _parse_tool_response(self, response_text: str) -> dict:
+    def _parse_tool_response(self, response_text: Union[str, List]) -> dict:
        """Parse tool response JSON and extract message and status.

        Args:
-            response_text: Raw JSON response text
+            response_text: Raw JSON response text OR list of content parts (for multi-modal)

        Returns:
            Dictionary with 'message' and 'status' keys
@@ -852,6 +898,14 @@ class Message(BaseMessage):
        Raises:
            ValueError: If JSON parsing fails
        """
        # Handle multi-modal content (list with text/images)
        if isinstance(response_text, list):
            text_representation = tool_return_to_text(response_text) or "[Multi-modal content]"
            return {
                "message": text_representation,
                "status": "success",
            }

        try:
            function_return = parse_json(response_text)
            return {
@@ -1301,7 +1355,9 @@ class Message(BaseMessage):
            tool_return = self.tool_returns[0]
            if not tool_return.tool_call_id:
                raise TypeError("OpenAI API requires tool_call_id to be set.")
-            func_response = truncate_tool_return(tool_return.func_response, tool_return_truncation_chars)
+            # Convert to text first (replaces images with placeholders), then truncate
+            func_response_text = tool_return_to_text(tool_return.func_response)
+            func_response = truncate_tool_return(func_response_text, tool_return_truncation_chars)
            openai_message = {
                "content": func_response,
                "role": self.role,
@@ -1356,8 +1412,9 @@ class Message(BaseMessage):
            for tr in m.tool_returns:
                if not tr.tool_call_id:
                    raise TypeError("ToolReturn came back without a tool_call_id.")
-                # Ensure explicit tool_returns are truncated for Chat Completions
-                func_response = truncate_tool_return(tr.func_response, tool_return_truncation_chars)
+                # Convert multi-modal to text (images → placeholders), then truncate
+                func_response_text = tool_return_to_text(tr.func_response)
+                func_response = truncate_tool_return(func_response_text, tool_return_truncation_chars)
                result.append(
                    {
                        "content": func_response,
@@ -1418,7 +1475,10 @@ class Message(BaseMessage):
            message_dicts.append(user_dict)

        elif self.role == "assistant" or self.role == "approval":
-            assert self.tool_calls is not None or (self.content is not None and len(self.content) > 0)
+            # Validate that message has content OpenAI Responses API can process
+            if self.tool_calls is None and (self.content is None or len(self.content) == 0):
+                # Skip this message (similar to Anthropic handling at line 1308)
+                return message_dicts

            # A few things may be in here, firstly reasoning content, secondly assistant messages, thirdly tool calls
            # TODO check if OpenAI Responses is capable of R->A->T like Anthropic?
@@ -1456,17 +1516,17 @@ class Message(BaseMessage):
            )

        elif self.role == "tool":
-            # Handle tool returns - similar pattern to Anthropic
+            # Handle tool returns - supports images via content arrays
            if self.tool_returns:
                for tool_return in self.tool_returns:
                    if not tool_return.tool_call_id:
                        raise TypeError("OpenAI Responses API requires tool_call_id to be set.")
-                    func_response = truncate_tool_return(tool_return.func_response, tool_return_truncation_chars)
+                    output = self._tool_return_to_responses_output(tool_return.func_response, tool_return_truncation_chars)
                    message_dicts.append(
                        {
                            "type": "function_call_output",
                            "call_id": tool_return.tool_call_id[:max_tool_id_length] if max_tool_id_length else tool_return.tool_call_id,
-                            "output": func_response,
+                            "output": output,
                        }
                    )
            else:
@@ -1534,6 +1594,50 @@ class Message(BaseMessage):

        return None

    @staticmethod
    def _image_dict_to_data_url(part: dict) -> Optional[str]:
        """Convert image dict to data URL."""
        source = part.get("source", {})
        if source.get("type") == "base64" and source.get("data"):
            media_type = source.get("media_type", "image/png")
            return f"data:{media_type};base64,{source['data']}"
        elif source.get("type") == "url":
            return source.get("url")
        return None

    @staticmethod
    def _tool_return_to_responses_output(
        func_response: Optional[Union[str, List]],
        tool_return_truncation_chars: Optional[int] = None,
    ) -> Union[str, List[dict]]:
        """Convert tool return to OpenAI Responses API format."""
        if func_response is None:
            return ""
        if isinstance(func_response, str):
            return truncate_tool_return(func_response, tool_return_truncation_chars) or ""

        output_parts: List[dict] = []
        for part in func_response:
            if isinstance(part, TextContent):
                text = truncate_tool_return(part.text, tool_return_truncation_chars) or ""
                output_parts.append({"type": "input_text", "text": text})
            elif isinstance(part, ImageContent):
                image_url = Message._image_source_to_data_url(part)
                if image_url:
                    detail = getattr(part.source, "detail", None) or "auto"
                    output_parts.append({"type": "input_image", "image_url": image_url, "detail": detail})
            elif isinstance(part, dict):
                if part.get("type") == "text":
                    text = truncate_tool_return(part.get("text", ""), tool_return_truncation_chars) or ""
                    output_parts.append({"type": "input_text", "text": text})
                elif part.get("type") == "image":
                    image_url = Message._image_dict_to_data_url(part)
                    if image_url:
                        detail = part.get("source", {}).get("detail", "auto")
                        output_parts.append({"type": "input_image", "image_url": image_url, "detail": detail})

        return output_parts if output_parts else ""

    @staticmethod
    def to_openai_responses_dicts_from_list(
        messages: List[Message],
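`_image_dict_to_data_url` is pure dict manipulation, so its behavior is easy to pin down in isolation. This free-function copy (illustrative only, same branching as the method above) shows the two source shapes it accepts:

```python
from typing import Optional


def image_dict_to_data_url(part: dict) -> Optional[str]:
    """base64 sources become data: URLs; url sources pass through unchanged."""
    source = part.get("source", {})
    if source.get("type") == "base64" and source.get("data"):
        media_type = source.get("media_type", "image/png")
        return f"data:{media_type};base64,{source['data']}"
    elif source.get("type") == "url":
        return source.get("url")
    return None


print(image_dict_to_data_url({"source": {"type": "base64", "data": "AAAA", "media_type": "image/jpeg"}}))
# data:image/jpeg;base64,AAAA
```

When `media_type` is missing, the helper defaults to `image/png`, matching the fallback used throughout this diff.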
@@ -1550,6 +1654,68 @@ class Message(BaseMessage):
        )
        return result

    @staticmethod
    def _get_base64_image_data(part: Union[ImageContent, dict]) -> Optional[tuple[str, str]]:
        """Extract base64 data and media type from ImageContent or dict."""
        if isinstance(part, ImageContent):
            source = part.source
            if source.type == ImageSourceType.base64:
                return source.data, source.media_type
            elif source.type == ImageSourceType.letta and getattr(source, "data", None):
                return source.data, getattr(source, "media_type", None) or "image/png"
        elif isinstance(part, dict) and part.get("type") == "image":
            source = part.get("source", {})
            if source.get("type") == "base64" and source.get("data"):
                return source["data"], source.get("media_type", "image/png")
        return None

    @staticmethod
    def _tool_return_to_google_parts(
        func_response: Optional[Union[str, List]],
        tool_return_truncation_chars: Optional[int] = None,
    ) -> tuple[str, List[dict]]:
        """Extract text and image parts for Google API format."""
        if isinstance(func_response, str):
            return truncate_tool_return(func_response, tool_return_truncation_chars) or "", []

        text_parts = []
        image_parts = []
        for part in func_response:
            if text := _get_text_from_part(part):
                text_parts.append(text)
            elif image_data := Message._get_base64_image_data(part):
                data, media_type = image_data
                image_parts.append({"inlineData": {"data": data, "mimeType": media_type}})

        text = truncate_tool_return("\n".join(text_parts), tool_return_truncation_chars) or ""
        if image_parts:
            suffix = f"[{len(image_parts)} image(s) attached]"
            text = f"{text}\n{suffix}" if text else suffix

        return text, image_parts
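`_tool_return_to_google_parts` splits one tool return into a text payload plus sibling `inlineData` parts. A dict-only sketch of that split (no truncation, illustrative names) under the same assumptions as the helpers above:

```python
def tool_return_to_google_parts_sketch(parts: list) -> tuple[str, list]:
    """Collect text into one string and base64 images into Gemini-style inlineData parts."""
    text_parts, image_parts = [], []
    for part in parts:
        if part.get("type") == "text" and part.get("text"):
            text_parts.append(part["text"])
        elif part.get("type") == "image":
            source = part.get("source", {})
            if source.get("type") == "base64" and source.get("data"):
                image_parts.append(
                    {"inlineData": {"data": source["data"], "mimeType": source.get("media_type", "image/png")}}
                )

    text = "\n".join(text_parts)
    if image_parts:
        suffix = f"[{len(image_parts)} image(s) attached]"
        text = f"{text}\n{suffix}" if text else suffix
    return text, image_parts


text, images = tool_return_to_google_parts_sketch(
    [{"type": "text", "text": "done"}, {"type": "image", "source": {"type": "base64", "data": "AAAA"}}]
)
print(text)
# done
# [1 image(s) attached]
```

The textual suffix mirrors the real helper: the JSON `functionResponse` stays text-only, and images ride alongside as separate parts.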
    @staticmethod
    def _tool_return_to_anthropic_content(
        func_response: Optional[Union[str, List]],
        tool_return_truncation_chars: Optional[int] = None,
    ) -> Union[str, List[dict]]:
        """Convert tool return to Anthropic tool_result content format."""
        if func_response is None:
            return ""
        if isinstance(func_response, str):
            return truncate_tool_return(func_response, tool_return_truncation_chars) or ""

        content: List[dict] = []
        for part in func_response:
            if text := _get_text_from_part(part):
                text = truncate_tool_return(text, tool_return_truncation_chars) or ""
                content.append({"type": "text", "text": text})
            elif image_data := Message._get_base64_image_data(part):
                data, media_type = image_data
                content.append({"type": "image", "source": {"type": "base64", "data": data, "media_type": media_type}})

        return content if content else ""

    def to_anthropic_dict(
        self,
        current_model: str,
@@ -1628,8 +1794,11 @@ class Message(BaseMessage):
            }

        elif self.role == "assistant" or self.role == "approval":
-            # assert self.tool_calls is not None or text_content is not None, vars(self)
-            assert self.tool_calls is not None or len(self.content) > 0
+            # Validate that message has content Anthropic API can process
+            if self.tool_calls is None and (self.content is None or len(self.content) == 0):
+                # Skip this message (consistent with OpenAI dict handling)
+                return None

            anthropic_message = {
                "role": "assistant",
            }
@@ -1759,12 +1928,13 @@ class Message(BaseMessage):
                        f"Message ID: {self.id}, Tool: {self.name or 'unknown'}, "
                        f"Tool return index: {idx}/{len(self.tool_returns)}"
                    )
-                    func_response = truncate_tool_return(tool_return.func_response, tool_return_truncation_chars)
+                    # Convert to Anthropic format (supports images)
+                    tool_result_content = self._tool_return_to_anthropic_content(tool_return.func_response, tool_return_truncation_chars)
                    content.append(
                        {
                            "type": "tool_result",
                            "tool_use_id": resolved_tool_call_id,
-                            "content": func_response,
+                            "content": tool_result_content,
                        }
                    )
            if content:
@@ -1884,7 +2054,16 @@ class Message(BaseMessage):
            }

        elif self.role == "assistant" or self.role == "approval":
-            assert self.tool_calls is not None or text_content is not None or len(self.content) > 1
+            # Validate that message has content Google API can process
+            if self.tool_calls is None and text_content is None and len(self.content) <= 1:
+                # Message has no tool calls, no extractable text, and not multi-part
+                logger.warning(
+                    f"Assistant/approval message {self.id} has no content Google API can convert: "
+                    f"tool_calls={self.tool_calls}, text_content={text_content}, content={self.content}"
+                )
+                # Return None to skip this message (similar to approval messages without tool_calls at line 1998)
+                return None

            google_ai_message = {
                "role": "model",  # NOTE: different
            }
@@ -2003,7 +2182,7 @@ class Message(BaseMessage):
        elif self.role == "tool":
            # NOTE: Significantly different tool calling format, more similar to function calling format

-            # Handle tool returns - similar pattern to Anthropic
+            # Handle tool returns - Google supports images as sibling inlineData parts
            if self.tool_returns:
                parts = []
                for tool_return in self.tool_returns:
@@ -2013,26 +2192,24 @@ class Message(BaseMessage):
                    # Use the function name if available, otherwise use tool_call_id
                    function_name = self.name if self.name else tool_return.tool_call_id

-                    # Truncate the tool return if needed
-                    func_response = truncate_tool_return(tool_return.func_response, tool_return_truncation_chars)
+                    text_content, image_parts = Message._tool_return_to_google_parts(
+                        tool_return.func_response, tool_return_truncation_chars
+                    )

                    # NOTE: Google AI API wants the function response as JSON only, no string
                    try:
-                        function_response = parse_json(func_response)
+                        function_response = parse_json(text_content)
                    except:
-                        function_response = {"function_response": func_response}
+                        function_response = {"function_response": text_content}

                    parts.append(
                        {
                            "functionResponse": {
                                "name": function_name,
-                                "response": {
-                                    "name": function_name,  # NOTE: name twice... why?
-                                    "content": function_response,
-                                },
+                                "response": {"name": function_name, "content": function_response},
                            }
                        }
                    )
                    parts.extend(image_parts)

                google_ai_message = {
                    "role": "function",
@@ -2325,7 +2502,9 @@ class ToolReturn(BaseModel):
    status: Literal["success", "error"] = Field(..., description="The status of the tool call")
    stdout: Optional[List[str]] = Field(default=None, description="Captured stdout (e.g. prints, logs) from the tool invocation")
    stderr: Optional[List[str]] = Field(default=None, description="Captured stderr from the tool invocation")
-    func_response: Optional[str] = Field(None, description="The function response string")
+    func_response: Optional[Union[str, List[LettaToolReturnContentUnion]]] = Field(
+        None, description="The function response - either a string or list of content parts (text/image)"
+    )


class MessageSearchRequest(BaseModel):
@@ -42,12 +42,14 @@ class Model(LLMConfig, ModelBase):
        "koboldcpp",
        "vllm",
        "hugging-face",
        "minimax",
        "mistral",
        "together",
        "bedrock",
        "deepseek",
        "xai",
        "zai",
        "openrouter",
        "chatgpt_oauth",
    ] = Field(..., description="Deprecated: Use 'provider_type' field instead. The endpoint type for the model.", deprecated=True)
    context_window: int = Field(
@@ -138,6 +140,7 @@ class Model(LLMConfig, ModelBase):
        ProviderType.deepseek: DeepseekModelSettings,
        ProviderType.together: TogetherModelSettings,
        ProviderType.bedrock: BedrockModelSettings,
        ProviderType.openrouter: OpenRouterModelSettings,
    }

    settings_class = PROVIDER_SETTINGS_MAP.get(self.provider_type)
@@ -456,6 +459,23 @@ class BedrockModelSettings(ModelSettings):
    }


class OpenRouterModelSettings(ModelSettings):
    """OpenRouter model configuration (OpenAI-compatible)."""

    provider_type: Literal[ProviderType.openrouter] = Field(ProviderType.openrouter, description="The type of the provider.")
    temperature: float = Field(0.7, description="The temperature of the model.")
    response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the model.")

    def _to_legacy_config_params(self) -> dict:
        return {
            "temperature": self.temperature,
            "max_tokens": self.max_output_tokens,
            "response_format": self.response_format,
            "parallel_tool_calls": self.parallel_tool_calls,
            "strict": False,  # OpenRouter does not support strict mode
        }


class ChatGPTOAuthReasoning(BaseModel):
    """Reasoning configuration for ChatGPT OAuth models (GPT-5.x, o-series)."""
@@ -495,6 +515,7 @@ ModelSettingsUnion = Annotated[
        DeepseekModelSettings,
        TogetherModelSettings,
        BedrockModelSettings,
        OpenRouterModelSettings,
        ChatGPTOAuthModelSettings,
    ],
    Field(discriminator="provider_type"),
@@ -44,7 +44,7 @@ class Passage(PassageBase):
    embedding: Optional[List[float]] = Field(..., description="The embedding of the passage.")
    embedding_config: Optional[EmbeddingConfig] = Field(..., description="The embedding configuration used by the passage.")

-    created_at: datetime = Field(default_factory=get_utc_time, description="The creation date of the passage.")
+    created_at: Optional[datetime] = Field(default_factory=get_utc_time, description="The creation date of the passage.")

    @field_validator("embedding", mode="before")
    @classmethod
@@ -83,6 +83,7 @@ class PassageCreate(PassageBase):
    # optionally provide embeddings
    embedding: Optional[List[float]] = Field(None, description="The embedding of the passage.")
    embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the passage.")
    created_at: Optional[datetime] = Field(None, description="Optional creation datetime for the passage.")


class PassageUpdate(PassageCreate):
@@ -29,6 +29,9 @@ class ProviderTrace(BaseProviderTrace):
        run_id (str): ID of the run this trace is associated with.
        source (str): Source service that generated this trace (memgpt-server, lettuce-py).
        organization_id (str): The unique identifier of the organization.
        user_id (str): The unique identifier of the user who initiated the request.
        compaction_settings (Dict[str, Any]): Compaction/summarization settings (only for summarization calls).
        llm_config (Dict[str, Any]): LLM configuration used for this call (only for non-summarization calls).
        created_at (datetime): The timestamp when the object was created.
    """

@@ -44,4 +47,30 @@ class ProviderTrace(BaseProviderTrace):
    run_id: Optional[str] = Field(None, description="ID of the run this trace is associated with")
    source: Optional[str] = Field(None, description="Source service that generated this trace (memgpt-server, lettuce-py)")

    # v2 protocol fields
    org_id: Optional[str] = Field(None, description="ID of the organization")
    user_id: Optional[str] = Field(None, description="ID of the user who initiated the request")
    compaction_settings: Optional[Dict[str, Any]] = Field(None, description="Compaction/summarization settings (summarization calls only)")
    llm_config: Optional[Dict[str, Any]] = Field(None, description="LLM configuration used for this call (non-summarization calls only)")

    created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")


class ProviderTraceMetadata(BaseProviderTrace):
    """Metadata-only representation of a provider trace (no request/response JSON)."""

    id: str = BaseProviderTrace.generate_id_field()
    step_id: Optional[str] = Field(None, description="ID of the step that this trace is associated with")

    # Telemetry context fields
    agent_id: Optional[str] = Field(None, description="ID of the agent that generated this trace")
    agent_tags: Optional[list[str]] = Field(None, description="Tags associated with the agent for filtering")
    call_type: Optional[str] = Field(None, description="Type of call (agent_step, summarization, etc.)")
    run_id: Optional[str] = Field(None, description="ID of the run this trace is associated with")
    source: Optional[str] = Field(None, description="Source service that generated this trace (memgpt-server, lettuce-py)")

    # v2 protocol fields
    org_id: Optional[str] = Field(None, description="ID of the organization")
    user_id: Optional[str] = Field(None, description="ID of the user who initiated the request")

    created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
@@ -12,10 +12,12 @@ from .google_vertex import GoogleVertexProvider
from .groq import GroqProvider
from .letta import LettaProvider
from .lmstudio import LMStudioOpenAIProvider
from .minimax import MiniMaxProvider
from .mistral import MistralProvider
from .ollama import OllamaProvider
from .openai import OpenAIProvider
from .openrouter import OpenRouterProvider
from .sglang import SGLangProvider
from .together import TogetherProvider
from .vllm import VLLMProvider
from .xai import XAIProvider
@@ -40,11 +42,13 @@ __all__ = [
    "GroqProvider",
    "LettaProvider",
    "LMStudioOpenAIProvider",
    "MiniMaxProvider",
    "MistralProvider",
    "OllamaProvider",
    "OpenAIProvider",
    "TogetherProvider",
    "VLLMProvider",  # Replaces ChatCompletions and Completions
    "SGLangProvider",
    "XAIProvider",
    "ZAIProvider",
    "OpenRouterProvider",
@@ -32,6 +32,7 @@ class Provider(ProviderBase):
    api_version: str | None = Field(None, description="API version used for requests to the provider.")
    organization_id: str | None = Field(None, description="The organization id of the user")
    updated_at: datetime | None = Field(None, description="The last update timestamp of the provider.")
    last_synced: datetime | None = Field(None, description="The last time models were synced for this provider.")

    # Encrypted fields (stored as Secret objects, serialized to strings for DB)
    # Secret class handles validation and serialization automatically via __get_pydantic_core_schema__
@@ -191,9 +192,12 @@ class Provider(ProviderBase):
        GroqProvider,
        LettaProvider,
        LMStudioOpenAIProvider,
        MiniMaxProvider,
        MistralProvider,
        OllamaProvider,
        OpenAIProvider,
        OpenRouterProvider,
        SGLangProvider,
        TogetherProvider,
        VLLMProvider,
        XAIProvider,
@@ -224,6 +228,8 @@ class Provider(ProviderBase):
                return OllamaProvider(**self.model_dump(exclude_none=True))
            case ProviderType.vllm:
                return VLLMProvider(**self.model_dump(exclude_none=True))  # Removed support for CompletionsProvider
            case ProviderType.sglang:
                return SGLangProvider(**self.model_dump(exclude_none=True))
            case ProviderType.mistral:
                return MistralProvider(**self.model_dump(exclude_none=True))
            case ProviderType.deepseek:
@@ -240,6 +246,10 @@ class Provider(ProviderBase):
                return LMStudioOpenAIProvider(**self.model_dump(exclude_none=True))
            case ProviderType.bedrock:
                return BedrockProvider(**self.model_dump(exclude_none=True))
            case ProviderType.minimax:
                return MiniMaxProvider(**self.model_dump(exclude_none=True))
            case ProviderType.openrouter:
                return OpenRouterProvider(**self.model_dump(exclude_none=True))
            case _:
                raise ValueError(f"Unknown provider type: {self.provider_type}")
@@ -18,6 +18,7 @@ logger = get_logger(__name__)
class BedrockProvider(Provider):
    provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.")
    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
    base_url: str = Field("bedrock", description="Identifier for Bedrock endpoint (used for model_endpoint)")
    access_key: str | None = Field(None, description="AWS access key ID for Bedrock")
    api_key: str | None = Field(None, description="AWS secret access key for Bedrock")
    region: str = Field(..., description="AWS region for Bedrock")
@@ -99,7 +100,7 @@ class BedrockProvider(Provider):
                LLMConfig(
                    model=model_name,
                    model_endpoint_type=self.provider_type.value,
-                    model_endpoint=None,
+                    model_endpoint="bedrock",
                    context_window=self.get_model_context_window(inference_profile_id),
                    # Store the full inference profile ID in the handle for API calls
                    handle=self.get_handle(inference_profile_id),
letta/schemas/providers/minimax.py (new file, 105 lines)
@@ -0,0 +1,105 @@
from typing import Literal

import anthropic
from pydantic import Field

from letta.errors import ErrorCode, LLMAuthenticationError, LLMError
from letta.log import get_logger
from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.providers.base import Provider

logger = get_logger(__name__)

# MiniMax model specifications from official documentation
# https://platform.minimax.io/docs/guides/models-intro
MODEL_LIST = [
    {
        "name": "MiniMax-M2.1",
        "context_window": 200000,
        "max_output": 128000,
        "description": "Polyglot code mastery, precision code refactoring (~60 tps)",
    },
    {
        "name": "MiniMax-M2.1-lightning",
        "context_window": 200000,
        "max_output": 128000,
        "description": "Same performance as M2.1, significantly faster (~100 tps)",
    },
    {
        "name": "MiniMax-M2",
        "context_window": 200000,
        "max_output": 128000,
        "description": "Agentic capabilities, advanced reasoning",
    },
]


class MiniMaxProvider(Provider):
    """
    MiniMax provider using Anthropic-compatible API.

    MiniMax models support native interleaved thinking without requiring beta headers.
    The API uses the standard messages endpoint (not beta).

    Documentation: https://platform.minimax.io/docs/api-reference/text-anthropic-api
    """

    provider_type: Literal[ProviderType.minimax] = Field(ProviderType.minimax, description="The type of the provider.")
    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
    api_key: str | None = Field(None, description="API key for the MiniMax API.", deprecated=True)
    base_url: str = Field("https://api.minimax.io/anthropic", description="Base URL for the MiniMax Anthropic-compatible API.")

    async def check_api_key(self):
        """Check if the API key is valid by making a test request to the MiniMax API."""
        api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
        if not api_key:
            raise ValueError("No API key provided")

        try:
            # Use async Anthropic client pointed at MiniMax's Anthropic-compatible endpoint
            client = anthropic.AsyncAnthropic(api_key=api_key, base_url=self.base_url)
            # Use count_tokens as a lightweight check - similar to Anthropic provider
            await client.messages.count_tokens(model=MODEL_LIST[-1]["name"], messages=[{"role": "user", "content": "a"}])
        except anthropic.AuthenticationError as e:
            raise LLMAuthenticationError(message=f"Failed to authenticate with MiniMax: {e}", code=ErrorCode.UNAUTHENTICATED)
        except Exception as e:
            raise LLMError(message=f"{e}", code=ErrorCode.INTERNAL_SERVER_ERROR)

    def get_default_max_output_tokens(self, model_name: str) -> int:
        """Get the default max output tokens for MiniMax models."""
        # All MiniMax models support 128K output tokens
        return 128000

    def get_model_context_window_size(self, model_name: str) -> int | None:
        """Get the context window size for a MiniMax model."""
        # All current MiniMax models have 200K context window
        for model in MODEL_LIST:
            if model["name"] == model_name:
                return model["context_window"]
        # Default fallback
        return 200000

    async def list_llm_models_async(self) -> list[LLMConfig]:
        """
        Return available MiniMax models.

        MiniMax doesn't have a models listing endpoint, so we use a hardcoded list.
        """
        configs = []
        for model in MODEL_LIST:
            configs.append(
                LLMConfig(
                    model=model["name"],
                    model_endpoint_type="minimax",
                    model_endpoint=self.base_url,
                    context_window=model["context_window"],
                    handle=self.get_handle(model["name"]),
                    max_tokens=model["max_output"],
                    # MiniMax models support native thinking, similar to Claude's extended thinking
                    put_inner_thoughts_in_kwargs=True,
                    provider_name=self.name,
                    provider_category=self.provider_category,
                )
            )
        return configs
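The provider resolves context windows by scanning `MODEL_LIST`. The same lookup, extracted as a free function for illustration (the trimmed `MODEL_LIST` below mirrors the spec table from the file):

```python
MODEL_LIST = [
    {"name": "MiniMax-M2.1", "context_window": 200000, "max_output": 128000},
    {"name": "MiniMax-M2.1-lightning", "context_window": 200000, "max_output": 128000},
    {"name": "MiniMax-M2", "context_window": 200000, "max_output": 128000},
]


def get_model_context_window_size(model_name: str) -> int:
    """Linear scan over the hardcoded spec list, with the provider's 200K fallback."""
    for model in MODEL_LIST:
        if model["name"] == model_name:
            return model["context_window"]
    # Default fallback for unknown MiniMax models
    return 200000


print(get_model_context_window_size("MiniMax-M2"))  # 200000
```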
@@ -42,22 +42,37 @@ class OpenAIProvider(Provider):
            raise LLMError(message=f"{e}", code=ErrorCode.INTERNAL_SERVER_ERROR)

    def get_default_max_output_tokens(self, model_name: str) -> int:
        """Get the default max output tokens for OpenAI models."""
        if model_name.startswith("gpt-5"):
            return 16384
        elif model_name.startswith("o1") or model_name.startswith("o3"):
            return 100000
        return 16384  # default for openai
        """Get the default max output tokens for OpenAI models (sync fallback)."""
        # Simple default for openai
        return 16384

    async def get_default_max_output_tokens_async(self, model_name: str) -> int:
        """Get the default max output tokens for OpenAI models.

        Uses litellm model specifications with a simple fallback.
        """
        from letta.model_specs.litellm_model_specs import get_max_output_tokens

        # Try litellm specs
        max_output = await get_max_output_tokens(model_name)
        if max_output is not None:
            return max_output

        # Simple default for openai
        return 16384

    async def _get_models_async(self) -> list[dict]:
        from letta.llm_api.openai import openai_get_model_list_async

        # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
        # See: https://openrouter.ai/docs/requests
        extra_params = {"supported_parameters": "tools"} if "openrouter.ai" in self.base_url else None

        # Similar to Nebius
        extra_params = {"verbose": True} if "nebius.com" in self.base_url else None
        # Provider-specific extra parameters for model listing
        extra_params = None
        if "openrouter.ai" in self.base_url:
            # OpenRouter: filter for models with tool calling support
            # See: https://openrouter.ai/docs/requests
            extra_params = {"supported_parameters": "tools"}
        elif "nebius.com" in self.base_url:
            # Nebius: use verbose mode for better model info
            extra_params = {"verbose": True}

        # Decrypt API key before using
        api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
@@ -76,7 +91,7 @@ class OpenAIProvider(Provider):

    async def list_llm_models_async(self) -> list[LLMConfig]:
        data = await self._get_models_async()
        return self._list_llm_models(data)
        return await self._list_llm_models(data)

    async def list_embedding_models_async(self) -> list[EmbeddingConfig]:
        """Return known OpenAI embedding models.
@@ -116,13 +131,13 @@ class OpenAIProvider(Provider):
            ),
        ]

    def _list_llm_models(self, data: list[dict]) -> list[LLMConfig]:
    async def _list_llm_models(self, data: list[dict]) -> list[LLMConfig]:
        """
        This handles filtering out LLM Models by provider that meet Letta's requirements.
        """
        configs = []
        for model in data:
            check = self._do_model_checks_for_name_and_context_size(model)
            check = await self._do_model_checks_for_name_and_context_size_async(model)
            if check is None:
                continue
            model_name, context_window_size = check
@@ -174,7 +189,7 @@ class OpenAIProvider(Provider):
                model_endpoint=self.base_url,
                context_window=context_window_size,
                handle=handle,
                max_tokens=self.get_default_max_output_tokens(model_name),
                max_tokens=await self.get_default_max_output_tokens_async(model_name),
                provider_name=self.name,
                provider_category=self.provider_category,
            )
@@ -188,12 +203,30 @@ class OpenAIProvider(Provider):
        return configs

    def _do_model_checks_for_name_and_context_size(self, model: dict, length_key: str = "context_length") -> tuple[str, int] | None:
        """Sync version - uses sync get_model_context_window_size (for subclasses with hardcoded values)."""
        if "id" not in model:
            logger.warning("Model missing 'id' field for provider: %s and model: %s", self.provider_type, model)
            return None

        model_name = model["id"]
        context_window_size = model.get(length_key) or self.get_model_context_window_size(model_name)
        context_window_size = self.get_model_context_window_size(model_name)

        if not context_window_size:
            logger.info("No context window size found for model: %s", model_name)
            return None

        return model_name, context_window_size

    async def _do_model_checks_for_name_and_context_size_async(
        self, model: dict, length_key: str = "context_length"
    ) -> tuple[str, int] | None:
        """Async version - uses async get_model_context_window_size_async (for litellm lookup)."""
        if "id" not in model:
            logger.warning("Model missing 'id' field for provider: %s and model: %s", self.provider_type, model)
            return None

        model_name = model["id"]
        context_window_size = await self.get_model_context_window_size_async(model_name)

        if not context_window_size:
            logger.info("No context window size found for model: %s", model_name)
@@ -211,19 +244,30 @@ class OpenAIProvider(Provider):
        return llm_config

    def get_model_context_window_size(self, model_name: str) -> int | None:
        if model_name in LLM_MAX_CONTEXT_WINDOW:
            return LLM_MAX_CONTEXT_WINDOW[model_name]
        else:
            logger.debug(
                "Model %s on %s for provider %s not found in LLM_MAX_CONTEXT_WINDOW. Using default of {LLM_MAX_CONTEXT_WINDOW['DEFAULT']}",
                model_name,
                self.base_url,
                self.__class__.__name__,
            )
            return LLM_MAX_CONTEXT_WINDOW["DEFAULT"]
        """Get the context window size for a model (sync fallback)."""
        return LLM_MAX_CONTEXT_WINDOW["DEFAULT"]

    async def get_model_context_window_size_async(self, model_name: str) -> int | None:
        """Get the context window size for a model.

        Uses litellm model specifications which covers all OpenAI models.
        """
        from letta.model_specs.litellm_model_specs import get_context_window

        context_window = await get_context_window(model_name)
        if context_window is not None:
            return context_window

        # Simple fallback
        logger.debug(
            "Model %s not found in litellm specs. Using default of %s",
            model_name,
            LLM_MAX_CONTEXT_WINDOW["DEFAULT"],
        )
        return LLM_MAX_CONTEXT_WINDOW["DEFAULT"]

    def get_model_context_window(self, model_name: str) -> int | None:
        return self.get_model_context_window_size(model_name)

    async def get_model_context_window_async(self, model_name: str) -> int | None:
        return self.get_model_context_window_size(model_name)
        return await self.get_model_context_window_size_async(model_name)
@@ -1,52 +1,106 @@
from typing import Literal

from openai import AsyncOpenAI, AuthenticationError
from pydantic import Field

from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LLM_MAX_CONTEXT_WINDOW
from letta.errors import ErrorCode, LLMAuthenticationError, LLMError
from letta.log import get_logger
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.providers.openai import OpenAIProvider

logger = get_logger(__name__)

# ALLOWED_PREFIXES = {"gpt-4", "gpt-5", "o1", "o3", "o4"}
# DISALLOWED_KEYWORDS = {"transcribe", "search", "realtime", "tts", "audio", "computer", "o1-mini", "o1-preview", "o1-pro", "chat"}
# DEFAULT_EMBEDDING_BATCH_SIZE = 1024
# Default context window for models not in the API response
DEFAULT_CONTEXT_WINDOW = 128000


class OpenRouterProvider(OpenAIProvider):
    provider_type: Literal[ProviderType.openai] = Field(ProviderType.openai, description="The type of the provider.")
    """
    OpenRouter provider - https://openrouter.ai/

    OpenRouter is an OpenAI-compatible API gateway that provides access to
    multiple LLM providers (Anthropic, Meta, Mistral, etc.) through a unified API.
    """

    provider_type: Literal[ProviderType.openrouter] = Field(ProviderType.openrouter, description="The type of the provider.")
    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
    api_key: str | None = Field(None, description="API key for the OpenRouter API.", deprecated=True)
    base_url: str = Field("https://openrouter.ai/api/v1", description="Base URL for the OpenRouter API.")

    def _list_llm_models(self, data: list[dict]) -> list[LLMConfig]:
    async def check_api_key(self):
        """Check if the API key is valid by making a test request to the OpenRouter API."""
        api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
        if not api_key:
            raise ValueError("No API key provided")

        try:
            # Use async OpenAI client pointed at OpenRouter's endpoint
            client = AsyncOpenAI(api_key=api_key, base_url=self.base_url)
            # Just list models to verify API key works
            await client.models.list()
        except AuthenticationError as e:
            raise LLMAuthenticationError(message=f"Failed to authenticate with OpenRouter: {e}", code=ErrorCode.UNAUTHENTICATED)
        except Exception as e:
            raise LLMError(message=f"{e}", code=ErrorCode.INTERNAL_SERVER_ERROR)

    def get_model_context_window_size(self, model_name: str) -> int | None:
        """Get the context window size for an OpenRouter model.

        OpenRouter models provide context_length in the API response,
        so this is mainly a fallback.
        """
        This handles filtering out LLM Models by provider that meet Letta's requirements.
        return DEFAULT_CONTEXT_WINDOW

    async def list_llm_models_async(self) -> list[LLMConfig]:
        """
        Return available OpenRouter models that support tool calling.

        OpenRouter provides a models endpoint that supports filtering by supported_parameters.
        We filter for models that support 'tools' to ensure Letta compatibility.
        """
        from letta.llm_api.openai import openai_get_model_list_async

        api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None

        # OpenRouter supports filtering models by supported parameters
        # See: https://openrouter.ai/docs/requests
        extra_params = {"supported_parameters": "tools"}

        response = await openai_get_model_list_async(
            self.base_url,
            api_key=api_key,
            extra_params=extra_params,
        )

        data = response.get("data", response)

        configs = []
        for model in data:
            check = self._do_model_checks_for_name_and_context_size(model)
            if check is None:
            if "id" not in model:
                logger.warning(f"OpenRouter model missing 'id' field: {model}")
                continue
            model_name, context_window_size = check

            handle = self.get_handle(model_name)
            model_name = model["id"]

            config = LLMConfig(
                model=model_name,
                model_endpoint_type="openai",
                model_endpoint=self.base_url,
                context_window=context_window_size,
                handle=handle,
                max_tokens=self.get_default_max_output_tokens(model_name),
                provider_name=self.name,
                provider_category=self.provider_category,
            # OpenRouter returns context_length in the model listing
            if "context_length" in model and model["context_length"]:
                context_window_size = model["context_length"]
            else:
                context_window_size = self.get_model_context_window_size(model_name)
                logger.debug(f"Model {model_name} missing context_length, using default: {context_window_size}")

            configs.append(
                LLMConfig(
                    model=model_name,
                    model_endpoint_type="openrouter",
                    model_endpoint=self.base_url,
                    context_window=context_window_size,
                    handle=self.get_handle(model_name),
                    max_tokens=self.get_default_max_output_tokens(model_name),
                    provider_name=self.name,
                    provider_category=self.provider_category,
                )
            )

            config = self._set_model_parameter_tuned_defaults(model_name, config)
            configs.append(config)

        return configs
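The OpenRouter listing loop above prefers the `context_length` that OpenRouter's models endpoint reports and falls back to a default only when the field is missing or falsy. A minimal standalone sketch of that fallback logic (the helper name is illustrative, not part of the provider class):

```python
# Sketch of the context-window fallback used when listing OpenRouter models.
# DEFAULT_CONTEXT_WINDOW mirrors the constant above; the model dict shape
# follows OpenRouter's /models response ({"id": ..., "context_length": ...}).
DEFAULT_CONTEXT_WINDOW = 128000


def resolve_context_window(model: dict) -> int:
    # OpenRouter returns context_length in its model listing;
    # treat 0 or a missing key as "not provided" and use the default.
    if model.get("context_length"):
        return model["context_length"]
    return DEFAULT_CONTEXT_WINDOW
```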
80
letta/schemas/providers/sglang.py
Normal file
@@ -0,0 +1,80 @@
"""
SGLang provider for Letta.

SGLang is a high-performance inference engine that exposes OpenAI-compatible API endpoints.
"""

from typing import Literal

from pydantic import Field

from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.providers.base import Provider


class SGLangProvider(Provider):
    provider_type: Literal[ProviderType.sglang] = Field(
        ProviderType.sglang,
        description="The type of the provider."
    )
    provider_category: ProviderCategory = Field(
        ProviderCategory.base,
        description="The category of the provider (base or byok)"
    )
    base_url: str = Field(
        ...,
        description="Base URL for the SGLang API (e.g., http://localhost:30000)."
    )
    api_key: str | None = Field(
        None,
        description="API key for the SGLang API (optional for local instances)."
    )
    default_prompt_formatter: str | None = Field(
        default=None,
        description="Default prompt formatter (aka model wrapper)."
    )
    handle_base: str | None = Field(
        None,
        description="Custom handle base name for model handles."
    )

    async def list_llm_models_async(self) -> list[LLMConfig]:
        from letta.llm_api.openai import openai_get_model_list_async

        # Ensure base_url ends with /v1 (SGLang uses same convention as vLLM)
        base_url = self.base_url.rstrip("/")
        if not base_url.endswith("/v1"):
            base_url = base_url + "/v1"

        # Decrypt API key before using (may be None for local instances)
        api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None

        response = await openai_get_model_list_async(base_url, api_key=api_key)
        data = response.get("data", response)

        configs = []

        for model in data:
            model_name = model["id"]

            configs.append(
                LLMConfig(
                    model=model_name,
                    model_endpoint_type="openai",  # SGLang is OpenAI-compatible
                    model_endpoint=base_url,
                    model_wrapper=self.default_prompt_formatter,
                    context_window=model.get("max_model_len", 8192),
                    handle=self.get_handle(model_name, base_name=self.handle_base) if self.handle_base else self.get_handle(model_name),
                    max_tokens=self.get_default_max_output_tokens(model_name),
                    provider_name=self.name,
                    provider_category=self.provider_category,
                )
            )

        return configs

    async def list_embedding_models_async(self) -> list[EmbeddingConfig]:
        # SGLang embedding support not common for training use cases
        return []
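The `/v1` normalization in `list_llm_models_async` above mirrors the vLLM convention. Pulled out as a standalone helper (the function name is hypothetical, for illustration only), it behaves like:

```python
def normalize_base_url(base_url: str) -> str:
    # SGLang (like vLLM) serves OpenAI-compatible routes under /v1;
    # strip trailing slashes first so we never produce ".../v1/v1" or "...//v1".
    base_url = base_url.rstrip("/")
    if not base_url.endswith("/v1"):
        base_url = base_url + "/v1"
    return base_url
```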
@@ -126,3 +126,53 @@ class LettaUsageStatistics(BaseModel):
    reasoning_tokens: Optional[int] = Field(
        None, description="The number of reasoning/thinking tokens generated. None if not reported by provider."
    )

    def to_usage(self, provider_type: Optional["ProviderType"] = None) -> "UsageStatistics":
        """Convert to UsageStatistics (OpenAI-compatible format).

        Args:
            provider_type: ProviderType enum indicating which provider format to use.
                Used to determine which cache field to populate.

        Returns:
            UsageStatistics object with nested prompt/completion token details.
        """
        from letta.schemas.enums import ProviderType
        from letta.schemas.openai.chat_completion_response import (
            UsageStatistics,
            UsageStatisticsCompletionTokenDetails,
            UsageStatisticsPromptTokenDetails,
        )

        # Providers that use Anthropic-style cache fields (cache_read_tokens, cache_creation_tokens)
        anthropic_style_providers = {ProviderType.anthropic, ProviderType.bedrock}

        # Build prompt_tokens_details if we have cache data
        prompt_tokens_details = None
        if self.cached_input_tokens is not None or self.cache_write_tokens is not None:
            if provider_type in anthropic_style_providers:
                # Anthropic uses cache_read_tokens and cache_creation_tokens
                prompt_tokens_details = UsageStatisticsPromptTokenDetails(
                    cache_read_tokens=self.cached_input_tokens,
                    cache_creation_tokens=self.cache_write_tokens,
                )
            else:
                # OpenAI/Gemini use cached_tokens
                prompt_tokens_details = UsageStatisticsPromptTokenDetails(
                    cached_tokens=self.cached_input_tokens,
                )

        # Build completion_tokens_details if we have reasoning tokens
        completion_tokens_details = None
        if self.reasoning_tokens is not None:
            completion_tokens_details = UsageStatisticsCompletionTokenDetails(
                reasoning_tokens=self.reasoning_tokens,
            )

        return UsageStatistics(
            prompt_tokens=self.prompt_tokens,
            completion_tokens=self.completion_tokens,
            total_tokens=self.total_tokens,
            prompt_tokens_details=prompt_tokens_details,
            completion_tokens_details=completion_tokens_details,
        )
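The branch on Anthropic-style providers above exists because Anthropic's raw usage reports cache reads and cache writes *separately* from `input_tokens`, while OpenAI and Gemini report the cached count as a subset already included in the total. A hedged sketch of the resulting total-input arithmetic (field names follow the providers' raw usage payloads, not Letta's schema):

```python
def total_input_tokens(provider: str, usage: dict) -> int:
    # Anthropic-style usage: input_tokens EXCLUDES cached tokens, so cache
    # reads and cache-creation writes must be added back to get the true total.
    if provider in {"anthropic", "bedrock"}:
        return (
            usage.get("input_tokens", 0)
            + usage.get("cache_read_input_tokens", 0)
            + usage.get("cache_creation_input_tokens", 0)
        )
    # OpenAI/Gemini-style usage: the reported input total already includes
    # any cached tokens, so no addition is needed.
    return usage.get("input_tokens", 0)
```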
@@ -38,6 +38,7 @@ from letta.errors import (
    HandleNotFoundError,
    LettaAgentNotFoundError,
    LettaExpiredError,
    LettaImageFetchError,
    LettaInvalidArgumentError,
    LettaInvalidMCPSchemaError,
    LettaMCPConnectionError,
@@ -64,6 +65,7 @@ from letta.schemas.letta_message import create_letta_error_message_schema, creat
from letta.schemas.letta_message_content import (
    create_letta_assistant_message_content_union_schema,
    create_letta_message_content_union_schema,
    create_letta_tool_return_content_union_schema,
    create_letta_user_message_content_union_schema,
)
from letta.server.constants import REST_DEFAULT_PORT
@@ -105,6 +107,7 @@ def generate_openapi_schema(app: FastAPI):
    letta_docs["components"]["schemas"]["LettaMessageUnion"] = create_letta_message_union_schema()
    letta_docs["components"]["schemas"]["LettaMessageContentUnion"] = create_letta_message_content_union_schema()
    letta_docs["components"]["schemas"]["LettaAssistantMessageContentUnion"] = create_letta_assistant_message_content_union_schema()
    letta_docs["components"]["schemas"]["LettaToolReturnContentUnion"] = create_letta_tool_return_content_union_schema()
    letta_docs["components"]["schemas"]["LettaUserMessageContentUnion"] = create_letta_user_message_content_union_schema()
    letta_docs["components"]["schemas"]["LettaErrorMessage"] = create_letta_error_message_schema()

@@ -163,10 +166,18 @@ async def lifespan(app_: FastAPI):
    except Exception as e:
        logger.warning(f"[Worker {worker_id}] Failed to download NLTK data: {e}")

    # logger.info(f"[Worker {worker_id}] Starting lifespan initialization")
    # logger.info(f"[Worker {worker_id}] Initializing database connections")
    # db_registry.initialize_async()
    # logger.info(f"[Worker {worker_id}] Database connections initialized")
    # Log effective database timeout settings for debugging
    try:
        from sqlalchemy import text

        from letta.server.db import db_registry

        async with db_registry.async_session() as session:
            result = await session.execute(text("SHOW statement_timeout"))
            statement_timeout = result.scalar()
            logger.warning(f"[Worker {worker_id}] PostgreSQL statement_timeout: {statement_timeout}")
    except Exception as e:
        logger.warning(f"[Worker {worker_id}] Failed to query statement_timeout: {e}")

    if should_use_pinecone():
        if settings.upsert_pinecone_indices:
@@ -180,7 +191,7 @@ async def lifespan(app_: FastAPI):

    logger.info(f"[Worker {worker_id}] Starting scheduler with leader election")
    global server
    await server.init_async()
    await server.init_async(init_with_default_org_and_user=not settings.no_default_actor)
    try:
        await start_scheduler_with_leader_election(server)
        logger.info(f"[Worker {worker_id}] Scheduler initialization completed")
@@ -475,6 +486,7 @@ def create_application() -> "FastAPI":
    app.add_exception_handler(LettaToolNameConflictError, _error_handler_400)
    app.add_exception_handler(AgentFileImportError, _error_handler_400)
    app.add_exception_handler(EmbeddingConfigRequiredError, _error_handler_400)
    app.add_exception_handler(LettaImageFetchError, _error_handler_400)
    app.add_exception_handler(ValueError, _error_handler_400)

    # 404 Not Found errors
@@ -3,7 +3,10 @@ from typing import TYPE_CHECKING, Optional
from fastapi import Header
from pydantic import BaseModel

from letta.errors import LettaInvalidArgumentError
from letta.otel.tracing import tracer
from letta.schemas.enums import PrimitiveType
from letta.validators import PRIMITIVE_ID_PATTERNS

if TYPE_CHECKING:
    from letta.server.server import SyncServer
@@ -42,6 +45,12 @@ def get_headers(
) -> HeaderParams:
    """Dependency injection function to extract common headers from requests."""
    with tracer.start_as_current_span("dependency.get_headers"):
        if actor_id is not None and PRIMITIVE_ID_PATTERNS[PrimitiveType.USER.value].match(actor_id) is None:
            raise LettaInvalidArgumentError(
                message=(f"Invalid user ID format: {actor_id}. Expected format: '{PrimitiveType.USER.value}-<uuid4>'"),
                argument_name="user_id",
            )

        return HeaderParams(
            actor_id=actor_id,
            user_agent=user_agent,
@@ -239,11 +239,11 @@ async def create_background_stream_processor(
        if isinstance(chunk, tuple):
            chunk = chunk[0]

        # Track terminal events
        # Track terminal events (check at line start to avoid false positives in message content)
        if isinstance(chunk, str):
            if "data: [DONE]" in chunk:
            if "\ndata: [DONE]" in chunk or chunk.startswith("data: [DONE]"):
                saw_done = True
            if "event: error" in chunk:
            if "\nevent: error" in chunk or chunk.startswith("event: error"):
                saw_error = True

        # Best-effort extraction of the error payload so we can persist it on the run.
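The change above anchors the terminal-event checks to line starts, since a streamed `data:` payload could itself contain the literal text `data: [DONE]` inside message content. A minimal sketch of the line-start check (standalone helper, name illustrative):

```python
def is_terminal_done(chunk: str) -> bool:
    # Match only when the SSE field begins a line: either the chunk itself
    # starts with the field, or the field follows a newline. A substring-only
    # check would also fire on message text that merely mentions the marker.
    return chunk.startswith("data: [DONE]") or "\ndata: [DONE]" in chunk
```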
@@ -308,6 +308,7 @@ async def _import_agent(
    strip_messages: bool = False,
    env_vars: Optional[dict[str, Any]] = None,
    override_embedding_handle: Optional[str] = None,
    override_model_handle: Optional[str] = None,
) -> List[str]:
    """
    Import an agent using the new AgentFileSchema format.
@@ -319,6 +320,11 @@ async def _import_agent(
    else:
        embedding_config_override = None

    if override_model_handle:
        llm_config_override = await server.get_llm_config_from_handle_async(actor=actor, handle=override_model_handle)
    else:
        llm_config_override = None

    import_result = await server.agent_serialization_manager.import_file(
        schema=agent_schema,
        actor=actor,
@@ -327,6 +333,7 @@ async def _import_agent(
        override_existing_tools=override_existing_tools,
        env_vars=env_vars,
        override_embedding_config=embedding_config_override,
        override_llm_config=llm_config_override,
        project_id=project_id,
    )

@@ -362,6 +369,10 @@ async def import_agent(
        None,
        description="Embedding handle to override with.",
    ),
    model: Optional[str] = Form(
        None,
        description="Model handle to override the agent's default model. This allows the imported agent to use a different model while keeping other defaults (e.g., context size) from the original configuration.",
    ),
    # Deprecated fields (maintain backward compatibility)
    append_copy_suffix: bool = Form(
        True,
@@ -378,6 +389,11 @@ async def import_agent(
        description="Override import with specific embedding handle. Use 'embedding' instead.",
        deprecated=True,
    ),
    override_model_handle: Optional[str] = Form(
        None,
        description="Model handle to override the agent's default model. Use 'model' instead.",
        deprecated=True,
    ),
    project_id: str | None = Form(
        None, description="The project ID to associate the uploaded agent with. This is now passed via headers.", deprecated=True
    ),
@@ -408,6 +424,7 @@ async def import_agent(
    # Handle backward compatibility: prefer new field names over deprecated ones
    final_name = name or override_name
    final_embedding_handle = embedding or override_embedding_handle or x_override_embedding_model
    final_model_handle = model or override_model_handle

    # Parse secrets (new) or env_vars_json (deprecated)
    env_vars = None
@@ -440,6 +457,7 @@ async def import_agent(
            strip_messages=strip_messages,
            env_vars=env_vars,
            override_embedding_handle=final_embedding_handle,
            override_model_handle=final_model_handle,
        )
    else:
        # This is a legacy AgentSchema
@@ -628,7 +646,9 @@ async def run_tool_for_agent(
|
||||
|
||||
# Get agent with all relationships
|
||||
agent = await server.agent_manager.get_agent_by_id_async(
|
||||
agent_id, actor, include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools"]
|
||||
agent_id,
|
||||
actor,
|
||||
include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools", "tags"],
|
||||
)
|
||||
|
||||
# Find the tool by name among attached tools
|
||||
@@ -701,7 +721,7 @@ async def attach_source(
|
||||
await server.agent_manager.insert_files_into_context_window(agent_state=agent_state, file_metadata_with_content=files, actor=actor)
|
||||
|
||||
if agent_state.enable_sleeptime:
|
||||
source = await server.source_manager.get_source_by_id(source_id=source_id)
|
||||
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
safe_create_task(server.sleeptime_document_ingest_async(agent_state, source, actor), label="sleeptime_document_ingest_async")
|
||||
|
||||
return agent_state
|
||||
@@ -728,7 +748,7 @@ async def attach_folder_to_agent(
|
||||
await server.agent_manager.insert_files_into_context_window(agent_state=agent_state, file_metadata_with_content=files, actor=actor)
|
||||
|
||||
if agent_state.enable_sleeptime:
|
||||
source = await server.source_manager.get_source_by_id(source_id=folder_id)
|
||||
source = await server.source_manager.get_source_by_id(source_id=folder_id, actor=actor)
|
||||
safe_create_task(server.sleeptime_document_ingest_async(agent_state, source, actor), label="sleeptime_document_ingest_async")
|
||||
|
||||
if is_1_0_sdk_version(headers):
|
||||
@@ -759,7 +779,7 @@ async def detach_source(
|
||||
|
||||
if agent_state.enable_sleeptime:
|
||||
try:
|
||||
source = await server.source_manager.get_source_by_id(source_id=source_id)
|
||||
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
block = await server.agent_manager.get_block_with_label_async(agent_id=agent_state.id, block_label=source.name, actor=actor)
|
||||
await server.block_manager.delete_block_async(block.id, actor)
|
||||
except:
|
||||
@@ -791,7 +811,7 @@ async def detach_folder_from_agent(
|
||||
|
||||
if agent_state.enable_sleeptime:
|
||||
try:
|
||||
source = await server.source_manager.get_source_by_id(source_id=folder_id)
|
||||
source = await server.source_manager.get_source_by_id(source_id=folder_id, actor=actor)
|
||||
block = await server.agent_manager.get_block_with_label_async(agent_id=agent_state.id, block_label=source.name, actor=actor)
|
||||
await server.block_manager.delete_block_async(block.id, actor)
|
||||
except:
|
||||
@@ -1256,7 +1276,7 @@ async def detach_identity_from_agent(
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/{agent_id}/archival-memory", response_model=list[Passage], operation_id="list_passages", deprecated=True)
|
||||
@router.get("/{agent_id}/archival-memory", response_model=list[Passage], operation_id="list_passages")
|
||||
async def list_passages(
|
||||
agent_id: AgentId,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
@@ -1285,7 +1305,7 @@ async def list_passages(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{agent_id}/archival-memory", response_model=list[Passage], operation_id="create_passage", deprecated=True)
|
||||
@router.post("/{agent_id}/archival-memory", response_model=list[Passage], operation_id="create_passage")
|
||||
async def create_passage(
|
||||
agent_id: AgentId,
|
||||
request: CreateArchivalMemory = Body(...),
|
||||
@@ -1306,7 +1326,6 @@ async def create_passage(
|
||||
"/{agent_id}/archival-memory/search",
|
||||
response_model=ArchivalMemorySearchResponse,
|
||||
operation_id="search_archival_memory",
|
||||
deprecated=True,
|
||||
)
|
||||
async def search_archival_memory(
|
||||
agent_id: AgentId,
|
||||
@@ -1354,7 +1373,7 @@ async def search_archival_memory(
|
||||
|
||||
# TODO(ethan): query or path parameter for memory_id?
|
||||
# @router.delete("/{agent_id}/archival")
|
||||
@router.delete("/{agent_id}/archival-memory/{memory_id}", response_model=None, operation_id="delete_passage", deprecated=True)
|
||||
@router.delete("/{agent_id}/archival-memory/{memory_id}", response_model=None, operation_id="delete_passage")
|
||||
async def delete_passage(
|
||||
memory_id: str,
|
||||
agent_id: AgentId,
|
||||
@@ -1520,7 +1539,9 @@ async def send_message(
|
||||
MetricRegistry().user_message_counter.add(1, get_ctx_attributes())
|
||||
# TODO: This is redundant, remove soon
|
||||
agent = await server.agent_manager.get_agent_by_id_async(
|
||||
agent_id, actor, include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools"]
|
||||
agent_id,
|
||||
actor,
|
||||
include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools", "tags"],
|
||||
)
|
||||
|
||||
# Handle model override if specified in the request
|
||||
@@ -1799,7 +1820,9 @@ async def _process_message_background(
|
||||
|
||||
try:
|
||||
agent = await server.agent_manager.get_agent_by_id_async(
|
||||
agent_id, actor, include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools"]
|
||||
agent_id,
|
||||
actor,
|
||||
include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools", "tags"],
|
||||
)
|
||||
|
||||
# Handle model override if specified
|
||||
@@ -1853,15 +1876,24 @@ async def _process_message_background(

         runs_manager = RunManager()
         from letta.schemas.enums import RunStatus
+        from letta.schemas.letta_stop_reason import StopReasonType

-        if result.stop_reason.stop_reason == "cancelled":
+        # Handle cases where stop_reason might be None (defensive)
+        if result.stop_reason and result.stop_reason.stop_reason == "cancelled":
             run_status = RunStatus.cancelled
-        else:
+            stop_reason = result.stop_reason.stop_reason
+        elif result.stop_reason:
             run_status = RunStatus.completed
+            stop_reason = result.stop_reason.stop_reason
+        else:
+            # Fallback: no stop_reason set (shouldn't happen but defensive)
+            logger.error(f"Run {run_id} completed without stop_reason in result, defaulting to end_turn")
+            run_status = RunStatus.completed
+            stop_reason = StopReasonType.end_turn

         await runs_manager.update_run_by_id_async(
             run_id=run_id,
-            update=RunUpdate(status=run_status, stop_reason=result.stop_reason.stop_reason),
+            update=RunUpdate(status=run_status, stop_reason=stop_reason),
            actor=actor,
         )

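The defensive branch above is, at heart, a total mapping from an optional stop reason to a `(run_status, stop_reason)` pair. A minimal standalone sketch — `RunStatus` and `StopReasonType` are mirrored here as plain enums, not the real Letta schemas:

```python
import enum
from typing import Optional, Tuple


class RunStatus(str, enum.Enum):
    cancelled = "cancelled"
    completed = "completed"


class StopReasonType(str, enum.Enum):
    end_turn = "end_turn"
    cancelled = "cancelled"


def resolve_run_outcome(stop_reason: Optional[str]) -> Tuple[RunStatus, str]:
    """Mirror the defensive branch: a missing stop reason falls back to end_turn."""
    if stop_reason == "cancelled":
        return RunStatus.cancelled, stop_reason
    if stop_reason is not None:
        return RunStatus.completed, stop_reason
    # Fallback: no stop_reason set (shouldn't happen but defensive)
    return RunStatus.completed, StopReasonType.end_turn.value


outcome = resolve_run_outcome(None)
```

The key point the patch makes is that every path now assigns `stop_reason` before the run update, so the update can never dereference a `None` result.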
@@ -1869,20 +1901,22 @@ async def _process_message_background(
         # Update run status to failed with specific error info
         runs_manager = RunManager()
         from letta.schemas.enums import RunStatus
+        from letta.schemas.letta_stop_reason import StopReasonType

         await runs_manager.update_run_by_id_async(
             run_id=run_id,
-            update=RunUpdate(status=RunStatus.failed, metadata={"error": str(e)}),
+            update=RunUpdate(status=RunStatus.failed, stop_reason=StopReasonType.error, metadata={"error": str(e)}),
             actor=actor,
         )
     except Exception as e:
         # Update run status to failed
         runs_manager = RunManager()
         from letta.schemas.enums import RunStatus
+        from letta.schemas.letta_stop_reason import StopReasonType

         await runs_manager.update_run_by_id_async(
             run_id=run_id,
-            update=RunUpdate(status=RunStatus.failed, metadata={"error": str(e)}),
+            update=RunUpdate(status=RunStatus.failed, stop_reason=StopReasonType.error, metadata={"error": str(e)}),
             actor=actor,
         )
     finally:

@@ -1966,7 +2000,9 @@ async def send_message_async(

     if use_lettuce:
         agent_state = await server.agent_manager.get_agent_by_id_async(
-            agent_id, actor, include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools"]
+            agent_id,
+            actor,
+            include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools", "tags"],
         )
     # Allow V1 agents only if the message async flag is enabled
     is_v1_message_async_enabled = (
@@ -2020,10 +2056,11 @@ async def send_message_async(
     async def update_failed_run():
         runs_manager = RunManager()
         from letta.schemas.enums import RunStatus
+        from letta.schemas.letta_stop_reason import StopReasonType

         await runs_manager.update_run_by_id_async(
             run_id=run.id,
-            update=RunUpdate(status=RunStatus.failed, metadata={"error": error_str}),
+            update=RunUpdate(status=RunStatus.failed, stop_reason=StopReasonType.error, metadata={"error": error_str}),
             actor=actor,
         )

@@ -48,6 +48,13 @@ class PassageCreateRequest(BaseModel):
     text: str = Field(..., description="The text content of the passage")
     metadata: Optional[Dict] = Field(default=None, description="Optional metadata for the passage")
     tags: Optional[List[str]] = Field(default=None, description="Optional tags for categorizing the passage")
+    created_at: Optional[str] = Field(default=None, description="Optional creation datetime for the passage (ISO 8601 format)")
+
+
+class PassageBatchCreateRequest(BaseModel):
+    """Request model for creating multiple passages in an archive."""
+
+    passages: List[PassageCreateRequest] = Field(..., description="Passages to create in the archive")


 @router.post("/", response_model=PydanticArchive, operation_id="create_archive")
@@ -65,16 +72,14 @@ async def create_archive(
     if embedding_config is None:
         embedding_handle = archive.embedding
         if embedding_handle is None:
-            if settings.default_embedding_handle is None:
-                raise LettaInvalidArgumentError(
-                    "Must specify either embedding or embedding_config in request", argument_name="default_embedding_handle"
-                )
-            else:
-                embedding_handle = settings.default_embedding_handle
-        embedding_config = await server.get_embedding_config_from_handle_async(
-            handle=embedding_handle,
-            actor=actor,
-        )
+            embedding_handle = settings.default_embedding_handle
+        # Only resolve embedding config if we have an embedding handle
+        if embedding_handle is not None:
+            embedding_config = await server.get_embedding_config_from_handle_async(
+                handle=embedding_handle,
+                actor=actor,
+            )
+        # Otherwise, embedding_config remains None (text search only)

     return await server.archive_manager.create_archive_async(
         name=archive.name,
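The rewritten flow above is a two-step fallback (handle from the request, then the server default) where a still-missing handle now means "no embeddings, text search only" instead of an error. A minimal sketch of that resolution order — `resolve_embedding_handle` and its parameter names are illustrative, not the actual Letta API:

```python
from typing import Optional


def resolve_embedding_handle(request_handle: Optional[str], default_handle: Optional[str]) -> Optional[str]:
    """Prefer the handle from the request; fall back to the server default.

    Returning None signals a text-search-only archive instead of raising.
    """
    if request_handle is not None:
        return request_handle
    return default_handle


# Only resolve a config when a handle exists; None stays None.
first = resolve_embedding_handle("openai/text-embedding-3-small", None)
second = resolve_embedding_handle(None, "default/handle")
third = resolve_embedding_handle(None, None)
```

The behavioral change in the diff is exactly the third case: previously it raised `LettaInvalidArgumentError`; now it quietly yields an archive without an embedding config.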
@@ -227,6 +232,27 @@ async def create_passage_in_archive(
         text=passage.text,
         metadata=passage.metadata,
+        tags=passage.tags,
+        created_at=passage.created_at,
         actor=actor,
     )
+
+
+@router.post("/{archive_id}/passages/batch", response_model=List[Passage], operation_id="create_passages_in_archive")
+async def create_passages_in_archive(
+    archive_id: ArchiveId,
+    payload: PassageBatchCreateRequest = Body(...),
+    server: "SyncServer" = Depends(get_letta_server),
+    headers: HeaderParams = Depends(get_headers),
+):
+    """
+    Create multiple passages in an archive.
+
+    This adds passages to the archive and creates embeddings for vector storage.
+    """
+    actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
+    return await server.archive_manager.create_passages_in_archive_async(
+        archive_id=archive_id,
+        passages=[passage.model_dump() for passage in payload.passages],
+        actor=actor,
+    )

@@ -5,16 +5,20 @@ from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
 from pydantic import BaseModel, Field
 from starlette.responses import StreamingResponse

+from letta.agents.agent_loop import AgentLoop
 from letta.agents.letta_agent_v3 import LettaAgentV3
+from letta.constants import REDIS_RUN_ID_PREFIX
+from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
 from letta.errors import LettaExpiredError, LettaInvalidArgumentError, NoActiveRunsToCancelError
 from letta.helpers.datetime_helpers import get_utc_time
 from letta.log import get_logger
 from letta.schemas.conversation import Conversation, CreateConversation, UpdateConversation
 from letta.schemas.enums import RunStatus
+from letta.schemas.job import LettaRequestConfig
 from letta.schemas.letta_message import LettaMessageUnion
-from letta.schemas.letta_request import LettaStreamingRequest, RetrieveStreamRequest
+from letta.schemas.letta_request import ConversationMessageRequest, LettaStreamingRequest, RetrieveStreamRequest
 from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
 from letta.schemas.run import Run as PydanticRun
 from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
 from letta.server.rest_api.redis_stream_manager import redis_sse_stream_generator
 from letta.server.rest_api.streaming_response import (
@@ -60,6 +64,7 @@ async def list_conversations(
     agent_id: str = Query(..., description="The agent ID to list conversations for"),
     limit: int = Query(50, description="Maximum number of conversations to return"),
     after: Optional[str] = Query(None, description="Cursor for pagination (conversation ID)"),
+    summary_search: Optional[str] = Query(None, description="Search for text within conversation summaries"),
     server: SyncServer = Depends(get_letta_server),
     headers: HeaderParams = Depends(get_headers),
 ):
@@ -70,6 +75,7 @@ async def list_conversations(
         actor=actor,
         limit=limit,
         after=after,
+        summary_search=summary_search,
     )

@@ -154,51 +160,112 @@ async def list_conversation_messages(

 @router.post(
     "/{conversation_id}/messages",
-    response_model=LettaStreamingResponse,
+    response_model=LettaResponse,
     operation_id="send_conversation_message",
     responses={
         200: {
             "description": "Successful response",
             "content": {
-                "text/event-stream": {"description": "Server-Sent Events stream"},
+                "text/event-stream": {"description": "Server-Sent Events stream (default, when streaming=true)"},
+                "application/json": {"description": "JSON response (when streaming=false)"},
             },
         }
     },
 )
 async def send_conversation_message(
     conversation_id: ConversationId,
-    request: LettaStreamingRequest = Body(...),
+    request: ConversationMessageRequest = Body(...),
     server: SyncServer = Depends(get_letta_server),
     headers: HeaderParams = Depends(get_headers),
 ) -> StreamingResponse | LettaResponse:
     """
-    Send a message to a conversation and get a streaming response.
+    Send a message to a conversation and get a response.

-    This endpoint sends a message to an existing conversation and streams
-    the agent's response back.
+    This endpoint sends a message to an existing conversation.
+    By default (streaming=true), returns a streaming response (Server-Sent Events).
+    Set streaming=false to get a complete JSON response.
     """
     actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)

-    # Get the conversation to find the agent_id
+    if not request.messages or len(request.messages) == 0:
+        raise HTTPException(status_code=422, detail="Messages must not be empty")
+
     conversation = await conversation_manager.get_conversation_by_id(
         conversation_id=conversation_id,
         actor=actor,
     )

-    # Force streaming mode for this endpoint
-    request.streaming = True
-
-    # Use streaming service
-    streaming_service = StreamingService(server)
-    run, result = await streaming_service.create_agent_stream(
-        agent_id=conversation.agent_id,
-        actor=actor,
-        request=request,
-        run_type="send_conversation_message",
-        conversation_id=conversation_id,
-    )
-
-    return result
+    # Streaming mode (default)
+    if request.streaming:
+        # Convert to LettaStreamingRequest for StreamingService compatibility
+        streaming_request = LettaStreamingRequest(
+            messages=request.messages,
+            streaming=True,
+            stream_tokens=request.stream_tokens,
+            include_pings=request.include_pings,
+            background=request.background,
+            max_steps=request.max_steps,
+            use_assistant_message=request.use_assistant_message,
+            assistant_message_tool_name=request.assistant_message_tool_name,
+            assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
+            include_return_message_types=request.include_return_message_types,
+            override_model=request.override_model,
+            client_tools=request.client_tools,
+        )
+        streaming_service = StreamingService(server)
+        run, result = await streaming_service.create_agent_stream(
+            agent_id=conversation.agent_id,
+            actor=actor,
+            request=streaming_request,
+            run_type="send_conversation_message",
+            conversation_id=conversation_id,
+        )
+        return result
+
+    # Non-streaming mode
+    agent = await server.agent_manager.get_agent_by_id_async(
+        conversation.agent_id,
+        actor,
+        include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools", "tags"],
+    )
+
+    if request.override_model:
+        override_llm_config = await server.get_llm_config_from_handle_async(
+            actor=actor,
+            handle=request.override_model,
+        )
+        agent = agent.model_copy(update={"llm_config": override_llm_config})
+
+    # Create a run for execution tracking
+    run = None
+    if settings.track_agent_run:
+        runs_manager = RunManager()
+        run = await runs_manager.create_run(
+            pydantic_run=PydanticRun(
+                agent_id=conversation.agent_id,
+                background=False,
+                metadata={
+                    "run_type": "send_conversation_message",
+                },
+                request_config=LettaRequestConfig.from_letta_request(request),
+            ),
+            actor=actor,
+        )
+
+    # Set run_id in Redis for cancellation support
+    redis_client = await get_redis_client()
+    await redis_client.set(f"{REDIS_RUN_ID_PREFIX}:{conversation.agent_id}", run.id if run else None)
+
+    agent_loop = AgentLoop.load(agent_state=agent, actor=actor)
+    return await agent_loop.step(
+        request.messages,
+        max_steps=request.max_steps,
+        run_id=run.id if run else None,
+        use_assistant_message=request.use_assistant_message,
+        include_return_message_types=request.include_return_message_types,
+        client_tools=request.client_tools,
+        conversation_id=conversation_id,
+    )


 @router.post(
@@ -289,11 +356,14 @@ async def retrieve_conversation_stream(
     )

     if settings.enable_cancellation_aware_streaming:
+        from letta.server.rest_api.streaming_response import cancellation_aware_stream_wrapper, get_cancellation_event_for_run
+
         stream = cancellation_aware_stream_wrapper(
             stream_generator=stream,
             run_manager=server.run_manager,
             run_id=run.id,
             actor=actor,
+            cancellation_event=get_cancellation_event_for_run(run.id),
         )

     if request and request.include_pings and settings.enable_keepalive:
@@ -594,7 +594,7 @@ async def load_file_to_source_async(server: SyncServer, source_id: str, job_id:


 async def sleeptime_document_ingest_async(server: SyncServer, source_id: str, actor: User, clear_history: bool = False):
-    source = await server.source_manager.get_source_by_id(source_id=source_id)
+    source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
     agents = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor)
     for agent in agents:
         if agent.enable_sleeptime:
@@ -231,7 +231,7 @@ async def list_messages_for_batch(

     # Get messages directly using our efficient method
     messages = await server.batch_manager.get_messages_for_letta_batch_async(
-        letta_batch_job_id=batch_id, limit=limit, actor=actor, agent_id=agent_id, ascending=(order == "asc"), before=before, after=after
+        letta_batch_job_id=batch_id, actor=actor, limit=limit, agent_id=agent_id, sort_descending=(order == "desc"), cursor=after
     )

     return LettaBatchMessages(messages=messages)
@@ -4,18 +4,62 @@ from typing import List, Literal, Optional

 from fastapi import APIRouter, Body, Depends
 from pydantic import BaseModel, Field

+from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.enums import TagMatchMode
 from letta.schemas.passage import Passage
+from letta.schemas.user import User as PydanticUser
 from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
 from letta.server.server import SyncServer

 router = APIRouter(prefix="/passages", tags=["passages"])


+async def _get_embedding_config_for_search(
+    server: SyncServer,
+    actor: PydanticUser,
+    agent_id: Optional[str],
+    archive_id: Optional[str],
+) -> Optional[EmbeddingConfig]:
+    """Determine which embedding config to use for a passage search.
+
+    Args:
+        server: The SyncServer instance
+        actor: The user making the request
+        agent_id: Optional agent ID to get embedding config from
+        archive_id: Optional archive ID to get embedding config from
+
+    Returns:
+        The embedding config to use, or None if not found
+
+    Priority:
+        1. If agent_id is provided, use that agent's embedding config
+        2. If archive_id is provided, use that archive's embedding config
+        3. Otherwise, try to get embedding config from any existing agent
+        4. Fall back to server default if no agents exist
+    """
+    if agent_id:
+        agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
+        return agent_state.embedding_config
+
+    if archive_id:
+        archive = await server.archive_manager.get_archive_by_id_async(archive_id=archive_id, actor=actor)
+        return archive.embedding_config
+
+    # Search across all passages - try to get embedding config from any agent
+    agent_count = await server.agent_manager.size_async(actor=actor)
+    if agent_count > 0:
+        agents = await server.agent_manager.list_agents_async(actor=actor, limit=1)
+        if agents:
+            return agents[0].embedding_config
+
+    # Fall back to server default
+    return server.default_embedding_config
+
+
 class PassageSearchRequest(BaseModel):
     """Request model for searching passages across archives."""

-    query: str = Field(..., description="Text query for semantic search")
+    query: Optional[str] = Field(None, description="Text query for semantic search")
     agent_id: Optional[str] = Field(None, description="Filter passages by agent ID")
     archive_id: Optional[str] = Field(None, description="Filter passages by archive ID")
     tags: Optional[List[str]] = Field(None, description="Optional list of tags to filter search results")
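The four-step priority that `_get_embedding_config_for_search` documents is a first-match-wins chain over fallible lookups. A minimal sketch of that pattern with stubbed synchronous lookups (the helper name and values are illustrative only):

```python
from typing import Callable, Optional, Sequence


def first_non_none(lookups: Sequence[Callable[[], Optional[str]]], default: str) -> str:
    """Return the result of the first lookup that yields a config; else the server default."""
    for lookup in lookups:
        config = lookup()
        if config is not None:
            return config
    return default


# Priority: agent config -> archive config -> any agent's config -> server default.
chain = [lambda: None, lambda: "archive-config", lambda: "agent0-config"]
resolved = first_non_none(chain, "server-default")
fallback = first_non_none([lambda: None], "server-default")
```

In the real helper each lookup is an awaited manager call, but the control flow is the same: evaluate lazily, stop at the first hit, and only reach the server default when every earlier source is absent.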
@@ -56,29 +100,16 @@ async def search_passages(
     # Convert tag_match_mode to enum
     tag_mode = TagMatchMode.ANY if request.tag_match_mode == "any" else TagMatchMode.ALL

-    # Determine which embedding config to use
+    # Determine embedding config (only needed when query text is provided)
+    embed_query = bool(request.query)
     embedding_config = None
-    if request.agent_id:
-        # Search by agent
-        agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=request.agent_id, actor=actor)
-        embedding_config = agent_state.embedding_config
-    elif request.archive_id:
-        # Search by archive_id
-        archive = await server.archive_manager.get_archive_by_id_async(archive_id=request.archive_id, actor=actor)
-        embedding_config = archive.embedding_config
-    else:
-        # Search across all passages in the organization
-        # Get default embedding config from any agent or use server default
-        agent_count = await server.agent_manager.size_async(actor=actor)
-        if agent_count > 0:
-            # Get first agent to derive embedding config
-            agents = await server.agent_manager.list_agents_async(actor=actor, limit=1)
-            if agents:
-                embedding_config = agents[0].embedding_config
-
-        if not embedding_config:
-            # Fall back to server default
-            embedding_config = server.default_embedding_config
+    if embed_query:
+        embedding_config = await _get_embedding_config_for_search(
+            server=server,
+            actor=actor,
+            agent_id=request.agent_id,
+            archive_id=request.archive_id,
+        )

     # Search passages
     passages_with_metadata = await server.agent_manager.query_agent_passages_async(
@@ -88,7 +119,7 @@ async def search_passages(
         query_text=request.query,
         limit=request.limit,
         embedding_config=embedding_config,
-        embed_query=True,
+        embed_query=embed_query,
         tags=request.tags,
         tag_match_mode=tag_mode,
         start_date=request.start_date,
@@ -1,6 +1,6 @@
 from typing import TYPE_CHECKING, List, Literal, Optional

-from fastapi import APIRouter, Body, Depends, Query, status
+from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
 from fastapi.responses import JSONResponse

 from letta.schemas.enums import ProviderCategory, ProviderType
@@ -144,6 +144,27 @@ async def check_existing_provider(
     )


+@router.patch("/{provider_id}/refresh", response_model=Provider, operation_id="refresh_provider_models")
+async def refresh_provider_models(
+    provider_id: ProviderId,
+    headers: HeaderParams = Depends(get_headers),
+    server: "SyncServer" = Depends(get_letta_server),
+):
+    """
+    Refresh models for a BYOK provider by querying the provider's API.
+    Adds new models and removes ones no longer available.
+    """
+    actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
+    provider = await server.provider_manager.get_provider_async(provider_id=provider_id, actor=actor)
+
+    # Only allow refresh for BYOK providers
+    if provider.provider_category != ProviderCategory.byok:
+        raise HTTPException(status_code=400, detail="Refresh is only supported for BYOK providers")
+
+    await server.provider_manager._sync_default_models_for_provider(provider, actor)
+    return await server.provider_manager.get_provider_async(provider_id=provider_id, actor=actor)
+
+
 @router.delete("/{provider_id}", response_model=None, operation_id="delete_provider")
 async def delete_provider(
     provider_id: ProviderId,
@@ -393,11 +393,14 @@ async def retrieve_stream_for_run(
     )

     if settings.enable_cancellation_aware_streaming:
+        from letta.server.rest_api.streaming_response import cancellation_aware_stream_wrapper, get_cancellation_event_for_run
+
         stream = cancellation_aware_stream_wrapper(
             stream_generator=stream,
             run_manager=server.run_manager,
             run_id=run_id,
             actor=actor,
+            cancellation_event=get_cancellation_event_for_run(run_id),
         )

     if request.include_pings and settings.enable_keepalive:
@@ -485,7 +485,7 @@ async def load_file_to_source_async(server: SyncServer, source_id: str, job_id:


 async def sleeptime_document_ingest_async(server: SyncServer, source_id: str, actor: User, clear_history: bool = False):
-    source = await server.source_manager.get_source_by_id(source_id=source_id)
+    source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
     agents = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor)
     for agent in agents:
         if agent.enable_sleeptime:
@@ -4,6 +4,7 @@ from fastapi import APIRouter, Body, Depends, Query

 from letta.schemas.user import User, UserCreate, UserUpdate
 from letta.server.rest_api.dependencies import get_letta_server
+from letta.validators import UserIdQueryRequired

 if TYPE_CHECKING:
     from letta.schemas.user import User
@@ -52,7 +53,7 @@ async def update_user(

 @router.delete("/", tags=["admin"], response_model=User, operation_id="delete_user")
 async def delete_user(
-    user_id: str = Query(..., description="The user_id key to be deleted."),
+    user_id: UserIdQueryRequired,
     server: "SyncServer" = Depends(get_letta_server),
 ):
     # TODO make a soft deletion, instead of a hard deletion
@@ -7,6 +7,7 @@ import json
 import re
 from collections.abc import AsyncIterator
 from datetime import datetime, timezone
+from typing import Dict, Optional
 from uuid import uuid4

 import anyio
@@ -26,6 +27,17 @@ from letta.utils import safe_create_task

 logger = get_logger(__name__)

+# Global registry of cancellation events per run_id
+# Note: Events are small and we don't bother cleaning them up
+_cancellation_events: Dict[str, asyncio.Event] = {}
+
+
+def get_cancellation_event_for_run(run_id: str) -> asyncio.Event:
+    """Get or create a cancellation event for a run."""
+    if run_id not in _cancellation_events:
+        _cancellation_events[run_id] = asyncio.Event()
+    return _cancellation_events[run_id]
+
+
 class RunCancelledException(Exception):
     """Exception raised when a run is explicitly cancelled (not due to client timeout)"""
@@ -125,6 +137,7 @@ async def cancellation_aware_stream_wrapper(
     run_id: str,
     actor: User,
     cancellation_check_interval: float = 0.5,
+    cancellation_event: Optional[asyncio.Event] = None,
 ) -> AsyncIterator[str | bytes]:
     """
     Wraps a stream generator to provide real-time run cancellation checking.
@@ -156,11 +169,22 @@ async def cancellation_aware_stream_wrapper(
                 run = await run_manager.get_run_by_id(run_id=run_id, actor=actor)
                 if run.status == RunStatus.cancelled:
                     logger.info(f"Stream cancelled for run {run_id}, interrupting stream")
+
+                    # Signal cancellation via shared event if available
+                    if cancellation_event:
+                        cancellation_event.set()
+                        logger.info(f"Set cancellation event for run {run_id}")
+
                     # Send cancellation event to client
-                    cancellation_event = {"message_type": "stop_reason", "stop_reason": "cancelled"}
-                    yield f"data: {json.dumps(cancellation_event)}\n\n"
-                    # Raise custom exception for explicit run cancellation
-                    raise RunCancelledException(run_id, f"Run {run_id} was cancelled")
+                    stop_event = {"message_type": "stop_reason", "stop_reason": "cancelled"}
+                    yield f"data: {json.dumps(stop_event)}\n\n"
+
+                    # Inject exception INTO the generator so its except blocks can catch it
+                    try:
+                        await stream_generator.athrow(RunCancelledException(run_id, f"Run {run_id} was cancelled"))
+                    except (StopAsyncIteration, RunCancelledException):
+                        # Generator closed gracefully or raised the exception back
+                        break
             except RunCancelledException:
                 # Re-raise cancellation immediately, don't catch it
                 raise
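The key mechanism in the hunk above is `athrow`: it resumes the async generator at its suspended `yield` with the exception raised *inside* it, so the generator's own `except`/`finally` blocks run before the wrapper moves on. A self-contained sketch (names like `Cancelled` and `stream` are illustrative, not the Letta classes):

```python
import asyncio


class Cancelled(Exception):
    """Stand-in for RunCancelledException."""


cleanup_ran = []


async def stream():
    try:
        while True:
            yield "chunk"
    except Cancelled:
        # The generator itself observes the cancellation at its yield point...
        cleanup_ran.append(True)
        # ...and can finish gracefully instead of re-raising.
        return


async def main() -> list:
    gen = stream()
    assert await gen.__anext__() == "chunk"
    try:
        # Deliver the exception INTO the generator, as the wrapper does.
        await gen.athrow(Cancelled())
    except (StopAsyncIteration, Cancelled):
        # Generator closed gracefully, or bounced the exception back.
        pass
    return cleanup_ran


observed = asyncio.run(main())
```

Because the generator returns after catching `Cancelled`, the `athrow` call itself raises `StopAsyncIteration`; that is exactly why the wrapper catches both exception types and then `break`s.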
@@ -173,9 +197,10 @@ async def cancellation_aware_stream_wrapper(
             yield chunk

     except RunCancelledException:
-        # Re-raise RunCancelledException to distinguish from client timeout
+        # Don't re-raise - we already injected the exception into the generator
+        # The generator has handled it and set its stream_was_cancelled flag
         logger.info(f"Stream for run {run_id} was explicitly cancelled and cleaned up")
-        raise
+        # Don't raise - let it exit gracefully
     except asyncio.CancelledError:
         # Re-raise CancelledError (likely client timeout) to ensure proper cleanup
         logger.info(f"Stream for run {run_id} was cancelled (likely client timeout) and cleaned up")
@@ -20,7 +20,7 @@ from letta.constants import (
 )
 from letta.errors import ContextWindowExceededError, RateLimitExceededError
 from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns, ns_to_ms
-from letta.helpers.message_helper import convert_message_creates_to_messages
+from letta.helpers.message_helper import convert_message_creates_to_messages, resolve_tool_return_images
 from letta.log import get_logger
 from letta.otel.context import get_ctx_attributes
 from letta.otel.metric_registry import MetricRegistry
@@ -171,18 +171,26 @@ async def create_input_messages(
     return messages


-def create_approval_response_message_from_input(
+async def create_approval_response_message_from_input(
     agent_state: AgentState, input_message: ApprovalCreate, run_id: Optional[str] = None
 ) -> List[Message]:
-    def maybe_convert_tool_return_message(maybe_tool_return: LettaToolReturn):
+    async def maybe_convert_tool_return_message(maybe_tool_return: LettaToolReturn):
         if isinstance(maybe_tool_return, LettaToolReturn):
-            packaged_function_response = package_function_response(
-                maybe_tool_return.status == "success", maybe_tool_return.tool_return, agent_state.timezone
-            )
+            tool_return_content = maybe_tool_return.tool_return
+
+            # Handle tool_return content - can be string or list of content parts (text/image)
+            if isinstance(tool_return_content, str):
+                # String content - wrap with package_function_response as before
+                func_response = package_function_response(maybe_tool_return.status == "success", tool_return_content, agent_state.timezone)
+            else:
+                # List of content parts (text/image) - resolve URL images to base64 first
+                resolved_content = await resolve_tool_return_images(tool_return_content)
+                func_response = resolved_content
+
             return ToolReturn(
                 tool_call_id=maybe_tool_return.tool_call_id,
                 status=maybe_tool_return.status,
-                func_response=packaged_function_response,
+                func_response=func_response,
                 stdout=maybe_tool_return.stdout,
                 stderr=maybe_tool_return.stderr,
             )
@@ -196,6 +204,11 @@ def create_approval_response_message_from_input(
         getattr(input_message, "approval_request_id", None),
     )

+    # Process all tool returns concurrently (for async image resolution)
+    import asyncio
+
+    converted_approvals = await asyncio.gather(*[maybe_convert_tool_return_message(approval) for approval in approvals_list])
+
     return [
         Message(
             role=MessageRole.approval,
@@ -204,7 +217,7 @@ def create_approval_response_message_from_input(
             approval_request_id=input_message.approval_request_id,
             approve=input_message.approve,
             denial_reason=input_message.reason,
-            approvals=[maybe_convert_tool_return_message(approval) for approval in approvals_list],
+            approvals=list(converted_approvals),
             run_id=run_id,
             group_id=input_message.group_id
             if input_message.group_id
@@ -19,7 +19,6 @@ from letta.config import LettaConfig
 from letta.constants import LETTA_TOOL_EXECUTION_DIR
 from letta.data_sources.connectors import DataConnector, load_data
 from letta.errors import (
-    EmbeddingConfigRequiredError,
     HandleNotFoundError,
     LettaInvalidArgumentError,
     LettaMCPConnectionError,
@@ -68,10 +67,12 @@ from letta.schemas.providers import (
     GroqProvider,
     LettaProvider,
     LMStudioOpenAIProvider,
+    MiniMaxProvider,
     OllamaProvider,
     OpenAIProvider,
     OpenRouterProvider,
     Provider,
+    SGLangProvider,
     TogetherProvider,
     VLLMProvider,
     XAIProvider,
@@ -283,15 +284,33 @@ class SyncServer(object):
|
||||
# NOTE: to use the /chat/completions endpoint, you need to specify extra flags on vLLM startup
|
||||
# see: https://docs.vllm.ai/en/stable/features/tool_calling.html
|
||||
# e.g. "... --enable-auto-tool-choice --tool-call-parser hermes"
|
||||
# Auto-append /v1 to the base URL
|
||||
vllm_url = (
|
||||
model_settings.vllm_api_base if model_settings.vllm_api_base.endswith("/v1") else model_settings.vllm_api_base + "/v1"
|
||||
)
|
||||
self._enabled_providers.append(
|
||||
VLLMProvider(
|
||||
name="vllm",
|
||||
base_url=model_settings.vllm_api_base,
|
||||
base_url=vllm_url,
|
||||
default_prompt_formatter=model_settings.default_prompt_formatter,
|
||||
handle_base=model_settings.vllm_handle_base,
|
||||
)
|
||||
)
|
||||
|
||||
if model_settings.sglang_api_base:
|
||||
# Auto-append /v1 to the base URL
|
||||
sglang_url = (
|
||||
model_settings.sglang_api_base if model_settings.sglang_api_base.endswith("/v1") else model_settings.sglang_api_base + "/v1"
|
||||
)
|
||||
self._enabled_providers.append(
|
||||
SGLangProvider(
|
||||
name="sglang",
|
||||
base_url=sglang_url,
|
||||
default_prompt_formatter=model_settings.default_prompt_formatter,
|
||||
handle_base=model_settings.sglang_handle_base,
|
||||
)
|
||||
)
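The same `/v1` normalization is applied to both the vLLM and SGLang base URLs, since both expose OpenAI-compatible APIs under that path. The pattern can be factored into a small idempotent helper; this is a sketch, and `ensure_v1_suffix` is a hypothetical name, not part of the diff:

```python
def ensure_v1_suffix(base_url: str) -> str:
    """Append /v1 when missing; already-normalized URLs pass through unchanged."""
    return base_url if base_url.endswith("/v1") else base_url + "/v1"

print(ensure_v1_suffix("http://localhost:8000"))     # http://localhost:8000/v1
print(ensure_v1_suffix("http://localhost:8000/v1"))  # http://localhost:8000/v1
```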

if model_settings.aws_access_key_id and model_settings.aws_secret_access_key and model_settings.aws_default_region:
self._enabled_providers.append(
BedrockProvider(
@@ -324,6 +343,13 @@ class SyncServer(object):
api_key_enc=Secret.from_plaintext(model_settings.xai_api_key),
)
)
if model_settings.minimax_api_key:
self._enabled_providers.append(
MiniMaxProvider(
name="minimax",
api_key_enc=Secret.from_plaintext(model_settings.minimax_api_key),
)
)
if model_settings.zai_api_key:
self._enabled_providers.append(
ZAIProvider(
@@ -443,6 +469,8 @@ class SyncServer(object):
embedding_models=embedding_models,
organization_id=None,  # Global models
)
# Update last_synced timestamp
await self.provider_manager.update_provider_last_synced_async(persisted_provider.id)
logger.info(
f"Synced {len(llm_models)} LLM models and {len(embedding_models)} embedding models for provider {persisted_provider.name}"
)
@@ -628,9 +656,10 @@ class SyncServer(object):
actor=actor,
)

async def create_sleeptime_agent_async(self, main_agent: AgentState, actor: User) -> AgentState:
async def create_sleeptime_agent_async(self, main_agent: AgentState, actor: User) -> Optional[AgentState]:
if main_agent.embedding_config is None:
raise EmbeddingConfigRequiredError(agent_id=main_agent.id, operation="create_sleeptime_agent")
logger.warning(f"Skipping sleeptime agent creation for agent {main_agent.id}: no embedding config provided")
return None
request = CreateAgent(
name=main_agent.name + "-sleeptime",
agent_type=AgentType.sleeptime_agent,
@@ -662,9 +691,10 @@ class SyncServer(object):
)
return await self.agent_manager.get_agent_by_id_async(agent_id=main_agent.id, actor=actor)

async def create_voice_sleeptime_agent_async(self, main_agent: AgentState, actor: User) -> AgentState:
async def create_voice_sleeptime_agent_async(self, main_agent: AgentState, actor: User) -> Optional[AgentState]:
if main_agent.embedding_config is None:
raise EmbeddingConfigRequiredError(agent_id=main_agent.id, operation="create_voice_sleeptime_agent")
logger.warning(f"Skipping voice sleeptime agent creation for agent {main_agent.id}: no embedding config provided")
return None
# TODO: Inject system
request = CreateAgent(
name=main_agent.name + "-sleeptime",
@@ -956,7 +986,7 @@ class SyncServer(object):
from letta.data_sources.connectors import DirectoryConnector

# TODO: move this into a thread
source = await self.source_manager.get_source_by_id(source_id=source_id)
source = await self.source_manager.get_source_by_id(source_id=source_id, actor=actor)
connector = DirectoryConnector(input_files=[file_path])
num_passages, num_documents = await self.load_data(user_id=source.created_by_id, source_name=source.name, connector=connector)

@@ -1041,9 +1071,10 @@ class SyncServer(object):

async def create_document_sleeptime_agent_async(
self, main_agent: AgentState, source: Source, actor: User, clear_history: bool = False
) -> AgentState:
) -> Optional[AgentState]:
if main_agent.embedding_config is None:
raise EmbeddingConfigRequiredError(agent_id=main_agent.id, operation="create_document_sleeptime_agent")
logger.warning(f"Skipping document sleeptime agent creation for agent {main_agent.id}: no embedding config provided")
return None
try:
block = await self.agent_manager.get_block_with_label_async(agent_id=main_agent.id, block_label=source.name, actor=actor)
except:
@@ -1151,10 +1182,18 @@ class SyncServer(object):
if provider_type and provider.provider_type != provider_type:
continue

# For bedrock, use schema default for base_url since DB may have NULL
# TODO: can maybe do this for all models but want to isolate change so we don't break any other providers
if provider.provider_type == ProviderType.bedrock:
typed_provider = provider.cast_to_subtype()
model_endpoint = typed_provider.base_url
else:
model_endpoint = provider.base_url

llm_config = LLMConfig(
model=model.name,
model_endpoint_type=model.model_endpoint_type,
model_endpoint=provider.base_url or model.model_endpoint_type,
model_endpoint=model_endpoint,
context_window=model.max_context_window or 16384,
handle=model.handle,
provider_name=provider.name,
@@ -1162,7 +1201,7 @@ class SyncServer(object):
)
llm_models.append(llm_config)

# Get BYOK provider models by hitting provider endpoints directly
# Get BYOK provider models - sync if not synced yet, then read from DB
if include_byok:
byok_providers = await self.provider_manager.list_providers_async(
actor=actor,
@@ -1173,9 +1212,39 @@ class SyncServer(object):

for provider in byok_providers:
try:
# Get typed provider to access schema defaults (e.g., base_url)
typed_provider = provider.cast_to_subtype()
models = await typed_provider.list_llm_models_async()
llm_models.extend(models)

# Sync models if not synced yet
if provider.last_synced is None:
models = await typed_provider.list_llm_models_async()
embedding_models = await typed_provider.list_embedding_models_async()
await self.provider_manager.sync_provider_models_async(
provider=provider,
llm_models=models,
embedding_models=embedding_models,
organization_id=provider.organization_id,
)
await self.provider_manager.update_provider_last_synced_async(provider.id, actor=actor)

# Read from database
provider_llm_models = await self.provider_manager.list_models_async(
actor=actor,
model_type="llm",
provider_id=provider.id,
enabled=True,
)
for model in provider_llm_models:
llm_config = LLMConfig(
model=model.name,
model_endpoint_type=model.model_endpoint_type,
model_endpoint=typed_provider.base_url,
context_window=model.max_context_window or constants.DEFAULT_CONTEXT_WINDOW,
handle=model.handle,
provider_name=provider.name,
provider_category=ProviderCategory.byok,
)
llm_models.append(llm_config)
except Exception as e:
logger.warning(f"Failed to fetch models from BYOK provider {provider.name}: {e}")
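The `last_synced is None` check above implements a sync-on-first-read cache: the BYOK provider endpoint is hit once, the result is persisted, and every later listing is served from the database. A minimal in-memory sketch of that pattern follows; all class and function names here are illustrative stand-ins, not the real manager API:

```python
import asyncio


class FakeProvider:
    """Stand-in for a BYOK provider endpoint."""

    def __init__(self):
        self.last_synced = None
        self.fetch_calls = 0

    async def list_llm_models_async(self):
        self.fetch_calls += 1
        return ["model-a", "model-b"]


class FakeModelStore:
    """Stand-in for the database-backed model list."""

    def __init__(self):
        self._models = []

    async def sync(self, models):
        self._models = list(models)

    async def list_models(self):
        return list(self._models)


async def get_models(provider, store):
    # Hit the remote endpoint only if this provider has never been synced.
    if provider.last_synced is None:
        await store.sync(await provider.list_llm_models_async())
        provider.last_synced = "now"
    # Always serve the request from the persisted copy.
    return await store.list_models()


async def main():
    provider, store = FakeProvider(), FakeModelStore()
    first = await get_models(provider, store)
    second = await get_models(provider, store)
    # Two reads, but only one remote fetch.
    print(first, second, provider.fetch_calls)

asyncio.run(main())
```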

@@ -1217,7 +1286,7 @@ class SyncServer(object):
)
embedding_models.append(embedding_config)

# Get BYOK provider models by hitting provider endpoints directly
# Get BYOK provider models - sync if not synced yet, then read from DB
byok_providers = await self.provider_manager.list_providers_async(
actor=actor,
provider_category=[ProviderCategory.byok],
@@ -1225,9 +1294,38 @@ class SyncServer(object):

for provider in byok_providers:
try:
# Get typed provider to access schema defaults (e.g., base_url)
typed_provider = provider.cast_to_subtype()
models = await typed_provider.list_embedding_models_async()
embedding_models.extend(models)

# Sync models if not synced yet
if provider.last_synced is None:
llm_models = await typed_provider.list_llm_models_async()
emb_models = await typed_provider.list_embedding_models_async()
await self.provider_manager.sync_provider_models_async(
provider=provider,
llm_models=llm_models,
embedding_models=emb_models,
organization_id=provider.organization_id,
)
await self.provider_manager.update_provider_last_synced_async(provider.id, actor=actor)

# Read from database
provider_embedding_models = await self.provider_manager.list_models_async(
actor=actor,
model_type="embedding",
provider_id=provider.id,
enabled=True,
)
for model in provider_embedding_models:
embedding_config = EmbeddingConfig(
embedding_model=model.name,
embedding_endpoint_type=model.model_endpoint_type,
embedding_endpoint=typed_provider.base_url,
embedding_dim=model.embedding_dim or 1536,
embedding_chunk_size=constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
handle=model.handle,
)
embedding_models.append(embedding_config)
except Exception as e:
logger.warning(f"Failed to fetch embedding models from BYOK provider {provider.name}: {e}")


@@ -357,7 +357,11 @@ class AgentManager:
)
agent_create.llm_config = LLMConfig.apply_reasoning_setting_to_config(
agent_create.llm_config,
agent_create.reasoning if agent_create.reasoning is not None else default_reasoning,
agent_create.reasoning
if agent_create.reasoning is not None
else (
agent_create.llm_config.enable_reasoner if agent_create.llm_config.enable_reasoner is not None else default_reasoning
),
agent_create.agent_type,
)
else:
@@ -2042,10 +2046,12 @@ class AgentManager:
if other_agent_id != agent_id:
try:
other_agent = await AgentModel.read_async(db_session=session, identifier=other_agent_id, actor=actor)
if other_agent.agent_type == AgentType.sleeptime_agent and block not in other_agent.core_memory:
other_agent.core_memory.append(block)
# await other_agent.update_async(session, actor=actor, no_commit=True)
await other_agent.update_async(session, actor=actor)
if other_agent.agent_type == AgentType.sleeptime_agent:
# Check if block with same label already exists
existing_block = next((b for b in other_agent.core_memory if b.label == block.label), None)
if not existing_block:
other_agent.core_memory.append(block)
await other_agent.update_async(session, actor=actor)
except NoResultFound:
# Agent might not exist anymore, skip
continue
@@ -2321,15 +2327,6 @@ class AgentManager:
# Use Turbopuffer for vector search if archive is configured for TPUF
if archive.vector_db_provider == VectorDBProvider.TPUF:
from letta.helpers.tpuf_client import TurbopufferClient
from letta.llm_api.llm_client import LLMClient

# Generate embedding for query
embedding_client = LLMClient.create(
provider_type=embedding_config.embedding_endpoint_type,
actor=actor,
)
embeddings = await embedding_client.request_embeddings([query_text], embedding_config)
query_embedding = embeddings[0]

# Query Turbopuffer - use hybrid search when text is available
tpuf_client = TurbopufferClient()
@@ -2488,13 +2485,15 @@ class AgentManager:

# Get results using existing passage query method
limit = top_k if top_k is not None else RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
# Only use embedding-based search if embedding config is available
use_embedding_search = agent_state.embedding_config is not None
passages_with_metadata = await self.query_agent_passages_async(
actor=actor,
agent_id=agent_id,
query_text=query,
limit=limit,
embedding_config=agent_state.embedding_config,
embed_query=True,
embed_query=use_embedding_search,
tags=tags,
tag_match_mode=tag_mode,
start_date=start_date,
@@ -3053,10 +3052,19 @@ class AgentManager:
)

# Apply cursor-based pagination
if before:
query = query.where(BlockModel.id < before)
if after:
query = query.where(BlockModel.id > after)
# Note: cursor direction must account for sort order
# - ascending order: "after X" means id > X, "before X" means id < X
# - descending order: "after X" means id < X, "before X" means id > X
if ascending:
if before:
query = query.where(BlockModel.id < before)
if after:
query = query.where(BlockModel.id > after)
else:
if before:
query = query.where(BlockModel.id > before)
if after:
query = query.where(BlockModel.id < after)
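The comment in that hunk encodes one invariant: "after X" always means "later in the returned sequence", so the comparison operator must flip with the sort order. The same logic can be sketched in memory on plain integer ids (illustrative only; the real code builds SQL predicates on `BlockModel.id`):

```python
def paginate(ids, ascending, before=None, after=None):
    """Filter an id collection the way the cursor predicates above do."""
    if ascending:
        if before is not None:
            ids = [i for i in ids if i < before]
        if after is not None:
            ids = [i for i in ids if i > after]
    else:
        # Descending: "after X" walks toward smaller ids, so the comparisons flip.
        if before is not None:
            ids = [i for i in ids if i > before]
        if after is not None:
            ids = [i for i in ids if i < after]
    return sorted(ids, reverse=not ascending)

print(paginate([1, 2, 3, 4, 5], ascending=True, after=2))   # [3, 4, 5]
print(paginate([1, 2, 3, 4, 5], ascending=False, after=4))  # [3, 2, 1]
```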

# Apply sorting - use id instead of created_at for core memory blocks
if ascending:

@@ -33,6 +33,7 @@ from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import FileProcessingStatus, VectorDBProvider
from letta.schemas.file import FileMetadata
from letta.schemas.group import Group, GroupCreate
from letta.schemas.llm_config import LLMConfig
from letta.schemas.mcp import MCPServer
from letta.schemas.message import Message
from letta.schemas.source import Source
@@ -472,6 +473,7 @@ class AgentSerializationManager:
dry_run: bool = False,
env_vars: Optional[Dict[str, Any]] = None,
override_embedding_config: Optional[EmbeddingConfig] = None,
override_llm_config: Optional[LLMConfig] = None,
project_id: Optional[str] = None,
) -> ImportResult:
"""
@@ -672,6 +674,11 @@ class AgentSerializationManager:
agent_schema.embedding_config = override_embedding_config
agent_schema.embedding = override_embedding_config.handle

# Override llm_config if provided (keeps other defaults like context size)
if override_llm_config:
agent_schema.llm_config = override_llm_config
agent_schema.model = override_llm_config.handle

# Convert AgentSchema back to CreateAgent, remapping tool/block IDs
agent_data = agent_schema.model_dump(exclude={"id", "in_context_message_ids", "messages"})


@@ -4,7 +4,6 @@ from typing import Dict, List, Optional

from sqlalchemy import delete, or_, select

from letta.errors import EmbeddingConfigRequiredError
from letta.helpers.tpuf_client import should_use_tpuf
from letta.log import get_logger
from letta.orm import ArchivalPassage, Archive as ArchiveModel, ArchivesAgents
@@ -32,7 +31,7 @@ class ArchiveManager:
async def create_archive_async(
self,
name: str,
embedding_config: EmbeddingConfig,
embedding_config: Optional[EmbeddingConfig] = None,
description: Optional[str] = None,
actor: PydanticUser = None,
) -> PydanticArchive:
@@ -289,6 +288,7 @@ class ArchiveManager:
text: str,
metadata: Optional[Dict] = None,
tags: Optional[List[str]] = None,
created_at: Optional[str] = None,
actor: PydanticUser = None,
) -> PydanticPassage:
"""Create a passage in an archive.
@@ -298,6 +298,7 @@ class ArchiveManager:
text: The text content of the passage
metadata: Optional metadata for the passage
tags: Optional tags for categorizing the passage
created_at: Optional creation datetime in ISO 8601 format
actor: User performing the operation

Returns:
@@ -312,13 +313,20 @@ class ArchiveManager:
# Verify the archive exists and user has access
archive = await self.get_archive_by_id_async(archive_id=archive_id, actor=actor)

# Generate embeddings for the text
embedding_client = LLMClient.create(
provider_type=archive.embedding_config.embedding_endpoint_type,
actor=actor,
)
embeddings = await embedding_client.request_embeddings([text], archive.embedding_config)
embedding = embeddings[0] if embeddings else None
# Generate embeddings for the text if embedding config is available
embedding = None
if archive.embedding_config is not None:
embedding_client = LLMClient.create(
provider_type=archive.embedding_config.embedding_endpoint_type,
actor=actor,
)
embeddings = await embedding_client.request_embeddings([text], archive.embedding_config)
embedding = embeddings[0] if embeddings else None

# Parse created_at from ISO string if provided
parsed_created_at = None
if created_at:
parsed_created_at = datetime.fromisoformat(created_at)

# Create the passage object with embedding
passage = PydanticPassage(
@@ -329,6 +337,7 @@ class ArchiveManager:
tags=tags,
embedding_config=archive.embedding_config,
embedding=embedding,
created_at=parsed_created_at,
)

# Use PassageManager to create the passage
@@ -345,13 +354,14 @@ class ArchiveManager:

tpuf_client = TurbopufferClient()

# Insert to Turbopuffer with the same ID as SQL
# Insert to Turbopuffer with the same ID as SQL, reusing existing embedding
await tpuf_client.insert_archival_memories(
archive_id=archive.id,
text_chunks=[created_passage.text],
passage_ids=[created_passage.id],
organization_id=actor.organization_id,
actor=actor,
embeddings=[created_passage.embedding],
)
logger.info(f"Uploaded passage {created_passage.id} to Turbopuffer for archive {archive_id}")
except Exception as e:
@@ -362,6 +372,92 @@ class ArchiveManager:
logger.info(f"Created passage {created_passage.id} in archive {archive_id}")
return created_passage

@enforce_types
@raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE)
@trace_method
async def create_passages_in_archive_async(
self,
archive_id: str,
passages: List[Dict],
actor: PydanticUser = None,
) -> List[PydanticPassage]:
"""Create multiple passages in an archive.

Args:
archive_id: ID of the archive to add the passages to
passages: Passage create payloads
actor: User performing the operation

Returns:
The created passages

Raises:
NoResultFound: If archive not found
"""
if not passages:
return []

from letta.llm_api.llm_client import LLMClient
from letta.services.passage_manager import PassageManager

archive = await self.get_archive_by_id_async(archive_id=archive_id, actor=actor)

texts = [passage["text"] for passage in passages]
embedding_client = LLMClient.create(
provider_type=archive.embedding_config.embedding_endpoint_type,
actor=actor,
)
embeddings = await embedding_client.request_embeddings(texts, archive.embedding_config)

if len(embeddings) != len(passages):
raise ValueError("Embedding response count does not match passages count")
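Embedding all texts in one batched request and validating the response count before pairing matters because `zip` truncates silently: if the provider returned fewer vectors than texts, some passages would be stored without embeddings and nothing would fail. The guard can be sketched in isolation, with a stubbed embedder standing in for the real `request_embeddings` call:

```python
def pair_texts_with_embeddings(texts, embed_fn):
    """Batch-embed texts and pair each text with its vector, failing loudly on a count mismatch."""
    embeddings = embed_fn(texts)
    if len(embeddings) != len(texts):
        raise ValueError("Embedding response count does not match passages count")
    return list(zip(texts, embeddings))

# Stub embedder: one 3-dim vector per text (a stand-in for the real client).
fake_embed = lambda texts: [[float(len(t)), 0.0, 0.0] for t in texts]

pairs = pair_texts_with_embeddings(["hello", "world!"], fake_embed)
print(pairs[0])  # ('hello', [5.0, 0.0, 0.0])
```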

# Build PydanticPassage objects for batch creation
pydantic_passages: List[PydanticPassage] = []
for passage_payload, embedding in zip(passages, embeddings):
# Parse created_at from ISO string if provided
created_at = passage_payload.get("created_at")
if created_at and isinstance(created_at, str):
created_at = datetime.fromisoformat(created_at)

passage = PydanticPassage(
text=passage_payload["text"],
archive_id=archive_id,
organization_id=actor.organization_id,
metadata=passage_payload.get("metadata") or {},
tags=passage_payload.get("tags"),
embedding_config=archive.embedding_config,
embedding=embedding,
created_at=created_at,
)
pydantic_passages.append(passage)

# Use batch create for efficient single-transaction insert
passage_manager = PassageManager()
created_passages = await passage_manager.create_agent_passages_async(
pydantic_passages=pydantic_passages,
actor=actor,
)

if archive.vector_db_provider == VectorDBProvider.TPUF:
try:
from letta.helpers.tpuf_client import TurbopufferClient

tpuf_client = TurbopufferClient()
await tpuf_client.insert_archival_memories(
archive_id=archive.id,
text_chunks=[passage.text for passage in created_passages],
passage_ids=[passage.id for passage in created_passages],
organization_id=actor.organization_id,
actor=actor,
)
logger.info(f"Uploaded {len(created_passages)} passages to Turbopuffer for archive {archive_id}")
except Exception as e:
logger.error(f"Failed to upload passages to Turbopuffer: {e}")

logger.info(f"Created {len(created_passages)} passages in archive {archive_id}")
return created_passages

@enforce_types
@raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE)
@raise_on_invalid_id(param_name="passage_id", expected_prefix=PrimitiveType.PASSAGE)
@@ -433,9 +529,7 @@ class ArchiveManager:
)
return archive

# Create a default archive for this agent
if agent_state.embedding_config is None:
raise EmbeddingConfigRequiredError(agent_id=agent_state.id, operation="create_default_archive")
# Create a default archive for this agent (embedding_config is optional)
archive_name = f"{agent_state.name}'s Archive"
archive = await self.create_archive_async(
name=archive_name,

@@ -508,7 +508,7 @@ class BlockManager:
@enforce_types
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
@trace_method
async def get_block_by_id_async(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]:
async def get_block_by_id_async(self, block_id: str, actor: PydanticUser) -> Optional[PydanticBlock]:
"""Retrieve a block by its ID, including tags."""
async with db_registry.async_session() as session:
try:
@@ -523,7 +523,7 @@ class BlockManager:

@enforce_types
@trace_method
async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]:
async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: PydanticUser) -> List[PydanticBlock]:
"""Retrieve blocks by their ids without loading unnecessary relationships. Async implementation."""
if not block_ids:
return []
@@ -540,9 +540,8 @@ class BlockManager:
noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups), noload(BlockModel.tags)
)

# Apply access control if actor is provided
if actor:
query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
# Apply access control - actor is required for org-scoping
query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)

# TODO: Add soft delete filter if applicable
# if hasattr(BlockModel, "is_deleted"):

@@ -105,9 +105,53 @@ class ConversationManager:
actor: PydanticUser,
limit: int = 50,
after: Optional[str] = None,
summary_search: Optional[str] = None,
) -> List[PydanticConversation]:
"""List conversations for an agent with cursor-based pagination."""
"""List conversations for an agent with cursor-based pagination.

Args:
agent_id: The agent ID to list conversations for
actor: The user performing the action
limit: Maximum number of conversations to return
after: Cursor for pagination (conversation ID)
summary_search: Optional text to search for within the summary field

Returns:
List of conversations matching the criteria
"""
async with db_registry.async_session() as session:
# If summary search is provided, use custom query
if summary_search:
from sqlalchemy import and_

stmt = (
select(ConversationModel)
.where(
and_(
ConversationModel.agent_id == agent_id,
ConversationModel.organization_id == actor.organization_id,
ConversationModel.summary.isnot(None),
ConversationModel.summary.contains(summary_search),
)
)
.order_by(ConversationModel.created_at.desc())
.limit(limit)
)

if after:
# Add cursor filtering
after_conv = await ConversationModel.read_async(
db_session=session,
identifier=after,
actor=actor,
)
stmt = stmt.where(ConversationModel.created_at < after_conv.created_at)

result = await session.execute(stmt)
conversations = result.scalars().all()
return [conv.to_pydantic() for conv in conversations]

# Use default list logic
conversations = await ConversationModel.list_async(
db_session=session,
actor=actor,

@@ -91,18 +91,17 @@ class FileManager:
await session.rollback()
return await self.get_file_by_id(file_metadata.id, actor=actor)

# TODO: We make actor optional for now, but should most likely be enforced due to security reasons
@enforce_types
@raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE)
@trace_method
# @async_redis_cache(
# key_func=lambda self, file_id, actor=None, include_content=False, strip_directory_prefix=False: f"{file_id}:{actor.organization_id if actor else 'none'}:{include_content}:{strip_directory_prefix}",
# key_func=lambda self, file_id, actor, include_content=False, strip_directory_prefix=False: f"{file_id}:{actor.organization_id}:{include_content}:{strip_directory_prefix}",
# prefix="file_content",
# ttl_s=3600,
# model_class=PydanticFileMetadata,
# )
async def get_file_by_id(
self, file_id: str, actor: Optional[PydanticUser] = None, *, include_content: bool = False, strip_directory_prefix: bool = False
self, file_id: str, actor: PydanticUser, *, include_content: bool = False, strip_directory_prefix: bool = False
) -> Optional[PydanticFileMetadata]:
"""Retrieve a file by its ID.

@@ -479,7 +478,7 @@ class FileManager:
async def delete_file(self, file_id: str, actor: PydanticUser) -> PydanticFileMetadata:
"""Delete a file by its ID."""
async with db_registry.async_session() as session:
file = await FileMetadataModel.read_async(db_session=session, identifier=file_id)
file = await FileMetadataModel.read_async(db_session=session, identifier=file_id, actor=actor)

# invalidate cache for this file before deletion
await self._invalidate_file_caches(file_id, actor, file.original_file_name, file.source_id)

@@ -1194,9 +1194,9 @@ async def build_agent_passage_query(
"""

# Handle embedding for vector search
# If embed_query is True but no embedding_config, fall through to text search
embedded_text = None
if embed_query:
assert embedding_config is not None, "embedding_config must be specified for vector search"
if embed_query and embedding_config is not None:
assert query_text is not None, "query_text must be specified for vector search"

# Use the new LLMClient for embeddings
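With the assert removed, `embed_query` now degrades gracefully: vector search runs only when an embedding config actually exists, and the query otherwise falls through to text search. The branch can be sketched as a standalone decision function (`choose_search_mode` is a hypothetical helper for illustration, not part of the diff):

```python
def choose_search_mode(embed_query: bool, embedding_config, query_text):
    """Pick vector vs. text search; vector search needs both the flag and a config."""
    if embed_query and embedding_config is not None:
        assert query_text is not None, "query_text must be specified for vector search"
        return "vector"
    return "text"

print(choose_search_mode(True, {"endpoint": "openai"}, "find notes"))  # vector
print(choose_search_mode(True, None, "find notes"))                    # text (graceful fallback)
print(choose_search_mode(False, {"endpoint": "openai"}, "find notes")) # text
```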
|
||||
|
||||
@@ -63,7 +63,7 @@ class LLMBatchManager:
|
||||
self,
|
||||
llm_batch_id: str,
|
||||
status: JobStatus,
|
||||
actor: Optional[PydanticUser] = None,
|
||||
actor: PydanticUser,
|
||||
latest_polling_response: Optional[BetaMessageBatch] = None,
|
||||
) -> PydanticLLMBatchJob:
|
||||
"""Update a batch job’s status and optionally its polling response."""
|
||||
@@ -107,8 +107,8 @@ class LLMBatchManager:
|
||||
async def list_llm_batch_jobs_async(
|
||||
self,
|
||||
letta_batch_id: str,
|
||||
actor: PydanticUser,
|
||||
limit: Optional[int] = None,
|
||||
actor: Optional[PydanticUser] = None,
|
||||
after: Optional[str] = None,
|
||||
) -> List[PydanticLLMBatchJob]:
|
||||
"""
|
||||
@@ -153,8 +153,8 @@ class LLMBatchManager:
|
||||
async def get_messages_for_letta_batch_async(
|
||||
self,
|
||||
letta_batch_job_id: str,
|
||||
actor: PydanticUser,
|
||||
limit: int = 100,
|
||||
actor: Optional[PydanticUser] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
sort_descending: bool = True,
|
||||
cursor: Optional[str] = None, # Message ID as cursor
|
||||
|
||||
@@ -419,6 +419,9 @@ class MCPManager:
|
||||
"""
|
||||
# Create base MCPServer object
|
||||
if isinstance(server_config, StdioServerConfig):
|
||||
# Check if stdio MCP servers are disabled (not suitable for multi-tenant deployments)
|
||||
if tool_settings.mcp_disable_stdio:
|
||||
raise ValueError("MCP stdio servers are disabled. Set MCP_DISABLE_STDIO=false to enable them.")
|
||||
mcp_server = MCPServer(server_name=server_config.server_name, server_type=server_config.type, stdio_config=server_config)
|
||||
elif isinstance(server_config, SSEServerConfig):
|
||||
mcp_server = MCPServer(
|
||||
@@ -832,6 +835,9 @@ class MCPManager:
|
            server_config = SSEServerConfig(**server_config.model_dump())
            return AsyncFastMCPSSEClient(server_config=server_config, oauth=oauth, agent_id=agent_id)
        elif server_config.type == MCPServerType.STDIO:
+           # Check if stdio MCP servers are disabled (not suitable for multi-tenant deployments)
+           if tool_settings.mcp_disable_stdio:
+               raise ValueError("MCP stdio servers are disabled. Set MCP_DISABLE_STDIO=false to enable them.")
            server_config = StdioServerConfig(**server_config.model_dump())
            return AsyncStdioMCPClient(server_config=server_config, oauth_provider=None, agent_id=agent_id)
        elif server_config.type == MCPServerType.STREAMABLE_HTTP:

@@ -516,6 +516,9 @@ class MCPServerManager:
        """
        # Create base MCPServer object
        if isinstance(server_config, StdioServerConfig):
+           # Check if stdio MCP servers are disabled (not suitable for multi-tenant deployments)
+           if tool_settings.mcp_disable_stdio:
+               raise ValueError("MCP stdio servers are disabled. Set MCP_DISABLE_STDIO=false to enable them.")
            mcp_server = MCPServer(server_name=server_config.server_name, server_type=server_config.type, stdio_config=server_config)
        elif isinstance(server_config, SSEServerConfig):
            mcp_server = MCPServer(

@@ -1003,6 +1006,9 @@ class MCPServerManager:
            server_config = SSEServerConfig(**server_config.model_dump())
            return AsyncFastMCPSSEClient(server_config=server_config, oauth=oauth, agent_id=agent_id)
        elif server_config.type == MCPServerType.STDIO:
+           # Check if stdio MCP servers are disabled (not suitable for multi-tenant deployments)
+           if tool_settings.mcp_disable_stdio:
+               raise ValueError("MCP stdio servers are disabled. Set MCP_DISABLE_STDIO=false to enable them.")
            server_config = StdioServerConfig(**server_config.model_dump())
            return AsyncStdioMCPClient(server_config=server_config, oauth_provider=None, agent_id=agent_id)
        elif server_config.type == MCPServerType.STREAMABLE_HTTP:
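The stdio gate added in both hunks is a settings-driven feature flag: read a deployment flag once, then refuse to construct the client when it is set. A minimal self-contained sketch of the pattern (the `ToolSettings` dataclass and `create_stdio_client` helper here are illustrative stand-ins, not Letta's actual objects; only the `MCP_DISABLE_STDIO` name comes from the diff):

```python
import os
from dataclasses import dataclass, field


@dataclass
class ToolSettings:
    # Read the flag from the environment; "true"/"1" disables stdio servers.
    mcp_disable_stdio: bool = field(
        default_factory=lambda: os.getenv("MCP_DISABLE_STDIO", "false").lower() in ("true", "1")
    )


def create_stdio_client(settings: ToolSettings, command: str) -> str:
    """Refuse to build a stdio client when the deployment disables them."""
    if settings.mcp_disable_stdio:
        raise ValueError("MCP stdio servers are disabled. Set MCP_DISABLE_STDIO=false to enable them.")
    return f"stdio-client:{command}"
```

The same check is duplicated at every stdio construction site in the diff, which is why it appears in three hunks.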
@@ -8,7 +8,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import noload

from letta.constants import MAX_EMBEDDING_DIM
from letta.errors import EmbeddingConfigRequiredError
from letta.helpers.decorators import async_redis_cache
from letta.llm_api.llm_client import LLMClient
from letta.log import get_logger

@@ -193,6 +192,93 @@ class PassageManager:

        return passage.to_pydantic()

    @enforce_types
    @trace_method
    async def create_agent_passages_async(self, pydantic_passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]:
        """Create multiple agent passages in a single database transaction.

        Args:
            pydantic_passages: List of passages to create
            actor: User performing the operation

        Returns:
            List of created passages
        """
        if not pydantic_passages:
            return []

        import numpy as np

        from letta.helpers.tpuf_client import should_use_tpuf

        use_tpuf = should_use_tpuf()
        passage_objects: List[ArchivalPassage] = []
        all_tags_data: List[tuple] = []  # (passage_index, tags) for creating tags after passages are created

        for idx, pydantic_passage in enumerate(pydantic_passages):
            if not pydantic_passage.archive_id:
                raise ValueError("Agent passage must have archive_id")
            if pydantic_passage.source_id:
                raise ValueError("Agent passage cannot have source_id")

            data = pydantic_passage.model_dump(to_orm=True)

            # Deduplicate tags if provided (for dual storage consistency)
            tags = data.get("tags")
            if tags:
                tags = list(set(tags))
                all_tags_data.append((idx, tags))

            # Pad embeddings to MAX_EMBEDDING_DIM for pgvector (only when using Postgres as vector DB)
            embedding = data["embedding"]
            if embedding and not use_tpuf:
                np_embedding = np.array(embedding)
                if np_embedding.shape[0] != MAX_EMBEDDING_DIM:
                    embedding = np.pad(np_embedding, (0, MAX_EMBEDDING_DIM - np_embedding.shape[0]), mode="constant").tolist()

            # Sanitize text to remove null bytes which PostgreSQL rejects
            text = data["text"]
            if text and "\x00" in text:
                text = text.replace("\x00", "")
                logger.warning(f"Removed null bytes from passage text (length: {len(data['text'])} -> {len(text)})")

            common_fields = {
                "id": data.get("id"),
                "text": text,
                "embedding": embedding,
                "embedding_config": data["embedding_config"],
                "organization_id": data["organization_id"],
                "metadata_": data.get("metadata_", {}),
                "tags": tags,
                "is_deleted": data.get("is_deleted", False),
                "created_at": data.get("created_at", datetime.now(timezone.utc)),
            }
            agent_fields = {"archive_id": data["archive_id"]}
            passage = ArchivalPassage(**common_fields, **agent_fields)
            passage_objects.append(passage)

        async with db_registry.async_session() as session:
            # Batch create all passages in a single transaction
            created_passages = await ArchivalPassage.batch_create_async(
                items=passage_objects,
                db_session=session,
                actor=actor,
            )

            # Create tags for passages that have them
            for idx, tags in all_tags_data:
                created_passage = created_passages[idx]
                await self._create_tags_for_passage(
                    session=session,
                    passage_id=created_passage.id,
                    archive_id=created_passage.archive_id,
                    organization_id=created_passage.organization_id,
                    tags=tags,
                    actor=actor,
                )

            return [p.to_pydantic() for p in created_passages]

    @enforce_types
    @trace_method
    async def create_source_passage_async(
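Two of the per-passage normalizations in `create_agent_passages_async` are easy to isolate: zero-padding an embedding out to the fixed column width pgvector expects with `np.pad`, and stripping the NUL bytes PostgreSQL's text type rejects. A minimal sketch (the 8-dimension `MAX_EMBEDDING_DIM` is illustrative; Letta's real constant is much larger):

```python
import numpy as np

MAX_EMBEDDING_DIM = 8  # illustrative; the real constant is much larger


def pad_embedding(embedding: list[float]) -> list[float]:
    """Zero-pad an embedding to the fixed column width pgvector expects."""
    arr = np.array(embedding)
    if arr.shape[0] != MAX_EMBEDDING_DIM:
        # (0, n) pads n zeros on the right, none on the left.
        arr = np.pad(arr, (0, MAX_EMBEDDING_DIM - arr.shape[0]), mode="constant")
    return arr.tolist()


def sanitize_text(text: str) -> str:
    """Drop NUL bytes, which PostgreSQL rejects inside text values."""
    return text.replace("\x00", "") if "\x00" in text else text
```

Padding only happens on the Postgres path; when Turbopuffer is the vector store the raw embedding is kept as-is.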
@@ -474,15 +560,6 @@ class PassageManager:
        Returns:
            List of created passage objects
        """
-       if agent_state.embedding_config is None:
-           raise EmbeddingConfigRequiredError(agent_id=agent_state.id, operation="insert_passage")
-
-       embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size
-       embedding_client = LLMClient.create(
-           provider_type=agent_state.embedding_config.embedding_endpoint_type,
-           actor=actor,
-       )
-
        # Get or create the default archive for the agent
        archive = await self.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=agent_state, actor=actor)

@@ -493,8 +570,16 @@ class PassageManager:
            return []

        try:
-           # Generate embeddings for all chunks using the new async API
-           embeddings = await embedding_client.request_embeddings(text_chunks, agent_state.embedding_config)
+           # Generate embeddings if embedding config is available
+           if agent_state.embedding_config is not None:
+               embedding_client = LLMClient.create(
+                   provider_type=agent_state.embedding_config.embedding_endpoint_type,
+                   actor=actor,
+               )
+               embeddings = await embedding_client.request_embeddings(text_chunks, agent_state.embedding_config)
+           else:
+               # No embedding config - store passages without embeddings (text search only)
+               embeddings = [None] * len(text_chunks)

            passages = []

@@ -525,20 +610,21 @@ class PassageManager:

            tpuf_client = TurbopufferClient()

-           # Extract IDs and texts from the created passages
+           # Extract IDs, texts, and embeddings from the created passages
            passage_ids = [p.id for p in passages]
            passage_texts = [p.text for p in passages]
+           passage_embeddings = [p.embedding for p in passages]

-           # Insert to Turbopuffer with the same IDs as SQL
-           # TurbopufferClient will generate embeddings internally using default config
+           # Insert to Turbopuffer with the same IDs as SQL, reusing existing embeddings
            await tpuf_client.insert_archival_memories(
                archive_id=archive.id,
                text_chunks=passage_texts,
-               passage_ids=passage_ids,  # Use same IDs as SQL
+               passage_ids=passage_ids,
                organization_id=actor.organization_id,
                actor=actor,
                tags=tags,
                created_at=passages[0].created_at if passages else None,
+               embeddings=passage_embeddings,
            )
        except Exception as e:
            logger.error(f"Failed to insert passages to Turbopuffer: {e}")
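The reshuffled logic above is a graceful-degradation pattern: embed when a config exists, otherwise fall back to `None` placeholders so the passages remain usable for text-only search and list indices stay aligned downstream. A toy sketch (the `fake_embed` stub stands in for `request_embeddings`; it is not Letta's API):

```python
from typing import Optional


def fake_embed(chunks: list[str]) -> list[list[float]]:
    """Stand-in for a real embedding call; returns one vector per chunk."""
    return [[float(len(c))] for c in chunks]


def embed_or_skip(chunks: list[str], has_config: bool) -> list[Optional[list[float]]]:
    """Embed when a config is available; otherwise keep text-search-only placeholders."""
    if has_config:
        return fake_embed(chunks)
    # No embedding config: one None per chunk keeps indices aligned downstream.
    return [None] * len(chunks)
```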
@@ -98,9 +98,35 @@ class ProviderManager:
                deleted_provider.access_key_enc = access_key_secret.get_encrypted()

                await deleted_provider.update_async(session, actor=actor)

                # Also restore any soft-deleted models associated with this provider
                # This is needed because the unique constraint on provider_models doesn't include is_deleted,
                # so soft-deleted models would block creation of new models with the same handle
                from sqlalchemy import update

                restore_models_stmt = (
                    update(ProviderModelORM)
                    .where(
                        and_(
                            ProviderModelORM.provider_id == deleted_provider.id,
                            ProviderModelORM.is_deleted == True,
                        )
                    )
                    .values(is_deleted=False)
                )
                result = await session.execute(restore_models_stmt)
                if result.rowcount > 0:
                    logger.info(f"Restored {result.rowcount} soft-deleted model(s) for provider '{request.name}'")

                # Commit the provider and model restoration before syncing
                # This is needed because _sync_default_models_for_provider opens a new session
                # that can't see uncommitted changes from this session
                await session.commit()

                provider_pydantic = deleted_provider.to_pydantic()

                # For BYOK providers, automatically sync available models
                # This will add any new models and remove any that are no longer available
                if is_byok:
                    await self._sync_default_models_for_provider(provider_pydantic, actor)
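The bulk un-soft-delete above is a single `UPDATE … SET is_deleted = false` scoped to one provider. The same statement can be sketched in raw SQL with the stdlib `sqlite3` module (the table shape is a toy stand-in for `provider_models`; the real code issues it through SQLAlchemy's `update()`):

```python
import sqlite3


def restore_models(conn: sqlite3.Connection, provider_id: str) -> int:
    """Flip is_deleted back to 0 for every soft-deleted model row of one provider."""
    cur = conn.execute(
        "UPDATE provider_models SET is_deleted = 0 WHERE provider_id = ? AND is_deleted = 1",
        (provider_id,),
    )
    conn.commit()
    # rowcount reports how many rows were restored, mirroring the logged count.
    return cur.rowcount


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE provider_models (id INTEGER PRIMARY KEY, provider_id TEXT, handle TEXT, is_deleted INTEGER)"
)
conn.executemany(
    "INSERT INTO provider_models (provider_id, handle, is_deleted) VALUES (?, ?, ?)",
    [("p1", "p1/gpt", 1), ("p1", "p1/ada", 1), ("p2", "p2/claude", 1)],
)
conn.commit()
restored = restore_models(conn, "p1")
```

The restore is needed because, as the diff's comment notes, the unique constraint on handles ignores `is_deleted`, so stale soft-deleted rows would otherwise collide with freshly synced models.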
@@ -119,6 +145,14 @@ class ProviderManager:
        # if provider.name == provider.provider_type.value:
        #     raise ValueError("Provider name must be unique and different from provider type")

        # Fill in schema-default base_url if not provided
        # This ensures providers like ZAI get their default endpoint persisted to DB
        # rather than relying on cast_to_subtype() at read time
        if provider.base_url is None:
            typed_provider = provider.cast_to_subtype()
            if typed_provider.base_url is not None:
                provider.base_url = typed_provider.base_url

        # Only assign organization id for non-base providers
        # Base providers should be globally accessible (org_id = None)
        if is_byok:
@@ -201,6 +235,21 @@ class ProviderManager:
            await existing_provider.update_async(session, actor=actor)
            return existing_provider.to_pydantic()

    @enforce_types
    @raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
    async def update_provider_last_synced_async(self, provider_id: str, actor: Optional[PydanticUser] = None) -> None:
        """Update the last_synced timestamp for a provider.

        Note: actor is optional to support system-level operations (e.g., during server initialization
        for global providers). When actor is provided, org-scoping is enforced.
        """
        from datetime import datetime, timezone

        async with db_registry.async_session() as session:
            provider = await ProviderModel.read_async(db_session=session, identifier=provider_id, actor=actor)
            provider.last_synced = datetime.now(timezone.utc)
            await session.commit()

    @enforce_types
    @raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
    @trace_method
@@ -476,81 +525,19 @@ class ProviderManager:

    async def _sync_default_models_for_provider(self, provider: PydanticProvider, actor: PydanticUser) -> None:
        """Sync models for a newly created BYOK provider by querying the provider's API."""
        from letta.log import get_logger

        logger = get_logger(__name__)

        try:
-           # Get the provider class and create an instance
-           from letta.schemas.enums import ProviderType
-           from letta.schemas.providers.anthropic import AnthropicProvider
-           from letta.schemas.providers.azure import AzureProvider
-           from letta.schemas.providers.bedrock import BedrockProvider
-           from letta.schemas.providers.google_gemini import GoogleAIProvider
-           from letta.schemas.providers.groq import GroqProvider
-           from letta.schemas.providers.ollama import OllamaProvider
-           from letta.schemas.providers.openai import OpenAIProvider
-           from letta.schemas.providers.zai import ZAIProvider
+           # Use cast_to_subtype() which properly handles all provider types and preserves api_key_enc
+           typed_provider = provider.cast_to_subtype()
+           llm_models = await typed_provider.list_llm_models_async()
+           embedding_models = await typed_provider.list_embedding_models_async()

-           # ChatGPT OAuth requires cast_to_subtype to preserve api_key_enc and id
-           # (needed for OAuth token refresh and database persistence)
-           if provider.provider_type == ProviderType.chatgpt_oauth:
-               provider_instance = provider.cast_to_subtype()
-           else:
-               provider_type_to_class = {
-                   "openai": OpenAIProvider,
-                   "anthropic": AnthropicProvider,
-                   "groq": GroqProvider,
-                   "google": GoogleAIProvider,
-                   "ollama": OllamaProvider,
-                   "bedrock": BedrockProvider,
-                   "azure": AzureProvider,
-                   "zai": ZAIProvider,
-               }
-
-               provider_type = provider.provider_type.value if hasattr(provider.provider_type, "value") else str(provider.provider_type)
-               provider_class = provider_type_to_class.get(provider_type)
-
-               if not provider_class:
-                   logger.warning(f"No provider class found for type '{provider_type}'")
-                   return
-
-               # Create provider instance with necessary parameters
-               api_key = await provider.api_key_enc.get_plaintext_async() if provider.api_key_enc else None
-               access_key = await provider.access_key_enc.get_plaintext_async() if provider.access_key_enc else None
-               kwargs = {
-                   "name": provider.name,
-                   "api_key": api_key,
-                   "provider_category": provider.provider_category,
-               }
-               if provider.base_url:
-                   kwargs["base_url"] = provider.base_url
-               if access_key:
-                   kwargs["access_key"] = access_key
-               if provider.region:
-                   kwargs["region"] = provider.region
-               if provider.api_version:
-                   kwargs["api_version"] = provider.api_version
-
-               provider_instance = provider_class(**kwargs)
-
-           # Query the provider's API for available models
-           llm_models = await provider_instance.list_llm_models_async()
-           embedding_models = await provider_instance.list_embedding_models_async()

            # Update handles and provider_name for BYOK providers
            for model in llm_models:
                model.provider_name = provider.name
                model.handle = f"{provider.name}/{model.model}"
                model.provider_category = provider.provider_category

            for model in embedding_models:
                model.handle = f"{provider.name}/{model.embedding_model}"

            # Use existing sync_provider_models_async to save to database
            await self.sync_provider_models_async(
-               provider=provider, llm_models=llm_models, embedding_models=embedding_models, organization_id=actor.organization_id
+               provider=provider,
+               llm_models=llm_models,
+               embedding_models=embedding_models,
+               organization_id=actor.organization_id,
            )
+           await self.update_provider_last_synced_async(provider.id, actor=actor)

        except Exception as e:
            logger.error(f"Failed to sync models for provider '{provider.name}': {e}")
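The refactor above collapses an explicit type-to-class dispatch map (plus hand-built constructor kwargs) into a single polymorphic `cast_to_subtype()` call on the schema object. The shape of that change can be sketched with toy classes (names and the registry are illustrative, not Letta's):

```python
from dataclasses import dataclass


@dataclass
class Provider:
    provider_type: str
    name: str

    def cast_to_subtype(self) -> "Provider":
        # One registry lookup hidden behind the schema, instead of every
        # caller maintaining its own type -> class map and kwargs plumbing.
        subtype = _REGISTRY[self.provider_type]
        return subtype(provider_type=self.provider_type, name=self.name)


@dataclass
class OpenAIProvider(Provider):
    def list_models(self) -> list[str]:
        return [f"{self.name}/gpt-x"]


@dataclass
class AnthropicProvider(Provider):
    def list_models(self) -> list[str]:
        return [f"{self.name}/claude-x"]


_REGISTRY = {"openai": OpenAIProvider, "anthropic": AnthropicProvider}
```

Centralizing the cast also fixes the special cases the old map had to carry by hand, such as ChatGPT OAuth needing `api_key_enc` preserved.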
@@ -713,7 +700,7 @@ class ProviderManager:
                enabled=True,
                model_endpoint_type=llm_config.model_endpoint_type,
                max_context_window=llm_config.context_window,
-               supports_token_streaming=llm_config.model_endpoint_type in ["openai", "anthropic", "deepseek"],
+               supports_token_streaming=llm_config.model_endpoint_type in ["openai", "anthropic", "deepseek", "openrouter"],
                supports_tool_calling=True,  # Assume true for LLMs for now
            )
@@ -737,14 +724,27 @@ class ProviderManager:
                    # Roll back the session to clear the failed transaction
                    await session.rollback()
                else:
-                   # Check if max_context_window needs to be updated
+                   # Check if max_context_window or model_endpoint_type needs to be updated
                    existing_model = existing[0]
                    needs_update = False

                    if existing_model.max_context_window != llm_config.context_window:
                        logger.info(
                            f" Updating LLM model {llm_config.handle} max_context_window: "
                            f"{existing_model.max_context_window} -> {llm_config.context_window}"
                        )
                        existing_model.max_context_window = llm_config.context_window
                        needs_update = True

                    if existing_model.model_endpoint_type != llm_config.model_endpoint_type:
                        logger.info(
                            f" Updating LLM model {llm_config.handle} model_endpoint_type: "
                            f"{existing_model.model_endpoint_type} -> {llm_config.model_endpoint_type}"
                        )
                        existing_model.model_endpoint_type = llm_config.model_endpoint_type
                        needs_update = True

                    if needs_update:
                        await existing_model.update_async(session)
                    else:
                        logger.info(f" LLM model {llm_config.handle} already exists (ID: {existing[0].id}), skipping")
@@ -801,7 +801,17 @@ class ProviderManager:
                    # Roll back the session to clear the failed transaction
                    await session.rollback()
                else:
-                   logger.info(f" Embedding model {embedding_config.handle} already exists (ID: {existing[0].id}), skipping")
+                   # Check if model_endpoint_type needs to be updated
+                   existing_model = existing[0]
+                   if existing_model.model_endpoint_type != embedding_config.embedding_endpoint_type:
+                       logger.info(
+                           f" Updating embedding model {embedding_config.handle} model_endpoint_type: "
+                           f"{existing_model.model_endpoint_type} -> {embedding_config.embedding_endpoint_type}"
+                       )
+                       existing_model.model_endpoint_type = embedding_config.embedding_endpoint_type
+                       await existing_model.update_async(session)
+                   else:
+                       logger.info(f" Embedding model {embedding_config.handle} already exists (ID: {existing[0].id}), skipping")

    @enforce_types
    @trace_method
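Both sync branches above follow the same upsert-drift pattern: compare each mutable field of the stored row against the incoming config, mark dirty, and write only when something actually changed. A minimal sketch with plain dicts (illustrative, not Letta's ORM):

```python
def sync_model_record(existing: dict, incoming: dict) -> bool:
    """Update mutable fields in place; return True when a write is needed."""
    needs_update = False
    for field in ("max_context_window", "model_endpoint_type"):
        if existing.get(field) != incoming.get(field):
            existing[field] = incoming[field]
            needs_update = True
    return needs_update
```

Skipping the write when nothing drifted keeps repeated syncs cheap and the "already exists, skipping" log truthful.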
@@ -972,8 +982,8 @@ class ProviderManager:
            # Determine the model endpoint - use provider's base_url if set,
            # otherwise use provider-specific defaults
-           if provider.base_url:
-               model_endpoint = provider.base_url
+           if typed_provider.base_url:
+               model_endpoint = typed_provider.base_url
            elif provider.provider_type == ProviderType.chatgpt_oauth:
                # ChatGPT OAuth uses the ChatGPT backend API, not a generic endpoint pattern
                from letta.schemas.providers.chatgpt_oauth import CHATGPT_CODEX_ENDPOINT
@@ -2,10 +2,12 @@

from letta.helpers.json_helpers import json_dumps, json_loads
from letta.orm.provider_trace import ProviderTrace as ProviderTraceModel
-from letta.schemas.provider_trace import ProviderTrace
+from letta.orm.provider_trace_metadata import ProviderTraceMetadata as ProviderTraceMetadataModel
+from letta.schemas.provider_trace import ProviderTrace, ProviderTraceMetadata
from letta.schemas.user import User
from letta.server.db import db_registry
from letta.services.provider_trace_backends.base import ProviderTraceBackendClient
from letta.settings import telemetry_settings


class PostgresProviderTraceBackend(ProviderTraceBackendClient):

@@ -15,7 +17,17 @@ class PostgresProviderTraceBackend(ProviderTraceBackendClient):
        self,
        actor: User,
        provider_trace: ProviderTrace,
    ) -> ProviderTrace | ProviderTraceMetadata:
        if telemetry_settings.provider_trace_pg_metadata_only:
            return await self._create_metadata_only_async(actor, provider_trace)
        return await self._create_full_async(actor, provider_trace)

    async def _create_full_async(
        self,
        actor: User,
        provider_trace: ProviderTrace,
    ) -> ProviderTrace:
        """Write full provider trace to provider_traces table."""
        async with db_registry.async_session() as session:
            provider_trace_model = ProviderTraceModel(**provider_trace.model_dump())
            provider_trace_model.organization_id = actor.organization_id

@@ -31,11 +43,44 @@ class PostgresProviderTraceBackend(ProviderTraceBackendClient):
            await provider_trace_model.create_async(session, actor=actor, no_commit=True, no_refresh=True)
            return provider_trace_model.to_pydantic()

    async def _create_metadata_only_async(
        self,
        actor: User,
        provider_trace: ProviderTrace,
    ) -> ProviderTraceMetadata:
        """Write metadata-only trace to provider_trace_metadata table."""
        metadata = ProviderTraceMetadata(
            id=provider_trace.id,
            step_id=provider_trace.step_id,
            agent_id=provider_trace.agent_id,
            agent_tags=provider_trace.agent_tags,
            call_type=provider_trace.call_type,
            run_id=provider_trace.run_id,
            source=provider_trace.source,
            org_id=provider_trace.org_id,
            user_id=provider_trace.user_id,
        )
        metadata_model = ProviderTraceMetadataModel(**metadata.model_dump())
        metadata_model.organization_id = actor.organization_id

        async with db_registry.async_session() as session:
            await metadata_model.create_async(session, actor=actor, no_commit=True, no_refresh=True)
            return metadata_model.to_pydantic()

    async def get_by_step_id_async(
        self,
        step_id: str,
        actor: User,
    ) -> ProviderTrace | None:
        """Read from provider_traces table. Always reads from full table regardless of write flag."""
        return await self._get_full_by_step_id_async(step_id, actor)

    async def _get_full_by_step_id_async(
        self,
        step_id: str,
        actor: User,
    ) -> ProviderTrace | None:
        """Read from provider_traces table."""
        async with db_registry.async_session() as session:
            provider_trace_model = await ProviderTraceModel.read_async(
                db_session=session,

@@ -43,3 +88,17 @@ class PostgresProviderTraceBackend(ProviderTraceBackendClient):
                actor=actor,
            )
            return provider_trace_model.to_pydantic() if provider_trace_model else None

    async def _get_metadata_by_step_id_async(
        self,
        step_id: str,
        actor: User,
    ) -> ProviderTraceMetadata | None:
        """Read from provider_trace_metadata table."""
        async with db_registry.async_session() as session:
            metadata_model = await ProviderTraceMetadataModel.read_async(
                db_session=session,
                step_id=step_id,
                actor=actor,
            )
            return metadata_model.to_pydantic() if metadata_model else None
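The `create` entry point above is a settings-switched writer: one public method, two private backends chosen by a telemetry flag, where the metadata path keeps only cheap identifying fields and drops the large request/response payloads. A toy sketch of the shape (class, flag, and field names here are illustrative):

```python
class TraceWriter:
    """Route writes to a full table or a slim metadata table based on one flag."""

    def __init__(self, metadata_only: bool):
        self.metadata_only = metadata_only
        self.full_rows: list[dict] = []
        self.metadata_rows: list[dict] = []

    def create(self, trace: dict) -> dict:
        if self.metadata_only:
            return self._create_metadata_only(trace)
        return self._create_full(trace)

    def _create_full(self, trace: dict) -> dict:
        self.full_rows.append(trace)
        return trace

    def _create_metadata_only(self, trace: dict) -> dict:
        # Keep only cheap identifying fields; drop the large payloads.
        slim = {k: trace[k] for k in ("id", "step_id", "org_id") if k in trace}
        self.metadata_rows.append(slim)
        return slim
```

Note that in the diff, reads (`get_by_step_id_async`) always go to the full table regardless of the write flag, so flipping the flag never breaks lookups of previously written full traces.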
@@ -17,7 +17,8 @@ logger = get_logger(__name__)

# Protocol version for crouton communication.
# Bump this when making breaking changes to the record schema.
# Must match ProtocolVersion in apps/crouton/main.go.
-PROTOCOL_VERSION = 1
+# v2: Added user_id, compaction_settings (summarization), llm_config (non-summarization)
+PROTOCOL_VERSION = 2


class SocketProviderTraceBackend(ProviderTraceBackendClient):

@@ -94,6 +95,11 @@ class SocketProviderTraceBackend(ProviderTraceBackendClient):
            "error": error,
            "error_type": error_type,
            "timestamp": datetime.now(timezone.utc).isoformat(),
+           # v2 protocol fields
+           "org_id": provider_trace.org_id,
+           "user_id": provider_trace.user_id,
+           "compaction_settings": provider_trace.compaction_settings,
+           "llm_config": provider_trace.llm_config,
        }

        # Fire-and-forget in background thread
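The protocol bump pairs a version constant with the new payload fields, so a consumer built against v1 can detect records it cannot parse instead of silently misreading them. A minimal sketch of that handshake (the record fields follow the diff; the version-checking reader is illustrative, not the actual crouton consumer):

```python
PROTOCOL_VERSION = 2


def build_record(org_id: str, user_id: str) -> dict:
    """Emit a record stamped with the writer's protocol version."""
    return {
        "protocol_version": PROTOCOL_VERSION,
        # v2 fields; a v1 reader would not know how to interpret these.
        "org_id": org_id,
        "user_id": user_id,
    }


def read_record(record: dict, supported_version: int) -> dict:
    """Reject records whose version the consumer does not understand."""
    if record.get("protocol_version", 0) > supported_version:
        raise ValueError(f"Unsupported protocol version: {record.get('protocol_version')}")
    return record
```

This is why the comment insists the constant must match `ProtocolVersion` in `apps/crouton/main.go`: both ends have to agree on what a given version implies about the record schema.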
@@ -455,9 +455,11 @@ class RunManager:
        # Dispatch callback outside of database session if needed
        if needs_callback:
            if refresh_result_messages:
+               # Defensive: ensure stop_reason is never None
+               stop_reason_value = pydantic_run.stop_reason if pydantic_run.stop_reason else StopReasonType.completed
                result = LettaResponse(
                    messages=await self.get_run_messages(run_id=run_id, actor=actor),
-                   stop_reason=LettaStopReason(stop_reason=pydantic_run.stop_reason),
+                   stop_reason=LettaStopReason(stop_reason=stop_reason_value),
                    usage=await self.get_run_usage(run_id=run_id, actor=actor),
                )
                final_metadata["result"] = result.model_dump()

@@ -719,7 +721,7 @@ class RunManager:
        )

        # Use the standard function to create properly formatted approval response messages
-       approval_response_messages = create_approval_response_message_from_input(
+       approval_response_messages = await create_approval_response_message_from_input(
            agent_state=agent_state,
            input_message=approval_input,
            run_id=run_id,
@@ -167,9 +167,7 @@ class SandboxConfigManager:

    @enforce_types
    @trace_method
-   async def get_sandbox_config_by_type_async(
-       self, type: SandboxType, actor: Optional[PydanticUser] = None
-   ) -> Optional[PydanticSandboxConfig]:
+   async def get_sandbox_config_by_type_async(self, type: SandboxType, actor: PydanticUser) -> Optional[PydanticSandboxConfig]:
        """Retrieve a sandbox config by its type."""
        async with db_registry.async_session() as session:
            try:

@@ -345,7 +343,7 @@ class SandboxConfigManager:
    @raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG)
    @trace_method
    async def get_sandbox_env_var_by_key_and_sandbox_config_id_async(
-       self, key: str, sandbox_config_id: str, actor: Optional[PydanticUser] = None
+       self, key: str, sandbox_config_id: str, actor: PydanticUser
    ) -> Optional[PydanticEnvVar]:
        """Retrieve a sandbox environment variable by its key and sandbox_config_id."""
        async with db_registry.async_session() as session:
Some files were not shown because too many files have changed in this diff.