chore: bump version 0.16.5 (#3202)

This commit is contained in:
cthomas
2026-02-24 11:02:17 -08:00
committed by GitHub
410 changed files with 21274 additions and 6850 deletions

View File

@@ -16,7 +16,6 @@ from letta.schemas.agent import AgentState
from letta.schemas.llm_config import LLMConfig
from letta.services.organization_manager import OrganizationManager
from letta.services.user_manager import UserManager
from letta.settings import tool_settings
def pytest_configure(config):

View File

@@ -31,7 +31,7 @@ def get_support_status(passed_tests, feature_tests):
# Filter out error tests when checking for support
non_error_tests = [test for test in feature_tests if not test.endswith("_error")]
error_tests = [test for test in feature_tests if test.endswith("_error")]
[test for test in feature_tests if test.endswith("_error")]
# Check which non-error tests passed
passed_non_error_tests = [test for test in non_error_tests if test in passed_tests]
@@ -137,7 +137,7 @@ def get_github_repo_info():
else:
return None
return repo_path
except:
except Exception:
pass
# Default fallback
@@ -335,7 +335,7 @@ def process_model_sweep_report(input_file, output_file, config_file=None, debug=
# Format timestamp if it's a full ISO string
if "T" in str(last_scanned):
last_scanned = str(last_scanned).split("T")[0] # Just the date part
except:
except Exception:
last_scanned = "Unknown"
# Calculate support score for ranking

View File

@@ -1,16 +1,12 @@
import base64
import json
import os
import socket
import threading
import time
import uuid
from typing import Any, Dict, List
import httpx
import pytest
import requests
from dotenv import load_dotenv
from letta_client import Letta, MessageCreate, Run
from letta_client.core.api_error import ApiError
from letta_client.types import (
@@ -694,7 +690,7 @@ def test_token_streaming_agent_loop_error(
stream_tokens=True,
)
list(response)
except:
except Exception:
pass # only some models throw an error TODO: make this consistent
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)

View File

@@ -381,6 +381,10 @@ jobs:
GOOGLE_CLOUD_PROJECT: ${{ secrets.GOOGLE_CLOUD_PROJECT }}
GOOGLE_CLOUD_LOCATION: ${{ secrets.GOOGLE_CLOUD_LOCATION }}
# Real object store (required for git-backed memory integration test)
# Use DEV bucket/prefix variable to avoid prod resources.
LETTA_OBJECT_STORE_URI: ${{ vars.LETTA_OBJECT_STORE_URI_DEV }}
# Feature flags (shared across all test types)
LETTA_ENABLE_BATCH_JOB_POLLING: true

View File

@@ -23,3 +23,10 @@ repos:
- id: ruff-check
args: [ --fix ]
- id: ruff-format
- repo: local
hooks:
- id: ty
name: ty check
entry: uv run ty check .
language: python

View File

@@ -1,5 +1,5 @@
# Start with pgvector base for builder
FROM ankane/pgvector:v0.5.1 AS builder
FROM pgvector/pgvector:0.8.1-pg15 AS builder
# comment to trigger ci
# Install Python and required packages
RUN apt-get update && apt-get install -y \
@@ -39,7 +39,7 @@ COPY . .
RUN uv sync --frozen --no-dev --all-extras --python 3.11
# Runtime stage
FROM ankane/pgvector:v0.5.1 AS runtime
FROM pgvector/pgvector:0.8.1-pg15 AS runtime
# Overridable Node.js version with --build-arg NODE_VERSION
ARG NODE_VERSION=22

View File

@@ -8,8 +8,6 @@ Create Date: 2025-10-07 13:01:17.872405
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.

View File

@@ -8,8 +8,6 @@ Create Date: 2025-09-10 19:16:39.118760
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.

View File

@@ -8,8 +8,6 @@ Create Date: 2025-12-17 15:46:06.184858
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.

View File

@@ -8,8 +8,6 @@ Create Date: 2025-10-03 12:10:51.065067
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.

View File

@@ -0,0 +1,33 @@
"""add_usage_columns_to_steps
Revision ID: 3e54e2fa2f7e
Revises: a1b2c3d4e5f8
Create Date: 2026-02-03 16:35:51.327031
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "3e54e2fa2f7e"
down_revision: Union[str, None] = "a1b2c3d4e5f8"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add per-step usage/accounting columns to the ``steps`` table.

    All columns are nullable so existing rows need no backfill.
    """
    usage_columns = (
        ("model_handle", sa.String()),
        ("cached_input_tokens", sa.Integer()),
        ("cache_write_tokens", sa.Integer()),
        ("reasoning_tokens", sa.Integer()),
    )
    for column_name, column_type in usage_columns:
        op.add_column("steps", sa.Column(column_name, column_type, nullable=True))
def downgrade() -> None:
    """Drop the usage columns, in reverse order of their creation."""
    for column_name in ("reasoning_tokens", "cache_write_tokens", "cached_input_tokens", "model_handle"):
        op.drop_column("steps", column_name)

View File

@@ -9,6 +9,7 @@ Create Date: 2024-12-14 17:23:08.772554
from typing import Sequence, Union
import sqlalchemy as sa
from pgvector.sqlalchemy import Vector
from sqlalchemy.dialects import postgresql
from alembic import op

View File

@@ -8,8 +8,6 @@ Create Date: 2025-09-19 10:58:19.658106
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.

View File

@@ -8,8 +8,6 @@ Create Date: 2025-10-06 13:17:09.918439
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.

View File

@@ -8,8 +8,6 @@ Create Date: 2025-11-11 19:16:00.000000
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from letta.settings import settings

View File

@@ -8,8 +8,6 @@ Create Date: 2025-12-07 15:30:43.407495
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.

View File

@@ -8,8 +8,6 @@ Create Date: 2025-11-11 21:16:00.000000
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from letta.settings import settings

View File

@@ -0,0 +1,29 @@
"""Add model and model_settings columns to conversations table for model overrides
Revision ID: b2c3d4e5f6a8
Revises: 3e54e2fa2f7e
Create Date: 2026-02-23 02:50:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "b2c3d4e5f6a8"
down_revision: Union[str, None] = "3e54e2fa2f7e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add per-conversation model-override columns to ``conversations``.

    Both columns are nullable: absence means "use the agent's default model".
    """
    for column_name, column_type in (("model", sa.String()), ("model_settings", sa.JSON())):
        op.add_column("conversations", sa.Column(column_name, column_type, nullable=True))
def downgrade() -> None:
    """Remove the model-override columns, newest first."""
    for column_name in ("model_settings", "model"):
        op.drop_column("conversations", column_name)

View File

@@ -23,7 +23,7 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# determine backfill value based on current pinecone settings
try:
from pinecone import IndexEmbed, PineconeAsyncio
from pinecone import IndexEmbed, PineconeAsyncio # noqa: F401
pinecone_available = True
except ImportError:

View File

@@ -10,8 +10,6 @@ import json
import os
# Add the app directory to path to import our crypto utils
import sys
from pathlib import Path
from typing import Sequence, Union
import sqlalchemy as sa

View File

@@ -8,8 +8,6 @@ Create Date: 2025-11-07 15:43:59.446292
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from letta.settings import settings

View File

@@ -8,8 +8,6 @@ Create Date: 2025-10-04 00:44:06.663817
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.

412
conf.yaml Normal file
View File

@@ -0,0 +1,412 @@
# Letta Configuration File
# Place at ~/.letta/conf.yaml, ./conf.yaml, or set LETTA_CONFIG_PATH
# Environment variables take precedence over config file values
#
# Top-level keys and their env var mappings:
# letta: -> LETTA_*
# model: -> Provider-prefixed (OPENAI_*, ANTHROPIC_*, etc.)
# tool: -> Prefix-based (E2B_*, MCP_*, TOOL_*, etc.)
# datadog: -> DD_*
letta:
# =============================================================================
# Core Settings (LETTA_*)
# =============================================================================
debug: false
# environment: ""
# Default handles
# default_llm_handle: ""
# default_embedding_handle: ""
# SSE Streaming
enable_keepalive: true
keepalive_interval: 50.0
enable_cancellation_aware_streaming: true
# =============================================================================
# PostgreSQL (LETTA_PG_*)
# =============================================================================
pg:
# db: ""
# user: ""
# password: ""
# host: ""
# port: ""
# uri: ""
pool_size: 25
max_overflow: 10
pool_timeout: 30
pool_recycle: 1800
echo: false
# Connection pool settings (LETTA_POOL_*)
pool:
pre_ping: true
use_lifo: true
# Database settings (LETTA_DB_*)
# db:
# max_concurrent_sessions: ""
disable_sqlalchemy_pooling: true
enable_db_pool_monitoring: true
db_pool_monitoring_interval: 30
# =============================================================================
# Redis (LETTA_REDIS_*)
# =============================================================================
redis:
# host: ""
port: 6379
# =============================================================================
# Multi-Agent (LETTA_MULTI_AGENT_*)
# =============================================================================
multi_agent:
send_message_max_retries: 3
send_message_timeout: 1200
concurrent_sends: 50
# =============================================================================
# OTEL / Observability (LETTA_OTEL_*, LETTA_CLICKHOUSE_*)
# =============================================================================
otel:
# exporter_otlp_endpoint: ""
preferred_temporality: 1
clickhouse:
# endpoint: ""
database: otel
username: default
# password: ""
disable_tracing: false
llm_api_logging: true
track_last_agent_run: false
track_errored_messages: true
track_stop_reason: true
track_agent_run: true
track_provider_trace: true
# =============================================================================
# Uvicorn (LETTA_UVICORN_*)
# =============================================================================
uvicorn:
workers: 1
reload: false
timeout_keep_alive: 5
# Runtime settings
use_uvloop: false
use_granian: false
sqlalchemy_tracing: false
event_loop_threadpool_max_workers: 43
# =============================================================================
# Experimental
# =============================================================================
use_vertex_structured_outputs_experimental: false
use_asyncio_shield: true
# =============================================================================
# Lettuce (LETTA_USE_LETTUCE_*)
# =============================================================================
use_lettuce_for_file_uploads: false
# =============================================================================
# Batch Job Polling (LETTA_POLL_*, LETTA_BATCH_*)
# =============================================================================
enable_batch_job_polling: false
poll_running_llm_batches_interval_seconds: 300
poll_lock_retry_interval_seconds: 480
batch_job_polling_lookback_weeks: 2
# batch_job_polling_batch_size: ""
# =============================================================================
# LLM Timeouts (LETTA_LLM_*)
# =============================================================================
llm:
request_timeout_seconds: 60.0
stream_timeout_seconds: 600.0
# =============================================================================
# Pinecone (LETTA_PINECONE_*, LETTA_ENABLE_PINECONE, LETTA_UPSERT_PINECONE_INDICES)
# =============================================================================
enable_pinecone: false
upsert_pinecone_indices: false
pinecone:
# api_key: ""
source_index: sources
agent_index: recall
# =============================================================================
# Turbopuffer (LETTA_TPUF_*, LETTA_USE_TPUF, LETTA_EMBED_*)
# =============================================================================
use_tpuf: false
embed_all_messages: false
embed_tools: false
tpuf:
# api_key: ""
region: gcp-us-central1
# =============================================================================
# File Processing (LETTA_FILE_PROCESSING_*)
# =============================================================================
file_processing:
timeout_minutes: 30
timeout_error_message: "File processing timed out after {} minutes. Please try again."
# =============================================================================
# Letta Client (LETTA_DEFAULT_*)
# =============================================================================
default_base_url: http://localhost:8283
# default_token: ""
# =============================================================================
# Agent Architecture
# =============================================================================
use_letta_v1_agent: false
archival_memory_token_limit: 8192
# =============================================================================
# Security
# =============================================================================
no_default_actor: false
# encryption_key: ""
# =============================================================================
# OCR
# =============================================================================
# mistral_api_key: ""
# =============================================================================
# Summarizer (LETTA_SUMMARIZER_*)
# =============================================================================
summarizer:
mode: partial_evict_message_buffer_mode
message_buffer_limit: 60
message_buffer_min: 15
enable_summarization: true
max_summarization_retries: 3
partial_evict_summarizer_percentage: 0.30
evict_all_messages: false
max_summarizer_retries: 3
memory_warning_threshold: 0.75
send_memory_warning_message: false
desired_memory_token_pressure: 0.3
keep_last_n_messages: 0
# =============================================================================
# Logging (LETTA_LOGGING_*)
# =============================================================================
logging:
debug: false
json_logging: false
log_level: WARNING
verbose_telemetry_logging: false
# =============================================================================
# Telemetry (LETTA_TELEMETRY_*)
# =============================================================================
telemetry:
enable_datadog: false
provider_trace_backend: postgres
socket_path: /var/run/telemetry/telemetry.sock
provider_trace_pg_metadata_only: false
# source: ""
# Datadog settings (LETTA_TELEMETRY_DATADOG_*)
datadog:
agent_host: localhost
agent_port: 8126
service_name: letta-server
profiling_enabled: false
profiling_memory_enabled: false
profiling_heap_enabled: false
# git_repository_url: ""
# git_commit_sha: ""
main_package: letta
# =============================================================================
# Model Settings (-> OPENAI_*, ANTHROPIC_*, AWS_*, etc.)
# =============================================================================
model:
# Global settings
global_max_context_window_limit: 32000
inner_thoughts_kwarg: thinking
default_prompt_formatter: chatml
# OpenAI (-> OPENAI_*)
openai:
# api_key: ""
api_base: https://api.openai.com/v1
# Anthropic (-> ANTHROPIC_*)
anthropic:
# api_key: ""
max_retries: 3
sonnet_1m: false
# Azure OpenAI (-> AZURE_*)
azure:
# api_key: ""
# base_url: ""
api_version: "2024-09-01-preview"
# Google Gemini (-> GEMINI_*)
gemini:
# api_key: ""
base_url: https://generativelanguage.googleapis.com/
force_minimum_thinking_budget: false
max_retries: 5
# Google Vertex (-> GOOGLE_CLOUD_*)
# google_cloud:
# project: ""
# location: ""
# AWS Bedrock (-> AWS_*, BEDROCK_*)
aws:
# access_key_id: ""
# secret_access_key: ""
default_region: us-east-1
bedrock:
anthropic_version: bedrock-2023-05-31
# OpenRouter (-> OPENROUTER_*)
# openrouter:
# api_key: ""
# referer: ""
# title: ""
# handle_base: ""
# Groq (-> GROQ_*)
# groq:
# api_key: ""
# Together (-> TOGETHER_*)
# together:
# api_key: ""
# DeepSeek (-> DEEPSEEK_*)
# deepseek:
# api_key: ""
# xAI/Grok (-> XAI_*)
# xai:
# api_key: ""
# Z.ai/ZhipuAI (-> ZAI_*)
zai:
# api_key: ""
base_url: https://api.z.ai/api/paas/v4/
# MiniMax (-> MINIMAX_*)
# minimax:
# api_key: ""
# Ollama (-> OLLAMA_*)
# ollama:
# base_url: ""
# vLLM (-> VLLM_*)
# vllm:
# api_base: ""
# handle_base: ""
# SGLang (-> SGLANG_*)
# sglang:
# api_base: ""
# handle_base: ""
# LM Studio (-> LMSTUDIO_*)
# lmstudio:
# base_url: ""
# OpenLLM (-> OPENLLM_*)
# openllm:
# auth_type: ""
# api_key: ""
# =============================================================================
# Tool Settings (-> E2B_*, MCP_*, MODAL_*, TOOL_*, etc.)
# =============================================================================
tool:
# E2B Sandbox (-> E2B_*)
# e2b:
# api_key: ""
# sandbox_template_id: ""
# Modal Sandbox (-> MODAL_*)
# modal:
# token_id: ""
# token_secret: ""
# Search Providers (-> TAVILY_*, EXA_*)
# tavily:
# api_key: ""
# exa:
# api_key: ""
# Local Sandbox (-> TOOL_*)
tool:
# exec_dir: ""
sandbox_timeout: 180
# exec_venv_name: ""
exec_autoreload_venv: true
# MCP (-> MCP_*)
mcp:
connect_to_server_timeout: 30.0
list_tools_timeout: 30.0
execute_tool_timeout: 60.0
read_from_config: false
disable_stdio: true
# =============================================================================
# Datadog Agent Settings (-> DD_*)
# =============================================================================
# datadog:
# site: ""
# service: ""
# version: ""
#
# trace:
# enabled: false
# agent_url: ""
# health_metrics_enabled: false
#
# dogstatsd:
# url: ""
#
# logs:
# injection: false
#
# runtime:
# metrics_enabled: false
#
# appsec:
# enabled: false
# sca_enabled: false
#
# iast:
# enabled: false
#
# exception:
# replay_enabled: false
#
# llmobs:
# enabled: false
# ml_app: ""
#
# instrumentation:
# install_type: ""
#
# git:
# repository_url: ""
# commit_sha: ""
#
# main_package: ""

View File

@@ -1,6 +1,6 @@
services:
letta_db:
image: ankane/pgvector:v0.5.1
image: pgvector/pgvector:0.8.1-pg15
networks:
default:
aliases:

File diff suppressed because it is too large Load Diff

View File

@@ -5,7 +5,7 @@ try:
__version__ = version("letta")
except PackageNotFoundError:
# Fallback for development installations
__version__ = "0.16.4"
__version__ = "0.16.5"
if os.environ.get("LETTA_VERSION"):
__version__ = os.environ["LETTA_VERSION"]
@@ -16,26 +16,32 @@ try:
from letta.settings import DatabaseChoice, settings
if settings.database_engine == DatabaseChoice.SQLITE:
from letta.orm import sqlite_functions
from letta.orm import sqlite_functions # noqa: F401
except ImportError:
# If sqlite_vec is not installed, it's fine for client usage
pass
# # imports for easier access
from letta.schemas.agent import AgentState
from letta.schemas.block import Block
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import JobStatus
from letta.schemas.file import FileMetadata
from letta.schemas.job import Job
from letta.schemas.letta_message import LettaMessage, LettaPing
from letta.schemas.letta_stop_reason import LettaStopReason
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ArchivalMemorySummary, BasicBlockMemory, ChatMemory, Memory, RecallMemorySummary
from letta.schemas.message import Message
from letta.schemas.organization import Organization
from letta.schemas.passage import Passage
from letta.schemas.source import Source
from letta.schemas.tool import Tool
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.schemas.agent import AgentState as AgentState
from letta.schemas.block import Block as Block
from letta.schemas.embedding_config import EmbeddingConfig as EmbeddingConfig
from letta.schemas.enums import JobStatus as JobStatus
from letta.schemas.file import FileMetadata as FileMetadata
from letta.schemas.job import Job as Job
from letta.schemas.letta_message import LettaErrorMessage as LettaErrorMessage, LettaMessage as LettaMessage, LettaPing as LettaPing
from letta.schemas.letta_stop_reason import LettaStopReason as LettaStopReason
from letta.schemas.llm_config import LLMConfig as LLMConfig
from letta.schemas.memory import (
ArchivalMemorySummary as ArchivalMemorySummary,
BasicBlockMemory as BasicBlockMemory,
ChatMemory as ChatMemory,
Memory as Memory,
RecallMemorySummary as RecallMemorySummary,
)
from letta.schemas.message import Message as Message
from letta.schemas.organization import Organization as Organization
from letta.schemas.passage import Passage as Passage
from letta.schemas.source import Source as Source
from letta.schemas.tool import Tool as Tool
from letta.schemas.usage import LettaUsageStatistics as LettaUsageStatistics
from letta.schemas.user import User as User

View File

@@ -2,10 +2,11 @@ from abc import ABC, abstractmethod
from typing import AsyncGenerator
from letta.llm_api.llm_client_base import LLMClientBase
from letta.schemas.enums import LLMCallType
from letta.schemas.letta_message import LettaMessage
from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, ToolCall
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, ChoiceLogprobs, ToolCall
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.services.telemetry_manager import TelemetryManager
@@ -24,6 +25,7 @@ class LettaLLMAdapter(ABC):
self,
llm_client: LLMClientBase,
llm_config: LLMConfig,
call_type: LLMCallType,
agent_id: str | None = None,
agent_tags: list[str] | None = None,
run_id: str | None = None,
@@ -32,6 +34,7 @@ class LettaLLMAdapter(ABC):
) -> None:
self.llm_client: LLMClientBase = llm_client
self.llm_config: LLMConfig = llm_config
self.call_type: LLMCallType = call_type
self.agent_id: str | None = agent_id
self.agent_tags: list[str] | None = agent_tags
self.run_id: str | None = run_id
@@ -45,9 +48,14 @@ class LettaLLMAdapter(ABC):
self.content: list[TextContent | ReasoningContent | RedactedReasoningContent] | None = None
self.tool_call: ToolCall | None = None
self.tool_calls: list[ToolCall] = []
self.logprobs: ChoiceLogprobs | None = None
# SGLang native endpoint data (for multi-turn RL training)
self.output_ids: list[int] | None = None
self.output_token_logprobs: list[list[float]] | None = None
self.usage: LettaUsageStatistics = LettaUsageStatistics()
self.telemetry_manager: TelemetryManager = TelemetryManager()
self.llm_request_finish_timestamp_ns: int | None = None
self._finish_reason: str | None = None
@abstractmethod
async def invoke_llm(
@@ -85,6 +93,8 @@ class LettaLLMAdapter(ABC):
Returns:
str | None: The finish_reason if available, None otherwise
"""
if self._finish_reason is not None:
return self._finish_reason
if self.chat_completions_response and self.chat_completions_response.choices:
return self.chat_completions_response.choices[0].finish_reason
return None

View File

@@ -2,7 +2,7 @@ from typing import AsyncGenerator
from letta.adapters.letta_llm_adapter import LettaLLMAdapter
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
from letta.otel.tracing import log_attributes, log_event, safe_json_dumps, trace_method
from letta.otel.tracing import log_attributes, safe_json_dumps, trace_method
from letta.schemas.letta_message import LettaMessage
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, TextContent
from letta.schemas.provider_trace import ProviderTrace
@@ -66,7 +66,13 @@ class LettaLLMRequestAdapter(LettaLLMAdapter):
self.reasoning_content = [OmittedReasoningContent()]
elif self.chat_completions_response.choices[0].message.content:
# Reasoning placed into content for legacy reasons
self.reasoning_content = [TextContent(text=self.chat_completions_response.choices[0].message.content)]
# Carry thought_signature on TextContent when ReasoningContent doesn't exist to hold it
self.reasoning_content = [
TextContent(
text=self.chat_completions_response.choices[0].message.content,
signature=self.chat_completions_response.choices[0].message.reasoning_content_signature,
)
]
else:
# logger.info("No reasoning content found.")
self.reasoning_content = None
@@ -77,6 +83,9 @@ class LettaLLMRequestAdapter(LettaLLMAdapter):
else:
self.tool_call = None
# Extract logprobs if present
self.logprobs = self.chat_completions_response.choices[0].logprobs
# Extract usage statistics
self.usage.step_count = 1
self.usage.completion_tokens = self.chat_completions_response.usage.completion_tokens
@@ -127,6 +136,7 @@ class LettaLLMRequestAdapter(LettaLLMAdapter):
agent_id=self.agent_id,
agent_tags=self.agent_tags,
run_id=self.run_id,
call_type=self.call_type,
org_id=self.org_id,
user_id=self.user_id,
llm_config=self.llm_config.model_dump() if self.llm_config else None,

View File

@@ -1,16 +1,16 @@
from typing import AsyncGenerator
from letta.adapters.letta_llm_adapter import LettaLLMAdapter
from letta.errors import LLMError
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
from letta.interfaces.anthropic_streaming_interface import AnthropicStreamingInterface
from letta.interfaces.openai_streaming_interface import OpenAIStreamingInterface
from letta.llm_api.llm_client_base import LLMClientBase
from letta.otel.tracing import log_attributes, safe_json_dumps, trace_method
from letta.schemas.enums import ProviderType
from letta.schemas.enums import LLMCallType, ProviderType
from letta.schemas.letta_message import LettaMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.provider_trace import ProviderTrace
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.settings import settings
from letta.utils import safe_create_task
@@ -30,13 +30,23 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
self,
llm_client: LLMClientBase,
llm_config: LLMConfig,
call_type: LLMCallType,
agent_id: str | None = None,
agent_tags: list[str] | None = None,
run_id: str | None = None,
org_id: str | None = None,
user_id: str | None = None,
) -> None:
super().__init__(llm_client, llm_config, agent_id=agent_id, agent_tags=agent_tags, run_id=run_id, org_id=org_id, user_id=user_id)
super().__init__(
llm_client,
llm_config,
call_type=call_type,
agent_id=agent_id,
agent_tags=agent_tags,
run_id=run_id,
org_id=org_id,
user_id=user_id,
)
self.interface: OpenAIStreamingInterface | AnthropicStreamingInterface | None = None
async def invoke_llm(
@@ -88,11 +98,23 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
# Extract optional parameters
# ttft_span = kwargs.get('ttft_span', None)
request_start_ns = get_utc_timestamp_ns()
# Start the streaming request (map provider errors to common LLMError types)
try:
stream = await self.llm_client.stream_async(request_data, self.llm_config)
except Exception as e:
raise self.llm_client.handle_llm_error(e)
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
latency_ms = int((self.llm_request_finish_timestamp_ns - request_start_ns) / 1_000_000)
await self.llm_client.log_provider_trace_async(
request_data=request_data,
response_json=None,
llm_config=self.llm_config,
latency_ms=latency_ms,
error_msg=str(e),
error_type=type(e).__name__,
)
raise self.llm_client.handle_llm_error(e, llm_config=self.llm_config)
# Process the stream and yield chunks immediately for TTFT
# Wrap in error handling to convert provider errors to common LLMError types
@@ -101,7 +123,19 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
# Yield each chunk immediately as it arrives
yield chunk
except Exception as e:
raise self.llm_client.handle_llm_error(e)
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
latency_ms = int((self.llm_request_finish_timestamp_ns - request_start_ns) / 1_000_000)
await self.llm_client.log_provider_trace_async(
request_data=request_data,
response_json=None,
llm_config=self.llm_config,
latency_ms=latency_ms,
error_msg=str(e),
error_type=type(e).__name__,
)
if isinstance(e, LLMError):
raise
raise self.llm_client.handle_llm_error(e, llm_config=self.llm_config)
# After streaming completes, extract the accumulated data
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
@@ -109,7 +143,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
# Extract tool call from the interface
try:
self.tool_call = self.interface.get_tool_call_object()
except ValueError as e:
except ValueError:
# No tool call, handle upstream
self.tool_call = None
@@ -183,6 +217,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
agent_id=self.agent_id,
agent_tags=self.agent_tags,
run_id=self.run_id,
call_type=self.call_type,
org_id=self.org_id,
user_id=self.user_id,
llm_config=self.llm_config.model_dump() if self.llm_config else None,

View File

@@ -0,0 +1,515 @@
"""
SGLang Native Adapter for multi-turn RL training.
This adapter uses SGLang's native /generate endpoint instead of the OpenAI-compatible
endpoint to get token IDs and per-token logprobs, which are essential for proper
multi-turn RL training with loss masking.
Uses HuggingFace tokenizer's apply_chat_template() for proper tool formatting.
"""
import json
import re
import time
import uuid
from typing import Any, AsyncGenerator, Optional
from letta.adapters.simple_llm_request_adapter import SimpleLLMRequestAdapter
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
from letta.llm_api.sglang_native_client import SGLangNativeClient
from letta.log import get_logger
from letta.schemas.letta_message import LettaMessage
from letta.schemas.letta_message_content import TextContent
from letta.schemas.openai.chat_completion_response import (
ChatCompletionResponse,
ChatCompletionTokenLogprob,
Choice,
ChoiceLogprobs,
FunctionCall,
Message as ChoiceMessage,
ToolCall,
UsageStatistics,
)
logger = get_logger(__name__)
# Global tokenizer cache
_tokenizer_cache: dict[str, Any] = {}
class SGLangNativeAdapter(SimpleLLMRequestAdapter):
"""
Adapter that uses SGLang's native /generate endpoint for multi-turn RL training.
Key differences from SimpleLLMRequestAdapter:
- Uses /generate instead of /v1/chat/completions
- Returns output_ids (token IDs) in addition to text
- Returns output_token_logprobs with [logprob, token_id] pairs
- Formats tools into prompt and parses tool calls from response
These are essential for building accurate loss masks in multi-turn training.
"""
def __init__(self, *args, **kwargs):
    """Initialize the adapter; the SGLang client and tokenizer are created lazily."""
    super().__init__(*args, **kwargs)
    # Lazily constructed native /generate client — see _get_sglang_client().
    self._sglang_client: Optional[SGLangNativeClient] = None
    # HuggingFace tokenizer handle — see _get_tokenizer(); None until first use.
    self._tokenizer: Any = None
def _get_tokenizer(self) -> Any:
    """Return a (module-level cached) HuggingFace tokenizer for the configured model.

    Returns:
        The ``AutoTokenizer`` instance, or ``None`` when no model name is
        configured, ``transformers`` is not installed, or loading fails —
        callers then fall back to manual prompt formatting.
    """
    model_name = self.llm_config.model
    if not model_name:
        logger.warning("No model name in llm_config, cannot load tokenizer")
        return None

    # Note: no `global` statement is needed — we only mutate the cache dict,
    # never rebind the module-level name (the original declaration was redundant).
    if model_name in _tokenizer_cache:
        return _tokenizer_cache[model_name]

    try:
        # Local import: transformers is an optional dependency of this adapter.
        from transformers import AutoTokenizer

        logger.info(f"Loading tokenizer for model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        _tokenizer_cache[model_name] = tokenizer
        return tokenizer
    except ImportError:
        logger.warning("transformers not installed, falling back to manual formatting")
        return None
    except Exception as e:
        logger.warning(f"Failed to load tokenizer: {e}, falling back to manual formatting")
        return None
def _get_sglang_client(self) -> SGLangNativeClient:
"""Get or create SGLang native client."""
if self._sglang_client is None:
# Get base URL from llm_config, removing /v1 suffix if present
base_url = self.llm_config.model_endpoint or ""
# SGLang local instances typically don't need API key
self._sglang_client = SGLangNativeClient(
base_url=base_url,
api_key=None,
)
return self._sglang_client
def _format_tools_for_prompt(self, tools: list) -> str:
"""
Format tools in Qwen3 chat template format for the system prompt.
This matches the exact format produced by Qwen3's tokenizer.apply_chat_template()
with tools parameter.
"""
if not tools:
return ""
# Format each tool as JSON (matching Qwen3 template exactly)
tool_jsons = []
for tool in tools:
# Handle both dict and object formats
if isinstance(tool, dict):
# Already in OpenAI format
tool_jsons.append(json.dumps(tool))
else:
# Convert object to dict
tool_dict = {
"type": "function",
"function": {
"name": getattr(getattr(tool, "function", tool), "name", ""),
"description": getattr(getattr(tool, "function", tool), "description", ""),
"parameters": getattr(getattr(tool, "function", tool), "parameters", {}),
},
}
tool_jsons.append(json.dumps(tool_dict))
# Use exact Qwen3 format
tools_section = (
"\n\n# Tools\n\n"
"You may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n"
"<tools>\n" + "\n".join(tool_jsons) + "\n"
"</tools>\n\n"
"For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n"
"<tool_call>\n"
'{"name": <function-name>, "arguments": <args-json-object>}\n'
"</tool_call>"
)
return tools_section
def _convert_messages_to_openai_format(self, messages: list) -> list[dict]:
"""Convert Letta Message objects to OpenAI-style message dicts."""
openai_messages = []
for msg in messages:
# Handle both dict and Pydantic Message objects
if hasattr(msg, "role"):
role = msg.role
content = msg.content if hasattr(msg, "content") else ""
# Handle content that might be a list of content parts
if isinstance(content, list):
content = " ".join([c.text if hasattr(c, "text") else str(c) for c in content])
elif content is None:
content = ""
tool_calls = getattr(msg, "tool_calls", None)
tool_call_id = getattr(msg, "tool_call_id", None)
name = getattr(msg, "name", None)
else:
role = msg.get("role", "user")
content = msg.get("content", "")
tool_calls = msg.get("tool_calls", None)
tool_call_id = msg.get("tool_call_id", None)
name = msg.get("name", None)
openai_msg = {"role": role, "content": content}
if tool_calls:
# Convert tool calls to OpenAI format
openai_tool_calls = []
for tc in tool_calls:
if hasattr(tc, "function"):
tc_dict = {
"id": getattr(tc, "id", f"call_{uuid.uuid4().hex[:8]}"),
"type": "function",
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments
if isinstance(tc.function.arguments, str)
else json.dumps(tc.function.arguments),
},
}
else:
tc_dict = {
"id": tc.get("id", f"call_{uuid.uuid4().hex[:8]}"),
"type": "function",
"function": tc.get("function", {}),
}
openai_tool_calls.append(tc_dict)
openai_msg["tool_calls"] = openai_tool_calls
if tool_call_id:
openai_msg["tool_call_id"] = tool_call_id
if name and role == "tool":
openai_msg["name"] = name
openai_messages.append(openai_msg)
return openai_messages
def _convert_tools_to_openai_format(self, tools: list) -> list[dict]:
"""Convert tools to OpenAI format for tokenizer."""
openai_tools = []
for tool in tools:
if isinstance(tool, dict):
# Already a dict, ensure it's in the right format
if "function" in tool:
openai_tools.append(tool)
else:
# Might be the function directly
openai_tools.append({"type": "function", "function": tool})
else:
# Convert object to dict
func = getattr(tool, "function", tool)
tool_dict = {
"type": "function",
"function": {
"name": getattr(func, "name", ""),
"description": getattr(func, "description", ""),
"parameters": getattr(func, "parameters", {}),
},
}
openai_tools.append(tool_dict)
return openai_tools
def _format_messages_to_text(self, messages: list, tools: list) -> str:
"""
Format messages to text using tokenizer's apply_chat_template if available.
Falls back to manual formatting if tokenizer is not available.
"""
tokenizer = self._get_tokenizer()
if tokenizer is not None:
# Use tokenizer's apply_chat_template for proper formatting
openai_messages = self._convert_messages_to_openai_format(messages)
openai_tools = self._convert_tools_to_openai_format(tools) if tools else None
try:
formatted = tokenizer.apply_chat_template(
openai_messages,
tokenize=False,
add_generation_prompt=True,
tools=openai_tools,
)
logger.debug(f"Formatted prompt using tokenizer ({len(formatted)} chars)")
return formatted
except Exception as e:
logger.warning(f"apply_chat_template failed: {e}, falling back to manual formatting")
# Fallback to manual formatting
return self._format_messages_to_text_manual(messages, tools)
def _format_messages_to_text_manual(self, messages: list, tools: list) -> str:
"""Manual fallback formatting for when tokenizer is not available."""
formatted_parts = []
tools_section = self._format_tools_for_prompt(tools)
for msg in messages:
# Handle both dict and Pydantic Message objects
if hasattr(msg, "role"):
role = msg.role
content = msg.content if hasattr(msg, "content") else ""
if isinstance(content, list):
content = " ".join([c.text if hasattr(c, "text") else str(c) for c in content])
elif content is None:
content = ""
tool_calls = getattr(msg, "tool_calls", None)
else:
role = msg.get("role", "user")
content = msg.get("content", "")
tool_calls = msg.get("tool_calls", None)
if role == "system":
system_content = content + tools_section if tools_section else content
formatted_parts.append(f"<|im_start|>system\n{system_content}<|im_end|>")
tools_section = ""
elif role == "user":
formatted_parts.append(f"<|im_start|>user\n{content}<|im_end|>")
elif role == "assistant":
if tool_calls:
tc_parts = []
for tc in tool_calls:
if hasattr(tc, "function"):
tc_name = tc.function.name
tc_args = tc.function.arguments
else:
tc_name = tc.get("function", {}).get("name", "")
tc_args = tc.get("function", {}).get("arguments", "{}")
if isinstance(tc_args, str):
try:
tc_args = json.loads(tc_args)
except Exception:
pass
tc_parts.append(f'<tool_call>\n{{"name": "{tc_name}", "arguments": {json.dumps(tc_args)}}}\n</tool_call>')
assistant_content = content + "\n" + "\n".join(tc_parts) if content else "\n".join(tc_parts)
formatted_parts.append(f"<|im_start|>assistant\n{assistant_content}<|im_end|>")
elif content:
formatted_parts.append(f"<|im_start|>assistant\n{content}<|im_end|>")
elif role == "tool":
formatted_parts.append(f"<|im_start|>user\n<tool_response>\n{content}\n</tool_response><|im_end|>")
formatted_parts.append("<|im_start|>assistant\n")
return "\n".join(formatted_parts)
def _parse_tool_calls(self, text: str) -> list[ToolCall]:
"""
Parse tool calls from response text.
Looks for patterns like:
<tool_call>
{"name": "tool_name", "arguments": {...}}
</tool_call>
"""
tool_calls = []
# Find all tool_call blocks
pattern = r"<tool_call>\s*(\{.*?\})\s*</tool_call>"
matches = re.findall(pattern, text, re.DOTALL)
for match in matches:
try:
tc_data = json.loads(match)
name = tc_data.get("name", "")
arguments = tc_data.get("arguments", {})
if isinstance(arguments, dict):
arguments = json.dumps(arguments)
tool_call = ToolCall(
id=f"call_{uuid.uuid4().hex[:8]}",
type="function",
function=FunctionCall(
name=name,
arguments=arguments,
),
)
tool_calls.append(tool_call)
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse tool call JSON: {e}")
continue
return tool_calls
def _extract_content_without_tool_calls(self, text: str) -> str:
"""Extract content from response, removing tool_call blocks."""
# Remove tool_call blocks
cleaned = re.sub(r"<tool_call>.*?</tool_call>", "", text, flags=re.DOTALL)
# Clean up whitespace
cleaned = cleaned.strip()
return cleaned
async def invoke_llm(
self,
request_data: dict,
messages: list,
tools: list,
use_assistant_message: bool,
requires_approval_tools: list[str] = [],
step_id: str | None = None,
actor: str | None = None,
) -> AsyncGenerator[LettaMessage | None, None]:
"""
Execute LLM request using SGLang native endpoint.
This method:
1. Formats messages and tools to text using chat template
2. Calls SGLang native /generate endpoint
3. Extracts output_ids and output_token_logprobs
4. Parses tool calls from response
5. Converts response to standard format
"""
self.request_data = request_data
# Get sampling params from request_data
sampling_params = {
"temperature": request_data.get("temperature", 0.7),
"max_new_tokens": request_data.get("max_tokens", 4096),
"top_p": request_data.get("top_p", 0.9),
}
# Format messages to text (includes tools in prompt)
text_input = self._format_messages_to_text(messages, tools)
# Call SGLang native endpoint
client = self._get_sglang_client()
try:
response = await client.generate(
text=text_input,
sampling_params=sampling_params,
return_logprob=True,
)
except Exception as e:
logger.error(f"SGLang native endpoint error: {e}")
raise
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
# Store native response data
self.response_data = response
# Extract SGLang native data
self.output_ids = response.get("output_ids")
# output_token_logprobs is inside meta_info
meta_info = response.get("meta_info", {})
self.output_token_logprobs = meta_info.get("output_token_logprobs")
# Extract text response
text_response = response.get("text", "")
# Remove trailing end token if present
if text_response.endswith("<|im_end|>"):
text_response = text_response[:-10]
# Parse tool calls from response
parsed_tool_calls = self._parse_tool_calls(text_response)
# Extract content (text without tool_call blocks)
content_text = self._extract_content_without_tool_calls(text_response)
# Determine finish reason
meta_info = response.get("meta_info", {})
finish_reason_info = meta_info.get("finish_reason", {})
if isinstance(finish_reason_info, dict):
finish_reason = finish_reason_info.get("type", "stop")
else:
finish_reason = "stop"
# If we have tool calls, set finish_reason to tool_calls
if parsed_tool_calls:
finish_reason = "tool_calls"
# Convert to standard ChatCompletionResponse format for compatibility
# Build logprobs in OpenAI format from SGLang format
logprobs_content = None
if self.output_token_logprobs:
logprobs_content = []
for i, lp_data in enumerate(self.output_token_logprobs):
# SGLang format: [logprob, token_id, top_logprob]
logprob = lp_data[0] if len(lp_data) > 0 else 0.0
token_id = lp_data[1] if len(lp_data) > 1 else 0
logprobs_content.append(
ChatCompletionTokenLogprob(
token=str(token_id),
logprob=logprob,
bytes=None,
top_logprobs=[],
)
)
choice_logprobs = ChoiceLogprobs(content=logprobs_content) if logprobs_content else None
# Build chat completion response
prompt_tokens = meta_info.get("prompt_tokens", 0)
completion_tokens = len(self.output_ids) if self.output_ids else 0
self.chat_completions_response = ChatCompletionResponse(
id=meta_info.get("id", "sglang-native"),
created=int(time.time()),
choices=[
Choice(
finish_reason=finish_reason,
index=0,
message=ChoiceMessage(
role="assistant",
content=content_text if content_text else None,
tool_calls=parsed_tool_calls if parsed_tool_calls else None,
),
logprobs=choice_logprobs,
)
],
usage=UsageStatistics(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
# Extract content
if content_text:
self.content = [TextContent(text=content_text)]
else:
self.content = None
# No reasoning content from native endpoint
self.reasoning_content = None
# Set tool calls
self.tool_calls = parsed_tool_calls
self.tool_call = parsed_tool_calls[0] if parsed_tool_calls else None
# Set logprobs
self.logprobs = choice_logprobs
# Extract usage statistics
self.usage.step_count = 1
self.usage.completion_tokens = completion_tokens
self.usage.prompt_tokens = prompt_tokens
self.usage.total_tokens = prompt_tokens + completion_tokens
self.log_provider_trace(step_id=step_id, actor=actor)
logger.info(
f"SGLang native response: {len(self.output_ids or [])} tokens, "
f"{len(self.output_token_logprobs or [])} logprobs, "
f"{len(parsed_tool_calls)} tool calls"
)
yield None
return

View File

@@ -1,7 +1,9 @@
from typing import AsyncGenerator
from letta.adapters.letta_llm_request_adapter import LettaLLMRequestAdapter
from letta.errors import LLMError
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
from letta.schemas.enums import LLMCallType
from letta.schemas.letta_message import LettaMessage
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, TextContent
from letta.schemas.usage import normalize_cache_tokens, normalize_reasoning_tokens
@@ -45,7 +47,7 @@ class SimpleLLMRequestAdapter(LettaLLMRequestAdapter):
agent_id=self.agent_id,
agent_tags=self.agent_tags,
run_id=self.run_id,
call_type="agent_step",
call_type=LLMCallType.agent_step,
org_id=self.org_id,
user_id=self.user_id,
llm_config=self.llm_config.model_dump() if self.llm_config else None,
@@ -53,7 +55,9 @@ class SimpleLLMRequestAdapter(LettaLLMRequestAdapter):
try:
self.response_data = await self.llm_client.request_async_with_telemetry(request_data, self.llm_config)
except Exception as e:
raise self.llm_client.handle_llm_error(e)
if isinstance(e, LLMError):
raise
raise self.llm_client.handle_llm_error(e, llm_config=self.llm_config)
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
@@ -80,7 +84,12 @@ class SimpleLLMRequestAdapter(LettaLLMRequestAdapter):
if self.chat_completions_response.choices[0].message.content:
# NOTE: big difference - 'content' goes into 'content'
# Reasoning placed into content for legacy reasons
self.content = [TextContent(text=self.chat_completions_response.choices[0].message.content)]
# Carry thought_signature on TextContent when ReasoningContent doesn't exist to hold it
# (e.g. Gemini 2.5 Flash with include_thoughts=False still returns thought_signature)
orphan_sig = (
self.chat_completions_response.choices[0].message.reasoning_content_signature if not self.reasoning_content else None
)
self.content = [TextContent(text=self.chat_completions_response.choices[0].message.content, signature=orphan_sig)]
else:
self.content = None
@@ -93,6 +102,9 @@ class SimpleLLMRequestAdapter(LettaLLMRequestAdapter):
self.tool_calls = list(tool_calls)
self.tool_call = self.tool_calls[0] if self.tool_calls else None
# Extract logprobs if present
self.logprobs = self.chat_completions_response.choices[0].logprobs
# Extract usage statistics
self.usage.step_count = 1
self.usage.completion_tokens = self.chat_completions_response.usage.completion_tokens

View File

@@ -1,7 +1,7 @@
import json
from typing import AsyncGenerator, List
from letta.adapters.letta_llm_stream_adapter import LettaLLMStreamAdapter
from letta.errors import LLMError
from letta.log import get_logger
logger = get_logger(__name__)
@@ -70,6 +70,9 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
# Store request data
self.request_data = request_data
# Track request start time for latency calculation
request_start_ns = get_utc_timestamp_ns()
# Get cancellation event for this run to enable graceful cancellation (before branching)
cancellation_event = get_cancellation_event_for_run(self.run_id) if self.run_id else None
@@ -138,7 +141,19 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
else:
stream = await self.llm_client.stream_async(request_data, self.llm_config)
except Exception as e:
raise self.llm_client.handle_llm_error(e)
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
latency_ms = int((self.llm_request_finish_timestamp_ns - request_start_ns) / 1_000_000)
await self.llm_client.log_provider_trace_async(
request_data=request_data,
response_json=None,
llm_config=self.llm_config,
latency_ms=latency_ms,
error_msg=str(e),
error_type=type(e).__name__,
)
if isinstance(e, LLMError):
raise
raise self.llm_client.handle_llm_error(e, llm_config=self.llm_config)
# Process the stream and yield chunks immediately for TTFT
try:
@@ -146,8 +161,19 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
# Yield each chunk immediately as it arrives
yield chunk
except Exception as e:
# Map provider-specific errors during streaming to common LLMError types
raise self.llm_client.handle_llm_error(e)
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
latency_ms = int((self.llm_request_finish_timestamp_ns - request_start_ns) / 1_000_000)
await self.llm_client.log_provider_trace_async(
request_data=request_data,
response_json=None,
llm_config=self.llm_config,
latency_ms=latency_ms,
error_msg=str(e),
error_type=type(e).__name__,
)
if isinstance(e, LLMError):
raise
raise self.llm_client.handle_llm_error(e, llm_config=self.llm_config)
# After streaming completes, extract the accumulated data
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
@@ -172,6 +198,22 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
# Store any additional data from the interface
self.message_id = self.interface.letta_message_id
# Populate finish_reason for downstream continuation logic.
# In Responses streaming, max_output_tokens is expressed via incomplete_details.reason.
if hasattr(self.interface, "final_response") and self.interface.final_response is not None:
resp = self.interface.final_response
incomplete_details = getattr(resp, "incomplete_details", None)
incomplete_reason = getattr(incomplete_details, "reason", None) if incomplete_details else None
if incomplete_reason == "max_output_tokens":
self._finish_reason = "length"
elif incomplete_reason == "content_filter":
self._finish_reason = "content_filter"
elif incomplete_reason is not None:
# Unknown incomplete reason — preserve it as-is for diagnostics
self._finish_reason = incomplete_reason
elif getattr(resp, "status", None) == "completed":
self._finish_reason = "stop"
# Log request and response data
self.log_provider_trace(step_id=step_id, actor=actor)
@@ -232,6 +274,7 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
agent_id=self.agent_id,
agent_tags=self.agent_tags,
run_id=self.run_id,
call_type=self.call_type,
org_id=self.org_id,
user_id=self.user_id,
llm_config=self.llm_config.model_dump() if self.llm_config else None,

View File

@@ -123,32 +123,17 @@ class BaseAgent(ABC):
curr_system_message = in_context_messages[0]
curr_system_message_text = curr_system_message.content[0].text
# extract the dynamic section that includes memory blocks, tool rules, and directories
# this avoids timestamp comparison issues
def extract_dynamic_section(text):
start_marker = "</base_instructions>"
end_marker = "<memory_metadata>"
start_idx = text.find(start_marker)
end_idx = text.find(end_marker)
if start_idx != -1 and end_idx != -1:
return text[start_idx:end_idx]
return text # fallback to full text if markers not found
curr_dynamic_section = extract_dynamic_section(curr_system_message_text)
# generate just the memory string with current state for comparison
# generate memory string with current state for comparison
curr_memory_str = agent_state.memory.compile(
tool_usage_rules=tool_constraint_block,
sources=agent_state.sources,
max_files_open=agent_state.max_files_open,
llm_config=agent_state.llm_config,
)
new_dynamic_section = extract_dynamic_section(curr_memory_str)
# compare just the dynamic sections (memory blocks, tool rules, directories)
if curr_dynamic_section == new_dynamic_section:
system_prompt_changed = agent_state.system not in curr_system_message_text
memory_changed = curr_memory_str not in curr_system_message_text
if (not system_prompt_changed) and (not memory_changed):
logger.debug(
f"Memory and sources haven't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
)
@@ -183,7 +168,7 @@ class BaseAgent(ABC):
actor=self.actor,
project_id=agent_state.project_id,
)
return [new_system_message] + in_context_messages[1:]
return [new_system_message, *in_context_messages[1:]]
else:
return in_context_messages

View File

@@ -25,6 +25,11 @@ class BaseAgentV2(ABC):
self.actor = actor
self.logger = get_logger(agent_state.id)
@property
def agent_id(self) -> str:
"""Return the agent ID for backward compatibility with code expecting self.agent_id."""
return self.agent_state.id
@abstractmethod
async def build_request(
self,
@@ -46,6 +51,7 @@ class BaseAgentV2(ABC):
include_return_message_types: list[MessageType] | None = None,
request_start_timestamp_ns: int | None = None,
client_tools: list["ClientToolSchema"] | None = None,
include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility
) -> LettaResponse:
"""
Execute the agent loop in blocking mode, returning all messages at once.
@@ -53,6 +59,7 @@ class BaseAgentV2(ABC):
Args:
client_tools: Optional list of client-side tools. When called, execution pauses
for client to provide tool returns.
include_compaction_messages: Not used in V2, but accepted for API compatibility.
"""
raise NotImplementedError
@@ -66,8 +73,9 @@ class BaseAgentV2(ABC):
use_assistant_message: bool = True,
include_return_message_types: list[MessageType] | None = None,
request_start_timestamp_ns: int | None = None,
conversation_id: str | None = None,
conversation_id: str | None = None,
client_tools: list["ClientToolSchema"] | None = None,
include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility
) -> AsyncGenerator[LettaMessage | LegacyLettaMessage | MessageStreamStatus, None]:
"""
Execute the agent loop in streaming mode, yielding chunks as they become available.
@@ -78,5 +86,6 @@ class BaseAgentV2(ABC):
Args:
client_tools: Optional list of client-side tools. When called, execution pauses
for client to provide tool returns.
include_compaction_messages: Not used in V2, but accepted for API compatibility.
"""
raise NotImplementedError

View File

@@ -8,7 +8,7 @@ from letta.log import get_logger
from letta.orm.errors import NoResultFound
from letta.prompts.gpt_system import get_system_text
from letta.schemas.block import Block, BlockUpdate
from letta.schemas.enums import MessageRole
from letta.schemas.enums import LLMCallType, MessageRole
from letta.schemas.letta_message_content import TextContent
from letta.schemas.message import Message, MessageCreate
from letta.schemas.user import User
@@ -79,7 +79,7 @@ class EphemeralSummaryAgent(BaseAgent):
content=[TextContent(text=get_system_text("summary_system_prompt"))],
)
messages = await convert_message_creates_to_messages(
message_creates=[system_message_create] + input_messages,
message_creates=[system_message_create, *input_messages],
agent_id=self.agent_id,
timezone=agent_state.timezone,
run_id=None, # TODO: add this
@@ -92,7 +92,7 @@ class EphemeralSummaryAgent(BaseAgent):
telemetry_manager=TelemetryManager(),
agent_id=self.agent_id,
agent_tags=agent_state.tags,
call_type="summarization",
call_type=LLMCallType.summarization,
)
response_data = await llm_client.request_async_with_telemetry(request_data, agent_state.llm_config)
response = await llm_client.convert_response_to_chat_completion(response_data, messages, agent_state.llm_config)

View File

@@ -1,10 +1,12 @@
import json
import uuid
import xml.etree.ElementTree as ET
from typing import Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from uuid import UUID, uuid4
from letta.errors import PendingApprovalError
if TYPE_CHECKING:
from letta.schemas.tool import Tool
from letta.errors import LettaError, PendingApprovalError
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import get_utc_time
from letta.log import get_logger
@@ -233,6 +235,11 @@ async def _prepare_in_context_messages_no_persist_async(
current_in_context_messages = [system_message]
else:
# Default mode: load messages from agent_state.message_ids
if not agent_state.message_ids:
raise LettaError(
message=f"Agent {agent_state.id} has no in-context messages. "
"This typically means the agent's system message was not initialized correctly.",
)
if agent_state.message_buffer_autoclear:
# If autoclear is enabled, only include the most recent system message (usually at index 0)
current_in_context_messages = [
@@ -242,6 +249,14 @@ async def _prepare_in_context_messages_no_persist_async(
# Otherwise, include the full list of messages by ID for context
current_in_context_messages = await message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor)
# Convert ToolReturnCreate to ApprovalCreate for unified processing
if input_messages[0].type == "tool_return":
tool_return_msg = input_messages[0]
input_messages = [
ApprovalCreate(approvals=tool_return_msg.tool_returns),
*input_messages[1:],
]
# Check for approval-related message validation
if input_messages[0].type == "approval":
# User is trying to send an approval response
@@ -254,12 +269,31 @@ async def _prepare_in_context_messages_no_persist_async(
for msg in reversed(recent_messages):
if msg.role == "tool" and validate_persisted_tool_call_ids(msg, input_messages[0]):
logger.info(
f"Idempotency check: Found matching tool return in recent history. "
f"Idempotency check: Found matching tool return in recent in-context history. "
f"tool_returns={msg.tool_returns}, approval_response.approvals={input_messages[0].approvals}"
)
approval_already_processed = True
break
# If not found in context and summarization just happened, check full history
non_system_summary_messages = [
m for m in current_in_context_messages if m.role not in (MessageRole.system, MessageRole.summary)
]
if not approval_already_processed and len(non_system_summary_messages) == 0:
last_tool_messages = await message_manager.list_messages(
actor=actor,
agent_id=agent_state.id,
roles=[MessageRole.tool],
limit=1,
ascending=False, # Most recent first
)
if len(last_tool_messages) == 1 and validate_persisted_tool_call_ids(last_tool_messages[0], input_messages[0]):
logger.info(
f"Idempotency check: Found matching tool return in full history (post-compaction). "
f"tool_returns={last_tool_messages[0].tool_returns}, approval_response.approvals={input_messages[0].approvals}"
)
approval_already_processed = True
if approval_already_processed:
# Approval already handled, just process follow-up messages if any or manually inject keep-alive message
keep_alive_messages = input_messages[1:] or [

View File

@@ -13,14 +13,13 @@ from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent
from letta.agents.helpers import (
_build_rule_violation_result,
_create_letta_response,
_load_last_function_response,
_pop_heartbeat,
_prepare_in_context_messages_no_persist_async,
_safe_load_tool_call_str,
generate_step_id,
)
from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX, REQUEST_HEARTBEAT_PARAM
from letta.errors import ContextWindowExceededError
from letta.errors import ContextWindowExceededError, LLMError
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import AsyncTimer, get_utc_time, get_utc_timestamp_ns, ns_to_ms
from letta.helpers.reasoning_helper import scrub_inner_thoughts_from_messages
@@ -35,7 +34,7 @@ from letta.otel.context import get_ctx_attributes
from letta.otel.metric_registry import MetricRegistry
from letta.otel.tracing import log_event, trace_method, tracer
from letta.schemas.agent import AgentState, UpdateAgent
from letta.schemas.enums import JobStatus, ProviderType, StepStatus, ToolType
from letta.schemas.enums import JobStatus, LLMCallType, ProviderType, StepStatus, ToolType
from letta.schemas.letta_message import MessageType
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.letta_response import LettaResponse
@@ -49,7 +48,6 @@ from letta.schemas.openai.chat_completion_response import (
UsageStatisticsCompletionTokenDetails,
UsageStatisticsPromptTokenDetails,
)
from letta.schemas.provider_trace import ProviderTrace
from letta.schemas.step import StepProgression
from letta.schemas.step_metrics import StepMetrics
from letta.schemas.tool_execution_result import ToolExecutionResult
@@ -294,6 +292,7 @@ class LettaAgent(BaseAgent):
agent_step_span.set_attributes({"step_id": step_id})
step_progression = StepProgression.START
caught_exception = None
should_continue = False
step_metrics = StepMetrics(id=step_id) # Initialize metrics tracking
@@ -312,6 +311,7 @@ class LettaAgent(BaseAgent):
step_id=step_id,
project_id=agent_state.project_id,
status=StepStatus.PENDING,
model_handle=agent_state.llm_config.handle,
)
# Only use step_id in messages if step was actually created
effective_step_id = step_id if logged_step else None
@@ -370,8 +370,12 @@ class LettaAgent(BaseAgent):
elif response.choices[0].message.omitted_reasoning_content:
reasoning = [OmittedReasoningContent()]
elif response.choices[0].message.content:
# Carry thought_signature on TextContent when ReasoningContent doesn't exist to hold it
reasoning = [
TextContent(text=response.choices[0].message.content)
TextContent(
text=response.choices[0].message.content,
signature=response.choices[0].message.reasoning_content_signature,
)
] # reasoning placed into content for legacy reasons
else:
self.logger.info("No reasoning content found.")
@@ -409,24 +413,6 @@ class LettaAgent(BaseAgent):
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
agent_step_span.end()
# Log LLM Trace
if settings.track_provider_trace:
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace=ProviderTrace(
request_json=request_data,
response_json=response_data,
step_id=step_id,
agent_id=self.agent_id,
agent_tags=agent_state.tags,
run_id=self.current_run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
llm_config=self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None,
),
)
step_progression = StepProgression.LOGGED_TRACE
# stream step
# TODO: improve TTFT
filter_user_messages = [m for m in persisted_messages if m.role != "user"]
@@ -453,6 +439,7 @@ class LettaAgent(BaseAgent):
)
except Exception as e:
caught_exception = e
# Handle any unexpected errors during step processing
self.logger.error(f"Error during step processing: {e}")
job_update_metadata = {"error": str(e)}
@@ -499,8 +486,8 @@ class LettaAgent(BaseAgent):
await self.step_manager.update_step_error_async(
actor=self.actor,
step_id=step_id, # Use original step_id for telemetry
error_type=type(e).__name__ if "e" in locals() else "Unknown",
error_message=str(e) if "e" in locals() else "Unknown error",
error_type=type(caught_exception).__name__ if caught_exception is not None else "Unknown",
error_message=str(caught_exception) if caught_exception is not None else "Unknown error",
error_traceback=traceback.format_exc(),
stop_reason=stop_reason,
)
@@ -646,6 +633,7 @@ class LettaAgent(BaseAgent):
agent_step_span.set_attributes({"step_id": step_id})
step_progression = StepProgression.START
caught_exception = None
should_continue = False
step_metrics = StepMetrics(id=step_id) # Initialize metrics tracking
@@ -664,6 +652,7 @@ class LettaAgent(BaseAgent):
step_id=step_id,
project_id=agent_state.project_id,
status=StepStatus.PENDING,
model_handle=agent_state.llm_config.handle,
)
# Only use step_id in messages if step was actually created
effective_step_id = step_id if logged_step else None
@@ -720,8 +709,12 @@ class LettaAgent(BaseAgent):
)
]
elif response.choices[0].message.content:
# Carry thought_signature on TextContent when ReasoningContent doesn't exist to hold it
reasoning = [
TextContent(text=response.choices[0].message.content)
TextContent(
text=response.choices[0].message.content,
signature=response.choices[0].message.reasoning_content_signature,
)
] # reasoning placed into content for legacy reasons
elif response.choices[0].message.omitted_reasoning_content:
reasoning = [OmittedReasoningContent()]
@@ -762,24 +755,6 @@ class LettaAgent(BaseAgent):
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
agent_step_span.end()
# Log LLM Trace
if settings.track_provider_trace:
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace=ProviderTrace(
request_json=request_data,
response_json=response_data,
step_id=step_id,
agent_id=self.agent_id,
agent_tags=agent_state.tags,
run_id=self.current_run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
llm_config=self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None,
),
)
step_progression = StepProgression.LOGGED_TRACE
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
step_progression = StepProgression.FINISHED
@@ -795,6 +770,7 @@ class LettaAgent(BaseAgent):
)
except Exception as e:
caught_exception = e
# Handle any unexpected errors during step processing
self.logger.error(f"Error during step processing: {e}")
job_update_metadata = {"error": str(e)}
@@ -837,8 +813,8 @@ class LettaAgent(BaseAgent):
await self.step_manager.update_step_error_async(
actor=self.actor,
step_id=step_id, # Use original step_id for telemetry
error_type=type(e).__name__ if "e" in locals() else "Unknown",
error_message=str(e) if "e" in locals() else "Unknown error",
error_type=type(caught_exception).__name__ if caught_exception is not None else "Unknown",
error_message=str(caught_exception) if caught_exception is not None else "Unknown error",
error_traceback=traceback.format_exc(),
stop_reason=stop_reason,
)
@@ -1000,6 +976,7 @@ class LettaAgent(BaseAgent):
agent_step_span.set_attributes({"step_id": step_id})
step_progression = StepProgression.START
caught_exception = None
should_continue = False
step_metrics = StepMetrics(id=step_id) # Initialize metrics tracking
@@ -1018,6 +995,7 @@ class LettaAgent(BaseAgent):
step_id=step_id,
project_id=agent_state.project_id,
status=StepStatus.PENDING,
model_handle=agent_state.llm_config.handle,
)
# Only use step_id in messages if step was actually created
effective_step_id = step_id if logged_step else None
@@ -1152,6 +1130,8 @@ class LettaAgent(BaseAgent):
"output_tokens": interface.output_tokens,
},
},
llm_config=agent_state.llm_config,
latency_ms=int(llm_request_ms),
)
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
tool_call,
@@ -1220,41 +1200,6 @@ class LettaAgent(BaseAgent):
# TODO (cliandy): the stream POST request span has ended at this point, we should tie this to the stream
# log_event("agent.stream.llm_response.processed") # [4^]
# Log LLM Trace
# We are piecing together the streamed response here.
# Content here does not match the actual response schema as streams come in chunks.
if settings.track_provider_trace:
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace=ProviderTrace(
request_json=request_data,
response_json={
"content": {
"tool_call": tool_call.model_dump_json(),
"reasoning": [content.model_dump_json() for content in reasoning_content],
},
"id": interface.message_id,
"model": interface.model,
"role": "assistant",
# "stop_reason": "",
# "stop_sequence": None,
"type": "message",
"usage": {
"input_tokens": usage.prompt_tokens,
"output_tokens": usage.completion_tokens,
},
},
step_id=step_id,
agent_id=self.agent_id,
agent_tags=agent_state.tags,
run_id=self.current_run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
llm_config=self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None,
),
)
step_progression = StepProgression.LOGGED_TRACE
if persisted_messages[-1].role != "approval":
# yields tool response as this is handled from Letta and not the response from the LLM provider
tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
@@ -1287,6 +1232,7 @@ class LettaAgent(BaseAgent):
self.logger.warning(f"Failed to record step metrics: {metrics_error}")
except Exception as e:
caught_exception = e
# Handle any unexpected errors during step processing
self.logger.error(f"Error during step processing: {e}")
job_update_metadata = {"error": str(e)}
@@ -1333,8 +1279,8 @@ class LettaAgent(BaseAgent):
await self.step_manager.update_step_error_async(
actor=self.actor,
step_id=step_id, # Use original step_id for telemetry
error_type=type(e).__name__ if "e" in locals() else "Unknown",
error_message=str(e) if "e" in locals() else "Unknown error",
error_type=type(caught_exception).__name__ if caught_exception is not None else "Unknown",
error_message=str(caught_exception) if caught_exception is not None else "Unknown error",
error_traceback=traceback.format_exc(),
stop_reason=stop_reason,
)
@@ -1481,7 +1427,7 @@ class LettaAgent(BaseAgent):
agent_tags=agent_state.tags,
run_id=self.current_run_id,
step_id=step_metrics.id,
call_type="agent_step",
call_type=LLMCallType.agent_step,
)
response = await llm_client.request_async_with_telemetry(request_data, agent_state.llm_config)
@@ -1554,13 +1500,13 @@ class LettaAgent(BaseAgent):
agent_tags=agent_state.tags,
run_id=self.current_run_id,
step_id=step_id,
call_type="agent_step",
call_type=LLMCallType.agent_step,
)
# Attempt LLM request with telemetry wrapper
return (
request_data,
await llm_client.stream_async_with_telemetry(request_data, agent_state.llm_config),
await llm_client.stream_async(request_data, agent_state.llm_config),
current_in_context_messages,
new_in_context_messages,
valid_tool_names,
@@ -1605,8 +1551,10 @@ class LettaAgent(BaseAgent):
run_id=run_id,
step_id=step_id,
)
elif isinstance(e, LLMError):
raise
else:
raise llm_client.handle_llm_error(e)
raise llm_client.handle_llm_error(e, llm_config=llm_config)
@trace_method
async def _rebuild_context_window(
@@ -1626,7 +1574,7 @@ class LettaAgent(BaseAgent):
self.logger.warning(
f"Total tokens {total_tokens} exceeds configured max tokens {llm_config.context_window}, forcefully clearing message history."
)
new_in_context_messages, updated = await self.summarizer.summarize(
new_in_context_messages, _updated = await self.summarizer.summarize(
in_context_messages=in_context_messages,
new_letta_messages=new_letta_messages,
force=True,
@@ -1639,7 +1587,7 @@ class LettaAgent(BaseAgent):
self.logger.info(
f"Total tokens {total_tokens} does not exceed configured max tokens {llm_config.context_window}, passing summarizing w/o force."
)
new_in_context_messages, updated = await self.summarizer.summarize(
new_in_context_messages, _updated = await self.summarizer.summarize(
in_context_messages=in_context_messages,
new_letta_messages=new_letta_messages,
run_id=run_id,
@@ -1659,7 +1607,7 @@ class LettaAgent(BaseAgent):
agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=self.agent_id, actor=self.actor)
message_ids = agent_state.message_ids
in_context_messages = await self.message_manager.get_messages_by_ids_async(message_ids=message_ids, actor=self.actor)
new_in_context_messages, updated = await self.summarizer.summarize(
new_in_context_messages, _updated = await self.summarizer.summarize(
in_context_messages=in_context_messages, new_letta_messages=[], force=True
)
return await self.agent_manager.update_message_ids_async(

View File

@@ -217,7 +217,7 @@ class LettaAgentBatch(BaseAgent):
if batch_items:
log_event(name="bulk_create_batch_items")
batch_items_persisted = await self.batch_manager.create_llm_batch_items_bulk_async(batch_items, actor=self.actor)
await self.batch_manager.create_llm_batch_items_bulk_async(batch_items, actor=self.actor)
log_event(name="return_batch_response")
return LettaBatchResponse(

View File

@@ -9,7 +9,6 @@ from letta.adapters.letta_llm_adapter import LettaLLMAdapter
from letta.adapters.letta_llm_request_adapter import LettaLLMRequestAdapter
from letta.adapters.letta_llm_stream_adapter import LettaLLMStreamAdapter
from letta.agents.base_agent_v2 import BaseAgentV2
from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent
from letta.agents.helpers import (
_build_rule_violation_result,
_load_last_function_response,
@@ -20,7 +19,7 @@ from letta.agents.helpers import (
generate_step_id,
)
from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX, REQUEST_HEARTBEAT_PARAM
from letta.errors import ContextWindowExceededError, LLMError
from letta.errors import ContextWindowExceededError, InsufficientCreditsError, LLMError
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns, ns_to_ms
from letta.helpers.reasoning_helper import scrub_inner_thoughts_from_messages
@@ -31,7 +30,7 @@ from letta.log import get_logger
from letta.otel.tracing import log_event, trace_method, tracer
from letta.prompts.prompt_generator import PromptGenerator
from letta.schemas.agent import AgentState, UpdateAgent
from letta.schemas.enums import AgentType, MessageStreamStatus, RunStatus, StepStatus
from letta.schemas.enums import AgentType, LLMCallType, MessageStreamStatus, RunStatus, StepStatus
from letta.schemas.letta_message import LettaMessage, MessageType
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.letta_request import ClientToolSchema
@@ -58,6 +57,7 @@ from letta.server.rest_api.utils import (
from letta.services.agent_manager import AgentManager
from letta.services.archive_manager import ArchiveManager
from letta.services.block_manager import BlockManager
from letta.services.credit_verification_service import CreditVerificationService
from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
@@ -67,10 +67,10 @@ from letta.services.summarizer.enums import SummarizationMode
from letta.services.summarizer.summarizer import Summarizer
from letta.services.telemetry_manager import TelemetryManager
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
from letta.settings import model_settings, settings, summarizer_settings
from letta.settings import settings, summarizer_settings
from letta.system import package_function_response
from letta.types import JsonDict
from letta.utils import log_telemetry, safe_create_task, united_diff, validate_function_response
from letta.utils import log_telemetry, safe_create_task, safe_create_task_with_return, united_diff, validate_function_response
class LettaAgentV2(BaseAgentV2):
@@ -106,6 +106,7 @@ class LettaAgentV2(BaseAgentV2):
self.passage_manager = PassageManager()
self.step_manager = StepManager()
self.telemetry_manager = TelemetryManager()
self.credit_verification_service = CreditVerificationService()
## TODO: Expand to more
# if summarizer_settings.enable_summarization and model_settings.openai_api_key:
@@ -158,6 +159,8 @@ class LettaAgentV2(BaseAgentV2):
llm_adapter=LettaLLMRequestAdapter(
llm_client=self.llm_client,
llm_config=self.agent_state.llm_config,
call_type=LLMCallType.agent_step,
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
org_id=self.actor.organization_id,
user_id=self.actor.id,
@@ -181,6 +184,7 @@ class LettaAgentV2(BaseAgentV2):
include_return_message_types: list[MessageType] | None = None,
request_start_timestamp_ns: int | None = None,
client_tools: list[ClientToolSchema] | None = None,
include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility
) -> LettaResponse:
"""
Execute the agent loop in blocking mode, returning all messages at once.
@@ -193,6 +197,7 @@ class LettaAgentV2(BaseAgentV2):
include_return_message_types: Filter for which message types to return
request_start_timestamp_ns: Start time for tracking request duration
client_tools: Optional list of client-side tools (not used in V2, for API compatibility)
include_compaction_messages: Not used in V2, but accepted for API compatibility.
Returns:
LettaResponse: Complete response with all messages and metadata
@@ -205,15 +210,25 @@ class LettaAgentV2(BaseAgentV2):
)
in_context_messages = in_context_messages + input_messages_to_persist
response_letta_messages = []
credit_task = None
for i in range(max_steps):
remaining_turns = max_steps - i - 1
# Await credit check from previous iteration before running next step
if credit_task is not None:
if not await credit_task:
self.should_continue = False
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.insufficient_credits)
break
credit_task = None
response = self._step(
messages=in_context_messages + self.response_messages,
input_messages_to_persist=input_messages_to_persist,
llm_adapter=LettaLLMRequestAdapter(
llm_client=self.llm_client,
llm_config=self.agent_state.llm_config,
call_type=LLMCallType.agent_step,
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,
@@ -233,6 +248,9 @@ class LettaAgentV2(BaseAgentV2):
if not self.should_continue:
break
# Fire credit check to run in parallel with loop overhead / next step setup
credit_task = safe_create_task_with_return(self._check_credits())
input_messages_to_persist = []
# Rebuild context window after stepping
@@ -271,6 +289,7 @@ class LettaAgentV2(BaseAgentV2):
request_start_timestamp_ns: int | None = None,
conversation_id: str | None = None, # Not used in V2, but accepted for API compatibility
client_tools: list[ClientToolSchema] | None = None,
include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility
) -> AsyncGenerator[str, None]:
"""
Execute the agent loop in streaming mode, yielding chunks as they become available.
@@ -289,6 +308,7 @@ class LettaAgentV2(BaseAgentV2):
include_return_message_types: Filter for which message types to return
request_start_timestamp_ns: Start time for tracking request duration
client_tools: Optional list of client-side tools (not used in V2, for API compatibility)
include_compaction_messages: Not used in V2, but accepted for API compatibility.
Yields:
str: JSON-formatted SSE data chunks for each completed step
@@ -301,6 +321,7 @@ class LettaAgentV2(BaseAgentV2):
llm_adapter = LettaLLMStreamAdapter(
llm_client=self.llm_client,
llm_config=self.agent_state.llm_config,
call_type=LLMCallType.agent_step,
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,
@@ -311,6 +332,7 @@ class LettaAgentV2(BaseAgentV2):
llm_adapter = LettaLLMRequestAdapter(
llm_client=self.llm_client,
llm_config=self.agent_state.llm_config,
call_type=LLMCallType.agent_step,
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,
@@ -323,7 +345,16 @@ class LettaAgentV2(BaseAgentV2):
input_messages, self.agent_state, self.message_manager, self.actor, run_id
)
in_context_messages = in_context_messages + input_messages_to_persist
credit_task = None
for i in range(max_steps):
# Await credit check from previous iteration before running next step
if credit_task is not None:
if not await credit_task:
self.should_continue = False
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.insufficient_credits)
break
credit_task = None
response = self._step(
messages=in_context_messages + self.response_messages,
input_messages_to_persist=input_messages_to_persist,
@@ -342,6 +373,9 @@ class LettaAgentV2(BaseAgentV2):
if not self.should_continue:
break
# Fire credit check to run in parallel with loop overhead / next step setup
credit_task = safe_create_task_with_return(self._check_credits())
input_messages_to_persist = []
if self.stop_reason is None:
@@ -420,8 +454,9 @@ class LettaAgentV2(BaseAgentV2):
raise AssertionError("run_id is required when enforce_run_id_set is True")
step_progression = StepProgression.START
caught_exception = None
# TODO(@caren): clean this up
tool_call, reasoning_content, agent_step_span, first_chunk, step_id, logged_step, step_start_ns, step_metrics = (
tool_call, reasoning_content, agent_step_span, first_chunk, step_id, logged_step, _step_start_ns, step_metrics = (
None,
None,
None,
@@ -580,6 +615,7 @@ class LettaAgentV2(BaseAgentV2):
)
step_progression, step_metrics = await self._step_checkpoint_finish(step_metrics, agent_step_span, logged_step)
except Exception as e:
caught_exception = e
self.logger.warning(f"Error during step processing: {e}")
self.job_update_metadata = {"error": str(e)}
@@ -615,8 +651,8 @@ class LettaAgentV2(BaseAgentV2):
await self.step_manager.update_step_error_async(
actor=self.actor,
step_id=step_id, # Use original step_id for telemetry
error_type=type(e).__name__ if "e" in locals() else "Unknown",
error_message=str(e) if "e" in locals() else "Unknown error",
error_type=type(caught_exception).__name__ if caught_exception is not None else "Unknown",
error_message=str(caught_exception) if caught_exception is not None else "Unknown error",
error_traceback=traceback.format_exc(),
stop_reason=self.stop_reason,
)
@@ -667,6 +703,17 @@ class LettaAgentV2(BaseAgentV2):
self.last_function_response = None
self.response_messages = []
async def _check_credits(self) -> bool:
"""Check if the organization still has credits. Returns True if OK or not configured."""
try:
await self.credit_verification_service.verify_credits(self.actor.organization_id, self.agent_state.id)
return True
except InsufficientCreditsError:
self.logger.warning(
f"Insufficient credits for organization {self.actor.organization_id}, agent {self.agent_state.id}, stopping agent loop"
)
return False
@trace_method
async def _check_run_cancellation(self, run_id) -> bool:
try:
@@ -678,20 +725,37 @@ class LettaAgentV2(BaseAgentV2):
return False
@trace_method
async def _refresh_messages(self, in_context_messages: list[Message]):
num_messages = await self.message_manager.size_async(
agent_id=self.agent_state.id,
actor=self.actor,
)
num_archival_memories = await self.passage_manager.agent_passage_size_async(
agent_id=self.agent_state.id,
actor=self.actor,
)
in_context_messages = await self._rebuild_memory(
in_context_messages,
num_messages=num_messages,
num_archival_memories=num_archival_memories,
)
async def _refresh_messages(self, in_context_messages: list[Message], force_system_prompt_refresh: bool = False):
"""Refresh in-context messages.
This performs two tasks:
1) Rebuild the *system prompt* only if the memory/tool-rules/directories section has changed.
This avoids rebuilding the system prompt on every step due to dynamic metadata (e.g. message counts),
which can bust prefix caching.
2) Scrub inner thoughts from messages.
Args:
in_context_messages: Current in-context messages
force_system_prompt_refresh: If True, forces evaluation of whether the system prompt needs to be rebuilt.
(The rebuild will still be skipped if memory/tool-rules/directories haven't changed.)
Returns:
Refreshed in-context messages.
"""
# Only rebuild when explicitly forced (e.g., after compaction).
# Normal turns should not trigger system prompt recompilation.
if force_system_prompt_refresh:
try:
in_context_messages = await self._rebuild_memory(
in_context_messages,
num_messages=None,
num_archival_memories=None,
force=True,
)
except Exception:
raise
# Always scrub inner thoughts regardless of system prompt refresh
in_context_messages = scrub_inner_thoughts_from_messages(in_context_messages, self.agent_state.llm_config)
return in_context_messages
@@ -699,8 +763,9 @@ class LettaAgentV2(BaseAgentV2):
async def _rebuild_memory(
self,
in_context_messages: list[Message],
num_messages: int,
num_archival_memories: int,
num_messages: int | None,
num_archival_memories: int | None,
force: bool = False,
):
agent_state = await self.agent_manager.refresh_memory_async(agent_state=self.agent_state, actor=self.actor)
@@ -721,49 +786,26 @@ class LettaAgentV2(BaseAgentV2):
else:
archive_tags = None
# TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
curr_system_message = in_context_messages[0]
curr_system_message_text = curr_system_message.content[0].text
# Extract the memory section that includes <memory_blocks>, tool rules, and directories.
# This avoids timestamp comparison issues in <memory_metadata>, which is dynamic.
def extract_memory_section(text: str) -> str:
# Primary pattern: everything from <memory_blocks> up to <memory_metadata>
mem_start = text.find("<memory_blocks>")
meta_start = text.find("<memory_metadata>")
if mem_start != -1:
if meta_start != -1 and meta_start > mem_start:
return text[mem_start:meta_start]
return text[mem_start:]
# Fallback pattern used in some legacy prompts: between </base_instructions> and <memory_metadata>
base_end = text.find("</base_instructions>")
if base_end != -1:
if meta_start != -1 and meta_start > base_end:
return text[base_end + len("</base_instructions>") : meta_start]
return text[base_end + len("</base_instructions>") :]
# Last resort: return full text
return text
curr_memory_section = extract_memory_section(curr_system_message_text)
# refresh files
agent_state = await self.agent_manager.refresh_file_blocks(agent_state=agent_state, actor=self.actor)
# generate just the memory string with current state for comparison
# generate memory string with current state
curr_memory_str = agent_state.memory.compile(
tool_usage_rules=tool_constraint_block,
sources=agent_state.sources,
max_files_open=agent_state.max_files_open,
llm_config=agent_state.llm_config,
)
new_memory_section = extract_memory_section(curr_memory_str)
# compare just the memory sections (memory blocks, tool rules, directories)
if curr_memory_section.strip() == new_memory_section.strip():
# Skip rebuild unless explicitly forced and unless system/memory content actually changed.
system_prompt_changed = agent_state.system not in curr_system_message_text
memory_changed = curr_memory_str not in curr_system_message_text
if (not force) and (not system_prompt_changed) and (not memory_changed):
self.logger.debug(
f"Memory and sources haven't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
f"Memory, sources, and system prompt haven't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
)
return in_context_messages
@@ -793,7 +835,7 @@ class LettaAgentV2(BaseAgentV2):
new_system_message = await self.message_manager.update_message_by_id_async(
curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
)
return [new_system_message] + in_context_messages[1:]
return [new_system_message, *in_context_messages[1:]]
else:
return in_context_messages
@@ -864,6 +906,7 @@ class LettaAgentV2(BaseAgentV2):
step_id=step_id,
project_id=self.agent_state.project_id,
status=StepStatus.PENDING,
model_handle=self.agent_state.llm_config.handle,
)
# Also create step metrics early and update at the end of the step
@@ -1279,7 +1322,7 @@ class LettaAgentV2(BaseAgentV2):
self.logger.warning(
f"Total tokens {total_tokens} exceeds configured max tokens {self.agent_state.llm_config.context_window}, forcefully clearing message history."
)
new_in_context_messages, updated = await self.summarizer.summarize(
new_in_context_messages, _updated = await self.summarizer.summarize(
in_context_messages=in_context_messages,
new_letta_messages=new_letta_messages,
force=True,
@@ -1292,7 +1335,7 @@ class LettaAgentV2(BaseAgentV2):
self.logger.info(
f"Total tokens {total_tokens} does not exceed configured max tokens {self.agent_state.llm_config.context_window}, passing summarizing w/o force."
)
new_in_context_messages, updated = await self.summarizer.summarize(
new_in_context_messages, _updated = await self.summarizer.summarize(
in_context_messages=in_context_messages,
new_letta_messages=new_letta_messages,
run_id=run_id,

File diff suppressed because it is too large Load Diff

View File

@@ -1,10 +1,13 @@
import json
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any, AsyncGenerator, Dict, List, Optional
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional
import openai
if TYPE_CHECKING:
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.agents.base_agent import BaseAgent
from letta.agents.exceptions import IncompatibleAgentType
from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent
@@ -250,7 +253,6 @@ class VoiceAgent(BaseAgent):
agent_state=agent_state,
)
tool_result = tool_execution_result.func_return
success_flag = tool_execution_result.success_flag
# 3. Provide function_call response back into the conversation
# TODO: fix this tool format
@@ -292,7 +294,7 @@ class VoiceAgent(BaseAgent):
new_letta_messages = await self.message_manager.create_many_messages_async(letta_message_db_queue, actor=self.actor)
# TODO: Make this more general and configurable, less brittle
new_in_context_messages, updated = await summarizer.summarize(
new_in_context_messages, _updated = await summarizer.summarize(
in_context_messages=in_context_messages, new_letta_messages=new_letta_messages
)

View File

@@ -1,4 +1,9 @@
from typing import AsyncGenerator, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, AsyncGenerator, List, Optional, Tuple, Union
if TYPE_CHECKING:
from opentelemetry.trace import Span
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.agents.helpers import _create_letta_response, serialize_message_history
from letta.agents.letta_agent import LettaAgent
@@ -89,7 +94,7 @@ class VoiceSleeptimeAgent(LettaAgent):
current_in_context_messages, new_in_context_messages, stop_reason, usage = await super()._step(
agent_state=agent_state, input_messages=input_messages, max_steps=max_steps
)
new_in_context_messages, updated = await self.summarizer.summarize(
new_in_context_messages, _updated = await self.summarizer.summarize(
in_context_messages=current_in_context_messages, new_letta_messages=new_in_context_messages
)
self.agent_manager.set_in_context_messages(

View File

@@ -5,7 +5,6 @@ from typing import Annotated, Optional
import typer
from letta.log import get_logger
from letta.streaming_interface import StreamingRefreshCLIInterface as interface # for printing to terminal
logger = get_logger(__name__)

232
letta/config_file.py Normal file
View File

@@ -0,0 +1,232 @@
"""
Letta Configuration File Support
Loads hierarchical YAML config and maps it to environment variables.
Supported top-level keys and their env var prefixes:
letta: -> LETTA_*
model: -> * (provider-prefixed: OPENAI_*, ANTHROPIC_*, etc.)
tool: -> * (prefix-based: E2B_*, MCP_*, TOOL_*, etc.)
datadog: -> DD_*
Config file format:
letta:
telemetry:
enable_datadog: true
pg:
host: localhost
model:
openai:
api_key: sk-xxx
anthropic:
api_key: sk-yyy
tool:
e2b:
api_key: xxx
mcp:
disable_stdio: true
datadog:
site: us5.datadoghq.com
service: memgpt-server
This maps to environment variables:
LETTA_TELEMETRY_ENABLE_DATADOG=true
LETTA_PG_HOST=localhost
OPENAI_API_KEY=sk-xxx
ANTHROPIC_API_KEY=sk-yyy
E2B_API_KEY=xxx
MCP_DISABLE_STDIO=true
DD_SITE=us5.datadoghq.com
DD_SERVICE=memgpt-server
Config file locations (in order of precedence):
1. ~/.letta/conf.yaml
2. ./conf.yaml
3. LETTA_CONFIG_PATH environment variable
"""
import os
from pathlib import Path
from typing import Any
import yaml
# Config file locations
DEFAULT_USER_CONFIG = Path.home() / ".letta" / "conf.yaml"
DEFAULT_PROJECT_CONFIG = Path.cwd() / "conf.yaml"
def load_config_file(config_path: str | Path | None = None) -> dict[str, Any]:
    """
    Load and merge Letta configuration from YAML file(s).

    Candidate files are collected from lowest to highest precedence:
    the per-user config, the project-local config, the path named by the
    ``LETTA_CONFIG_PATH`` environment variable, and finally an explicit
    ``config_path`` argument. Later files override earlier ones via a
    deep merge.

    Args:
        config_path: Optional explicit path to a config file (highest precedence).

    Returns:
        The merged config dict, or an empty dict if no readable config was found.
    """
    candidates: list[Path] = []

    # Lowest-to-highest precedence ordering; later entries win on merge.
    for default_path in (DEFAULT_USER_CONFIG, DEFAULT_PROJECT_CONFIG):
        if default_path.exists():
            candidates.append(default_path)

    env_override = os.environ.get("LETTA_CONFIG_PATH")
    if env_override and Path(env_override).exists():
        candidates.append(Path(env_override))

    # An explicitly passed path trumps everything else.
    if config_path:
        explicit = Path(config_path)
        if explicit.exists():
            candidates.append(explicit)

    merged: dict[str, Any] = {}
    for candidate in candidates:
        try:
            with open(candidate, "r") as fh:
                loaded = yaml.safe_load(fh)
            if loaded:
                merged = _deep_merge(merged, loaded)
        except Exception:
            # Best-effort loading: skip files that are unreadable or malformed.
            pass

    return merged
def _deep_merge(base: dict, override: dict) -> dict:
"""Deep merge two dicts, override values take precedence."""
result = base.copy()
for key, value in override.items():
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
result[key] = _deep_merge(result[key], value)
else:
result[key] = value
return result
def _flatten_with_prefix(d: dict, prefix: str, env_vars: dict[str, str]) -> None:
"""Flatten a dict with a given prefix."""
for key, value in d.items():
env_key = f"{prefix}_{key}".upper() if prefix else key.upper()
if isinstance(value, dict):
_flatten_with_prefix(value, env_key, env_vars)
elif value is not None:
if isinstance(value, bool):
env_vars[env_key] = str(value).lower()
else:
env_vars[env_key] = str(value)
def _flatten_model_settings(d: dict, env_vars: dict[str, str]) -> None:
"""
Flatten model settings where nested keys become prefixes.
model:
openai:
api_key: xxx -> OPENAI_API_KEY
api_base: yyy -> OPENAI_API_BASE
anthropic:
api_key: zzz -> ANTHROPIC_API_KEY
global_max_context_window_limit: 32000 -> GLOBAL_MAX_CONTEXT_WINDOW_LIMIT
"""
for key, value in d.items():
if isinstance(value, dict):
# Nested provider config: openai.api_key -> OPENAI_API_KEY
_flatten_with_prefix(value, key.upper(), env_vars)
elif value is not None:
# Top-level model setting
env_key = key.upper()
if isinstance(value, bool):
env_vars[env_key] = str(value).lower()
else:
env_vars[env_key] = str(value)
def _flatten_tool_settings(d: dict, env_vars: dict[str, str]) -> None:
"""
Flatten tool settings where nested keys become prefixes.
tool:
e2b:
api_key: xxx -> E2B_API_KEY
sandbox_template_id: y -> E2B_SANDBOX_TEMPLATE_ID
mcp:
disable_stdio: true -> MCP_DISABLE_STDIO
tool_sandbox_timeout: 180 -> TOOL_SANDBOX_TIMEOUT
"""
for key, value in d.items():
if isinstance(value, dict):
# Nested tool config: e2b.api_key -> E2B_API_KEY
_flatten_with_prefix(value, key.upper(), env_vars)
elif value is not None:
# Top-level tool setting
env_key = key.upper()
if isinstance(value, bool):
env_vars[env_key] = str(value).lower()
else:
env_vars[env_key] = str(value)
def config_to_env_vars(config: dict[str, Any]) -> dict[str, str]:
    """
    Convert the hierarchical config dict to flat environment variables.

    Each supported top-level section has its own prefixing rule:
      - ``letta``   -> LETTA_* prefix
      - ``model``   -> provider-prefixed (OPENAI_*, ANTHROPIC_*, ...)
      - ``tool``    -> prefix-based (E2B_*, MCP_*, TOOL_*, ...)
      - ``datadog`` -> DD_* prefix

    Args:
        config: Hierarchical config dict.

    Returns:
        Dict mapping environment variable name -> string value.
    """
    env_vars: dict[str, str] = {}

    # Section name -> flattening strategy for that section.
    section_handlers = {
        "letta": lambda section: _flatten_with_prefix(section, "LETTA", env_vars),
        "model": lambda section: _flatten_model_settings(section, env_vars),
        "tool": lambda section: _flatten_tool_settings(section, env_vars),
        "datadog": lambda section: _flatten_with_prefix(section, "DD", env_vars),
    }
    for section_name, handle in section_handlers.items():
        if section_name in config:
            handle(config[section_name])

    return env_vars
def apply_config_to_env(config_path: str | Path | None = None) -> None:
    """
    Load the config file and export its values as environment variables.

    Variables already present in the environment are left untouched, so
    real environment variables always take precedence over config-file
    values.

    Args:
        config_path: Optional explicit path to a config file.
    """
    config = load_config_file(config_path)
    if not config:
        return

    for name, value in config_to_env_vars(config).items():
        # setdefault only writes when the variable is not already set,
        # preserving env-var precedence over the config file.
        os.environ.setdefault(name, value)

View File

@@ -78,7 +78,7 @@ DEFAULT_CONTEXT_WINDOW = 32000
# Summarization trigger threshold (multiplier of context_window limit)
# Summarization triggers when step usage > context_window * SUMMARIZATION_TRIGGER_MULTIPLIER
SUMMARIZATION_TRIGGER_MULTIPLIER = 1.0
SUMMARIZATION_TRIGGER_MULTIPLIER = 0.9 # using instead of 1.0 to avoid "too many tokens in prompt" fallbacks
# number of concurrent embedding requests to send
EMBEDDING_BATCH_SIZE = 200
@@ -252,8 +252,11 @@ LLM_MAX_CONTEXT_WINDOW = {
"deepseek-chat": 64000,
"deepseek-reasoner": 64000,
# glm (Z.AI)
"glm-4.6": 200000,
"glm-4.5": 128000,
"glm-4.6": 200000,
"glm-4.7": 200000,
"glm-5": 200000,
"glm-5-code": 200000,
## OpenAI models: https://platform.openai.com/docs/models/overview
# gpt-5
"gpt-5": 272000,
@@ -383,6 +386,7 @@ LLM_MAX_CONTEXT_WINDOW = {
"gemini-2.5-computer-use-preview-10-2025": 1048576,
# gemini 3
"gemini-3-pro-preview": 1048576,
"gemini-3.1-pro-preview": 1048576,
"gemini-3-flash-preview": 1048576,
# gemini latest aliases
"gemini-flash-latest": 1048576,
@@ -457,10 +461,18 @@ REDIS_RUN_ID_PREFIX = "agent:send_message:run_id"
CONVERSATION_LOCK_PREFIX = "conversation:lock:"
CONVERSATION_LOCK_TTL_SECONDS = 300 # 5 minutes
# Memory repo locks - prevents concurrent modifications to git-based memory
MEMORY_REPO_LOCK_PREFIX = "memory_repo:lock:"
MEMORY_REPO_LOCK_TTL_SECONDS = 60 # 1 minute (git operations should be fast)
# TODO: This is temporary, eventually use token-based eviction
# File based controls
DEFAULT_MAX_FILES_OPEN = 5
DEFAULT_CORE_MEMORY_SOURCE_CHAR_LIMIT: int = 50000
# Max values for file controls (int32 limit to match database INTEGER type)
MAX_INT32: int = 2147483647
MAX_PER_FILE_VIEW_WINDOW_CHAR_LIMIT: int = MAX_INT32
MAX_FILES_OPEN_LIMIT: int = 1000 # Practical limit - no agent needs 1000+ files open
GET_PROVIDERS_TIMEOUT_SECONDS = 10

View File

@@ -1,4 +1,7 @@
from typing import Dict, Iterator, List, Tuple
from typing import TYPE_CHECKING, Dict, Iterator, List, Tuple
if TYPE_CHECKING:
from letta.schemas.user import User
import typer
@@ -143,7 +146,13 @@ async def load_data(connector: DataConnector, source: Source, passage_manager: P
class DirectoryConnector(DataConnector):
def __init__(self, input_files: List[str] = None, input_directory: str = None, recursive: bool = False, extensions: List[str] = None):
def __init__(
self,
input_files: List[str] | None = None,
input_directory: str | None = None,
recursive: bool = False,
extensions: List[str] | None = None,
):
"""
Connector for reading text data from a directory of files.

View File

@@ -2,8 +2,16 @@ import asyncio
from functools import wraps
from typing import Any, Dict, List, Optional, Set, Union
from letta.constants import CONVERSATION_LOCK_PREFIX, CONVERSATION_LOCK_TTL_SECONDS, REDIS_EXCLUDE, REDIS_INCLUDE, REDIS_SET_DEFAULT_VAL
from letta.errors import ConversationBusyError
from letta.constants import (
CONVERSATION_LOCK_PREFIX,
CONVERSATION_LOCK_TTL_SECONDS,
MEMORY_REPO_LOCK_PREFIX,
MEMORY_REPO_LOCK_TTL_SECONDS,
REDIS_EXCLUDE,
REDIS_INCLUDE,
REDIS_SET_DEFAULT_VAL,
)
from letta.errors import ConversationBusyError, MemoryRepoBusyError
from letta.log import get_logger
from letta.settings import settings
@@ -141,7 +149,7 @@ class AsyncRedisClient:
try:
client = await self.get_client()
return await client.get(key)
except:
except Exception:
return default
@with_retry()
@@ -230,6 +238,64 @@ class AsyncRedisClient:
logger.warning(f"Failed to release conversation lock for conversation {conversation_id}: {e}")
return False
async def acquire_memory_repo_lock(
    self,
    agent_id: str,
    token: str,
) -> Optional["Lock"]:
    """
    Acquire the distributed lock guarding an agent's git-based memory repo.

    Prevents concurrent modifications to an agent's memory.

    Args:
        agent_id: The agent ID whose memory is being modified.
        token: Unique identifier for the lock holder (for debugging/tracing).

    Returns:
        The acquired Lock object, or None when the redis Lock class is unavailable.

    Raises:
        MemoryRepoBusyError: If another holder currently owns the lock.
    """
    if Lock is None:
        return None
    redis_client = await self.get_client()
    key = f"{MEMORY_REPO_LOCK_PREFIX}{agent_id}"
    repo_lock = Lock(
        redis_client,
        key,
        timeout=MEMORY_REPO_LOCK_TTL_SECONDS,
        blocking=False,
        thread_local=False,
        raise_on_release_error=False,
    )
    acquired = await repo_lock.acquire(token=token)
    if acquired:
        return repo_lock
    # Lock is held by someone else; surface the holder's token for debugging.
    holder_token = await redis_client.get(key)
    raise MemoryRepoBusyError(
        agent_id=agent_id,
        lock_holder_token=holder_token,
    )
async def release_memory_repo_lock(self, agent_id: str) -> bool:
    """
    Best-effort release of an agent's memory repo lock.

    Args:
        agent_id: The agent ID to release the lock for.

    Returns:
        True if the lock key was deleted, False if the release failed.
    """
    try:
        redis_client = await self.get_client()
        await redis_client.delete(f"{MEMORY_REPO_LOCK_PREFIX}{agent_id}")
    except Exception as exc:
        # Failure to release is logged but never propagated; the lock's TTL
        # guarantees it expires on its own.
        logger.warning(f"Failed to release memory repo lock for agent {agent_id}: {exc}")
        return False
    return True
@with_retry()
async def exists(self, *keys: str) -> int:
"""Check if keys exist."""
@@ -254,7 +320,7 @@ class AsyncRedisClient:
client = await self.get_client()
result = await client.smismember(key, values)
return result if isinstance(values, list) else result[0]
except:
except Exception:
return [0] * len(values) if isinstance(values, list) else 0
async def srem(self, key: str, *members: Union[str, int, float]) -> int:
@@ -464,6 +530,16 @@ class NoopAsyncRedisClient(AsyncRedisClient):
async def release_conversation_lock(self, conversation_id: str) -> bool:
return False
async def acquire_memory_repo_lock(
    self,
    agent_id: str,
    token: str,
) -> Optional["Lock"]:
    # No-op client: Redis is disabled, so no distributed lock is ever taken
    # and callers proceed without mutual exclusion.
    return None
async def release_memory_repo_lock(self, agent_id: str) -> bool:
    # No-op client: there is never a lock to release, so always report False.
    return False
async def check_inclusion_and_exclusion(self, member: str, group: str) -> bool:
return False

View File

@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union
# Avoid circular imports
if TYPE_CHECKING:
from letta.schemas.letta_message import LettaMessage
from letta.schemas.message import Message
@@ -20,6 +21,7 @@ class ErrorCode(Enum):
TIMEOUT = "TIMEOUT"
CONFLICT = "CONFLICT"
EXPIRED = "EXPIRED"
PAYMENT_REQUIRED = "PAYMENT_REQUIRED"
class LettaError(Exception):
@@ -91,6 +93,22 @@ class ConversationBusyError(LettaError):
super().__init__(message=message, code=code, details=details)
class MemoryRepoBusyError(LettaError):
    """Error raised when attempting to modify memory while another operation is in progress."""

    def __init__(self, agent_id: str, lock_holder_token: Optional[str] = None):
        # Keep the identifiers on the instance so callers and logs can see
        # which agent is locked and (if known) who holds the lock.
        self.agent_id = agent_id
        self.lock_holder_token = lock_holder_token
        super().__init__(
            message="Cannot modify memory: Another operation is currently in progress for this agent's memory. Please wait for the current operation to complete.",
            code=ErrorCode.CONFLICT,
            details={
                "error_code": "MEMORY_REPO_BUSY",
                "agent_id": agent_id,
                "lock_holder_token": lock_holder_token,
            },
        )
class LettaToolCreateError(LettaError):
"""Error raised when a tool cannot be created."""
@@ -167,7 +185,9 @@ class LettaImageFetchError(LettaError):
def __init__(self, url: str, reason: str):
details = {"url": url, "reason": reason}
super().__init__(
message=f"Failed to fetch image from {url}: {reason}", code=ErrorCode.INVALID_ARGUMENT, details=details,
message=f"Failed to fetch image from {url}: {reason}",
code=ErrorCode.INVALID_ARGUMENT,
details=details,
)
@@ -238,6 +258,10 @@ class LLMBadRequestError(LLMError):
"""Error when LLM service cannot process request"""
class LLMInsufficientCreditsError(LLMError):
"""Error when LLM provider reports insufficient credits or quota"""
class LLMAuthenticationError(LLMError):
"""Error when authentication fails with LLM service"""
@@ -308,7 +332,9 @@ class ContextWindowExceededError(LettaError):
def __init__(self, message: str, details: dict = {}):
error_message = f"{message} ({details})"
super().__init__(
message=error_message, code=ErrorCode.CONTEXT_WINDOW_EXCEEDED, details=details,
message=error_message,
code=ErrorCode.CONTEXT_WINDOW_EXCEEDED,
details=details,
)
@@ -328,7 +354,9 @@ class RateLimitExceededError(LettaError):
def __init__(self, message: str, max_retries: int):
error_message = f"{message} ({max_retries})"
super().__init__(
message=error_message, code=ErrorCode.RATE_LIMIT_EXCEEDED, details={"max_retries": max_retries},
message=error_message,
code=ErrorCode.RATE_LIMIT_EXCEEDED,
details={"max_retries": max_retries},
)
@@ -383,7 +411,8 @@ class HandleNotFoundError(LettaError):
def __init__(self, handle: str, available_handles: List[str]):
super().__init__(
message=f"Handle {handle} not found, must be one of {available_handles}", code=ErrorCode.NOT_FOUND,
message=f"Handle {handle} not found, must be one of {available_handles}",
code=ErrorCode.NOT_FOUND,
)
@@ -423,6 +452,16 @@ class AgentFileImportError(Exception):
"""Exception raised during agent file import operations"""
class InsufficientCreditsError(LettaError):
    """Raised when an organization has no remaining credits."""

    def __init__(self):
        # The message and machine-readable error code are fixed; no arguments needed.
        details = {"error_code": "INSUFFICIENT_CREDITS"}
        super().__init__(message="Insufficient credits to process this request.", details=details)
class RunCancelError(LettaError):
"""Error raised when a run cannot be cancelled."""

View File

@@ -1,10 +1,11 @@
from typing import TYPE_CHECKING, Any, List, Literal, Optional
from letta.constants import CORE_MEMORY_LINE_NUMBER_WARNING
from typing import TYPE_CHECKING, List, Literal, Optional
if TYPE_CHECKING:
from letta.agents.letta_agent import LettaAgent as Agent
from letta.schemas.agent import AgentState
from letta.constants import CORE_MEMORY_LINE_NUMBER_WARNING
def memory(
agent_state: "AgentState",
@@ -242,7 +243,7 @@ async def archival_memory_search(
raise NotImplementedError("This should never be invoked directly. Contact Letta if you see this error message.")
def core_memory_append(agent_state: "AgentState", label: str, content: str) -> Optional[str]: # type: ignore
def core_memory_append(agent_state: "AgentState", label: str, content: str) -> str: # type: ignore
"""
Append to the contents of core memory.
@@ -251,15 +252,15 @@ def core_memory_append(agent_state: "AgentState", label: str, content: str) -> O
content (str): Content to write to the memory. All unicode (including emojis) are supported.
Returns:
Optional[str]: None is always returned as this function does not produce a response.
str: The updated value of the memory block.
"""
current_value = str(agent_state.memory.get_block(label).value)
new_value = current_value + "\n" + str(content)
agent_state.memory.update_block_value(label=label, value=new_value)
return None
return new_value
def core_memory_replace(agent_state: "AgentState", label: str, old_content: str, new_content: str) -> Optional[str]: # type: ignore
def core_memory_replace(agent_state: "AgentState", label: str, old_content: str, new_content: str) -> str: # type: ignore
"""
Replace the contents of core memory. To delete memories, use an empty string for new_content.
@@ -269,14 +270,14 @@ def core_memory_replace(agent_state: "AgentState", label: str, old_content: str,
new_content (str): Content to write to the memory. All unicode (including emojis) are supported.
Returns:
Optional[str]: None is always returned as this function does not produce a response.
str: The updated value of the memory block.
"""
current_value = str(agent_state.memory.get_block(label).value)
if old_content not in current_value:
raise ValueError(f"Old content '{old_content}' not found in memory block '{label}'")
new_value = current_value.replace(str(old_content), str(new_content))
agent_state.memory.update_block_value(label=label, value=new_value)
return None
return new_value
def rethink_memory(agent_state: "AgentState", new_memory: str, target_block_label: str) -> None:
@@ -307,125 +308,118 @@ SNIPPET_LINES: int = 4
# Based off of: https://github.com/anthropics/anthropic-quickstarts/blob/main/computer-use-demo/computer_use_demo/tools/edit.py?ref=musings.yasyf.com#L154
def memory_replace(agent_state: "AgentState", label: str, old_str: str, new_str: str) -> str: # type: ignore
def memory_replace(agent_state: "AgentState", label: str, old_string: str, new_string: str) -> str: # type: ignore
"""
The memory_replace command allows you to replace a specific string in a memory block with a new string. This is used for making precise edits.
Do NOT attempt to replace long strings, e.g. do not attempt to replace the entire contents of a memory block with a new string.
Args:
label (str): Section of the memory to be edited, identified by its label.
old_str (str): The text to replace (must match exactly, including whitespace and indentation).
new_str (str): The new text to insert in place of the old text. Do not include line number prefixes.
old_string (str): The text to replace (must match exactly, including whitespace and indentation).
new_string (str): The new text to insert in place of the old text. Do not include line number prefixes.
Examples:
# Update a block containing information about the user
memory_replace(label="human", old_str="Their name is Alice", new_str="Their name is Bob")
memory_replace(label="human", old_string="Their name is Alice", new_string="Their name is Bob")
# Update a block containing a todo list
memory_replace(label="todos", old_str="- [ ] Step 5: Search the web", new_str="- [x] Step 5: Search the web")
memory_replace(label="todos", old_string="- [ ] Step 5: Search the web", new_string="- [x] Step 5: Search the web")
# Pass an empty string to
memory_replace(label="human", old_str="Their name is Alice", new_str="")
memory_replace(label="human", old_string="Their name is Alice", new_string="")
# Bad example - do NOT add (view-only) line numbers to the args
memory_replace(label="human", old_str="1: Their name is Alice", new_str="1: Their name is Bob")
memory_replace(label="human", old_string="1: Their name is Alice", new_string="1: Their name is Bob")
# Bad example - do NOT include the line number warning either
memory_replace(label="human", old_str="# NOTE: Line numbers shown below (with arrows like '1→') are to help during editing. Do NOT include line number prefixes in your memory edit tool calls.\\n1→ Their name is Alice", new_str="1→ Their name is Bob")
memory_replace(label="human", old_string="# NOTE: Line numbers shown below (with arrows like '1→') are to help during editing. Do NOT include line number prefixes in your memory edit tool calls.\\n1→ Their name is Alice", new_string="1→ Their name is Bob")
# Good example - no line numbers or line number warning (they are view-only), just the text
memory_replace(label="human", old_str="Their name is Alice", new_str="Their name is Bob")
memory_replace(label="human", old_string="Their name is Alice", new_string="Their name is Bob")
Returns:
str: The success message
str: The updated value of the memory block.
"""
import re
if bool(re.search(r"\nLine \d+: ", old_str)):
if bool(re.search(r"\nLine \d+: ", old_string)):
raise ValueError(
"old_str contains a line number prefix, which is not allowed. Do not include line numbers when calling memory tools (line numbers are for display purposes only)."
"old_string contains a line number prefix, which is not allowed. Do not include line numbers when calling memory tools (line numbers are for display purposes only)."
)
if CORE_MEMORY_LINE_NUMBER_WARNING in old_str:
if CORE_MEMORY_LINE_NUMBER_WARNING in old_string:
raise ValueError(
"old_str contains a line number warning, which is not allowed. Do not include line number information when calling memory tools (line numbers are for display purposes only)."
"old_string contains a line number warning, which is not allowed. Do not include line number information when calling memory tools (line numbers are for display purposes only)."
)
if bool(re.search(r"\nLine \d+: ", new_str)):
if bool(re.search(r"\nLine \d+: ", new_string)):
raise ValueError(
"new_str contains a line number prefix, which is not allowed. Do not include line numbers when calling memory tools (line numbers are for display purposes only)."
"new_string contains a line number prefix, which is not allowed. Do not include line numbers when calling memory tools (line numbers are for display purposes only)."
)
old_str = str(old_str).expandtabs()
new_str = str(new_str).expandtabs()
old_string = str(old_string).expandtabs()
new_string = str(new_string).expandtabs()
current_value = str(agent_state.memory.get_block(label).value).expandtabs()
# Check if old_str is unique in the block
occurences = current_value.count(old_str)
# Check if old_string is unique in the block
occurences = current_value.count(old_string)
if occurences == 0:
raise ValueError(f"No replacement was performed, old_str `{old_str}` did not appear verbatim in memory block with label `{label}`.")
raise ValueError(
f"No replacement was performed, old_string `{old_string}` did not appear verbatim in memory block with label `{label}`."
)
elif occurences > 1:
content_value_lines = current_value.split("\n")
lines = [idx + 1 for idx, line in enumerate(content_value_lines) if old_str in line]
lines = [idx + 1 for idx, line in enumerate(content_value_lines) if old_string in line]
raise ValueError(
f"No replacement was performed. Multiple occurrences of old_str `{old_str}` in lines {lines}. Please ensure it is unique."
f"No replacement was performed. Multiple occurrences of old_string `{old_string}` in lines {lines}. Please ensure it is unique."
)
# Replace old_str with new_str
new_value = current_value.replace(str(old_str), str(new_str))
# Replace old_string with new_string
new_value = current_value.replace(str(old_string), str(new_string))
# Write the new content to the block
agent_state.memory.update_block_value(label=label, value=new_value)
# Create a snippet of the edited section
# SNIPPET_LINES = 3
# replacement_line = current_value.split(old_str)[0].count("\n")
# replacement_line = current_value.split(old_string)[0].count("\n")
# start_line = max(0, replacement_line - SNIPPET_LINES)
# end_line = replacement_line + SNIPPET_LINES + new_str.count("\n")
# end_line = replacement_line + SNIPPET_LINES + new_string.count("\n")
# snippet = "\n".join(new_value.split("\n")[start_line : end_line + 1])
# Prepare the success message
success_msg = (
f"The core memory block with label `{label}` has been successfully edited. "
f"Your system prompt has been recompiled with the updated memory contents and is now active in your context. "
f"Review the changes and make sure they are as expected (correct indentation, "
f"no duplicate lines, etc). Edit the memory block again if necessary."
)
# return None
return success_msg
return new_value
def memory_insert(agent_state: "AgentState", label: str, new_str: str, insert_line: int = -1) -> Optional[str]: # type: ignore
def memory_insert(agent_state: "AgentState", label: str, new_string: str, insert_line: int = -1) -> str: # type: ignore
"""
The memory_insert command allows you to insert text at a specific location in a memory block.
Args:
label (str): Section of the memory to be edited, identified by its label.
new_str (str): The text to insert. Do not include line number prefixes.
new_string (str): The text to insert. Do not include line number prefixes.
insert_line (int): The line number after which to insert the text (0 for beginning of file). Defaults to -1 (end of the file).
Examples:
# Update a block containing information about the user (append to the end of the block)
memory_insert(label="customer", new_str="The customer's ticket number is 12345")
memory_insert(label="customer", new_string="The customer's ticket number is 12345")
# Update a block containing information about the user (insert at the beginning of the block)
memory_insert(label="customer", new_str="The customer's ticket number is 12345", insert_line=0)
memory_insert(label="customer", new_string="The customer's ticket number is 12345", insert_line=0)
Returns:
Optional[str]: None is always returned as this function does not produce a response.
"""
import re
if bool(re.search(r"\nLine \d+: ", new_str)):
if bool(re.search(r"\nLine \d+: ", new_string)):
raise ValueError(
"new_str contains a line number prefix, which is not allowed. Do not include line numbers when calling memory tools (line numbers are for display purposes only)."
"new_string contains a line number prefix, which is not allowed. Do not include line numbers when calling memory tools (line numbers are for display purposes only)."
)
if CORE_MEMORY_LINE_NUMBER_WARNING in new_str:
if CORE_MEMORY_LINE_NUMBER_WARNING in new_string:
raise ValueError(
"new_str contains a line number warning, which is not allowed. Do not include line number information when calling memory tools (line numbers are for display purposes only)."
"new_string contains a line number warning, which is not allowed. Do not include line number information when calling memory tools (line numbers are for display purposes only)."
)
current_value = str(agent_state.memory.get_block(label).value).expandtabs()
new_str = str(new_str).expandtabs()
new_string = str(new_string).expandtabs()
current_value_lines = current_value.split("\n")
n_lines = len(current_value_lines)
@@ -438,11 +432,11 @@ def memory_insert(agent_state: "AgentState", label: str, new_str: str, insert_li
)
# Insert the new string as a line
new_str_lines = new_str.split("\n")
new_value_lines = current_value_lines[:insert_line] + new_str_lines + current_value_lines[insert_line:]
snippet_lines = (
new_string_lines = new_string.split("\n")
new_value_lines = current_value_lines[:insert_line] + new_string_lines + current_value_lines[insert_line:]
(
current_value_lines[max(0, insert_line - SNIPPET_LINES) : insert_line]
+ new_str_lines
+ new_string_lines
+ current_value_lines[insert_line : insert_line + SNIPPET_LINES]
)
@@ -453,15 +447,7 @@ def memory_insert(agent_state: "AgentState", label: str, new_str: str, insert_li
# Write into the block
agent_state.memory.update_block_value(label=label, value=new_value)
# Prepare the success message
success_msg = (
f"The core memory block with label `{label}` has been successfully edited. "
f"Your system prompt has been recompiled with the updated memory contents and is now active in your context. "
f"Review the changes and make sure they are as expected (correct indentation, "
f"no duplicate lines, etc). Edit the memory block again if necessary."
)
return success_msg
return new_value
def memory_apply_patch(agent_state: "AgentState", label: str, patch: str) -> str: # type: ignore
@@ -499,7 +485,7 @@ def memory_apply_patch(agent_state: "AgentState", label: str, patch: str) -> str
raise NotImplementedError("This should never be invoked directly. Contact Letta if you see this error message.")
def memory_rethink(agent_state: "AgentState", label: str, new_memory: str) -> None:
def memory_rethink(agent_state: "AgentState", label: str, new_memory: str) -> str:
"""
The memory_rethink command allows you to completely rewrite the contents of a memory block. Use this tool to make large sweeping changes (e.g. when you want to condense or reorganize the memory blocks), do NOT use this tool to make small precise edits (e.g. add or remove a line, replace a specific string, etc).
@@ -528,17 +514,7 @@ def memory_rethink(agent_state: "AgentState", label: str, new_memory: str) -> No
agent_state.memory.set_block(new_block)
agent_state.memory.update_block_value(label=label, value=new_memory)
# Prepare the success message
success_msg = (
f"The core memory block with label `{label}` has been successfully edited. "
f"Your system prompt has been recompiled with the updated memory contents and is now active in your context. "
f"Review the changes and make sure they are as expected (correct indentation, "
f"no duplicate lines, etc). Edit the memory block again if necessary."
)
# return None
return success_msg
return new_memory
def memory_finish_edits(agent_state: "AgentState") -> None: # type: ignore

View File

@@ -1,12 +1,13 @@
import asyncio
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import TYPE_CHECKING, List
if TYPE_CHECKING:
from letta.agents.letta_agent import LettaAgent as Agent
from letta.functions.helpers import (
_send_message_to_agents_matching_tags_async,
_send_message_to_all_agents_in_group_async,
execute_send_message_to_agent,
extract_send_message_from_steps_messages,
fire_and_forget_send_to_agent,
)
from letta.schemas.enums import MessageRole
@@ -27,9 +28,7 @@ def send_message_to_agent_and_wait_for_reply(self: "Agent", message: str, other_
str: The response from the target agent.
"""
augmented_message = (
f"[Incoming message from agent with ID '{self.agent_state.id}' - to reply to this message, "
f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] "
f"{message}"
f"[Incoming message from agent with ID '{self.agent_state.id}' - your response will be delivered to the sender] {message}"
)
messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=self.agent_state.name)]
@@ -56,57 +55,18 @@ def send_message_to_agents_matching_tags(self: "Agent", message: str, match_all:
in the returned list.
"""
server = get_letta_server()
augmented_message = (
f"[Incoming message from external Letta agent - to reply to this message, "
f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] "
f"{message}"
)
augmented_message = f"[Incoming message from external Letta agent - your response will be delivered to the sender] {message}"
# Find matching agents
matching_agents = server.agent_manager.list_agents_matching_tags(actor=self.user, match_all=match_all, match_some=match_some)
if not matching_agents:
return []
def process_agent(agent_id: str) -> str:
"""Loads an agent, formats the message, and executes .step()"""
actor = self.user # Ensure correct actor context
agent = server.load_agent(agent_id=agent_id, interface=None, actor=actor)
# Prepare the message
messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=self.agent_state.name)]
# Prepare the message
messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=self.agent_state.name)]
# Run .step() and return the response
usage_stats = agent.step(
input_messages=messages,
chaining=True,
max_chaining_steps=None,
stream=False,
skip_verify=True,
metadata=None,
put_inner_thoughts_first=True,
)
send_messages = extract_send_message_from_steps_messages(usage_stats.steps_messages, logger=agent.logger)
response_data = {
"agent_id": agent_id,
"response_messages": send_messages if send_messages else ["<no response>"],
}
return json.dumps(response_data, indent=2)
# Use ThreadPoolExecutor for parallel execution
results = []
with ThreadPoolExecutor(max_workers=settings.multi_agent_concurrent_sends) as executor:
future_to_agent = {executor.submit(process_agent, agent_state.id): agent_state for agent_state in matching_agents}
for future in as_completed(future_to_agent):
try:
results.append(future.result()) # Collect results
except Exception as e:
# Log or handle failure for specific agents if needed
self.logger.exception(f"Error processing agent {future_to_agent[future]}: {e}")
return results
# Use async helper for parallel message sending
return asyncio.run(_send_message_to_agents_matching_tags_async(self, server, messages, matching_agents))
def send_message_to_all_agents_in_group(self: "Agent", message: str) -> List[str]:
@@ -138,8 +98,8 @@ def send_message_to_agent_async(self: "Agent", message: str, other_agent_id: str
raise RuntimeError("This tool is not allowed to be run on Letta Cloud.")
message = (
f"[Incoming message from agent with ID '{self.agent_state.id}' - to reply to this message, "
f"make sure to use the 'send_message_to_agent_async' tool, or the agent will not receive your message] "
f"[Incoming message from agent with ID '{self.agent_state.id}' - "
f"this is a one-way notification; if you need to respond, use an agent-to-agent messaging tool if available] "
f"{message}"
)
messages = [MessageCreate(role=MessageRole.system, content=message, name=self.agent_state.name)]

View File

@@ -1,5 +1,8 @@
## Voice chat + sleeptime tools
from typing import List, Optional
from typing import TYPE_CHECKING, List, Optional
if TYPE_CHECKING:
from letta.schemas.agent import AgentState
from pydantic import BaseModel, Field

View File

@@ -179,7 +179,7 @@ def _extract_pydantic_classes(tree: ast.AST, imports_map: Dict[str, Any]) -> Dic
pass # Field is required, no default
else:
field_kwargs["default"] = default_val
except:
except Exception:
pass
fields[field_name] = Field(**field_kwargs)
@@ -188,7 +188,7 @@ def _extract_pydantic_classes(tree: ast.AST, imports_map: Dict[str, Any]) -> Dic
try:
default_val = ast.literal_eval(stmt.value)
fields[field_name] = default_val
except:
except Exception:
pass
# Create the dynamic Pydantic model

View File

@@ -3,7 +3,17 @@ import json
import logging
import threading
from random import uniform
from typing import Any, Dict, List, Optional, Type, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
if TYPE_CHECKING:
from letta.agents.letta_agent import LettaAgent as Agent
from letta.schemas.agent import AgentState
from letta.server.server import SyncServer
try:
from langchain.tools.base import BaseTool as LangChainBaseTool
except ImportError:
LangChainBaseTool = None
import humps
from pydantic import BaseModel, Field, create_model
@@ -21,6 +31,8 @@ from letta.server.rest_api.dependencies import get_letta_server
from letta.settings import settings
from letta.utils import safe_create_task
_background_tasks: set[asyncio.Task] = set()
# TODO needed?
def generate_mcp_tool_wrapper(mcp_tool_name: str) -> tuple[str, str]:
@@ -36,7 +48,8 @@ def {mcp_tool_name}(**kwargs):
def generate_langchain_tool_wrapper(
tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None
tool: "LangChainBaseTool",
additional_imports_module_attr_map: dict[str, str] | None = None,
) -> tuple[str, str]:
tool_name = tool.__class__.__name__
import_statement = f"from langchain_community.tools import {tool_name}"
@@ -428,15 +441,18 @@ def fire_and_forget_send_to_agent(
# 4) Try to schedule the coroutine in an existing loop, else spawn a thread
try:
loop = asyncio.get_running_loop()
# If we get here, a loop is running; schedule the coroutine in background
loop.create_task(background_task())
task = loop.create_task(background_task())
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
except RuntimeError:
# Means no event loop is running in this thread
run_in_background_thread(background_task())
async def _send_message_to_agents_matching_tags_async(
sender_agent: "Agent", server: "SyncServer", messages: List[MessageCreate], matching_agents: List["AgentState"]
sender_agent: "Agent",
server: "SyncServer",
messages: List[MessageCreate],
matching_agents: List["AgentState"],
) -> List[str]:
async def _send_single(agent_state):
return await _async_send_message_with_retries(
@@ -464,9 +480,7 @@ async def _send_message_to_all_agents_in_group_async(sender_agent: "Agent", mess
server = get_letta_server()
augmented_message = (
f"[Incoming message from agent with ID '{sender_agent.agent_state.id}' - to reply to this message, "
f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] "
f"{message}"
f"[Incoming message from agent with ID '{sender_agent.agent_state.id}' - your response will be delivered to the sender] {message}"
)
worker_agents_ids = sender_agent.agent_state.multi_agent_group.agent_ids
@@ -520,7 +534,9 @@ def generate_model_from_args_json_schema(schema: Dict[str, Any]) -> Type[BaseMod
return _create_model_from_schema(schema.get("title", "DynamicModel"), schema, nested_models)
def _create_model_from_schema(name: str, model_schema: Dict[str, Any], nested_models: Dict[str, Type[BaseModel]] = None) -> Type[BaseModel]:
def _create_model_from_schema(
name: str, model_schema: Dict[str, Any], nested_models: Dict[str, Type[BaseModel]] | None = None
) -> Type[BaseModel]:
fields = {}
for field_name, field_schema in model_schema["properties"].items():
field_type = _get_field_type(field_schema, nested_models)
@@ -531,7 +547,7 @@ def _create_model_from_schema(name: str, model_schema: Dict[str, Any], nested_mo
return create_model(name, **fields)
def _get_field_type(field_schema: Dict[str, Any], nested_models: Dict[str, Type[BaseModel]] = None) -> Any:
def _get_field_type(field_schema: Dict[str, Any], nested_models: Dict[str, Type[BaseModel]] | None = None) -> Any:
"""Helper to convert JSON schema types to Python types."""
if field_schema.get("type") == "string":
return str

View File

@@ -98,6 +98,32 @@ class BaseServerConfig(BaseModel):
return result
@staticmethod
def _sanitize_dict_key(key: str) -> str:
"""Strip surrounding quotes and trailing colons from a dict key."""
key = key.strip()
for quote in ('"', "'"):
if key.startswith(quote) and key.endswith(quote):
key = key[1:-1]
break
key = key.rstrip(":")
return key.strip()
@staticmethod
def _sanitize_dict_value(value: str) -> str:
"""Strip surrounding quotes from a dict value."""
value = value.strip()
for quote in ('"', "'"):
if value.startswith(quote) and value.endswith(quote):
value = value[1:-1]
break
return value
@classmethod
def _sanitize_dict(cls, d: Dict[str, str]) -> Dict[str, str]:
"""Sanitize a string dict by stripping quotes from keys and values."""
return {cls._sanitize_dict_key(k): cls._sanitize_dict_value(v) for k, v in d.items()}
def resolve_custom_headers(
self, custom_headers: Optional[Dict[str, str]], environment_variables: Optional[Dict[str, str]] = None
) -> Optional[Dict[str, str]]:
@@ -114,6 +140,8 @@ class BaseServerConfig(BaseModel):
if custom_headers is None:
return None
custom_headers = self._sanitize_dict(custom_headers)
resolved_headers = {}
for key, value in custom_headers.items():
# Resolve templated variables in each header value
@@ -164,8 +192,12 @@ class HTTPBasedServerConfig(BaseServerConfig):
return None
def resolve_environment_variables(self, environment_variables: Optional[Dict[str, str]] = None) -> None:
if self.auth_token and super().is_templated_tool_variable(self.auth_token):
self.auth_token = super().get_tool_variable(self.auth_token, environment_variables)
if self.auth_header:
self.auth_header = self._sanitize_dict_key(self.auth_header)
if self.auth_token:
self.auth_token = self._sanitize_dict_value(self.auth_token)
if super().is_templated_tool_variable(self.auth_token):
self.auth_token = super().get_tool_variable(self.auth_token, environment_variables)
self.custom_headers = super().resolve_custom_headers(self.custom_headers, environment_variables)
@@ -176,11 +208,11 @@ class HTTPBasedServerConfig(BaseServerConfig):
Returns:
Dictionary of headers or None if no headers are configured
"""
if self.custom_headers is not None or (self.auth_header is not None and self.auth_token is not None):
if self.custom_headers is not None or (self.auth_header and self.auth_token):
headers = self.custom_headers.copy() if self.custom_headers else {}
# Add auth header if specified
if self.auth_header is not None and self.auth_token is not None:
# Add auth header if specified (skip if either is empty to avoid illegal header values)
if self.auth_header and self.auth_token:
headers[self.auth_header] = self.auth_token
return headers

View File

@@ -96,7 +96,7 @@ def type_to_json_schema_type(py_type) -> dict:
# Handle array types
origin = get_origin(py_type)
if py_type == list or origin in (list, List):
if py_type is list or origin in (list, List):
args = get_args(py_type)
if len(args) == 0:
# is this correct
@@ -142,7 +142,7 @@ def type_to_json_schema_type(py_type) -> dict:
}
# Handle object types
if py_type == dict or origin in (dict, Dict):
if py_type is dict or origin in (dict, Dict):
args = get_args(py_type)
if not args:
# Generic dict without type arguments
@@ -704,8 +704,9 @@ def generate_tool_schema_for_mcp(
name = mcp_tool.name
description = mcp_tool.description
assert "type" in parameters_schema, parameters_schema
assert "properties" in parameters_schema, parameters_schema
if "type" not in parameters_schema:
parameters_schema["type"] = "object"
parameters_schema.setdefault("properties", {})
# assert "required" in parameters_schema, parameters_schema
# Normalize the schema to fix common issues with MCP schemas

View File

@@ -56,7 +56,7 @@ def validate_complete_json_schema(schema: Dict[str, Any]) -> Tuple[SchemaHealth,
"""
if obj_schema.get("type") != "object":
return False
props = obj_schema.get("properties", {})
obj_schema.get("properties", {})
required = obj_schema.get("required", [])
additional = obj_schema.get("additionalProperties", True)

View File

@@ -1,4 +1,7 @@
from typing import List, Optional
from typing import TYPE_CHECKING, List, Optional
if TYPE_CHECKING:
from letta.agents.letta_agent import LettaAgent as Agent
from letta.agents.base_agent import BaseAgent
from letta.agents.letta_agent import LettaAgent
@@ -92,7 +95,7 @@ class DynamicMultiAgent(BaseAgent):
# Parse manager response
responses = Message.to_letta_messages_from_list(manager_agent.last_response_messages)
assistant_message = [response for response in responses if response.message_type == "assistant_message"][0]
assistant_message = next(response for response in responses if response.message_type == "assistant_message")
for name, agent_id in [(agents[agent_id].agent_state.name, agent_id) for agent_id in agent_id_options]:
if name.lower() in assistant_message.content.lower():
speaker_id = agent_id

View File

@@ -98,7 +98,7 @@ def stringify_message(message: Message, use_assistant_name: bool = False) -> str
elif isinstance(content, ImageContent):
messages.append(f"{message.name or 'user'}: [Image Here]")
return "\n".join(messages)
except:
except Exception:
if message.content and len(message.content) > 0:
return f"{message.name or 'user'}: {message.content[0].text}"
return None

View File

@@ -1,4 +1,3 @@
import asyncio
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
@@ -213,7 +212,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
group_id=self.group.id, last_processed_message_id=last_response_messages[-1].id, actor=self.actor
)
for sleeptime_agent_id in self.group.agent_ids:
run_id = await self._issue_background_task(
await self._issue_background_task(
sleeptime_agent_id,
last_response_messages,
last_processed_message_id,

View File

@@ -9,7 +9,6 @@ from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState
from letta.schemas.enums import RunStatus
from letta.schemas.group import Group, ManagerType
from letta.schemas.job import JobUpdate
from letta.schemas.letta_message import MessageType
from letta.schemas.letta_message_content import TextContent
from letta.schemas.letta_request import ClientToolSchema
@@ -47,6 +46,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
include_return_message_types: list[MessageType] | None = None,
request_start_timestamp_ns: int | None = None,
client_tools: list[ClientToolSchema] | None = None,
include_compaction_messages: bool = False,
) -> LettaResponse:
self.run_ids = []
@@ -61,6 +61,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
include_return_message_types=include_return_message_types,
request_start_timestamp_ns=request_start_timestamp_ns,
client_tools=client_tools,
include_compaction_messages=include_compaction_messages,
)
await self.run_sleeptime_agents()
@@ -79,6 +80,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
request_start_timestamp_ns: int | None = None,
include_return_message_types: list[MessageType] | None = None,
client_tools: list[ClientToolSchema] | None = None,
include_compaction_messages: bool = False,
) -> AsyncGenerator[str, None]:
self.run_ids = []
@@ -96,6 +98,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
include_return_message_types=include_return_message_types,
request_start_timestamp_ns=request_start_timestamp_ns,
client_tools=client_tools,
include_compaction_messages=include_compaction_messages,
):
yield chunk
finally:

View File

@@ -1,4 +1,3 @@
import asyncio
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
@@ -7,9 +6,8 @@ from letta.constants import DEFAULT_MAX_STEPS
from letta.groups.helpers import stringify_message
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState
from letta.schemas.enums import JobStatus, RunStatus
from letta.schemas.enums import RunStatus
from letta.schemas.group import Group, ManagerType
from letta.schemas.job import JobUpdate
from letta.schemas.letta_message import MessageType
from letta.schemas.letta_message_content import TextContent
from letta.schemas.letta_request import ClientToolSchema
@@ -48,6 +46,7 @@ class SleeptimeMultiAgentV4(LettaAgentV3):
request_start_timestamp_ns: int | None = None,
conversation_id: str | None = None,
client_tools: list[ClientToolSchema] | None = None,
include_compaction_messages: bool = False,
) -> LettaResponse:
self.run_ids = []
@@ -63,6 +62,7 @@ class SleeptimeMultiAgentV4(LettaAgentV3):
request_start_timestamp_ns=request_start_timestamp_ns,
conversation_id=conversation_id,
client_tools=client_tools,
include_compaction_messages=include_compaction_messages,
)
run_ids = await self.run_sleeptime_agents()
@@ -81,6 +81,7 @@ class SleeptimeMultiAgentV4(LettaAgentV3):
include_return_message_types: list[MessageType] | None = None,
conversation_id: str | None = None,
client_tools: list[ClientToolSchema] | None = None,
include_compaction_messages: bool = False,
) -> AsyncGenerator[str, None]:
self.run_ids = []
@@ -99,6 +100,7 @@ class SleeptimeMultiAgentV4(LettaAgentV3):
request_start_timestamp_ns=request_start_timestamp_ns,
conversation_id=conversation_id,
client_tools=client_tools,
include_compaction_messages=include_compaction_messages,
):
yield chunk
finally:

View File

@@ -1,19 +1,9 @@
from typing import List, Optional
from typing import List
from letta.agents.base_agent import BaseAgent
from letta.constants import DEFAULT_MESSAGE_TOOL
from letta.functions.function_sets.multi_agent import send_message_to_all_agents_in_group
from letta.functions.functions import parse_source_code
from letta.functions.schema_generator import generate_schema
from letta.interface import AgentInterface
from letta.orm import User
from letta.schemas.agent import AgentState
from letta.schemas.enums import ToolType
from letta.schemas.letta_message_content import TextContent
from letta.schemas.message import MessageCreate
from letta.schemas.tool import Tool
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
from letta.schemas.usage import LettaUsageStatistics
from letta.services.agent_manager import AgentManager
from letta.services.tool_manager import ToolManager

View File

@@ -1 +1 @@
from letta.helpers.tool_rule_solver import ToolRulesSolver
from letta.helpers.tool_rule_solver import ToolRulesSolver as ToolRulesSolver

View File

@@ -1,4 +1,7 @@
from typing import Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
if TYPE_CHECKING:
from letta.services.summarizer.summarizer_config import CompactionSettings
import numpy as np
from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse

View File

@@ -4,6 +4,53 @@ from datetime import datetime
from typing import Any
def sanitize_unicode_surrogates(value: Any) -> Any:
    """Recursively strip UTF-16 surrogate code points (U+D800–U+DFFF) from strings.

    Lone surrogates cannot be encoded to UTF-8 and therefore break JSON
    serialization of API payloads (UnicodeEncodeError). Strings are filtered
    character by character; dicts are walked recursively (keys included),
    lists are rebuilt as lists, tuples as tuples; every other type is
    returned untouched.

    Args:
        value: The value to sanitize

    Returns:
        The sanitized value with surrogate characters removed from all strings
    """
    if isinstance(value, str):
        try:
            # Keep only characters outside the surrogate range.
            return "".join(ch for ch in value if not 0xD800 <= ord(ch) <= 0xDFFF)
        except Exception:
            try:
                # Fallback: UTF-8 round-trip replaces unencodable characters.
                return value.encode("utf-8", errors="replace").decode("utf-8")
            except Exception:
                # Last resort: give the original string back unchanged.
                return value
    if isinstance(value, dict):
        # Sanitize both keys and values so no string in the mapping escapes.
        return {sanitize_unicode_surrogates(k): sanitize_unicode_surrogates(v) for k, v in value.items()}
    if isinstance(value, list):
        return [sanitize_unicode_surrogates(item) for item in value]
    if isinstance(value, tuple):
        # Preserve the tuple type on the way back out.
        return tuple(sanitize_unicode_surrogates(item) for item in value)
    # Non-container, non-string values (int, float, bool, None, ...) pass through.
    return value
def sanitize_null_bytes(value: Any) -> Any:
"""Recursively remove null bytes (0x00) from strings.

View File

@@ -139,6 +139,20 @@ async def _convert_message_create_to_message(
image_media_type, _ = mimetypes.guess_type(file_path)
if not image_media_type:
image_media_type = "image/jpeg" # default fallback
elif url.startswith("data:"):
# Handle data: URLs (inline base64 encoded images)
# Format: data:[<mediatype>][;base64],<data>
try:
# Split header from data
header, image_data = url.split(",", 1)
# Extract media type from header (e.g., "data:image/jpeg;base64")
header_parts = header.split(";")
image_media_type = header_parts[0].replace("data:", "") or "image/jpeg"
# Data is already base64 encoded, set directly and continue
content.source = Base64Image(media_type=image_media_type, data=image_data)
continue # Skip the common conversion path below
except ValueError:
raise LettaImageFetchError(url=url[:100] + "...", reason="Invalid data URL format")
else:
# Handle http(s):// URLs using async httpx
image_bytes, image_media_type = await _fetch_image_from_url(url)

View File

@@ -306,7 +306,9 @@ async def search_pinecone_index(query: str, limit: int, filter: Dict[str, Any],
@pinecone_retry()
@trace_method
async def list_pinecone_index_for_files(file_id: str, actor: User, limit: int = None, pagination_token: str = None) -> List[str]:
async def list_pinecone_index_for_files(
file_id: str, actor: User, limit: int | None = None, pagination_token: str | None = None
) -> List[str]:
if not PINECONE_AVAILABLE:
raise ImportError("Pinecone is not available. Please install pinecone to use this feature.")

View File

@@ -1,6 +1,6 @@
import copy
from collections import OrderedDict
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional
from letta.constants import PRE_EXECUTION_MESSAGE_ARG
from letta.schemas.tool import MCP_TOOL_METADATA_SCHEMA_STATUS, MCP_TOOL_METADATA_SCHEMA_WARNINGS
@@ -201,7 +201,7 @@ def add_pre_execution_message(tool_schema: Dict[str, Any], description: Optional
# Ensure pre-execution message is the first required field
if PRE_EXECUTION_MESSAGE_ARG not in required:
required = [PRE_EXECUTION_MESSAGE_ARG] + required
required = [PRE_EXECUTION_MESSAGE_ARG, *required]
# Update the schema with ordered properties and required list
schema["parameters"] = {

View File

@@ -3,12 +3,20 @@
import asyncio
import json
import logging
import random
from datetime import datetime, timezone
from typing import Any, Callable, List, Optional, Tuple
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, TypeVar
if TYPE_CHECKING:
from letta.schemas.tool import Tool as PydanticTool
from letta.schemas.user import User as PydanticUser
import httpx
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
from letta.errors import LettaInvalidArgumentError
from letta.otel.tracing import trace_method
from letta.otel.tracing import log_event, trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole, TagMatchMode
from letta.schemas.passage import Passage as PydanticPassage
@@ -16,6 +24,136 @@ from letta.settings import model_settings, settings
logger = logging.getLogger(__name__)
# Type variable for generic async retry decorator
T = TypeVar("T")
# Default retry configuration for turbopuffer operations
TPUF_MAX_RETRIES = 3  # retries after the initial attempt
TPUF_INITIAL_DELAY = 1.0  # seconds before the first retry
TPUF_EXPONENTIAL_BASE = 2.0  # delay multiplier between retries
TPUF_JITTER = True  # add up to 10% random jitter to each delay


def is_transient_error(error: Exception) -> bool:
    """Check if an error is transient and should be retried.

    Transient means either an httpx connection/timeout/network exception, or
    any error whose message contains a known connectivity-failure pattern.

    Args:
        error: The exception to check

    Returns:
        True if the error is transient and can be retried
    """
    # httpx connection errors (network issues, DNS failures, etc.),
    # timeouts, and generic network errors are always retryable.
    if isinstance(error, (httpx.ConnectError, httpx.TimeoutException, httpx.NetworkError)):
        return True

    # Fall back to message inspection for errors wrapped by other layers.
    error_str = str(error).lower()
    transient_patterns = (
        "connect call failed",
        "connection refused",
        "connection reset",
        "connection timed out",
        "temporary failure",
        "name resolution",
        "dns",
        "network unreachable",
        "no route to host",
        "ssl handshake",
    )
    return any(pattern in error_str for pattern in transient_patterns)


def async_retry_with_backoff(
    max_retries: int = TPUF_MAX_RETRIES,
    initial_delay: float = TPUF_INITIAL_DELAY,
    exponential_base: float = TPUF_EXPONENTIAL_BASE,
    jitter: bool = TPUF_JITTER,
):
    """Decorator for async functions that retries on transient errors with exponential backoff.

    Args:
        max_retries: Maximum number of retry attempts
        initial_delay: Initial delay between retries in seconds
        exponential_base: Base for exponential backoff calculation
        jitter: Whether to add random jitter to delays

    Returns:
        Decorated async function with retry logic

    Raises:
        Whatever the wrapped function raises: non-transient errors propagate
        immediately; transient errors propagate once max_retries is exhausted.
    """

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        async def wrapper(*args, **kwargs) -> Any:
            num_retries = 0
            delay = initial_delay
            while True:
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    # Non-transient errors are never retried.
                    if not is_transient_error(e):
                        raise

                    num_retries += 1

                    # Check exhaustion BEFORE logging a retry: previously the
                    # final failure emitted a "retry_attempt" event and logged
                    # "Retrying in Xs..." and then raised, which was misleading.
                    if num_retries > max_retries:
                        log_event(
                            "turbopuffer_max_retries_exceeded",
                            {
                                "max_retries": max_retries,
                                "error_type": type(e).__name__,
                                "error": str(e),
                                "function": func.__name__,
                            },
                        )
                        logger.error(f"Turbopuffer operation '{func.__name__}' failed after {max_retries} retries: {e}")
                        raise

                    # Log the retry attempt
                    log_event(
                        "turbopuffer_retry_attempt",
                        {
                            "attempt": num_retries,
                            "delay": delay,
                            "error_type": type(e).__name__,
                            "error": str(e),
                            "function": func.__name__,
                        },
                    )
                    logger.warning(
                        f"Turbopuffer operation '{func.__name__}' failed with transient error "
                        f"(attempt {num_retries}/{max_retries}): {e}. Retrying in {delay:.1f}s..."
                    )

                    # Wait, then grow the delay for the next attempt.
                    await asyncio.sleep(delay)
                    delay *= exponential_base
                    if jitter:
                        delay *= 1 + random.random() * 0.1  # add up to 10% jitter

        return wrapper

    return decorator
# Global semaphore for Turbopuffer operations to prevent overwhelming the service
# This is separate from embedding semaphore since Turbopuffer can handle more concurrency
_GLOBAL_TURBOPUFFER_SEMAPHORE = asyncio.Semaphore(5)
@@ -25,11 +163,11 @@ def _run_turbopuffer_write_in_thread(
api_key: str,
region: str,
namespace_name: str,
upsert_columns: dict = None,
deletes: list = None,
delete_by_filter: tuple = None,
upsert_columns: dict | None = None,
deletes: list | None = None,
delete_by_filter: tuple | None = None,
distance_metric: str = "cosine_distance",
schema: dict = None,
schema: dict | None = None,
):
"""
Sync wrapper to run turbopuffer write in isolated event loop.
@@ -93,7 +231,7 @@ class TurbopufferClient:
embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
)
def __init__(self, api_key: str = None, region: str = None):
def __init__(self, api_key: str | None = None, region: str | None = None):
"""Initialize Turbopuffer client."""
self.api_key = api_key or settings.tpuf_api_key
self.region = region or settings.tpuf_region
@@ -222,6 +360,7 @@ class TurbopufferClient:
return json.dumps(parts)
@trace_method
@async_retry_with_backoff()
async def insert_tools(
self,
tools: List["PydanticTool"],
@@ -238,7 +377,6 @@ class TurbopufferClient:
Returns:
True if successful
"""
from turbopuffer import AsyncTurbopuffer
if not tools:
return True
@@ -313,6 +451,7 @@ class TurbopufferClient:
raise
@trace_method
@async_retry_with_backoff()
async def insert_archival_memories(
self,
archive_id: str,
@@ -339,7 +478,6 @@ class TurbopufferClient:
Returns:
List of PydanticPassage objects that were inserted
"""
from turbopuffer import AsyncTurbopuffer
# filter out empty text chunks
filtered_chunks = [(i, text) for i, text in enumerate(text_chunks) if text.strip()]
@@ -464,6 +602,7 @@ class TurbopufferClient:
raise
@trace_method
@async_retry_with_backoff()
async def insert_messages(
self,
agent_id: str,
@@ -494,7 +633,6 @@ class TurbopufferClient:
Returns:
True if successful
"""
from turbopuffer import AsyncTurbopuffer
# filter out empty message texts
filtered_messages = [(i, text) for i, text in enumerate(message_texts) if text.strip()]
@@ -609,6 +747,7 @@ class TurbopufferClient:
raise
@trace_method
@async_retry_with_backoff()
async def _execute_query(
self,
namespace_name: str,
@@ -1377,9 +1516,9 @@ class TurbopufferClient:
return sorted_results[:top_k]
@trace_method
@async_retry_with_backoff()
async def delete_passage(self, archive_id: str, passage_id: str) -> bool:
"""Delete a passage from Turbopuffer."""
from turbopuffer import AsyncTurbopuffer
namespace_name = await self._get_archive_namespace_name(archive_id)
@@ -1399,9 +1538,9 @@ class TurbopufferClient:
raise
@trace_method
@async_retry_with_backoff()
async def delete_passages(self, archive_id: str, passage_ids: List[str]) -> bool:
"""Delete multiple passages from Turbopuffer."""
from turbopuffer import AsyncTurbopuffer
if not passage_ids:
return True
@@ -1424,6 +1563,7 @@ class TurbopufferClient:
raise
@trace_method
@async_retry_with_backoff()
async def delete_all_passages(self, archive_id: str) -> bool:
"""Delete all passages for an archive from Turbopuffer."""
from turbopuffer import AsyncTurbopuffer
@@ -1442,9 +1582,9 @@ class TurbopufferClient:
raise
@trace_method
@async_retry_with_backoff()
async def delete_messages(self, agent_id: str, organization_id: str, message_ids: List[str]) -> bool:
"""Delete multiple messages from Turbopuffer."""
from turbopuffer import AsyncTurbopuffer
if not message_ids:
return True
@@ -1467,9 +1607,9 @@ class TurbopufferClient:
raise
@trace_method
@async_retry_with_backoff()
async def delete_all_messages(self, agent_id: str, organization_id: str) -> bool:
"""Delete all messages for an agent from Turbopuffer."""
from turbopuffer import AsyncTurbopuffer
namespace_name = await self._get_message_namespace_name(organization_id)
@@ -1509,6 +1649,7 @@ class TurbopufferClient:
return namespace_name
@trace_method
@async_retry_with_backoff()
async def insert_file_passages(
self,
source_id: str,
@@ -1531,7 +1672,6 @@ class TurbopufferClient:
Returns:
List of PydanticPassage objects that were inserted
"""
from turbopuffer import AsyncTurbopuffer
if not text_chunks:
return []
@@ -1765,9 +1905,9 @@ class TurbopufferClient:
return passages_with_scores
@trace_method
@async_retry_with_backoff()
async def delete_file_passages(self, source_id: str, file_id: str, organization_id: str) -> bool:
"""Delete all passages for a specific file from Turbopuffer."""
from turbopuffer import AsyncTurbopuffer
namespace_name = await self._get_file_passages_namespace_name(organization_id)
@@ -1793,9 +1933,9 @@ class TurbopufferClient:
raise
@trace_method
@async_retry_with_backoff()
async def delete_source_passages(self, source_id: str, organization_id: str) -> bool:
"""Delete all passages for a source from Turbopuffer."""
from turbopuffer import AsyncTurbopuffer
namespace_name = await self._get_file_passages_namespace_name(organization_id)
@@ -1817,6 +1957,7 @@ class TurbopufferClient:
# tool methods
@trace_method
@async_retry_with_backoff()
async def delete_tools(self, organization_id: str, tool_ids: List[str]) -> bool:
"""Delete tools from Turbopuffer.
@@ -1827,7 +1968,6 @@ class TurbopufferClient:
Returns:
True if successful
"""
from turbopuffer import AsyncTurbopuffer
if not tool_ids:
return True

View File

@@ -136,7 +136,7 @@ class CLIInterface(AgentInterface):
else:
try:
msg_json = json_loads(msg)
except:
except Exception:
printd(f"{CLI_WARNING_PREFIX}failed to parse user message into json")
printd_user_message("🧑", msg)
return

View File

@@ -3,7 +3,12 @@ import json
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from enum import Enum
from typing import Optional
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from opentelemetry.trace import Span
from letta.schemas.usage import LettaUsageStatistics
from anthropic import AsyncStream
from anthropic.types.beta import (
@@ -274,7 +279,13 @@ class SimpleAnthropicStreamingInterface:
attributes={"stop_reason": StopReasonType.error.value, "error": str(e), "stacktrace": traceback.format_exc()},
)
yield LettaStopReason(stop_reason=StopReasonType.error)
raise e
# Transform Anthropic errors into our custom error types for consistent handling
from letta.llm_api.anthropic_client import AnthropicClient
client = AnthropicClient()
transformed_error = client.handle_llm_error(e)
raise transformed_error
finally:
logger.info("AnthropicStreamingInterface: Stream processing complete.")
@@ -316,6 +327,7 @@ class SimpleAnthropicStreamingInterface:
id=decrement_message_uuid(self.letta_message_id),
# Do not emit placeholder arguments here to avoid UI duplicates
tool_call=ToolCallDelta(name=name, tool_call_id=call_id),
tool_calls=ToolCallDelta(name=name, tool_call_id=call_id),
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
run_id=self.run_id,
@@ -421,6 +433,7 @@ class SimpleAnthropicStreamingInterface:
tool_call_msg = ApprovalRequestMessage(
id=decrement_message_uuid(self.letta_message_id),
tool_call=ToolCallDelta(name=name, tool_call_id=call_id, arguments=delta.partial_json),
tool_calls=ToolCallDelta(name=name, tool_call_id=call_id, arguments=delta.partial_json),
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
run_id=self.run_id,

View File

@@ -3,7 +3,12 @@ import json
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from enum import Enum
from typing import Optional
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from opentelemetry.trace import Span
from letta.schemas.usage import LettaUsageStatistics
from anthropic import AsyncStream
from anthropic.types.beta import (
@@ -116,7 +121,7 @@ class AnthropicStreamingInterface:
# Attempt to use OptimisticJSONParser to handle incomplete/malformed JSON
try:
tool_input = self.json_parser.parse(args_str)
except:
except Exception:
logger.warning(
f"Failed to decode tool call arguments for tool_call_id={self.tool_call_id}, "
f"name={self.tool_call_name}. Raw input: {args_str!r}. Error: {e}"
@@ -263,7 +268,13 @@ class AnthropicStreamingInterface:
attributes={"stop_reason": StopReasonType.error.value, "error": str(e), "stacktrace": traceback.format_exc()},
)
yield LettaStopReason(stop_reason=StopReasonType.error)
raise e
# Transform Anthropic errors into our custom error types for consistent handling
from letta.llm_api.anthropic_client import AnthropicClient
client = AnthropicClient()
transformed_error = client.handle_llm_error(e)
raise transformed_error
finally:
logger.info("AnthropicStreamingInterface: Stream processing complete.")
@@ -424,16 +435,19 @@ class AnthropicStreamingInterface:
if current_inner_thoughts:
tool_call_args = tool_call_args.replace(f'"{INNER_THOUGHTS_KWARG}": "{current_inner_thoughts}"', "")
tool_call_delta = ToolCallDelta(
name=self.tool_call_name,
tool_call_id=self.tool_call_id,
arguments=tool_call_args,
)
approval_msg = ApprovalRequestMessage(
id=self.letta_message_id,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
date=datetime.now(timezone.utc).isoformat(),
name=self.tool_call_name,
tool_call=ToolCallDelta(
name=self.tool_call_name,
tool_call_id=self.tool_call_id,
arguments=tool_call_args,
),
tool_call=tool_call_delta,
tool_calls=tool_call_delta,
run_id=self.run_id,
step_id=self.step_id,
)
@@ -493,6 +507,9 @@ class AnthropicStreamingInterface:
tool_call_msg = ApprovalRequestMessage(
id=self.letta_message_id,
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json),
tool_calls=ToolCallDelta(
name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json
),
date=datetime.now(timezone.utc).isoformat(),
run_id=self.run_id,
step_id=self.step_id,
@@ -576,9 +593,15 @@ class AnthropicStreamingInterface:
pass
elif isinstance(event, BetaRawContentBlockStopEvent):
# If we're exiting a tool use block and there are still buffered messages,
# we should flush them now
# we should flush them now.
# Ensure each flushed chunk has an otid before yielding.
if self.anthropic_mode == EventMode.TOOL_USE and self.tool_call_buffer:
for buffered_msg in self.tool_call_buffer:
if not buffered_msg.otid:
if prev_message_type and prev_message_type != buffered_msg.message_type:
message_index += 1
buffered_msg.otid = Message.generate_otid_from_id(buffered_msg.id, message_index)
prev_message_type = buffered_msg.message_type
yield buffered_msg
self.tool_call_buffer = []
@@ -644,7 +667,7 @@ class SimpleAnthropicStreamingInterface:
# Attempt to use OptimisticJSONParser to handle incomplete/malformed JSON
try:
tool_input = self.json_parser.parse(args_str)
except:
except Exception:
logger.warning(
f"Failed to decode tool call arguments for tool_call_id={self.tool_call_id}, "
f"name={self.tool_call_name}. Raw input: {args_str!r}. Error: {e}"
@@ -827,6 +850,7 @@ class SimpleAnthropicStreamingInterface:
tool_call_msg = ApprovalRequestMessage(
id=self.letta_message_id,
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
tool_calls=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
@@ -911,6 +935,7 @@ class SimpleAnthropicStreamingInterface:
tool_call_msg = ApprovalRequestMessage(
id=self.letta_message_id,
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json),
tool_calls=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json),
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,

View File

@@ -3,7 +3,12 @@ import base64
import json
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from typing import AsyncIterator, List, Optional
from typing import TYPE_CHECKING, AsyncIterator, List, Optional
if TYPE_CHECKING:
from opentelemetry.trace import Span
from letta.schemas.usage import LettaUsageStatistics
from google.genai.types import (
GenerateContentResponse,
@@ -97,9 +102,11 @@ class SimpleGeminiStreamingInterface:
    def get_content(self) -> List[ReasoningContent | TextContent | ToolCallContent]:
        """Return the accumulated content parts with the thinking signature attached.

        This is (unusually) in chunked format, instead of merged.
        """
        # Whether any reasoning chunk was collected this turn; when none was,
        # the signature is carried on the text content instead (see elif below).
        has_reasoning = any(isinstance(c, ReasoningContent) for c in self.content_parts)
        for content in self.content_parts:
            if isinstance(content, ReasoningContent):
                # This assumes there is only one signature per turn
                content.signature = self.thinking_signature
            elif isinstance(content, TextContent) and not has_reasoning and self.thinking_signature:
                # No reasoning part exists but a signature was produced; attach it
                # to the text part so it is not lost — presumably consumed
                # downstream the same way (NOTE(review): confirm with callers).
                content.signature = self.thinking_signature
        return self.content_parts
@@ -322,15 +329,18 @@ class SimpleGeminiStreamingInterface:
self.collected_tool_calls.append(ToolCall(id=call_id, function=FunctionCall(name=name, arguments=arguments_str)))
if self.tool_call_name and self.tool_call_name in self.requires_approval_tools:
tool_call_delta = ToolCallDelta(
name=name,
arguments=arguments_str,
tool_call_id=call_id,
)
yield ApprovalRequestMessage(
id=decrement_message_uuid(self.letta_message_id),
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
date=datetime.now(timezone.utc),
tool_call=ToolCallDelta(
name=name,
arguments=arguments_str,
tool_call_id=call_id,
),
tool_call=tool_call_delta,
tool_calls=tool_call_delta,
run_id=self.run_id,
step_id=self.step_id,
)

View File

@@ -1,7 +1,12 @@
import asyncio
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from typing import Optional
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from opentelemetry.trace import Span
from letta.schemas.usage import LettaUsageStatistics
from openai import AsyncStream
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
@@ -14,6 +19,7 @@ from openai.types.responses import (
ResponseFunctionCallArgumentsDeltaEvent,
ResponseFunctionCallArgumentsDoneEvent,
ResponseFunctionToolCall,
ResponseIncompleteEvent,
ResponseInProgressEvent,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
@@ -314,9 +320,6 @@ class OpenAIStreamingInterface:
# Track events for diagnostics
self.total_events_received += 1
self.last_event_type = "ChatCompletionChunk"
# Track events for diagnostics
self.total_events_received += 1
self.last_event_type = "ChatCompletionChunk"
if not self.model or not self.message_id:
self.model = chunk.model
@@ -414,25 +417,22 @@ class OpenAIStreamingInterface:
if prev_message_type and prev_message_type != "tool_call_message":
message_index += 1
self.tool_call_name = str(self._get_function_name_buffer())
tool_call_delta = ToolCallDelta(
name=self._get_function_name_buffer(),
arguments=None,
tool_call_id=self._get_current_function_id(),
)
if self.tool_call_name in self.requires_approval_tools:
tool_call_msg = ApprovalRequestMessage(
id=decrement_message_uuid(self.letta_message_id),
date=datetime.now(timezone.utc),
tool_call=ToolCallDelta(
name=self._get_function_name_buffer(),
arguments=None,
tool_call_id=self._get_current_function_id(),
),
tool_call=tool_call_delta,
tool_calls=tool_call_delta,
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
run_id=self.run_id,
step_id=self.step_id,
)
else:
tool_call_delta = ToolCallDelta(
name=self._get_function_name_buffer(),
arguments=None,
tool_call_id=self._get_current_function_id(),
)
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
date=datetime.now(timezone.utc),
@@ -471,7 +471,7 @@ class OpenAIStreamingInterface:
# Minimal, robust extraction: only emit the value of "message".
# If we buffered a prefix while name was streaming, feed it first.
if self._function_args_buffer_parts:
payload = "".join(self._function_args_buffer_parts + [tool_call.function.arguments])
payload = "".join([*self._function_args_buffer_parts, tool_call.function.arguments])
self._function_args_buffer_parts = None
else:
payload = tool_call.function.arguments
@@ -498,29 +498,26 @@ class OpenAIStreamingInterface:
# if the previous chunk had arguments but we needed to flush name
if self._function_args_buffer_parts:
# In this case, we should release the buffer + new data at once
combined_chunk = "".join(self._function_args_buffer_parts + [updates_main_json])
combined_chunk = "".join([*self._function_args_buffer_parts, updates_main_json])
if prev_message_type and prev_message_type != "tool_call_message":
message_index += 1
tool_call_delta = ToolCallDelta(
name=self._get_function_name_buffer(),
arguments=combined_chunk,
tool_call_id=self._get_current_function_id(),
)
if self._get_function_name_buffer() in self.requires_approval_tools:
tool_call_msg = ApprovalRequestMessage(
id=decrement_message_uuid(self.letta_message_id),
date=datetime.now(timezone.utc),
tool_call=ToolCallDelta(
name=self._get_function_name_buffer(),
arguments=combined_chunk,
tool_call_id=self._get_current_function_id(),
),
tool_call=tool_call_delta,
tool_calls=tool_call_delta,
# name=name,
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
run_id=self.run_id,
step_id=self.step_id,
)
else:
tool_call_delta = ToolCallDelta(
name=self._get_function_name_buffer(),
arguments=combined_chunk,
tool_call_id=self._get_current_function_id(),
)
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
date=datetime.now(timezone.utc),
@@ -540,26 +537,23 @@ class OpenAIStreamingInterface:
# If there's no buffer to clear, just output a new chunk with new data
if prev_message_type and prev_message_type != "tool_call_message":
message_index += 1
tool_call_delta = ToolCallDelta(
name=None,
arguments=updates_main_json,
tool_call_id=self._get_current_function_id(),
)
if self._get_function_name_buffer() in self.requires_approval_tools:
tool_call_msg = ApprovalRequestMessage(
id=decrement_message_uuid(self.letta_message_id),
date=datetime.now(timezone.utc),
tool_call=ToolCallDelta(
name=None,
arguments=updates_main_json,
tool_call_id=self._get_current_function_id(),
),
tool_call=tool_call_delta,
tool_calls=tool_call_delta,
# name=name,
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
run_id=self.run_id,
step_id=self.step_id,
)
else:
tool_call_delta = ToolCallDelta(
name=None,
arguments=updates_main_json,
tool_call_id=self._get_current_function_id(),
)
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
date=datetime.now(timezone.utc),
@@ -588,7 +582,7 @@ class SimpleOpenAIStreamingInterface:
messages: Optional[list] = None,
tools: Optional[list] = None,
requires_approval_tools: list = [],
model: str = None,
model: str | None = None,
run_id: str | None = None,
step_id: str | None = None,
cancellation_event: Optional["asyncio.Event"] = None,
@@ -639,7 +633,6 @@ class SimpleOpenAIStreamingInterface:
def get_content(self) -> list[TextContent | OmittedReasoningContent | ReasoningContent]:
shown_omitted = False
concat_content = ""
merged_messages = []
reasoning_content = []
concat_content_parts: list[str] = []
@@ -837,6 +830,10 @@ class SimpleOpenAIStreamingInterface:
prev_message_type: Optional[str] = None,
message_index: int = 0,
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
# Track events for diagnostics
self.total_events_received += 1
self.last_event_type = "ChatCompletionChunk"
if not self.model or not self.message_id:
self.model = chunk.model
self.message_id = chunk.id
@@ -887,14 +884,10 @@ class SimpleOpenAIStreamingInterface:
prev_message_type = assistant_msg.message_type
yield assistant_msg
if (
hasattr(chunk, "choices")
and len(chunk.choices) > 0
and hasattr(chunk.choices[0], "delta")
and hasattr(chunk.choices[0].delta, "reasoning_content")
):
if hasattr(chunk, "choices") and len(chunk.choices) > 0 and hasattr(chunk.choices[0], "delta"):
delta = chunk.choices[0].delta
reasoning_content = getattr(delta, "reasoning_content", None)
# Check for reasoning_content (standard) or reasoning (OpenRouter)
reasoning_content = getattr(delta, "reasoning_content", None) or getattr(delta, "reasoning", None)
if reasoning_content is not None and reasoning_content != "":
if prev_message_type and prev_message_type != "reasoning_message":
message_index += 1
@@ -945,7 +938,7 @@ class SimpleOpenAIStreamingInterface:
if resolved_id is None:
continue
delta = ToolCallDelta(
tool_call_delta = ToolCallDelta(
name=tool_call.function.name if (tool_call.function and tool_call.function.name) else None,
arguments=tool_call.function.arguments if (tool_call.function and tool_call.function.arguments) else None,
tool_call_id=resolved_id,
@@ -956,7 +949,8 @@ class SimpleOpenAIStreamingInterface:
tool_call_msg = ApprovalRequestMessage(
id=decrement_message_uuid(self.letta_message_id),
date=datetime.now(timezone.utc),
tool_call=delta,
tool_call=tool_call_delta,
tool_calls=tool_call_delta,
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
run_id=self.run_id,
step_id=self.step_id,
@@ -967,8 +961,8 @@ class SimpleOpenAIStreamingInterface:
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
date=datetime.now(timezone.utc),
tool_call=delta,
tool_calls=delta,
tool_call=tool_call_delta,
tool_calls=tool_call_delta,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
@@ -988,7 +982,7 @@ class SimpleOpenAIResponsesStreamingInterface:
messages: Optional[list] = None,
tools: Optional[list] = None,
requires_approval_tools: list = [],
model: str = None,
model: str | None = None,
run_id: str | None = None,
step_id: str | None = None,
cancellation_event: Optional["asyncio.Event"] = None,
@@ -1029,6 +1023,9 @@ class SimpleOpenAIResponsesStreamingInterface:
self.last_event_type: str | None = None
self.total_events_received: int = 0
self.stream_was_cancelled: bool = False
# For downstream finish_reason mapping (e.g. max_output_tokens -> "length")
# None means no incomplete reason was observed.
self.incomplete_reason: str | None = None
# -------- Mapping helpers (no broad try/except) --------
def _record_tool_mapping(self, event: object, item: object) -> tuple[str | None, str | None, int | None, str | None]:
@@ -1089,6 +1086,10 @@ class SimpleOpenAIResponsesStreamingInterface:
text=response.content[0].text,
)
)
elif len(response.content) == 0:
# Incomplete responses may have an output message with no content parts
# (model started the message item but hit max_output_tokens before producing text)
logger.warning("ResponseOutputMessage has 0 content parts (likely from an incomplete response), skipping.")
else:
raise ValueError(f"Got {len(response.content)} content parts, expected 1")
@@ -1254,8 +1255,6 @@ class SimpleOpenAIResponsesStreamingInterface:
if isinstance(new_event_item, ResponseReasoningItem):
# Look for summary delta, or encrypted_content
summary = new_event_item.summary
content = new_event_item.content # NOTE: always none
encrypted_content = new_event_item.encrypted_content
# TODO change to summarize reasoning message, but we need to figure out the streaming indices of summary problem
concat_summary = "".join([s.text for s in summary])
if concat_summary != "":
@@ -1283,27 +1282,24 @@ class SimpleOpenAIResponsesStreamingInterface:
self.tool_call_name = name
# Record mapping so subsequent argument deltas can be associated
self._record_tool_mapping(event, new_event_item)
tool_call_delta = ToolCallDelta(
name=name,
arguments=arguments if arguments != "" else None,
tool_call_id=call_id,
)
if self.tool_call_name and self.tool_call_name in self.requires_approval_tools:
yield ApprovalRequestMessage(
id=decrement_message_uuid(self.letta_message_id),
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
date=datetime.now(timezone.utc),
tool_call=ToolCallDelta(
name=name,
arguments=arguments if arguments != "" else None,
tool_call_id=call_id,
),
tool_call=tool_call_delta,
tool_calls=tool_call_delta,
run_id=self.run_id,
step_id=self.step_id,
)
else:
if prev_message_type and prev_message_type != "tool_call_message":
message_index += 1
tool_call_delta = ToolCallDelta(
name=name,
arguments=arguments if arguments != "" else None,
tool_call_id=call_id,
)
yield ToolCallMessage(
id=self.letta_message_id,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
@@ -1394,7 +1390,6 @@ class SimpleOpenAIResponsesStreamingInterface:
# NOTE: is this inclusive of the deltas?
# If not, we should add it to the rolling
summary_index = event.summary_index
text = event.text
return
# Reasoning summary streaming
@@ -1436,7 +1431,6 @@ class SimpleOpenAIResponsesStreamingInterface:
# Assistant message streaming
elif isinstance(event, ResponseTextDoneEvent):
# NOTE: inclusive, can skip
text = event.text
return
# Assistant message done
@@ -1451,7 +1445,7 @@ class SimpleOpenAIResponsesStreamingInterface:
delta = event.delta
# Resolve tool_call_id/name using output_index or item_id
resolved_call_id, resolved_name, out_idx, item_id = self._resolve_mapping_for_delta(event)
resolved_call_id, resolved_name, _out_idx, _item_id = self._resolve_mapping_for_delta(event)
# Fallback to last seen tool name for approval routing if mapping name missing
if not resolved_name:
@@ -1462,27 +1456,24 @@ class SimpleOpenAIResponsesStreamingInterface:
return
# We have a call id; emit approval or tool-call message accordingly
tool_call_delta = ToolCallDelta(
name=None,
arguments=delta,
tool_call_id=resolved_call_id,
)
if resolved_name and resolved_name in self.requires_approval_tools:
yield ApprovalRequestMessage(
id=decrement_message_uuid(self.letta_message_id),
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
date=datetime.now(timezone.utc),
tool_call=ToolCallDelta(
name=None,
arguments=delta,
tool_call_id=resolved_call_id,
),
tool_call=tool_call_delta,
tool_calls=tool_call_delta,
run_id=self.run_id,
step_id=self.step_id,
)
else:
if prev_message_type and prev_message_type != "tool_call_message":
message_index += 1
tool_call_delta = ToolCallDelta(
name=None,
arguments=delta,
tool_call_id=resolved_call_id,
)
yield ToolCallMessage(
id=self.letta_message_id,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
@@ -1497,7 +1488,6 @@ class SimpleOpenAIResponsesStreamingInterface:
# Function calls
elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent):
# NOTE: inclusive
full_args = event.arguments
return
# Generic
@@ -1506,31 +1496,55 @@ class SimpleOpenAIResponsesStreamingInterface:
return
# Generic finish
elif isinstance(event, ResponseCompletedEvent):
# NOTE we can "rebuild" the final state of the stream using the values in here, instead of relying on the accumulators
elif isinstance(event, (ResponseCompletedEvent, ResponseIncompleteEvent)):
# ResponseIncompleteEvent has the same response structure as ResponseCompletedEvent,
# but indicates the response was cut short (e.g. due to max_output_tokens).
# We still extract the partial response and usage data so they aren't silently lost.
if isinstance(event, ResponseIncompleteEvent):
self.incomplete_reason = (
getattr(event.response.incomplete_details, "reason", None) if event.response.incomplete_details else None
)
reason = self.incomplete_reason or "unknown"
logger.warning(
f"OpenAI Responses API returned an incomplete response (reason: {reason}). "
f"Model: {event.response.model}, output_tokens: {event.response.usage.output_tokens if event.response.usage else 'N/A'}. "
f"The partial response content will still be used."
)
self.final_response = event.response
self.model = event.response.model
self.input_tokens = event.response.usage.input_tokens
self.output_tokens = event.response.usage.output_tokens
self.message_id = event.response.id
# Store raw usage for transparent provider trace logging
try:
self.raw_usage = event.response.usage.model_dump(exclude_none=True)
except Exception as e:
logger.error(f"Failed to capture raw_usage from OpenAI Responses API: {e}")
self.raw_usage = None
# Capture cache token details (Responses API uses input_tokens_details)
# Use `is not None` to capture 0 values (meaning "provider reported 0 cached tokens")
if hasattr(event.response.usage, "input_tokens_details") and event.response.usage.input_tokens_details:
details = event.response.usage.input_tokens_details
if hasattr(details, "cached_tokens") and details.cached_tokens is not None:
self.cached_tokens = details.cached_tokens
# Capture reasoning token details (Responses API uses output_tokens_details)
# Use `is not None` to capture 0 values (meaning "provider reported 0 reasoning tokens")
if hasattr(event.response.usage, "output_tokens_details") and event.response.usage.output_tokens_details:
details = event.response.usage.output_tokens_details
if hasattr(details, "reasoning_tokens") and details.reasoning_tokens is not None:
self.reasoning_tokens = details.reasoning_tokens
usage = event.response.usage
if usage is not None:
self.input_tokens = usage.input_tokens
self.output_tokens = usage.output_tokens
# Store raw usage for transparent provider trace logging
try:
self.raw_usage = usage.model_dump(exclude_none=True)
except Exception as e:
logger.error(f"Failed to capture raw_usage from OpenAI Responses API: {e}")
self.raw_usage = None
# Capture cache token details (Responses API uses input_tokens_details)
# Use `is not None` to capture 0 values (meaning "provider reported 0 cached tokens")
if hasattr(usage, "input_tokens_details") and usage.input_tokens_details:
details = usage.input_tokens_details
if hasattr(details, "cached_tokens") and details.cached_tokens is not None:
self.cached_tokens = details.cached_tokens
# Capture reasoning token details (Responses API uses output_tokens_details)
# Use `is not None` to capture 0 values (meaning "provider reported 0 reasoning tokens")
if hasattr(usage, "output_tokens_details") and usage.output_tokens_details:
details = usage.output_tokens_details
if hasattr(details, "reasoning_tokens") and details.reasoning_tokens is not None:
self.reasoning_tokens = details.reasoning_tokens
else:
logger.warning(
"OpenAI Responses API finish event had no usage payload. "
"Proceeding with partial response but token metrics may be incomplete."
)
return
else:

View File

@@ -94,7 +94,7 @@ async def _try_acquire_lock_and_start_scheduler(server: SyncServer) -> bool:
if scheduler.running:
try:
scheduler.shutdown(wait=False)
except:
except Exception:
pass
return False
finally:

View File

@@ -19,6 +19,7 @@ from letta.errors import (
LLMAuthenticationError,
LLMBadRequestError,
LLMConnectionError,
LLMInsufficientCreditsError,
LLMNotFoundError,
LLMPermissionDeniedError,
LLMProviderOverloaded,
@@ -29,13 +30,16 @@ from letta.errors import (
)
from letta.helpers.datetime_helpers import get_utc_time_int
from letta.helpers.decorators import deprecated
from letta.helpers.json_helpers import sanitize_unicode_surrogates
from letta.llm_api.anthropic_constants import ANTHROPIC_MAX_STRICT_TOOLS, ANTHROPIC_STRICT_MODE_ALLOWLIST
from letta.llm_api.error_utils import is_insufficient_credits_message
from letta.llm_api.helpers import add_inner_thoughts_to_functions, unpack_all_inner_thoughts_from_kwargs
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentType
from letta.schemas.enums import ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
@@ -45,9 +49,7 @@ from letta.schemas.openai.chat_completion_response import (
FunctionCall,
Message as ChoiceMessage,
ToolCall,
UsageStatistics,
)
from letta.schemas.response_format import JsonSchemaResponseFormat
from letta.schemas.usage import LettaUsageStatistics
from letta.settings import model_settings
@@ -62,10 +64,16 @@ class AnthropicClient(LLMClientBase):
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
client = self._get_anthropic_client(llm_config, async_client=False)
betas: list[str] = []
# Interleaved thinking for reasoner (sync path parity)
# Opus 4.6 / Sonnet 4.6 Auto Thinking
if llm_config.enable_reasoner:
betas.append("interleaved-thinking-2025-05-14")
# 1M context beta for Sonnet 4/4.5 when enabled
if llm_config.model.startswith("claude-opus-4-6") or llm_config.model.startswith("claude-sonnet-4-6"):
betas.append("adaptive-thinking-2026-01-28")
# Interleaved thinking for other reasoners (sync path parity)
else:
betas.append("interleaved-thinking-2025-05-14")
# 1M context beta for Sonnet 4/4.5 or Opus 4.6 when enabled
try:
from letta.settings import model_settings
@@ -73,12 +81,23 @@ class AnthropicClient(LLMClientBase):
llm_config.model.startswith("claude-sonnet-4") or llm_config.model.startswith("claude-sonnet-4-5")
):
betas.append("context-1m-2025-08-07")
elif model_settings.anthropic_opus_1m and llm_config.model.startswith("claude-opus-4-6"):
betas.append("context-1m-2025-08-07")
except Exception:
pass
# Opus 4.5 effort parameter - to extend to other models, modify the model check
if llm_config.model.startswith("claude-opus-4-5") and llm_config.effort is not None:
# Effort parameter for Opus 4.5, Opus 4.6, and Sonnet 4.6 - to extend to other models, modify the model check
if (
llm_config.model.startswith("claude-opus-4-5")
or llm_config.model.startswith("claude-opus-4-6")
or llm_config.model.startswith("claude-sonnet-4-6")
) and llm_config.effort is not None:
betas.append("effort-2025-11-24")
# Max effort beta for Opus 4.6 / Sonnet 4.6
if (
llm_config.model.startswith("claude-opus-4-6") or llm_config.model.startswith("claude-sonnet-4-6")
) and llm_config.effort == "max":
betas.append("max-effort-2026-01-24")
# Context management for Opus 4.5 to preserve thinking blocks (improves cache hits)
if llm_config.model.startswith("claude-opus-4-5") and llm_config.enable_reasoner:
@@ -88,21 +107,46 @@ class AnthropicClient(LLMClientBase):
if llm_config.strict and _supports_structured_outputs(llm_config.model):
betas.append("structured-outputs-2025-11-13")
if betas:
response = client.beta.messages.create(**request_data, betas=betas)
else:
response = client.beta.messages.create(**request_data)
return response.model_dump()
try:
if betas:
response = client.beta.messages.create(**request_data, betas=betas)
else:
response = client.beta.messages.create(**request_data)
return response.model_dump()
except ValueError as e:
# Anthropic SDK raises ValueError when streaming is required for long-running operations
# See: https://github.com/anthropics/anthropic-sdk-python#streaming
if "streaming is required" in str(e).lower():
logger.warning(
"[Anthropic] Non-streaming request rejected due to potential long duration. Error: %s. "
"Note: Synchronous fallback to streaming is not supported. Use async API instead.",
str(e),
)
# Re-raise as LLMBadRequestError (maps to 502 Bad Gateway) since this is a downstream provider constraint
raise LLMBadRequestError(
message="This operation may take longer than 10 minutes and requires streaming. "
"Please use the async API (request_async) instead of the deprecated sync API. "
f"Original error: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
) from e
raise
@trace_method
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
request_data = sanitize_unicode_surrogates(request_data)
client = await self._get_anthropic_client_async(llm_config, async_client=True)
betas: list[str] = []
# interleaved thinking for reasoner
if llm_config.enable_reasoner:
betas.append("interleaved-thinking-2025-05-14")
# 1M context beta for Sonnet 4/4.5 when enabled
# Opus 4.6 / Sonnet 4.6 Auto Thinking
if llm_config.enable_reasoner:
if llm_config.model.startswith("claude-opus-4-6") or llm_config.model.startswith("claude-sonnet-4-6"):
betas.append("adaptive-thinking-2026-01-28")
# Interleaved thinking for other reasoners (sync path parity)
else:
betas.append("interleaved-thinking-2025-05-14")
# 1M context beta for Sonnet 4/4.5 or Opus 4.6 when enabled
try:
from letta.settings import model_settings
@@ -110,12 +154,23 @@ class AnthropicClient(LLMClientBase):
llm_config.model.startswith("claude-sonnet-4") or llm_config.model.startswith("claude-sonnet-4-5")
):
betas.append("context-1m-2025-08-07")
elif model_settings.anthropic_opus_1m and llm_config.model.startswith("claude-opus-4-6"):
betas.append("context-1m-2025-08-07")
except Exception:
pass
# Opus 4.5 effort parameter - to extend to other models, modify the model check
if llm_config.model.startswith("claude-opus-4-5") and llm_config.effort is not None:
# Effort parameter for Opus 4.5, Opus 4.6, and Sonnet 4.6 - to extend to other models, modify the model check
if (
llm_config.model.startswith("claude-opus-4-5")
or llm_config.model.startswith("claude-opus-4-6")
or llm_config.model.startswith("claude-sonnet-4-6")
) and llm_config.effort is not None:
betas.append("effort-2025-11-24")
# Max effort beta for Opus 4.6 / Sonnet 4.6
if (
llm_config.model.startswith("claude-opus-4-6") or llm_config.model.startswith("claude-sonnet-4-6")
) and llm_config.effort == "max":
betas.append("max-effort-2026-01-24")
# Context management for Opus 4.5 to preserve thinking blocks (improves cache hits)
if llm_config.model.startswith("claude-opus-4-5") and llm_config.enable_reasoner:
@@ -254,6 +309,8 @@ class AnthropicClient(LLMClientBase):
@trace_method
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[BetaRawMessageStreamEvent]:
request_data = sanitize_unicode_surrogates(request_data)
client = await self._get_anthropic_client_async(llm_config, async_client=True)
request_data["stream"] = True
@@ -262,12 +319,15 @@ class AnthropicClient(LLMClientBase):
# See: https://docs.anthropic.com/en/docs/build-with-claude/tool-use/fine-grained-streaming
betas = ["fine-grained-tool-streaming-2025-05-14"]
# If extended thinking, turn on interleaved header
# https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#interleaved-thinking
# Opus 4.6 / Sonnet 4.6 Auto Thinking
if llm_config.enable_reasoner:
betas.append("interleaved-thinking-2025-05-14")
if llm_config.model.startswith("claude-opus-4-6") or llm_config.model.startswith("claude-sonnet-4-6"):
betas.append("adaptive-thinking-2026-01-28")
# Interleaved thinking for other reasoners (sync path parity)
else:
betas.append("interleaved-thinking-2025-05-14")
# 1M context beta for Sonnet 4/4.5 when enabled
# 1M context beta for Sonnet 4/4.5 or Opus 4.6 when enabled
try:
from letta.settings import model_settings
@@ -275,12 +335,23 @@ class AnthropicClient(LLMClientBase):
llm_config.model.startswith("claude-sonnet-4") or llm_config.model.startswith("claude-sonnet-4-5")
):
betas.append("context-1m-2025-08-07")
elif model_settings.anthropic_opus_1m and llm_config.model.startswith("claude-opus-4-6"):
betas.append("context-1m-2025-08-07")
except Exception:
pass
# Opus 4.5 effort parameter - to extend to other models, modify the model check
if llm_config.model.startswith("claude-opus-4-5") and llm_config.effort is not None:
# Effort parameter for Opus 4.5, Opus 4.6, and Sonnet 4.6 - to extend to other models, modify the model check
if (
llm_config.model.startswith("claude-opus-4-5")
or llm_config.model.startswith("claude-opus-4-6")
or llm_config.model.startswith("claude-sonnet-4-6")
) and llm_config.effort is not None:
betas.append("effort-2025-11-24")
# Max effort beta for Opus 4.6 / Sonnet 4.6
if (
llm_config.model.startswith("claude-opus-4-6") or llm_config.model.startswith("claude-sonnet-4-6")
) and llm_config.effort == "max":
betas.append("max-effort-2026-01-24")
# Context management for Opus 4.5 to preserve thinking blocks (improves cache hits)
if llm_config.model.startswith("claude-opus-4-5") and llm_config.enable_reasoner:
@@ -335,7 +406,7 @@ class AnthropicClient(LLMClientBase):
for agent_id in agent_messages_mapping
}
client = await self._get_anthropic_client_async(list(agent_llm_config_mapping.values())[0], async_client=True)
client = await self._get_anthropic_client_async(next(iter(agent_llm_config_mapping.values())), async_client=True)
anthropic_requests = [
Request(custom_id=agent_id, params=MessageCreateParamsNonStreaming(**params)) for agent_id, params in requests.items()
@@ -461,25 +532,43 @@ class AnthropicClient(LLMClientBase):
}
# Extended Thinking
if self.is_reasoning_model(llm_config) and llm_config.enable_reasoner:
thinking_budget = max(llm_config.max_reasoning_tokens, 1024)
if thinking_budget != llm_config.max_reasoning_tokens:
logger.warning(
f"Max reasoning tokens must be at least 1024 for Claude. Setting max_reasoning_tokens to 1024 for model {llm_config.model}."
)
data["thinking"] = {
"type": "enabled",
"budget_tokens": thinking_budget,
}
# Note: Anthropic does not allow thinking when forcing tool use with split_thread_agent
should_enable_thinking = (
self.is_reasoning_model(llm_config)
and llm_config.enable_reasoner
and not (agent_type == AgentType.split_thread_agent and force_tool_call is not None)
)
if should_enable_thinking:
# Opus 4.6 / Sonnet 4.6 uses Auto Thinking (no budget tokens)
if llm_config.model.startswith("claude-opus-4-6") or llm_config.model.startswith("claude-sonnet-4-6"):
data["thinking"] = {
"type": "adaptive",
}
else:
# Traditional extended thinking with budget tokens
thinking_budget = max(llm_config.max_reasoning_tokens, 1024)
if thinking_budget != llm_config.max_reasoning_tokens:
logger.warning(
f"Max reasoning tokens must be at least 1024 for Claude. Setting max_reasoning_tokens to 1024 for model {llm_config.model}."
)
data["thinking"] = {
"type": "enabled",
"budget_tokens": thinking_budget,
}
# `temperature` may only be set to 1 when thinking is enabled. Please consult our documentation at https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking'
data["temperature"] = 1.0
# Silently disable prefix_fill for now
prefix_fill = False
# Effort configuration for Opus 4.5 (controls token spending)
# Effort configuration for Opus 4.5, Opus 4.6, and Sonnet 4.6 (controls token spending)
# To extend to other models, modify the model check
if llm_config.model.startswith("claude-opus-4-5") and llm_config.effort is not None:
if (
llm_config.model.startswith("claude-opus-4-5")
or llm_config.model.startswith("claude-opus-4-6")
or llm_config.model.startswith("claude-sonnet-4-6")
) and llm_config.effort is not None:
data["output_config"] = {"effort": llm_config.effort}
# Context management for Opus 4.5 to preserve thinking blocks and improve cache hits
@@ -510,11 +599,16 @@ class AnthropicClient(LLMClientBase):
# Special case for summarization path
tools_for_request = None
tool_choice = None
elif self.is_reasoning_model(llm_config) and llm_config.enable_reasoner or agent_type == AgentType.letta_v1_agent:
elif (self.is_reasoning_model(llm_config) and llm_config.enable_reasoner) or agent_type == AgentType.letta_v1_agent:
# NOTE: reasoning models currently do not allow for `any`
# NOTE: react agents should always have auto on, since the precense/absense of tool calls controls chaining
tool_choice = {"type": "auto", "disable_parallel_tool_use": True}
tools_for_request = [OpenAITool(function=f) for f in tools]
# NOTE: react agents should always have at least auto on, since the precense/absense of tool calls controls chaining
if agent_type == AgentType.split_thread_agent and force_tool_call is not None:
tool_choice = {"type": "tool", "name": force_tool_call, "disable_parallel_tool_use": True}
# When forcing a specific tool, only include that tool
tools_for_request = [OpenAITool(function=f) for f in tools if f["name"] == force_tool_call]
else:
tool_choice = {"type": "auto", "disable_parallel_tool_use": True}
tools_for_request = [OpenAITool(function=f) for f in tools]
elif force_tool_call is not None:
tool_choice = {"type": "tool", "name": force_tool_call, "disable_parallel_tool_use": True}
tools_for_request = [OpenAITool(function=f) for f in tools if f["name"] == force_tool_call]
@@ -691,7 +785,9 @@ class AnthropicClient(LLMClientBase):
return data
async def count_tokens(self, messages: List[dict] = None, model: str = None, tools: List[OpenAITool] = None) -> int:
async def count_tokens(
self, messages: List[dict] | None = None, model: str | None = None, tools: List[OpenAITool] | None = None
) -> int:
logging.getLogger("httpx").setLevel(logging.WARNING)
# Use the default client; token counting is lightweight and does not require BYOK overrides
client = anthropic.AsyncAnthropic()
@@ -811,6 +907,8 @@ class AnthropicClient(LLMClientBase):
and (model.startswith("claude-sonnet-4") or model.startswith("claude-sonnet-4-5"))
):
betas.append("context-1m-2025-08-07")
elif model and model_settings.anthropic_opus_1m and model.startswith("claude-opus-4-6"):
betas.append("context-1m-2025-08-07")
except Exception:
pass
@@ -851,10 +949,16 @@ class AnthropicClient(LLMClientBase):
or llm_config.model.startswith("claude-haiku-4-5")
# Opus 4.5 support - to extend effort parameter to other models, modify this check
or llm_config.model.startswith("claude-opus-4-5")
# Opus 4.6 support - uses Auto Thinking
or llm_config.model.startswith("claude-opus-4-6")
# Sonnet 4.6 support - same API as Opus 4.6
or llm_config.model.startswith("claude-sonnet-4-6")
)
@trace_method
def handle_llm_error(self, e: Exception) -> Exception:
def handle_llm_error(self, e: Exception, llm_config: Optional[LLMConfig] = None) -> Exception:
is_byok = (llm_config.provider_category == ProviderCategory.byok) if llm_config else None
# make sure to check for overflow errors, regardless of error type
error_str = str(e).lower()
if (
@@ -869,6 +973,7 @@ class AnthropicClient(LLMClientBase):
logger.warning(f"[Anthropic] Context window exceeded: {str(e)}")
return ContextWindowExceededError(
message=f"Context window exceeded for Anthropic: {str(e)}",
details={"is_byok": is_byok},
)
if isinstance(e, anthropic.APITimeoutError):
@@ -876,7 +981,7 @@ class AnthropicClient(LLMClientBase):
return LLMTimeoutError(
message=f"Request to Anthropic timed out: {str(e)}",
code=ErrorCode.TIMEOUT,
details={"cause": str(e.__cause__) if e.__cause__ else None},
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
)
if isinstance(e, anthropic.APIConnectionError):
@@ -884,7 +989,7 @@ class AnthropicClient(LLMClientBase):
return LLMConnectionError(
message=f"Failed to connect to Anthropic: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None},
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
)
# Handle httpx.RemoteProtocolError which can occur during streaming
@@ -895,7 +1000,7 @@ class AnthropicClient(LLMClientBase):
return LLMConnectionError(
message=f"Connection error during Anthropic streaming: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None},
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
)
# Handle httpx network errors which can occur during streaming
@@ -905,7 +1010,7 @@ class AnthropicClient(LLMClientBase):
return LLMConnectionError(
message=f"Network error during Anthropic streaming: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__},
details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__, "is_byok": is_byok},
)
if isinstance(e, anthropic.RateLimitError):
@@ -913,6 +1018,7 @@ class AnthropicClient(LLMClientBase):
return LLMRateLimitError(
message=f"Rate limited by Anthropic: {str(e)}",
code=ErrorCode.RATE_LIMIT_EXCEEDED,
details={"is_byok": is_byok},
)
if isinstance(e, anthropic.BadRequestError):
@@ -930,11 +1036,13 @@ class AnthropicClient(LLMClientBase):
# 400 - {'type': 'error', 'error': {'type': 'invalid_request_error', 'message': 'input length and `max_tokens` exceed context limit: 173298 + 32000 > 200000, decrease input length or `max_tokens` and try again'}}
return ContextWindowExceededError(
message=f"Bad request to Anthropic (context window exceeded): {str(e)}",
details={"is_byok": is_byok},
)
else:
return LLMBadRequestError(
message=f"Bad request to Anthropic: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"is_byok": is_byok},
)
if isinstance(e, anthropic.AuthenticationError):
@@ -942,6 +1050,7 @@ class AnthropicClient(LLMClientBase):
return LLMAuthenticationError(
message=f"Authentication failed with Anthropic: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"is_byok": is_byok},
)
if isinstance(e, anthropic.PermissionDeniedError):
@@ -949,6 +1058,7 @@ class AnthropicClient(LLMClientBase):
return LLMPermissionDeniedError(
message=f"Permission denied by Anthropic: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"is_byok": is_byok},
)
if isinstance(e, anthropic.NotFoundError):
@@ -956,6 +1066,7 @@ class AnthropicClient(LLMClientBase):
return LLMNotFoundError(
message=f"Resource not found in Anthropic: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"is_byok": is_byok},
)
if isinstance(e, anthropic.UnprocessableEntityError):
@@ -963,23 +1074,29 @@ class AnthropicClient(LLMClientBase):
return LLMUnprocessableEntityError(
message=f"Invalid request content for Anthropic: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"is_byok": is_byok},
)
if isinstance(e, anthropic.APIStatusError):
logger.warning(f"[Anthropic] API status error: {str(e)}")
# Handle 413 Request Entity Too Large - request payload exceeds size limits
if hasattr(e, "status_code") and e.status_code == 413:
logger.warning(f"[Anthropic] Request too large (413): {str(e)}")
return ContextWindowExceededError(
message=f"Request too large for Anthropic (413): {str(e)}",
if isinstance(e, anthropic.InternalServerError):
error_str = str(e).lower()
if "overflow" in error_str or "upstream connect error" in error_str:
logger.warning(f"[Anthropic] Upstream infrastructure error (transient): {str(e)}")
return LLMServerError(
message=f"Anthropic upstream infrastructure error (transient, may resolve on retry): {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={
"status_code": e.status_code if hasattr(e, "status_code") else None,
"transient": True,
},
)
if "overloaded" in str(e).lower():
if "overloaded" in error_str:
return LLMProviderOverloaded(
message=f"Anthropic API is overloaded: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
)
logger.warning(f"[Anthropic] Internal server error: {str(e)}")
return LLMServerError(
message=f"Anthropic API error: {str(e)}",
message=f"Anthropic internal server error: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={
"status_code": e.status_code if hasattr(e, "status_code") else None,
@@ -987,7 +1104,38 @@ class AnthropicClient(LLMClientBase):
},
)
return super().handle_llm_error(e)
if isinstance(e, anthropic.APIStatusError):
logger.warning(f"[Anthropic] API status error: {str(e)}")
if (hasattr(e, "status_code") and e.status_code == 402) or is_insufficient_credits_message(str(e)):
msg = str(e)
return LLMInsufficientCreditsError(
message=f"Insufficient credits (BYOK): {msg}" if is_byok else f"Insufficient credits: {msg}",
code=ErrorCode.PAYMENT_REQUIRED,
details={"status_code": getattr(e, "status_code", None), "is_byok": is_byok},
)
if hasattr(e, "status_code") and e.status_code == 413:
logger.warning(f"[Anthropic] Request too large (413): {str(e)}")
return ContextWindowExceededError(
message=f"Request too large for Anthropic (413): {str(e)}",
details={"is_byok": is_byok},
)
if "overloaded" in str(e).lower():
return LLMProviderOverloaded(
message=f"Anthropic API is overloaded: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"is_byok": is_byok},
)
return LLMServerError(
message=f"Anthropic API error: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={
"status_code": e.status_code if hasattr(e, "status_code") else None,
"response": str(e.response) if hasattr(e, "response") else None,
"is_byok": is_byok,
},
)
return super().handle_llm_error(e, llm_config=llm_config)
def extract_usage_statistics(self, response_data: dict | None, llm_config: LLMConfig) -> LettaUsageStatistics:
"""Extract usage statistics from Anthropic response and return as LettaUsageStatistics."""
@@ -1027,6 +1175,11 @@ class AnthropicClient(LLMClientBase):
input_messages: List[PydanticMessage],
llm_config: LLMConfig,
) -> ChatCompletionResponse:
if isinstance(response_data, str):
raise LLMServerError(
message="Anthropic endpoint returned a raw string instead of a JSON object. This usually indicates the endpoint URL is incorrect or returned an error page.",
code=ErrorCode.INTERNAL_SERVER_ERROR,
)
"""
Example response from Claude 3:
response.json = {
@@ -1096,7 +1249,7 @@ class AnthropicClient(LLMClientBase):
args_json = json.loads(arguments)
if not isinstance(args_json, dict):
raise LLMServerError("Expected parseable json object for arguments")
except:
except Exception:
arguments = str(tool_input["function"]["arguments"])
else:
arguments = json.dumps(tool_input, indent=2)
@@ -1117,7 +1270,23 @@ class AnthropicClient(LLMClientBase):
redacted_reasoning_content = content_part.data
else:
raise RuntimeError("Unexpected empty content in response")
# Log the full response for debugging
logger.error(
"[Anthropic] Received response with empty content. Response ID: %s, Model: %s, Stop reason: %s, Full response: %s",
response.id,
response.model,
response.stop_reason,
json.dumps(response_data),
)
raise LLMServerError(
message=f"LLM provider returned empty content in response (ID: {response.id}, model: {response.model}, stop_reason: {response.stop_reason})",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={
"response_id": response.id,
"model": response.model,
"stop_reason": response.stop_reason,
},
)
assert response.role == "assistant"
choice = Choice(
@@ -1372,7 +1541,7 @@ def is_heartbeat(message: dict, is_ping: bool = False) -> bool:
try:
message_json = json.loads(message["content"])
except:
except Exception:
return False
# Check if message_json is a dict (not int, str, list, etc.)

View File

@@ -1,18 +1,31 @@
import json
import os
from typing import List, Optional, Tuple
from openai import AsyncAzureOpenAI, AzureOpenAI
from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream, AzureOpenAI, OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.responses.response_stream_event import ResponseStreamEvent
from letta.helpers.json_helpers import sanitize_unicode_surrogates
from letta.llm_api.openai_client import OpenAIClient
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.settings import model_settings
logger = get_logger(__name__)
class AzureClient(OpenAIClient):
@staticmethod
def _is_v1_endpoint(base_url: str) -> bool:
if not base_url:
return False
return base_url.rstrip("/").endswith("/openai/v1")
def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
if llm_config.provider_category == ProviderCategory.byok:
from letta.services.provider_manager import ProviderManager
@@ -29,38 +42,99 @@ class AzureClient(OpenAIClient):
return None, None, None
def _resolve_credentials(self, api_key, base_url, api_version):
"""Resolve credentials, falling back to env vars. For v1 endpoints, api_version is not required."""
if not api_key:
api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
if not base_url:
base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
if not api_version and not self._is_v1_endpoint(base_url):
api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
return api_key, base_url, api_version
@trace_method
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""
Performs underlying synchronous request to OpenAI API and returns raw response dict.
"""
api_key, base_url, api_version = self.get_byok_overrides(llm_config)
if not api_key or not base_url or not api_version:
api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
api_key, base_url, api_version = self._resolve_credentials(api_key, base_url, api_version)
client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
response: ChatCompletion = client.chat.completions.create(**request_data)
return response.model_dump()
if self._is_v1_endpoint(base_url):
client = OpenAI(api_key=api_key, base_url=base_url)
else:
client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
# Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
if "input" in request_data and "messages" not in request_data:
resp = client.responses.create(**request_data)
return resp.model_dump()
else:
response: ChatCompletion = client.chat.completions.create(**request_data)
return response.model_dump()
@trace_method
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""
Performs underlying asynchronous request to OpenAI API and returns raw response dict.
"""
api_key, base_url, api_version = await self.get_byok_overrides_async(llm_config)
if not api_key or not base_url or not api_version:
api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
try:
client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
response: ChatCompletion = await client.chat.completions.create(**request_data)
except Exception as e:
raise self.handle_llm_error(e)
request_data = sanitize_unicode_surrogates(request_data)
return response.model_dump()
api_key, base_url, api_version = await self.get_byok_overrides_async(llm_config)
api_key, base_url, api_version = self._resolve_credentials(api_key, base_url, api_version)
try:
if self._is_v1_endpoint(base_url):
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
else:
client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
# Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
if "input" in request_data and "messages" not in request_data:
resp = await client.responses.create(**request_data)
return resp.model_dump()
else:
response: ChatCompletion = await client.chat.completions.create(**request_data)
return response.model_dump()
except Exception as e:
raise self.handle_llm_error(e, llm_config=llm_config)
@trace_method
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk | ResponseStreamEvent]:
"""
Performs underlying asynchronous streaming request to Azure/OpenAI and returns the async stream iterator.
"""
request_data = sanitize_unicode_surrogates(request_data)
api_key, base_url, api_version = await self.get_byok_overrides_async(llm_config)
api_key, base_url, api_version = self._resolve_credentials(api_key, base_url, api_version)
if self._is_v1_endpoint(base_url):
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
else:
client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
# Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
if "input" in request_data and "messages" not in request_data:
try:
response_stream: AsyncStream[ResponseStreamEvent] = await client.responses.create(
**request_data,
stream=True,
)
except Exception as e:
logger.error(f"Error streaming Azure Responses request: {e} with request data: {json.dumps(request_data)}")
raise e
else:
try:
response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
**request_data,
stream=True,
stream_options={"include_usage": True},
)
except Exception as e:
logger.error(f"Error streaming Azure Chat Completions request: {e} with request data: {json.dumps(request_data)}")
raise e
return response_stream
@trace_method
async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
@@ -68,7 +142,12 @@ class AzureClient(OpenAIClient):
api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=base_url)
if self._is_v1_endpoint(base_url):
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
else:
client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=base_url)
response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)
# TODO: add total usage

View File

@@ -1,10 +1,10 @@
"""ChatGPT OAuth Client - handles requests to chatgpt.com/backend-api/codex/responses."""
import asyncio
import json
from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union
from typing import Any, AsyncIterator, Dict, List, Optional
import httpx
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.responses import (
Response,
ResponseCompletedEvent,
@@ -32,6 +32,7 @@ from openai.types.responses.response_stream_event import ResponseStreamEvent
from letta.errors import (
ContextWindowExceededError,
ErrorCode,
LettaError,
LLMAuthenticationError,
LLMBadRequestError,
LLMConnectionError,
@@ -39,6 +40,7 @@ from letta.errors import (
LLMServerError,
LLMTimeoutError,
)
from letta.helpers.json_helpers import sanitize_unicode_surrogates
from letta.llm_api.llm_client_base import LLMClientBase
from letta.log import get_logger
from letta.otel.tracing import trace_method
@@ -47,11 +49,6 @@ from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_response import (
ChatCompletionResponse,
Choice,
FunctionCall,
Message as ChoiceMessage,
ToolCall,
UsageStatistics,
)
from letta.schemas.providers.chatgpt_oauth import ChatGPTOAuthCredentials, ChatGPTOAuthProvider
from letta.schemas.usage import LettaUsageStatistics
@@ -100,6 +97,10 @@ class ChatGPTOAuthClient(LLMClientBase):
4. Transforms responses back to OpenAI ChatCompletion format
"""
MAX_RETRIES = 3
# Transient httpx errors that are safe to retry (connection drops, transport-level failures)
_RETRYABLE_ERRORS = (httpx.ReadError, httpx.WriteError, httpx.ConnectError, httpx.RemoteProtocolError, LLMConnectionError)
@trace_method
async def _get_provider_and_credentials_async(self, llm_config: LLMConfig) -> tuple[ChatGPTOAuthProvider, ChatGPTOAuthCredentials]:
"""Get the ChatGPT OAuth provider and credentials with automatic refresh if needed.
@@ -153,6 +154,11 @@ class ChatGPTOAuthClient(LLMClientBase):
Returns:
Dictionary of HTTP headers.
"""
if not creds.access_token:
raise LLMAuthenticationError(
message="ChatGPT OAuth access_token is empty or missing",
code=ErrorCode.UNAUTHENTICATED,
)
return {
"Authorization": f"Bearer {creds.access_token}",
"ChatGPT-Account-Id": creds.account_id,
@@ -356,38 +362,68 @@ class ChatGPTOAuthClient(LLMClientBase):
Returns:
Response data in OpenAI ChatCompletion format.
"""
request_data = sanitize_unicode_surrogates(request_data)
_, creds = await self._get_provider_and_credentials_async(llm_config)
headers = self._build_headers(creds)
endpoint = llm_config.model_endpoint or CHATGPT_CODEX_ENDPOINT
# ChatGPT backend requires streaming, so we use client.stream() to handle SSE
async with httpx.AsyncClient() as client:
# Retry on transient network errors with exponential backoff
for attempt in range(self.MAX_RETRIES):
try:
async with client.stream(
"POST",
endpoint,
json=request_data,
headers=headers,
timeout=120.0,
) as response:
response.raise_for_status()
# Accumulate SSE events into a final response
return await self._accumulate_sse_response(response)
async with httpx.AsyncClient() as client:
async with client.stream(
"POST",
endpoint,
json=request_data,
headers=headers,
timeout=120.0,
) as response:
response.raise_for_status()
# Accumulate SSE events into a final response
return await self._accumulate_sse_response(response)
except httpx.HTTPStatusError as e:
raise self._handle_http_error(e)
mapped = self._handle_http_error(e)
if isinstance(mapped, tuple(self._RETRYABLE_ERRORS)) and attempt < self.MAX_RETRIES - 1:
wait = 2**attempt
logger.warning(
f"[ChatGPT] Retryable HTTP error on request (attempt {attempt + 1}/{self.MAX_RETRIES}), "
f"retrying in {wait}s: {type(mapped).__name__}: {mapped}"
)
await asyncio.sleep(wait)
continue
raise mapped
except httpx.TimeoutException:
raise LLMTimeoutError(
message="ChatGPT backend request timed out",
code=ErrorCode.TIMEOUT,
)
except self._RETRYABLE_ERRORS as e:
if attempt < self.MAX_RETRIES - 1:
wait = 2**attempt
logger.warning(
f"[ChatGPT] Transient error on request (attempt {attempt + 1}/{self.MAX_RETRIES}), "
f"retrying in {wait}s: {type(e).__name__}: {e}"
)
await asyncio.sleep(wait)
continue
raise LLMConnectionError(
message=f"Failed to connect to ChatGPT backend after {self.MAX_RETRIES} attempts: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__},
)
except httpx.RequestError as e:
raise LLMConnectionError(
message=f"Failed to connect to ChatGPT backend: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
)
# Should not be reached, but satisfy type checker
raise LLMConnectionError(message="ChatGPT request failed after all retries", code=ErrorCode.INTERNAL_SERVER_ERROR)
async def _accumulate_sse_response(self, response: httpx.Response) -> dict:
"""Accumulate SSE stream into a final response.
@@ -550,64 +586,102 @@ class ChatGPTOAuthClient(LLMClientBase):
Returns:
Async generator yielding ResponseStreamEvent objects.
"""
request_data = sanitize_unicode_surrogates(request_data)
_, creds = await self._get_provider_and_credentials_async(llm_config)
headers = self._build_headers(creds)
endpoint = llm_config.model_endpoint or CHATGPT_CODEX_ENDPOINT
async def stream_generator():
event_count = 0
# Track output item index for proper event construction
output_index = 0
# Track sequence_number in case backend doesn't provide it
# (OpenAI SDK expects incrementing sequence numbers starting at 0)
sequence_counter = 0
# Track whether we've yielded any events — once we have, we can't
# transparently retry because the caller has already consumed partial data.
has_yielded = False
async with httpx.AsyncClient() as client:
async with client.stream(
"POST",
endpoint,
json=request_data,
headers=headers,
timeout=120.0,
) as response:
# Check for error status
if response.status_code != 200:
error_body = await response.aread()
logger.error(f"ChatGPT SSE error: {response.status_code} - {error_body}")
raise self._handle_http_error_from_status(response.status_code, error_body.decode())
for attempt in range(self.MAX_RETRIES):
try:
async with httpx.AsyncClient() as client:
async with client.stream(
"POST",
endpoint,
json=request_data,
headers=headers,
timeout=120.0,
) as response:
# Check for error status
if response.status_code != 200:
error_body = await response.aread()
logger.error(f"ChatGPT SSE error: {response.status_code} - {error_body}")
raise self._handle_http_error_from_status(response.status_code, error_body.decode())
async for line in response.aiter_lines():
if not line.startswith("data: "):
continue
async for line in response.aiter_lines():
if not line or not line.startswith("data: "):
continue
data_str = line[6:]
if data_str == "[DONE]":
break
data_str = line[6:]
if data_str == "[DONE]":
break
try:
raw_event = json.loads(data_str)
event_type = raw_event.get("type")
event_count += 1
try:
raw_event = json.loads(data_str)
event_type = raw_event.get("type")
# Use backend-provided sequence_number if available, else use counter
# This ensures proper ordering even if backend doesn't provide it
if "sequence_number" not in raw_event:
raw_event["sequence_number"] = sequence_counter
sequence_counter = raw_event["sequence_number"] + 1
# Check for error events from the API (context window, rate limit, etc.)
if event_type == "error":
logger.error(f"ChatGPT SSE error event: {json.dumps(raw_event, default=str)[:1000]}")
raise self._handle_sse_error_event(raw_event)
# Track output index for output_item.added events
if event_type == "response.output_item.added":
output_index = raw_event.get("output_index", output_index)
# Check for response.failed or response.incomplete events
if event_type in ("response.failed", "response.incomplete"):
logger.error(f"ChatGPT SSE {event_type} event: {json.dumps(raw_event, default=str)[:1000]}")
resp_obj = raw_event.get("response", {})
error_info = resp_obj.get("error", {})
if error_info:
raise self._handle_sse_error_event({"error": error_info, "type": event_type})
else:
raise LLMBadRequestError(
message=f"ChatGPT request failed with status '{event_type}' (no error details provided)",
code=ErrorCode.INTERNAL_SERVER_ERROR,
)
# Convert to OpenAI SDK ResponseStreamEvent
sdk_event = self._convert_to_sdk_event(raw_event, output_index)
if sdk_event:
yield sdk_event
# Use backend-provided sequence_number if available, else use counter
# This ensures proper ordering even if backend doesn't provide it
if "sequence_number" not in raw_event:
raw_event["sequence_number"] = sequence_counter
sequence_counter = raw_event["sequence_number"] + 1
except json.JSONDecodeError:
logger.warning(f"Failed to parse SSE event: {data_str[:100]}")
continue
# Track output index for output_item.added events
if event_type == "response.output_item.added":
output_index = raw_event.get("output_index", output_index)
# Convert to OpenAI SDK ResponseStreamEvent
sdk_event = self._convert_to_sdk_event(raw_event, output_index)
if sdk_event:
yield sdk_event
has_yielded = True
except json.JSONDecodeError:
logger.warning(f"Failed to parse SSE event: {data_str[:100]}")
continue
# Stream completed successfully
return
except self._RETRYABLE_ERRORS as e:
if has_yielded or attempt >= self.MAX_RETRIES - 1:
# Already yielded partial data or exhausted retries — must propagate
raise
wait = 2**attempt
logger.warning(
f"[ChatGPT] Transient error on stream (attempt {attempt + 1}/{self.MAX_RETRIES}), "
f"retrying in {wait}s: {type(e).__name__}: {e}"
)
await asyncio.sleep(wait)
# Wrap the async generator in AsyncStreamWrapper to provide context manager protocol
return AsyncStreamWrapper(stream_generator())
@@ -944,10 +1018,16 @@ class ChatGPTOAuthClient(LLMClientBase):
part=part,
)
# Unhandled event types - log for debugging
logger.debug(f"Unhandled SSE event type: {event_type}")
# Unhandled event types
logger.warning(f"Unhandled ChatGPT SSE event type: {event_type}")
return None
@staticmethod
def _is_upstream_connection_error(error_body: str) -> bool:
"""Check if an error body indicates an upstream connection/proxy failure."""
lower = error_body.lower()
return "upstream connect error" in lower or "reset before headers" in lower or "connection termination" in lower
def _handle_http_error_from_status(self, status_code: int, error_body: str) -> Exception:
"""Create appropriate exception from HTTP status code.
@@ -968,9 +1048,14 @@ class ChatGPTOAuthClient(LLMClientBase):
message=f"ChatGPT rate limit exceeded: {error_body}",
code=ErrorCode.RATE_LIMIT_EXCEEDED,
)
elif status_code == 502 or (status_code >= 500 and self._is_upstream_connection_error(error_body)):
return LLMConnectionError(
message=f"ChatGPT upstream connection error: {error_body}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
)
elif status_code >= 500:
return LLMServerError(
message=f"ChatGPT server error: {error_body}",
message=f"ChatGPT API error: {error_body}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
)
else:
@@ -992,25 +1077,43 @@ class ChatGPTOAuthClient(LLMClientBase):
return "o1" in model or "o3" in model or "o4" in model or "gpt-5" in model
@trace_method
def handle_llm_error(self, e: Exception) -> Exception:
def handle_llm_error(self, e: Exception, llm_config: Optional[LLMConfig] = None) -> Exception:
"""Map ChatGPT-specific errors to common LLMError types.
Args:
e: Original exception.
llm_config: Optional LLM config to determine if this is a BYOK key.
Returns:
Mapped LLMError subclass.
"""
is_byok = (llm_config.provider_category == ProviderCategory.byok) if llm_config else None
# Already a typed LLM/Letta error (e.g. from SSE error handling) — pass through
if isinstance(e, LettaError):
return e
if isinstance(e, httpx.HTTPStatusError):
return self._handle_http_error(e)
return self._handle_http_error(e, is_byok=is_byok)
return super().handle_llm_error(e)
# Handle httpx network errors which can occur during streaming
# when the connection is unexpectedly closed while reading/writing
if isinstance(e, (httpx.ReadError, httpx.WriteError, httpx.ConnectError)):
logger.warning(f"[ChatGPT] Network error during streaming: {type(e).__name__}: {e}")
return LLMConnectionError(
message=f"Network error during ChatGPT streaming: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__, "is_byok": is_byok},
)
def _handle_http_error(self, e: httpx.HTTPStatusError) -> Exception:
return super().handle_llm_error(e, llm_config=llm_config)
def _handle_http_error(self, e: httpx.HTTPStatusError, is_byok: bool | None = None) -> Exception:
"""Handle HTTP status errors from ChatGPT backend.
Args:
e: HTTP status error.
is_byok: Whether the request used a BYOK key.
Returns:
Appropriate LLMError subclass.
@@ -1028,28 +1131,86 @@ class ChatGPTOAuthClient(LLMClientBase):
return LLMAuthenticationError(
message=f"ChatGPT authentication failed: {error_message}",
code=ErrorCode.UNAUTHENTICATED,
details={"is_byok": is_byok},
)
elif status_code == 429:
return LLMRateLimitError(
message=f"ChatGPT rate limit exceeded: {error_message}",
code=ErrorCode.RATE_LIMIT_EXCEEDED,
details={"is_byok": is_byok},
)
elif status_code == 400:
if "context" in error_message.lower() or "token" in error_message.lower():
return ContextWindowExceededError(
message=f"ChatGPT context window exceeded: {error_message}",
details={"is_byok": is_byok},
)
return LLMBadRequestError(
message=f"ChatGPT bad request: {error_message}",
code=ErrorCode.INVALID_ARGUMENT,
details={"is_byok": is_byok},
)
elif status_code == 502 or (status_code >= 500 and self._is_upstream_connection_error(error_message)):
return LLMConnectionError(
message=f"ChatGPT upstream connection error: {error_message}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"is_byok": is_byok},
)
elif status_code >= 500:
return LLMServerError(
message=f"ChatGPT server error: {error_message}",
message=f"ChatGPT API error: {error_message}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"is_byok": is_byok},
)
else:
return LLMBadRequestError(
message=f"ChatGPT request failed ({status_code}): {error_message}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"is_byok": is_byok},
)
def _handle_sse_error_event(self, raw_event: dict) -> Exception:
"""Create appropriate exception from an SSE error or response.failed event.
The ChatGPT backend can return errors as SSE events within a 200 OK stream,
e.g. {"type": "error", "error": {"type": "invalid_request_error",
"code": "context_length_exceeded", "message": "..."}}.
Args:
raw_event: Raw SSE event data containing an error.
Returns:
Appropriate LLM exception.
"""
error_obj = raw_event.get("error", {})
if isinstance(error_obj, str):
error_message = error_obj
error_code = None
else:
error_message = error_obj.get("message", "Unknown ChatGPT SSE error")
error_code = error_obj.get("code") or None
if error_code == "context_length_exceeded":
return ContextWindowExceededError(
message=f"ChatGPT context window exceeded: {error_message}",
)
elif error_code == "rate_limit_exceeded":
return LLMRateLimitError(
message=f"ChatGPT rate limit exceeded: {error_message}",
code=ErrorCode.RATE_LIMIT_EXCEEDED,
)
elif error_code == "authentication_error":
return LLMAuthenticationError(
message=f"ChatGPT authentication failed: {error_message}",
code=ErrorCode.UNAUTHENTICATED,
)
elif error_code == "server_error":
return LLMServerError(
message=f"ChatGPT API error: {error_message}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
)
else:
return LLMBadRequestError(
message=f"ChatGPT SSE error ({error_code or 'unknown'}): {error_message}",
code=ErrorCode.INVALID_ARGUMENT,
)

View File

@@ -5,6 +5,7 @@ from openai import AsyncOpenAI, AsyncStream, OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta.helpers.json_helpers import sanitize_unicode_surrogates
from letta.llm_api.openai_client import OpenAIClient
from letta.log import get_logger
from letta.otel.tracing import trace_method
@@ -97,6 +98,8 @@ class DeepseekClient(OpenAIClient):
"""
Performs underlying asynchronous request to OpenAI API and returns raw response dict.
"""
request_data = sanitize_unicode_surrogates(request_data)
api_key = model_settings.deepseek_api_key or os.environ.get("DEEPSEEK_API_KEY")
client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
@@ -108,6 +111,8 @@ class DeepseekClient(OpenAIClient):
"""
Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator.
"""
request_data = sanitize_unicode_surrogates(request_data)
api_key = model_settings.deepseek_api_key or os.environ.get("DEEPSEEK_API_KEY")
client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(

View File

@@ -20,3 +20,21 @@ def is_context_window_overflow_message(msg: str) -> bool:
or "context_length_exceeded" in msg
or "Input tokens exceed the configured limit" in msg
)
def is_insufficient_credits_message(msg: str) -> bool:
    """Best-effort detection for insufficient credits/quota/billing errors.

    BYOK users on OpenRouter, OpenAI, etc. may exhaust their credits mid-stream
    or get rejected pre-flight. We detect these so they map to 402 instead of 400/500.
    """
    text = msg.lower()
    credit_phrases = (
        "insufficient credits",
        "requires more credits",
        "add more credits",
        "exceeded your current quota",
        "you've exceeded your budget",
        "can only afford",
    )
    if any(phrase in text for phrase in credit_phrases):
        return True
    # Billing hard-limit errors only count when both markers are present.
    return "billing" in text and "hard limit" in text

View File

@@ -8,6 +8,7 @@ from letta.errors import ErrorCode, LLMAuthenticationError, LLMError
from letta.llm_api.google_constants import GOOGLE_MODEL_FOR_API_KEY_CHECK
from letta.llm_api.google_vertex_client import GoogleVertexClient
from letta.log import get_logger
from letta.schemas.llm_config import LLMConfig
from letta.settings import model_settings, settings
logger = get_logger(__name__)
@@ -16,10 +17,27 @@ logger = get_logger(__name__)
class GoogleAIClient(GoogleVertexClient):
provider_label = "Google AI"
def _get_client(self):
def _get_client(self, llm_config: Optional[LLMConfig] = None):
timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
api_key = None
if llm_config:
api_key, _, _ = self.get_byok_overrides(llm_config)
if not api_key:
api_key = model_settings.gemini_api_key
return genai.Client(
api_key=model_settings.gemini_api_key,
api_key=api_key,
http_options=HttpOptions(timeout=timeout_ms),
)
async def _get_client_async(self, llm_config: Optional[LLMConfig] = None):
timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
api_key = None
if llm_config:
api_key, _, _ = await self.get_byok_overrides_async(llm_config)
if not api_key:
api_key = model_settings.gemini_api_key
return genai.Client(
api_key=api_key,
http_options=HttpOptions(timeout=timeout_ms),
)

View File

@@ -1,5 +1,6 @@
GOOGLE_MODEL_TO_CONTEXT_LENGTH = {
"gemini-3-pro-preview": 1048576,
"gemini-3.1-pro-preview": 1048576,
"gemini-3-flash-preview": 1048576,
"gemini-2.5-pro": 1048576,
"gemini-2.5-flash": 1048576,

View File

@@ -1,9 +1,11 @@
import base64
import copy
import json
import uuid
from typing import AsyncIterator, List, Optional
import httpx
import pydantic_core
from google.genai import Client, errors
from google.genai.types import (
FunctionCallingConfig,
@@ -21,6 +23,7 @@ from letta.errors import (
LLMAuthenticationError,
LLMBadRequestError,
LLMConnectionError,
LLMInsufficientCreditsError,
LLMNotFoundError,
LLMPermissionDeniedError,
LLMRateLimitError,
@@ -29,12 +32,14 @@ from letta.errors import (
LLMUnprocessableEntityError,
)
from letta.helpers.datetime_helpers import get_utc_time_int
from letta.helpers.json_helpers import json_dumps, json_loads
from letta.helpers.json_helpers import json_dumps, json_loads, sanitize_unicode_surrogates
from letta.llm_api.error_utils import is_insufficient_credits_message
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.json_parser import clean_json_string_extra_backslash
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentType
from letta.schemas.enums import ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import Tool, Tool as OpenAITool
@@ -50,8 +55,31 @@ class GoogleVertexClient(LLMClientBase):
MAX_RETRIES = model_settings.gemini_max_retries
provider_label = "Google Vertex"
def _get_client(self):
def _get_client(self, llm_config: Optional[LLMConfig] = None):
timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
if llm_config:
api_key, _, _ = self.get_byok_overrides(llm_config)
if api_key:
return Client(
api_key=api_key,
http_options=HttpOptions(timeout=timeout_ms),
)
return Client(
vertexai=True,
project=model_settings.google_cloud_project,
location=model_settings.google_cloud_location,
http_options=HttpOptions(api_version="v1", timeout=timeout_ms),
)
async def _get_client_async(self, llm_config: Optional[LLMConfig] = None):
timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
if llm_config:
api_key, _, _ = await self.get_byok_overrides_async(llm_config)
if api_key:
return Client(
api_key=api_key,
http_options=HttpOptions(timeout=timeout_ms),
)
return Client(
vertexai=True,
project=model_settings.google_cloud_project,
@@ -71,22 +99,36 @@ class GoogleVertexClient(LLMClientBase):
Performs underlying request to llm and returns raw response.
"""
try:
client = self._get_client()
client = self._get_client(llm_config)
response = client.models.generate_content(
model=llm_config.model,
contents=request_data["contents"],
config=request_data["config"],
)
return response.model_dump()
except pydantic_core._pydantic_core.ValidationError as e:
# Handle Pydantic validation errors from the Google SDK
# This occurs when tool schemas contain unsupported fields
logger.error(
f"Pydantic validation error when calling {self._provider_name()} API. Tool schema contains unsupported fields. Error: {e}"
)
raise LLMBadRequestError(
message=f"Invalid tool schema for {self._provider_name()}: Tool parameters contain unsupported fields. "
f"Common issues: 'const', 'default', 'additionalProperties' are not supported by Google AI. "
f"Please check your tool definitions. Error: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
)
except Exception as e:
raise self.handle_llm_error(e)
raise self.handle_llm_error(e, llm_config=llm_config)
@trace_method
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""
Performs underlying request to llm and returns raw response.
"""
client = self._get_client()
request_data = sanitize_unicode_surrogates(request_data)
client = await self._get_client_async(llm_config)
# Gemini 2.5 models will often return MALFORMED_FUNCTION_CALL, force a retry
# https://github.com/googleapis/python-aiplatform/issues/4472
@@ -100,17 +142,30 @@ class GoogleVertexClient(LLMClientBase):
contents=request_data["contents"],
config=request_data["config"],
)
except pydantic_core._pydantic_core.ValidationError as e:
# Handle Pydantic validation errors from the Google SDK
# This occurs when tool schemas contain unsupported fields
logger.error(
f"Pydantic validation error when calling {self._provider_name()} API. "
f"Tool schema contains unsupported fields. Error: {e}"
)
raise LLMBadRequestError(
message=f"Invalid tool schema for {self._provider_name()}: Tool parameters contain unsupported fields. "
f"Common issues: 'const', 'default', 'additionalProperties' are not supported by Google AI. "
f"Please check your tool definitions. Error: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
)
except errors.APIError as e:
# Retry on 503 and 500 errors as well, usually ephemeral from Gemini
if e.code == 503 or e.code == 500 or e.code == 504:
logger.warning(f"Received {e}, retrying {retry_count}/{self.MAX_RETRIES}")
retry_count += 1
if retry_count > self.MAX_RETRIES:
raise self.handle_llm_error(e)
raise self.handle_llm_error(e, llm_config=llm_config)
continue
raise self.handle_llm_error(e)
raise self.handle_llm_error(e, llm_config=llm_config)
except Exception as e:
raise self.handle_llm_error(e)
raise self.handle_llm_error(e, llm_config=llm_config)
response_data = response.model_dump()
is_malformed_function_call = self.is_malformed_function_call(response_data)
if is_malformed_function_call:
@@ -148,7 +203,9 @@ class GoogleVertexClient(LLMClientBase):
@trace_method
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncIterator[GenerateContentResponse]:
client = self._get_client()
request_data = sanitize_unicode_surrogates(request_data)
client = await self._get_client_async(llm_config)
try:
response = await client.aio.models.generate_content_stream(
@@ -156,13 +213,35 @@ class GoogleVertexClient(LLMClientBase):
contents=request_data["contents"],
config=request_data["config"],
)
except pydantic_core._pydantic_core.ValidationError as e:
# Handle Pydantic validation errors from the Google SDK
# This occurs when tool schemas contain unsupported fields
logger.error(
f"Pydantic validation error when calling {self._provider_name()} API. Tool schema contains unsupported fields. Error: {e}"
)
raise LLMBadRequestError(
message=f"Invalid tool schema for {self._provider_name()}: Tool parameters contain unsupported fields. "
f"Common issues: 'const', 'default', 'additionalProperties' are not supported by Google AI. "
f"Please check your tool definitions. Error: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
)
except errors.APIError as e:
raise self.handle_llm_error(e)
except Exception as e:
logger.error(f"Error streaming {self._provider_name()} request: {e} with request data: {json.dumps(request_data)}")
raise e
# Direct yield - keeps response alive in generator's local scope throughout iteration
# This is required because the SDK's connection lifecycle is tied to the response object
async for chunk in response:
yield chunk
try:
async for chunk in response:
yield chunk
except errors.ClientError as e:
if e.code == 499:
logger.info(f"{self._provider_prefix()} Stream cancelled by client (499): {e}")
return
raise self.handle_llm_error(e, llm_config=llm_config)
except errors.APIError as e:
raise self.handle_llm_error(e, llm_config=llm_config)
@staticmethod
def add_dummy_model_messages(messages: List[dict]) -> List[dict]:
@@ -196,7 +275,7 @@ class GoogleVertexClient(LLMClientBase):
# Per https://ai.google.dev/gemini-api/docs/function-calling?example=meeting#notes_and_limitations
# * Only a subset of the OpenAPI schema is supported.
# * Supported parameter types in Python are limited.
unsupported_keys = ["default", "exclusiveMaximum", "exclusiveMinimum", "additionalProperties", "$schema"]
unsupported_keys = ["default", "exclusiveMaximum", "exclusiveMinimum", "additionalProperties", "$schema", "const", "$ref"]
keys_to_remove_at_this_level = [key for key in unsupported_keys if key in schema_part]
for key_to_remove in keys_to_remove_at_this_level:
logger.debug(f"Removing unsupported keyword '{key_to_remove}' from schema part.")
@@ -223,6 +302,49 @@ class GoogleVertexClient(LLMClientBase):
for item_schema in schema_part[key]:
self._clean_google_ai_schema_properties(item_schema)
def _resolve_json_schema_refs(self, schema: dict, defs: dict | None = None) -> dict:
"""
Recursively resolve $ref in JSON schema by inlining definitions.
Google GenAI SDK does not support $ref.
"""
if defs is None:
# Look for definitions at the top level
defs = schema.get("$defs") or schema.get("definitions") or {}
if not isinstance(schema, dict):
return schema
# If this is a ref, resolve it
if "$ref" in schema:
ref = schema["$ref"]
if isinstance(ref, str):
for prefix in ("#/$defs/", "#/definitions/"):
if ref.startswith(prefix):
ref_name = ref.split("/")[-1]
if ref_name in defs:
resolved = defs[ref_name].copy()
return self._resolve_json_schema_refs(resolved, defs)
break
logger.warning(f"Could not resolve $ref '{ref}' in schema — will be stripped by schema cleaner")
# Recursively process children
new_schema = schema.copy()
# We need to remove $defs/definitions from the output schema as Google doesn't support them
if "$defs" in new_schema:
del new_schema["$defs"]
if "definitions" in new_schema:
del new_schema["definitions"]
for k, v in new_schema.items():
if isinstance(v, dict):
new_schema[k] = self._resolve_json_schema_refs(v, defs)
elif isinstance(v, list):
new_schema[k] = [self._resolve_json_schema_refs(i, defs) if isinstance(i, dict) else i for i in v]
return new_schema
def convert_tools_to_google_ai_format(self, tools: List[Tool], llm_config: LLMConfig) -> List[dict]:
"""
OpenAI style:
@@ -273,7 +395,8 @@ class GoogleVertexClient(LLMClientBase):
dict(
name=t.function.name,
description=t.function.description,
parameters=t.function.parameters, # TODO need to unpack
# Deep copy parameters to avoid modifying the original Tool object
parameters=copy.deepcopy(t.function.parameters) if t.function.parameters else {},
)
for t in tools
]
@@ -284,6 +407,8 @@ class GoogleVertexClient(LLMClientBase):
# Google AI API only supports a subset of OpenAPI 3.0, so unsupported params must be cleaned
if "parameters" in func and isinstance(func["parameters"], dict):
# Resolve $ref in schema because Google AI SDK doesn't support them
func["parameters"] = self._resolve_json_schema_refs(func["parameters"])
self._clean_google_ai_schema_properties(func["parameters"])
# Add inner thoughts
@@ -549,6 +674,9 @@ class GoogleVertexClient(LLMClientBase):
content=inner_thoughts,
tool_calls=[tool_call],
)
if response_message.thought_signature:
thought_signature = base64.b64encode(response_message.thought_signature).decode("utf-8")
openai_response_message.reasoning_content_signature = thought_signature
else:
openai_response_message.content = inner_thoughts
if openai_response_message.tool_calls is None:
@@ -670,6 +798,7 @@ class GoogleVertexClient(LLMClientBase):
# "candidatesTokenCount": 27,
# "totalTokenCount": 36
# }
usage = None
if response.usage_metadata:
# Extract usage via centralized method
from letta.schemas.enums import ProviderType
@@ -750,54 +879,80 @@ class GoogleVertexClient(LLMClientBase):
return False
@trace_method
def handle_llm_error(self, e: Exception) -> Exception:
def handle_llm_error(self, e: Exception, llm_config: Optional[LLMConfig] = None) -> Exception:
is_byok = (llm_config.provider_category == ProviderCategory.byok) if llm_config else None
# Handle Google GenAI specific errors
if isinstance(e, errors.ClientError):
if e.code == 499:
logger.info(f"{self._provider_prefix()} Request cancelled by client (499): {e}")
return LLMConnectionError(
message=f"Request to {self._provider_name()} was cancelled (client disconnected): {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"status_code": 499, "cause": "client_cancelled", "is_byok": is_byok},
)
logger.warning(f"{self._provider_prefix()} Client error ({e.code}): {e}")
# Handle specific error codes
if e.code == 400:
error_str = str(e).lower()
if "context" in error_str and ("exceed" in error_str or "limit" in error_str or "too long" in error_str):
if ("context" in error_str or "token count" in error_str or "tokens allowed" in error_str) and (
"exceed" in error_str or "limit" in error_str or "too long" in error_str
):
return ContextWindowExceededError(
message=f"Bad request to {self._provider_name()} (context window exceeded): {str(e)}",
details={"is_byok": is_byok},
)
else:
return LLMBadRequestError(
message=f"Bad request to {self._provider_name()}: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
code=ErrorCode.INVALID_ARGUMENT,
details={"is_byok": is_byok},
)
elif e.code == 401:
return LLMAuthenticationError(
message=f"Authentication failed with {self._provider_name()}: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"is_byok": is_byok},
)
elif e.code == 403:
return LLMPermissionDeniedError(
message=f"Permission denied by {self._provider_name()}: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"is_byok": is_byok},
)
elif e.code == 404:
return LLMNotFoundError(
message=f"Resource not found in {self._provider_name()}: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"is_byok": is_byok},
)
elif e.code == 408:
return LLMTimeoutError(
message=f"Request to {self._provider_name()} timed out: {str(e)}",
code=ErrorCode.TIMEOUT,
details={"cause": str(e.__cause__) if e.__cause__ else None},
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
)
elif e.code == 402 or is_insufficient_credits_message(str(e)):
msg = str(e)
return LLMInsufficientCreditsError(
message=f"Insufficient credits (BYOK): {msg}" if is_byok else f"Insufficient credits: {msg}",
code=ErrorCode.PAYMENT_REQUIRED,
details={"status_code": e.code, "is_byok": is_byok},
)
elif e.code == 422:
return LLMUnprocessableEntityError(
message=f"Invalid request content for {self._provider_name()}: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"is_byok": is_byok},
)
elif e.code == 429:
logger.warning(f"{self._provider_prefix()} Rate limited (429). Consider backoff.")
return LLMRateLimitError(
message=f"Rate limited by {self._provider_name()}: {str(e)}",
code=ErrorCode.RATE_LIMIT_EXCEEDED,
details={"is_byok": is_byok},
)
else:
return LLMServerError(
@@ -806,6 +961,7 @@ class GoogleVertexClient(LLMClientBase):
details={
"status_code": e.code,
"response_json": getattr(e, "response_json", None),
"is_byok": is_byok,
},
)
@@ -820,13 +976,14 @@ class GoogleVertexClient(LLMClientBase):
details={
"status_code": e.code,
"response_json": getattr(e, "response_json", None),
"is_byok": is_byok,
},
)
elif e.code == 502:
return LLMConnectionError(
message=f"Bad gateway from {self._provider_name()}: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None},
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
)
elif e.code == 503:
return LLMServerError(
@@ -835,13 +992,14 @@ class GoogleVertexClient(LLMClientBase):
details={
"status_code": e.code,
"response_json": getattr(e, "response_json", None),
"is_byok": is_byok,
},
)
elif e.code == 504:
return LLMTimeoutError(
message=f"Gateway timeout from {self._provider_name()}: {str(e)}",
code=ErrorCode.TIMEOUT,
details={"cause": str(e.__cause__) if e.__cause__ else None},
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
)
else:
return LLMServerError(
@@ -850,6 +1008,7 @@ class GoogleVertexClient(LLMClientBase):
details={
"status_code": e.code,
"response_json": getattr(e, "response_json", None),
"is_byok": is_byok,
},
)
@@ -861,6 +1020,7 @@ class GoogleVertexClient(LLMClientBase):
details={
"status_code": e.code,
"response_json": getattr(e, "response_json", None),
"is_byok": is_byok,
},
)
@@ -872,7 +1032,7 @@ class GoogleVertexClient(LLMClientBase):
return LLMConnectionError(
message=f"Connection error during {self._provider_name()} streaming: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None},
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
)
# Handle httpx network errors which can occur during streaming
@@ -882,7 +1042,7 @@ class GoogleVertexClient(LLMClientBase):
return LLMConnectionError(
message=f"Network error during {self._provider_name()} streaming: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__},
details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__, "is_byok": is_byok},
)
# Handle connection-related errors
@@ -891,13 +1051,15 @@ class GoogleVertexClient(LLMClientBase):
return LLMConnectionError(
message=f"Failed to connect to {self._provider_name()}: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None},
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
)
# Fallback to base implementation for other errors
return super().handle_llm_error(e)
return super().handle_llm_error(e, llm_config=llm_config)
async def count_tokens(self, messages: List[dict] = None, model: str = None, tools: List[OpenAITool] = None) -> int:
async def count_tokens(
self, messages: List[dict] | None = None, model: str | None = None, tools: List[OpenAITool] | None = None
) -> int:
"""
Count tokens for the given messages and tools using the Gemini token counting API.

View File

@@ -5,6 +5,7 @@ from openai import AsyncOpenAI, AsyncStream, OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta.helpers.json_helpers import sanitize_unicode_surrogates
from letta.llm_api.openai_client import OpenAIClient
from letta.otel.tracing import trace_method
from letta.schemas.embedding_config import EmbeddingConfig
@@ -34,6 +35,11 @@ class GroqClient(OpenAIClient):
) -> dict:
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)
# Groq only supports string values for tool_choice: "none", "auto", "required"
# Convert object-format tool_choice (used for force_tool_call) to "required"
if "tool_choice" in data and isinstance(data["tool_choice"], dict):
data["tool_choice"] = "required"
# Groq validation - these fields are not supported and will cause 400 errors
# https://console.groq.com/docs/openai
if "top_logprobs" in data:
@@ -69,6 +75,8 @@ class GroqClient(OpenAIClient):
"""
Performs underlying asynchronous request to Groq API and returns raw response dict.
"""
request_data = sanitize_unicode_surrogates(request_data)
api_key = model_settings.groq_api_key or os.environ.get("GROQ_API_KEY")
client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)

View File

@@ -1,23 +1,17 @@
import copy
import json
import logging
from collections import OrderedDict
from typing import Any, List, Optional, Union
from typing import List, Optional
from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
from letta.helpers.json_helpers import json_dumps
from letta.log import get_logger
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice
from letta.schemas.response_format import (
JsonObjectResponseFormat,
JsonSchemaResponseFormat,
ResponseFormatType,
ResponseFormatUnion,
TextResponseFormat,
)
from letta.settings import summarizer_settings
from letta.utils import printd
logger = get_logger(__name__)

View File

@@ -23,7 +23,7 @@ from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.orm.user import User
from letta.otel.tracing import log_event, trace_method
from letta.schemas.enums import ProviderCategory
from letta.schemas.enums import LLMCallType, ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
@@ -245,6 +245,7 @@ def create(
request_json=prepare_openai_payload(data),
response_json=response.model_json_schema(),
step_id=step_id,
call_type=LLMCallType.agent_step,
),
)

View File

@@ -10,7 +10,7 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta.errors import ErrorCode, LLMConnectionError, LLMError
from letta.otel.tracing import log_event, trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import AgentType, ProviderCategory
from letta.schemas.enums import AgentType, LLMCallType, ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
@@ -61,8 +61,11 @@ class LLMClientBase:
user_id: Optional[str] = None,
compaction_settings: Optional[Dict] = None,
llm_config: Optional[Dict] = None,
actor: Optional["User"] = None,
) -> None:
"""Set telemetry context for provider trace logging."""
if actor is not None:
self.actor = actor
self._telemetry_manager = telemetry_manager
self._telemetry_agent_id = agent_id
self._telemetry_agent_tags = agent_tags
@@ -82,6 +85,10 @@ class LLMClientBase:
"""Wrapper around request_async that logs telemetry for all requests including errors.
Call set_telemetry_context() first to set agent_id, run_id, etc.
Telemetry is logged via TelemetryManager which supports multiple backends
(postgres, clickhouse, socket, etc.) configured via
LETTA_TELEMETRY_PROVIDER_TRACE_BACKEND.
"""
from letta.log import get_logger
@@ -97,6 +104,7 @@ class LLMClientBase:
error_type = type(e).__name__
raise
finally:
# Log telemetry via configured backends
if self._telemetry_manager and settings.track_provider_trace:
if self.actor is None:
logger.warning(f"Skipping telemetry: actor is None (call_type={self._telemetry_call_type})")
@@ -116,24 +124,33 @@ class LLMClientBase:
org_id=self._telemetry_org_id,
user_id=self._telemetry_user_id,
compaction_settings=self._telemetry_compaction_settings,
llm_config=self._telemetry_llm_config,
llm_config=llm_config.model_dump() if llm_config else self._telemetry_llm_config,
),
)
except Exception as e:
logger.warning(f"Failed to log telemetry: {e}")
async def stream_async_with_telemetry(self, request_data: dict, llm_config: LLMConfig):
"""Returns raw stream. Caller should log telemetry after processing via log_provider_trace_async().
Call set_telemetry_context() first to set agent_id, run_id, etc.
After consuming the stream, call log_provider_trace_async() with the response data.
"""
return await self.stream_async(request_data, llm_config)
async def log_provider_trace_async(self, request_data: dict, response_json: dict) -> None:
async def log_provider_trace_async(
self,
request_data: dict,
response_json: Optional[dict],
llm_config: Optional[LLMConfig] = None,
latency_ms: Optional[int] = None,
error_msg: Optional[str] = None,
error_type: Optional[str] = None,
) -> None:
"""Log provider trace telemetry. Call after processing LLM response.
Uses telemetry context set via set_telemetry_context().
Telemetry is logged via TelemetryManager which supports multiple backends.
Args:
request_data: The request payload sent to the LLM
response_json: The response payload from the LLM
llm_config: LLMConfig for extracting provider/model info
latency_ms: Latency in milliseconds (not used currently, kept for API compatibility)
error_msg: Error message if request failed (not used currently)
error_type: Error type if request failed (not used currently)
"""
from letta.log import get_logger
@@ -146,6 +163,13 @@ class LLMClientBase:
logger.warning(f"Skipping telemetry: actor is None (call_type={self._telemetry_call_type})")
return
if response_json is None:
if error_msg:
response_json = {"error": error_msg, "error_type": error_type}
else:
logger.warning(f"Skipping telemetry: no response_json or error_msg (call_type={self._telemetry_call_type})")
return
try:
pydantic_actor = self.actor.to_pydantic() if hasattr(self.actor, "to_pydantic") else self.actor
await self._telemetry_manager.create_provider_trace_async(
@@ -161,7 +185,7 @@ class LLMClientBase:
org_id=self._telemetry_org_id,
user_id=self._telemetry_user_id,
compaction_settings=self._telemetry_compaction_settings,
llm_config=self._telemetry_llm_config,
llm_config=llm_config.model_dump() if llm_config else self._telemetry_llm_config,
),
)
except Exception as e:
@@ -204,11 +228,12 @@ class LLMClientBase:
request_json=request_data,
response_json=response_data,
step_id=step_id,
call_type=LLMCallType.agent_step,
),
)
log_event(name="llm_response_received", attributes=response_data)
except Exception as e:
raise self.handle_llm_error(e)
raise self.handle_llm_error(e, llm_config=llm_config)
return await self.convert_response_to_chat_completion(response_data, messages, llm_config)
@@ -237,12 +262,13 @@ class LLMClientBase:
request_json=request_data,
response_json=response_data,
step_id=step_id,
call_type=LLMCallType.agent_step,
),
)
log_event(name="llm_response_received", attributes=response_data)
except Exception as e:
raise self.handle_llm_error(e)
raise self.handle_llm_error(e, llm_config=llm_config)
return await self.convert_response_to_chat_completion(response_data, messages, llm_config)
@@ -334,17 +360,20 @@ class LLMClientBase:
raise NotImplementedError
@abstractmethod
def handle_llm_error(self, e: Exception) -> Exception:
def handle_llm_error(self, e: Exception, llm_config: Optional["LLMConfig"] = None) -> Exception:
"""
Maps provider-specific errors to common LLMError types.
Each LLM provider should implement this to translate their specific errors.
Args:
e: The original provider-specific exception
llm_config: Optional LLM config to determine if this is a BYOK key
Returns:
An LLMError subclass that represents the error in a provider-agnostic way
"""
is_byok = (llm_config.provider_category == ProviderCategory.byok) if llm_config else None
# Handle httpx.RemoteProtocolError which can occur during streaming
# when the remote server closes the connection unexpectedly
# (e.g., "peer closed connection without sending complete message body")
@@ -356,10 +385,10 @@ class LLMClientBase:
return LLMConnectionError(
message=f"Connection error during streaming: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None},
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
)
return LLMError(f"Unhandled LLM error: {str(e)}")
return LLMError(message=f"Unhandled LLM error: {str(e)}", details={"is_byok": is_byok})
def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
"""

View File

@@ -4,6 +4,7 @@ import anthropic
from anthropic import AsyncStream
from anthropic.types.beta import BetaMessage, BetaRawMessageStreamEvent
from letta.helpers.json_helpers import sanitize_unicode_surrogates
from letta.llm_api.anthropic_client import AnthropicClient
from letta.log import get_logger
from letta.otel.tracing import trace_method
@@ -83,6 +84,8 @@ class MiniMaxClient(AnthropicClient):
Uses beta messages API for compatibility with Anthropic streaming interfaces.
"""
request_data = sanitize_unicode_surrogates(request_data)
client = await self._get_anthropic_client_async(llm_config, async_client=True)
try:
@@ -105,6 +108,8 @@ class MiniMaxClient(AnthropicClient):
Uses beta messages API for compatibility with Anthropic streaming interfaces.
"""
request_data = sanitize_unicode_surrogates(request_data)
client = await self._get_anthropic_client_async(llm_config, async_client=True)
request_data["stream"] = True

View File

@@ -10,7 +10,7 @@ async def mistral_get_model_list_async(url: str, api_key: str) -> dict:
url = smart_urljoin(url, "models")
headers = {"Content-Type": "application/json"}
if api_key is not None:
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
logger.debug("Sending request to %s", url)

View File

@@ -59,7 +59,7 @@ async def openai_get_model_list_async(
url = smart_urljoin(url, "models")
headers = {"Content-Type": "application/json"}
if api_key is not None:
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
if "openrouter.ai" in url:
if model_settings.openrouter_referer:
@@ -86,7 +86,7 @@ async def openai_get_model_list_async(
# Handle HTTP errors (e.g., response 4XX, 5XX)
try:
error_response = http_err.response.json()
except:
except Exception:
error_response = {"status_code": http_err.response.status_code, "text": http_err.response.text}
logger.debug(f"Got HTTPError, exception={http_err}, response={error_response}")
raise http_err
@@ -478,7 +478,7 @@ def openai_chat_completions_request_stream(
data = prepare_openai_payload(chat_completion_request)
data["stream"] = True
kwargs = {"api_key": api_key, "base_url": url, "max_retries": 0}
kwargs = {"api_key": api_key or "DUMMY_API_KEY", "base_url": url, "max_retries": 0}
if "openrouter.ai" in url:
headers = {}
if model_settings.openrouter_referer:
@@ -511,7 +511,7 @@ def openai_chat_completions_request(
https://platform.openai.com/docs/guides/text-generation?lang=curl
"""
data = prepare_openai_payload(chat_completion_request)
kwargs = {"api_key": api_key, "base_url": url, "max_retries": 0}
kwargs = {"api_key": api_key or "DUMMY_API_KEY", "base_url": url, "max_retries": 0}
if "openrouter.ai" in url:
headers = {}
if model_settings.openrouter_referer:
@@ -524,7 +524,17 @@ def openai_chat_completions_request(
log_event(name="llm_request_sent", attributes=data)
chat_completion = client.chat.completions.create(**data)
log_event(name="llm_response_received", attributes=chat_completion.model_dump())
return ChatCompletionResponse(**chat_completion.model_dump())
response = ChatCompletionResponse(**chat_completion.model_dump())
# Override tool_call IDs to ensure cross-provider compatibility (matches streaming path behavior)
# Some models (e.g. Kimi via OpenRouter) generate IDs like 'Read:93' which violate Anthropic's pattern
for choice in response.choices:
if choice.message.tool_calls:
for tool_call in choice.message.tool_calls:
if tool_call.id is not None:
tool_call.id = get_tool_call_id()
return response
def prepare_openai_payload(chat_completion_request: ChatCompletionRequest):

View File

@@ -20,6 +20,7 @@ from letta.errors import (
LLMAuthenticationError,
LLMBadRequestError,
LLMConnectionError,
LLMInsufficientCreditsError,
LLMNotFoundError,
LLMPermissionDeniedError,
LLMRateLimitError,
@@ -27,7 +28,8 @@ from letta.errors import (
LLMTimeoutError,
LLMUnprocessableEntityError,
)
from letta.llm_api.error_utils import is_context_window_overflow_message
from letta.helpers.json_helpers import sanitize_unicode_surrogates
from letta.llm_api.error_utils import is_context_window_overflow_message, is_insufficient_credits_message
from letta.llm_api.helpers import (
add_inner_thoughts_to_functions,
convert_response_format_to_responses_api,
@@ -39,13 +41,13 @@ from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentType
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ProviderCategory
from letta.schemas.letta_message_content import MessageContentType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import (
ChatCompletionRequest,
FunctionCall as ToolFunctionChoiceFunctionCall,
FunctionSchema,
Tool as OpenAITool,
ToolFunctionChoice,
cast_message_to_subtype,
@@ -56,7 +58,6 @@ from letta.schemas.openai.chat_completion_response import (
FunctionCall,
Message as ChoiceMessage,
ToolCall,
UsageStatistics,
)
from letta.schemas.openai.responses_request import ResponsesRequest
from letta.schemas.response_format import JsonSchemaResponseFormat
@@ -105,7 +106,7 @@ def accepts_developer_role(model: str) -> bool:
See: https://community.openai.com/t/developer-role-not-accepted-for-o1-o1-mini-o3-mini/1110750/7
"""
if is_openai_reasoning_model(model) and "o1-mini" not in model or "o1-preview" in model:
if (is_openai_reasoning_model(model) and "o1-mini" not in model) or "o1-preview" in model:
return True
else:
return False
@@ -244,6 +245,64 @@ class OpenAIClient(LLMClientBase):
def supports_structured_output(self, llm_config: LLMConfig) -> bool:
return supports_structured_output(llm_config)
def _is_openrouter_request(self, llm_config: LLMConfig) -> bool:
    """Return True when this request is routed through OpenRouter.

    A config counts as OpenRouter if its endpoint URL contains
    "openrouter.ai" or its provider_name is explicitly "openrouter".
    """
    endpoint = llm_config.model_endpoint
    routed_via_url = bool(endpoint) and "openrouter.ai" in endpoint
    return routed_via_url or llm_config.provider_name == "openrouter"
def _is_true_openai_request(self, llm_config: LLMConfig) -> bool:
    """Return True only for requests that go directly to OpenAI proper.

    Excludes OpenAI-compatible proxies (OpenRouter), the Letta inference
    endpoint, and any config whose provider_name names a non-OpenAI provider,
    so OpenAI-specific request fields (e.g. prompt caching) are not sent to
    endpoints that merely speak the OpenAI protocol.
    """
    if llm_config.model_endpoint_type != "openai":
        return False
    if self._is_openrouter_request(llm_config):
        return False
    # The Letta inference endpoint speaks the OpenAI protocol but is not OpenAI;
    # keep its behavior unchanged.
    if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT:
        return False
    # An explicitly non-OpenAI provider_name means this is a compatible proxy.
    return not (llm_config.provider_name and llm_config.provider_name != "openai")
def _normalize_model_name(self, model: Optional[str]) -> Optional[str]:
if not model:
return None
return model.split("/", 1)[-1]
def _supports_extended_prompt_cache_retention(self, model: Optional[str]) -> bool:
    """Whether OpenAI's extended (24h) prompt-cache retention applies to *model*.

    Per OpenAI docs, extended retention covers gpt-4.1 and the gpt-5 family;
    gpt-5-mini is excluded because it is not listed in the docs.
    """
    name = self._normalize_model_name(model)
    if not name:
        return False
    if name == "gpt-4.1":
        return True
    return name.startswith("gpt-5") and name != "gpt-5-mini"
def _apply_prompt_cache_settings(
    self,
    llm_config: LLMConfig,
    model: Optional[str],
    messages: List[PydanticMessage],
    request_obj: Any,
) -> None:
    """Mutate *request_obj* in place with OpenAI prompt-cache settings.

    prompt_cache_key is deliberately left unset: OpenAI's default routing
    (a hash of the first ~256 prompt tokens) already gives good cache
    affinity for Letta agents, whose system prompts are unique per agent,
    and an explicit key can disrupt existing warm caches and reduce hit
    rates.

    prompt_cache_retention is set to "24h" only for models that support
    extended retention (keeping cached prefixes active up to 24h vs the
    default 5-10 minutes), and only for requests that truly target OpenAI.
    """
    if self._is_true_openai_request(llm_config) and self._supports_extended_prompt_cache_retention(model):
        request_obj.prompt_cache_retention = "24h"
@trace_method
def build_request_data_responses(
self,
@@ -384,6 +443,13 @@ class OpenAIClient(LLMClientBase):
data.model = "memgpt-openai"
self._apply_prompt_cache_settings(
llm_config=llm_config,
model=model,
messages=messages,
request_obj=data,
)
request_data = data.model_dump(exclude_unset=True)
# print("responses request data", request_data)
return request_data
@@ -452,13 +518,11 @@ class OpenAIClient(LLMClientBase):
model = None
# TODO: we may need to extend this to more models using proxy?
is_openrouter = (llm_config.model_endpoint and "openrouter.ai" in llm_config.model_endpoint) or (
llm_config.provider_name == "openrouter"
)
is_openrouter = self._is_openrouter_request(llm_config)
if is_openrouter:
try:
model = llm_config.handle.split("/", 1)[-1]
except:
except Exception:
# don't raise error since this isn't robust against edge cases
pass
@@ -468,7 +532,12 @@ class OpenAIClient(LLMClientBase):
tool_choice = None
if tools: # only set tool_choice if tools exist
if force_tool_call is not None:
tool_choice = ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=force_tool_call))
# OpenRouter proxies to providers that may not support object-format tool_choice
# Use "required" instead which achieves similar effect
if is_openrouter:
tool_choice = "required"
else:
tool_choice = ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=force_tool_call))
elif requires_subsequent_tool_call:
tool_choice = "required"
elif self.requires_auto_tool_choice(llm_config) or agent_type == AgentType.letta_v1_agent:
@@ -505,6 +574,12 @@ class OpenAIClient(LLMClientBase):
if llm_config.frequency_penalty is not None:
data.frequency_penalty = llm_config.frequency_penalty
# Add logprobs configuration for RL training
if llm_config.return_logprobs:
data.logprobs = True
if llm_config.top_logprobs is not None:
data.top_logprobs = llm_config.top_logprobs
if tools and supports_parallel_tool_calling(model):
data.parallel_tool_calls = False
@@ -546,6 +621,13 @@ class OpenAIClient(LLMClientBase):
new_tools.append(tool.model_copy(deep=True))
data.tools = new_tools
self._apply_prompt_cache_settings(
llm_config=llm_config,
model=model,
messages=messages,
request_obj=data,
)
# Note: Tools are already processed by enable_strict_mode() in the workflow/agent code
# (temporal_letta_v1_agent_workflow.py or letta_agent_v3.py) before reaching here.
# enable_strict_mode() handles: strict flag, additionalProperties, required array, nullable fields
@@ -564,6 +646,17 @@ class OpenAIClient(LLMClientBase):
# If set, then in the backend "medium" thinking is turned on
# request_data["reasoning_effort"] = "medium"
# Add OpenRouter reasoning configuration via extra_body
if is_openrouter and llm_config.enable_reasoner:
reasoning_config = {}
if llm_config.reasoning_effort:
reasoning_config["effort"] = llm_config.reasoning_effort
if llm_config.max_reasoning_tokens and llm_config.max_reasoning_tokens > 0:
reasoning_config["max_tokens"] = llm_config.max_reasoning_tokens
if not reasoning_config:
reasoning_config = {"enabled": True}
request_data["extra_body"] = {"reasoning": reasoning_config}
return request_data
@trace_method
@@ -571,29 +664,51 @@ class OpenAIClient(LLMClientBase):
"""
Performs underlying synchronous request to OpenAI API and returns raw response dict.
"""
# Sanitize Unicode surrogates to prevent encoding errors
request_data = sanitize_unicode_surrogates(request_data)
client = OpenAI(**self._prepare_client_kwargs(llm_config))
# Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
if "input" in request_data and "messages" not in request_data:
resp = client.responses.create(**request_data)
return resp.model_dump()
else:
response: ChatCompletion = client.chat.completions.create(**request_data)
return response.model_dump()
try:
if "input" in request_data and "messages" not in request_data:
resp = client.responses.create(**request_data)
return resp.model_dump()
else:
response: ChatCompletion = client.chat.completions.create(**request_data)
return response.model_dump()
except json.JSONDecodeError as e:
logger.error(f"[OpenAI] Failed to parse API response as JSON: {e}")
raise LLMServerError(
message=f"OpenAI API returned invalid JSON response (likely an HTML error page): {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"json_error": str(e), "error_position": f"line {e.lineno} column {e.colno}"},
)
@trace_method
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""
Performs underlying asynchronous request to OpenAI API and returns raw response dict.
"""
# Sanitize Unicode surrogates to prevent encoding errors
request_data = sanitize_unicode_surrogates(request_data)
kwargs = await self._prepare_client_kwargs_async(llm_config)
client = AsyncOpenAI(**kwargs)
# Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
if "input" in request_data and "messages" not in request_data:
resp = await client.responses.create(**request_data)
return resp.model_dump()
else:
response: ChatCompletion = await client.chat.completions.create(**request_data)
return response.model_dump()
try:
if "input" in request_data and "messages" not in request_data:
resp = await client.responses.create(**request_data)
return resp.model_dump()
else:
response: ChatCompletion = await client.chat.completions.create(**request_data)
return response.model_dump()
except json.JSONDecodeError as e:
logger.error(f"[OpenAI] Failed to parse API response as JSON: {e}")
raise LLMServerError(
message=f"OpenAI API returned invalid JSON response (likely an HTML error page): {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"json_error": str(e), "error_position": f"line {e.lineno} column {e.colno}"},
)
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
return is_openai_reasoning_model(llm_config.model)
@@ -669,6 +784,12 @@ class OpenAIClient(LLMClientBase):
Converts raw OpenAI response dict into the ChatCompletionResponse Pydantic model.
Handles potential extraction of inner thoughts if they were added via kwargs.
"""
if isinstance(response_data, str):
raise LLMServerError(
message="LLM endpoint returned a raw string instead of a JSON object. This usually indicates the endpoint URL is incorrect or returned an error page.",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"raw_response": response_data[:500]},
)
if "object" in response_data and response_data["object"] == "response":
# Map Responses API shape to Chat Completions shape
# See example payload in tests/integration_test_send_message_v2.py
@@ -696,7 +817,6 @@ class OpenAIClient(LLMClientBase):
finish_reason = None
# Optionally capture reasoning presence
found_reasoning = False
for out in outputs:
out_type = (out or {}).get("type")
if out_type == "message":
@@ -707,7 +827,6 @@ class OpenAIClient(LLMClientBase):
if text_val:
assistant_text_parts.append(text_val)
elif out_type == "reasoning":
found_reasoning = True
reasoning_summary_parts = [part.get("text") for part in out.get("summary")]
reasoning_content_signature = out.get("encrypted_content")
elif out_type == "function_call":
@@ -765,12 +884,12 @@ class OpenAIClient(LLMClientBase):
):
if "choices" in response_data and len(response_data["choices"]) > 0:
choice_data = response_data["choices"][0]
if "message" in choice_data and "reasoning_content" in choice_data["message"]:
reasoning_content = choice_data["message"]["reasoning_content"]
if reasoning_content:
chat_completion_response.choices[0].message.reasoning_content = reasoning_content
chat_completion_response.choices[0].message.reasoning_content_signature = None
message_data = choice_data.get("message", {})
# Check for reasoning_content (standard) or reasoning (OpenRouter)
reasoning_content = message_data.get("reasoning_content") or message_data.get("reasoning")
if reasoning_content:
chat_completion_response.choices[0].message.reasoning_content = reasoning_content
chat_completion_response.choices[0].message.reasoning_content_signature = None
# Unpack inner thoughts if they were embedded in function arguments
if llm_config.put_inner_thoughts_in_kwargs:
@@ -789,6 +908,9 @@ class OpenAIClient(LLMClientBase):
"""
Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator.
"""
# Sanitize Unicode surrogates to prevent encoding errors
request_data = sanitize_unicode_surrogates(request_data)
kwargs = await self._prepare_client_kwargs_async(llm_config)
client = AsyncOpenAI(**kwargs)
@@ -820,6 +942,9 @@ class OpenAIClient(LLMClientBase):
"""
Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator.
"""
# Sanitize Unicode surrogates to prevent encoding errors
request_data = sanitize_unicode_surrogates(request_data)
kwargs = await self._prepare_client_kwargs_async(llm_config)
client = AsyncOpenAI(**kwargs)
response_stream: AsyncStream[ResponseStreamEvent] = await client.responses.create(**request_data, stream=True)
@@ -958,10 +1083,11 @@ class OpenAIClient(LLMClientBase):
return results
@trace_method
def handle_llm_error(self, e: Exception) -> Exception:
def handle_llm_error(self, e: Exception, llm_config: Optional[LLMConfig] = None) -> Exception:
"""
Maps OpenAI-specific errors to common LLMError types.
"""
is_byok = (llm_config.provider_category == ProviderCategory.byok) if llm_config else None
if isinstance(e, openai.APITimeoutError):
timeout_duration = getattr(e, "timeout", "unknown")
logger.warning(f"[OpenAI] Request timeout after {timeout_duration} seconds: {e}")
@@ -971,6 +1097,7 @@ class OpenAIClient(LLMClientBase):
details={
"timeout_duration": timeout_duration,
"cause": str(e.__cause__) if e.__cause__ else None,
"is_byok": is_byok,
},
)
@@ -979,7 +1106,7 @@ class OpenAIClient(LLMClientBase):
return LLMConnectionError(
message=f"Failed to connect to OpenAI: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None},
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
)
# Handle httpx.RemoteProtocolError which can occur during streaming
@@ -990,7 +1117,7 @@ class OpenAIClient(LLMClientBase):
return LLMConnectionError(
message=f"Connection error during OpenAI streaming: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None},
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
)
# Handle httpx network errors which can occur during streaming
@@ -1000,37 +1127,47 @@ class OpenAIClient(LLMClientBase):
return LLMConnectionError(
message=f"Network error during OpenAI streaming: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__},
details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__, "is_byok": is_byok},
)
if isinstance(e, openai.RateLimitError):
logger.warning(f"[OpenAI] Rate limited (429). Consider backoff. Error: {e}")
body_details = e.body if isinstance(e.body, dict) else {"body": e.body}
return LLMRateLimitError(
message=f"Rate limited by OpenAI: {str(e)}",
code=ErrorCode.RATE_LIMIT_EXCEEDED,
details=e.body, # Include body which often has rate limit details
details={**body_details, "is_byok": is_byok},
)
if isinstance(e, openai.BadRequestError):
logger.warning(f"[OpenAI] Bad request (400): {str(e)}")
# BadRequestError can signify different issues (e.g., invalid args, context length)
# Check for context_length_exceeded error code in the error body
error_str = str(e)
if "<html" in error_str.lower() or (e.body and isinstance(e.body, str) and "<html" in e.body.lower()):
logger.warning("[OpenAI] Received HTML error response from upstream endpoint (likely ALB or reverse proxy)")
return LLMBadRequestError(
message="Upstream endpoint returned HTML error (400 Bad Request). This usually indicates the configured API endpoint is not an OpenAI-compatible API or the request was rejected by a load balancer.",
code=ErrorCode.INVALID_ARGUMENT,
details={"raw_body_preview": error_str[:500]},
)
error_code = None
if e.body and isinstance(e.body, dict):
error_details = e.body.get("error", {})
if isinstance(error_details, dict):
error_code = error_details.get("code")
# Check both the error code and message content for context length issues
if error_code == "context_length_exceeded" or is_context_window_overflow_message(str(e)):
if error_code == "context_length_exceeded" or is_context_window_overflow_message(error_str):
return ContextWindowExceededError(
message=f"Bad request to OpenAI (context window exceeded): {str(e)}",
message=f"Bad request to OpenAI (context window exceeded): {error_str}",
details={"is_byok": is_byok},
)
else:
body_details = e.body if isinstance(e.body, dict) else {"body": e.body}
return LLMBadRequestError(
message=f"Bad request to OpenAI: {str(e)}",
code=ErrorCode.INVALID_ARGUMENT, # Or more specific if detectable
details=e.body,
message=f"Bad request to OpenAI-compatible endpoint: {str(e)}",
code=ErrorCode.INVALID_ARGUMENT,
details={**body_details, "is_byok": is_byok},
)
# NOTE: The OpenAI Python SDK may raise a generic `openai.APIError` while *iterating*
@@ -1040,46 +1177,91 @@ class OpenAIClient(LLMClientBase):
#
# Example message:
# "Your input exceeds the context window of this model. Please adjust your input and try again."
if isinstance(e, openai.APIError):
if isinstance(e, openai.APIError) and not isinstance(e, openai.APIStatusError):
msg = str(e)
if is_context_window_overflow_message(msg):
return ContextWindowExceededError(
message=f"OpenAI request exceeded the context window: {msg}",
details={
"provider_exception_type": type(e).__name__,
# Best-effort extraction (may not exist on APIError)
"body": getattr(e, "body", None),
"is_byok": is_byok,
},
)
if is_insufficient_credits_message(msg):
return LLMInsufficientCreditsError(
message=f"Insufficient credits (BYOK): {msg}" if is_byok else f"Insufficient credits: {msg}",
code=ErrorCode.PAYMENT_REQUIRED,
details={
"provider_exception_type": type(e).__name__,
"body": getattr(e, "body", None),
"is_byok": is_byok,
},
)
return LLMBadRequestError(
message=f"OpenAI API error: {msg}",
code=ErrorCode.INVALID_ARGUMENT,
details={
"provider_exception_type": type(e).__name__,
"body": getattr(e, "body", None),
"is_byok": is_byok,
},
)
if isinstance(e, openai.AuthenticationError):
logger.error(f"[OpenAI] Authentication error (401): {str(e)}") # More severe log level
body_details = e.body if isinstance(e.body, dict) else {"body": e.body}
return LLMAuthenticationError(
message=f"Authentication failed with OpenAI: {str(e)}", code=ErrorCode.UNAUTHENTICATED, details=e.body
message=f"Authentication failed with OpenAI: {str(e)}",
code=ErrorCode.UNAUTHENTICATED,
details={**body_details, "is_byok": is_byok},
)
if isinstance(e, openai.PermissionDeniedError):
logger.error(f"[OpenAI] Permission denied (403): {str(e)}") # More severe log level
body_details = e.body if isinstance(e.body, dict) else {"body": e.body}
return LLMPermissionDeniedError(
message=f"Permission denied by OpenAI: {str(e)}", code=ErrorCode.PERMISSION_DENIED, details=e.body
message=f"Permission denied by OpenAI: {str(e)}",
code=ErrorCode.PERMISSION_DENIED,
details={**body_details, "is_byok": is_byok},
)
if isinstance(e, openai.NotFoundError):
logger.warning(f"[OpenAI] Resource not found (404): {str(e)}")
# Could be invalid model name, etc.
return LLMNotFoundError(message=f"Resource not found in OpenAI: {str(e)}", code=ErrorCode.NOT_FOUND, details=e.body)
body_details = e.body if isinstance(e.body, dict) else {"body": e.body}
return LLMNotFoundError(
message=f"Resource not found in OpenAI: {str(e)}",
code=ErrorCode.NOT_FOUND,
details={**body_details, "is_byok": is_byok},
)
if isinstance(e, openai.UnprocessableEntityError):
logger.warning(f"[OpenAI] Unprocessable entity (422): {str(e)}")
body_details = e.body if isinstance(e.body, dict) else {"body": e.body}
return LLMUnprocessableEntityError(
message=f"Invalid request content for OpenAI: {str(e)}",
code=ErrorCode.INVALID_ARGUMENT, # Usually validation errors
details=e.body,
code=ErrorCode.INVALID_ARGUMENT,
details={**body_details, "is_byok": is_byok},
)
# General API error catch-all
if isinstance(e, openai.APIStatusError):
logger.warning(f"[OpenAI] API status error ({e.status_code}): {str(e)}")
# Handle 413 Request Entity Too Large - request payload exceeds size limits
if e.status_code == 413:
return ContextWindowExceededError(
message=f"Request too large for OpenAI (413): {str(e)}",
details={"is_byok": is_byok},
)
# Handle 402 Payment Required or credit-related messages
if e.status_code == 402 or is_insufficient_credits_message(str(e)):
msg = str(e)
return LLMInsufficientCreditsError(
message=f"Insufficient credits (BYOK): {msg}" if is_byok else f"Insufficient credits: {msg}",
code=ErrorCode.PAYMENT_REQUIRED,
details={"status_code": e.status_code, "body": e.body, "is_byok": is_byok},
)
# Map based on status code potentially
if e.status_code >= 500:
error_cls = LLMServerError
@@ -1096,11 +1278,12 @@ class OpenAIClient(LLMClientBase):
"status_code": e.status_code,
"response": str(e.response),
"body": e.body,
"is_byok": is_byok,
},
)
# Fallback for unexpected errors
return super().handle_llm_error(e)
return super().handle_llm_error(e, llm_config=llm_config)
def fill_image_content_in_messages(openai_message_list: List[dict], pydantic_message_list: List[PydanticMessage]) -> List[dict]:

View File

@@ -0,0 +1,108 @@
"""
SGLang Native Client for Letta.
This client uses SGLang's native /generate endpoint instead of the OpenAI-compatible
/v1/chat/completions endpoint. The native endpoint returns token IDs and per-token
logprobs, which are essential for multi-turn RL training.
The OpenAI-compatible endpoint only returns token strings, not IDs, making it
impossible to accurately reconstruct the token sequence for training.
"""
from typing import Any, Dict, Optional
import httpx
from letta.log import get_logger
logger = get_logger(__name__)
class SGLangNativeClient:
    """Client for SGLang's native /generate endpoint.

    Unlike the OpenAI-compatible endpoint, this returns:
    - output_ids: List of token IDs
    - output_token_logprobs: List of [logprob, token_id, top_logprob] tuples

    This is essential for RL training where we need exact token IDs, not
    re-tokenized text.
    """

    def __init__(self, base_url: str, api_key: Optional[str] = None):
        """
        Initialize the SGLang native client.

        Args:
            base_url: Base URL for SGLang server (e.g., http://localhost:30000)
            api_key: Optional API key for authentication
        """
        # The native endpoint lives at the server root, so a trailing "/v1"
        # (used by the OpenAI-compatible endpoint) must be dropped.
        normalized = base_url.rstrip("/")
        if normalized.endswith("/v1"):
            normalized = normalized[: -len("/v1")]
        self.base_url = normalized
        self.api_key = api_key

    async def generate(
        self,
        text: str,
        sampling_params: Optional[Dict[str, Any]] = None,
        return_logprob: bool = True,
    ) -> Dict[str, Any]:
        """
        Call SGLang's native /generate endpoint.

        Args:
            text: The formatted prompt text (with chat template applied)
            sampling_params: Sampling parameters (temperature, max_new_tokens, etc.)
            return_logprob: Whether to return logprobs (default True for RL training)

        Returns:
            Response dict with:
            - text: Generated text
            - output_ids: List of token IDs
            - output_token_logprobs: List of [logprob, token_id, top_logprob] tuples
            - meta_info: Metadata including finish_reason, prompt_tokens, etc.
        """
        request_headers = {"Content-Type": "application/json"}
        if self.api_key:
            request_headers["Authorization"] = f"Bearer {self.api_key}"

        body = {
            "text": text,
            "sampling_params": sampling_params or {},
            "return_logprob": return_logprob,
        }

        # Generation can be slow for long completions; allow a generous timeout.
        async with httpx.AsyncClient(timeout=300.0) as http:
            resp = await http.post(
                f"{self.base_url}/generate",
                json=body,
                headers=request_headers,
            )
            resp.raise_for_status()
            return resp.json()

    async def health_check(self) -> bool:
        """Check if the SGLang server is healthy."""
        try:
            async with httpx.AsyncClient(timeout=10.0) as http:
                resp = await http.get(f"{self.base_url}/health")
        except Exception:
            # Any connection/timeout failure means the server is not reachable.
            return False
        return resp.status_code == 200

View File

@@ -4,6 +4,7 @@ from typing import List
from openai import AsyncOpenAI, OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from letta.helpers.json_helpers import sanitize_unicode_surrogates
from letta.llm_api.openai_client import OpenAIClient
from letta.otel.tracing import trace_method
from letta.schemas.embedding_config import EmbeddingConfig
@@ -34,6 +35,8 @@ class TogetherClient(OpenAIClient):
"""
Performs underlying asynchronous request to OpenAI API and returns raw response dict.
"""
request_data = sanitize_unicode_surrogates(request_data)
api_key, _, _ = await self.get_byok_overrides_async(llm_config)
if not api_key:

View File

@@ -5,6 +5,7 @@ from openai import AsyncOpenAI, AsyncStream, OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta.helpers.json_helpers import sanitize_unicode_surrogates
from letta.llm_api.openai_client import OpenAIClient
from letta.otel.tracing import trace_method
from letta.schemas.embedding_config import EmbeddingConfig
@@ -59,6 +60,8 @@ class XAIClient(OpenAIClient):
"""
Performs underlying asynchronous request to OpenAI API and returns raw response dict.
"""
request_data = sanitize_unicode_surrogates(request_data)
api_key = model_settings.xai_api_key or os.environ.get("XAI_API_KEY")
client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
@@ -70,6 +73,8 @@ class XAIClient(OpenAIClient):
"""
Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator.
"""
request_data = sanitize_unicode_surrogates(request_data)
api_key = model_settings.xai_api_key or os.environ.get("XAI_API_KEY")
client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(

View File

@@ -1,19 +1,30 @@
import os
from typing import List, Optional
from openai import AsyncOpenAI, AsyncStream, OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta.helpers.json_helpers import sanitize_unicode_surrogates
from letta.llm_api.openai_client import OpenAIClient
from letta.otel.tracing import trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import AgentType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.settings import model_settings
def is_zai_reasoning_model(model_name: str) -> bool:
    """Check if the model is a ZAI reasoning model (GLM-4.5+)."""
    # str.startswith accepts a tuple of prefixes, so one call covers the family.
    return model_name.startswith(("glm-4.5", "glm-4.6", "glm-4.7", "glm-5"))
class ZAIClient(OpenAIClient):
"""Z.ai (ZhipuAI) client - uses OpenAI-compatible API."""
@@ -23,6 +34,10 @@ class ZAIClient(OpenAIClient):
def supports_structured_output(self, llm_config: LLMConfig) -> bool:
    # Always False: the Z.ai client is treated as not supporting OpenAI-style
    # structured output, regardless of the model in llm_config.
    return False
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
    """Returns True if the model is a ZAI reasoning model (GLM-4.5+)."""
    model_name = llm_config.model
    return is_zai_reasoning_model(model_name)
@trace_method
def build_request_data(
self,
@@ -35,6 +50,50 @@ class ZAIClient(OpenAIClient):
tool_return_truncation_chars: Optional[int] = None,
) -> dict:
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)
# Add thinking configuration for ZAI GLM-4.5+ models
# Must explicitly send type: "disabled" when reasoning is off, as GLM-4.7 has thinking on by default
if self.is_reasoning_model(llm_config):
if llm_config.enable_reasoner:
data["extra_body"] = {
"thinking": {
"type": "enabled",
"clear_thinking": False, # Preserved thinking for agents
}
}
else:
data["extra_body"] = {
"thinking": {
"type": "disabled",
}
}
# Sanitize empty text content — ZAI rejects empty text blocks
if "messages" in data:
for msg in data["messages"]:
content = msg.get("content") if isinstance(msg, dict) else getattr(msg, "content", None)
# String content: replace empty with None (assistant+tool_calls) or "."
if isinstance(content, str) and not content.strip():
role = msg.get("role") if isinstance(msg, dict) else getattr(msg, "role", None)
has_tool_calls = msg.get("tool_calls") if isinstance(msg, dict) else getattr(msg, "tool_calls", None)
if role == "assistant" and has_tool_calls:
# assistant + tool_calls: null content is valid in OpenAI format
if isinstance(msg, dict):
msg["content"] = None
else:
msg.content = None
else:
if isinstance(msg, dict):
msg["content"] = "."
else:
msg.content = "."
# List content: fix empty text blocks within arrays
elif isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
if not block.get("text", "").strip():
block["text"] = "."
return data
@trace_method
@@ -53,6 +112,8 @@ class ZAIClient(OpenAIClient):
"""
Performs underlying asynchronous request to Z.ai API and returns raw response dict.
"""
request_data = sanitize_unicode_surrogates(request_data)
api_key = model_settings.zai_api_key
client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
@@ -64,6 +125,8 @@ class ZAIClient(OpenAIClient):
"""
Performs underlying asynchronous streaming request to Z.ai and returns the async stream iterator.
"""
request_data = sanitize_unicode_surrogates(request_data)
api_key = model_settings.zai_api_key
client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
@@ -79,3 +142,39 @@ class ZAIClient(OpenAIClient):
response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)
return [r.embedding for r in response.data]
@trace_method
async def convert_response_to_chat_completion(
    self,
    response_data: dict,
    input_messages: List[PydanticMessage],
    llm_config: LLMConfig,
) -> ChatCompletionResponse:
    """Convert a raw ZAI response dict into a ChatCompletionResponse.

    Delegates the bulk of the conversion to the OpenAI-compatible parent
    implementation, then back-fills ``reasoning_content`` that ZAI GLM-4.5+
    models place on the raw message (the parent converter does not extract it).
    """
    completion = await super().convert_response_to_chat_completion(response_data, input_messages, llm_config)

    # Nothing to do if there is no first-choice message, or if the parent
    # conversion already populated reasoning content.
    choices = completion.choices
    if not choices or not choices[0].message or choices[0].message.reasoning_content:
        return completion

    # Mirror the raw-dict presence checks exactly: only back-fill when the
    # provider actually returned a non-empty reasoning_content field.
    if "choices" not in response_data or len(response_data["choices"]) == 0:
        return completion
    raw_choice = response_data["choices"][0]
    if "message" not in raw_choice or "reasoning_content" not in raw_choice["message"]:
        return completion
    raw_reasoning = raw_choice["message"]["reasoning_content"]
    if not raw_reasoning:
        return completion

    target_message = choices[0].message
    target_message.reasoning_content = raw_reasoning
    target_message.reasoning_content_signature = None
    # Reasoning models with the reasoner enabled: mark that reasoning
    # content was consumed for this response.
    if self.is_reasoning_model(llm_config) and llm_config.enable_reasoner:
        target_message.omitted_reasoning_content = True
    return completion

View File

@@ -16,7 +16,7 @@ from letta.local_llm.llamacpp.api import get_llamacpp_completion
from letta.local_llm.llm_chat_completion_wrappers import simple_summary_wrapper
from letta.local_llm.lmstudio.api import get_lmstudio_completion, get_lmstudio_completion_chatcompletions
from letta.local_llm.ollama.api import get_ollama_completion
from letta.local_llm.utils import count_tokens, get_available_wrappers
from letta.local_llm.utils import get_available_wrappers
from letta.local_llm.vllm.api import get_vllm_completion
from letta.local_llm.webui.api import get_webui_completion
from letta.local_llm.webui.legacy_api import get_webui_completion as get_webui_completion_legacy
@@ -177,7 +177,7 @@ def get_chat_completion(
raise LocalLLMError(
f"Invalid endpoint type {endpoint_type}, please set variable depending on your backend (webui, lmstudio, llamacpp, koboldcpp)"
)
except requests.exceptions.ConnectionError as e:
except requests.exceptions.ConnectionError:
raise LocalLLMConnectionError(f"Unable to connect to endpoint {endpoint}")
attributes = usage if isinstance(usage, dict) else {"usage": usage}
@@ -207,10 +207,12 @@ def get_chat_completion(
if usage["prompt_tokens"] is None:
printd("usage dict was missing prompt_tokens, computing on-the-fly...")
usage["prompt_tokens"] = count_tokens(prompt)
# Approximate token count: bytes / 4
usage["prompt_tokens"] = len(prompt.encode("utf-8")) // 4
# NOTE: we should compute on-the-fly anyways since we might have to correct for errors during JSON parsing
usage["completion_tokens"] = count_tokens(json_dumps(chat_completion_result))
# Approximate token count: bytes / 4
usage["completion_tokens"] = len(json_dumps(chat_completion_result).encode("utf-8")) // 4
"""
if usage["completion_tokens"] is None:
printd(f"usage dict was missing completion_tokens, computing on-the-fly...")

View File

@@ -2,7 +2,7 @@
# (settings.py imports from this module indirectly through log.py)
# Import this here to avoid circular dependency at module level
from letta.local_llm.llm_chat_completion_wrappers.chatml import ChatMLInnerMonologueWrapper
from letta.settings import DEFAULT_WRAPPER_NAME, INNER_THOUGHTS_KWARG
from letta.settings import INNER_THOUGHTS_KWARG
DEFAULT_WRAPPER = ChatMLInnerMonologueWrapper
INNER_THOUGHTS_KWARG_VERTEX = "thinking"

View File

@@ -5,7 +5,7 @@ from copy import copy
from enum import Enum
from inspect import getdoc, isclass
from types import NoneType
from typing import Any, Callable, List, Optional, Tuple, Type, Union, _GenericAlias, get_args, get_origin
from typing import Any, Callable, List, Optional, Tuple, Type, Union, _GenericAlias, get_args, get_origin # type: ignore[attr-defined]
from docstring_parser import parse
from pydantic import BaseModel, create_model
@@ -58,13 +58,13 @@ def map_pydantic_type_to_gbnf(pydantic_type: Type[Any]) -> str:
elif isclass(pydantic_type) and issubclass(pydantic_type, BaseModel):
return format_model_and_field_name(pydantic_type.__name__)
elif get_origin(pydantic_type) == list:
elif get_origin(pydantic_type) is list:
element_type = get_args(pydantic_type)[0]
return f"{map_pydantic_type_to_gbnf(element_type)}-list"
elif get_origin(pydantic_type) == set:
elif get_origin(pydantic_type) is set:
element_type = get_args(pydantic_type)[0]
return f"{map_pydantic_type_to_gbnf(element_type)}-set"
elif get_origin(pydantic_type) == Union:
elif get_origin(pydantic_type) is Union:
union_types = get_args(pydantic_type)
union_rules = [map_pydantic_type_to_gbnf(ut) for ut in union_types]
return f"union-{'-or-'.join(union_rules)}"
@@ -73,7 +73,7 @@ def map_pydantic_type_to_gbnf(pydantic_type: Type[Any]) -> str:
return f"optional-{map_pydantic_type_to_gbnf(element_type)}"
elif isclass(pydantic_type):
return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(pydantic_type.__name__)}"
elif get_origin(pydantic_type) == dict:
elif get_origin(pydantic_type) is dict:
key_type, value_type = get_args(pydantic_type)
return f"custom-dict-key-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(key_type))}-value-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(value_type))}"
else:
@@ -299,7 +299,7 @@ def generate_gbnf_rule_for_type(
enum_rule = f"{model_name}-{field_name} ::= {' | '.join(enum_values)}"
rules.append(enum_rule)
gbnf_type, rules = model_name + "-" + field_name, rules
elif get_origin(field_type) == list: # Array
elif get_origin(field_type) is list: # Array
element_type = get_args(field_type)[0]
element_rule_name, additional_rules = generate_gbnf_rule_for_type(
model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules
@@ -309,7 +309,7 @@ def generate_gbnf_rule_for_type(
rules.append(array_rule)
gbnf_type, rules = model_name + "-" + field_name, rules
elif get_origin(field_type) == set or field_type == set: # Array
elif get_origin(field_type) is set or field_type is set: # Array
element_type = get_args(field_type)[0]
element_rule_name, additional_rules = generate_gbnf_rule_for_type(
model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules
@@ -320,7 +320,7 @@ def generate_gbnf_rule_for_type(
gbnf_type, rules = model_name + "-" + field_name, rules
elif gbnf_type.startswith("custom-class-"):
nested_model_rules, field_types = get_members_structure(field_type, gbnf_type)
nested_model_rules, _field_types = get_members_structure(field_type, gbnf_type)
rules.append(nested_model_rules)
elif gbnf_type.startswith("custom-dict-"):
key_type, value_type = get_args(field_type)
@@ -502,15 +502,15 @@ def generate_gbnf_grammar(model: Type[BaseModel], processed_models: set, created
model_rule += '"\\n" ws "}"'
model_rule += '"\\n" markdown-code-block'
has_special_string = True
all_rules = [model_rule] + nested_rules
all_rules = [model_rule, *nested_rules]
return all_rules, has_special_string
def generate_gbnf_grammar_from_pydantic_models(
models: List[Type[BaseModel]],
outer_object_name: str = None,
outer_object_content: str = None,
outer_object_name: str | None = None,
outer_object_content: str | None = None,
list_of_outputs: bool = False,
add_inner_thoughts: bool = False,
allow_only_inner_thoughts: bool = False,
@@ -704,11 +704,11 @@ def generate_markdown_documentation(
# continue
if isclass(field_type) and issubclass(field_type, BaseModel):
pyd_models.append((field_type, False))
if get_origin(field_type) == list:
if get_origin(field_type) is list:
element_type = get_args(field_type)[0]
if isclass(element_type) and issubclass(element_type, BaseModel):
pyd_models.append((element_type, False))
if get_origin(field_type) == Union:
if get_origin(field_type) is Union:
element_types = get_args(field_type)
for element_type in element_types:
if isclass(element_type) and issubclass(element_type, BaseModel):
@@ -747,14 +747,14 @@ def generate_field_markdown(
field_info = model.model_fields.get(field_name)
field_description = field_info.description if field_info and field_info.description else ""
if get_origin(field_type) == list:
if get_origin(field_type) is list:
element_type = get_args(field_type)[0]
field_text = f"{indent}{field_name} ({field_type.__name__} of {element_type.__name__})"
if field_description != "":
field_text += ": "
else:
field_text += "\n"
elif get_origin(field_type) == Union:
elif get_origin(field_type) is Union:
element_types = get_args(field_type)
types = []
for element_type in element_types:
@@ -857,11 +857,11 @@ def generate_text_documentation(
for name, field_type in model.__annotations__.items():
# if name == "markdown_code_block":
# continue
if get_origin(field_type) == list:
if get_origin(field_type) is list:
element_type = get_args(field_type)[0]
if isclass(element_type) and issubclass(element_type, BaseModel):
pyd_models.append((element_type, False))
if get_origin(field_type) == Union:
if get_origin(field_type) is Union:
element_types = get_args(field_type)
for element_type in element_types:
if isclass(element_type) and issubclass(element_type, BaseModel):
@@ -905,14 +905,14 @@ def generate_field_text(
field_info = model.model_fields.get(field_name)
field_description = field_info.description if field_info and field_info.description else ""
if get_origin(field_type) == list:
if get_origin(field_type) is list:
element_type = get_args(field_type)[0]
field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)} of {format_model_and_field_name(element_type.__name__)})"
if field_description != "":
field_text += ":\n"
else:
field_text += "\n"
elif get_origin(field_type) == Union:
elif get_origin(field_type) is Union:
element_types = get_args(field_type)
types = []
for element_type in element_types:
@@ -1015,8 +1015,8 @@ def generate_and_save_gbnf_grammar_and_documentation(
pydantic_model_list,
grammar_file_path="./generated_grammar.gbnf",
documentation_file_path="./generated_grammar_documentation.md",
outer_object_name: str = None,
outer_object_content: str = None,
outer_object_name: str | None = None,
outer_object_content: str | None = None,
model_prefix: str = "Output Model",
fields_prefix: str = "Output Fields",
list_of_outputs: bool = False,
@@ -1049,8 +1049,8 @@ def generate_and_save_gbnf_grammar_and_documentation(
def generate_gbnf_grammar_and_documentation(
pydantic_model_list,
outer_object_name: str = None,
outer_object_content: str = None,
outer_object_name: str | None = None,
outer_object_content: str | None = None,
model_prefix: str = "Output Model",
fields_prefix: str = "Output Fields",
list_of_outputs: bool = False,
@@ -1087,8 +1087,8 @@ def generate_gbnf_grammar_and_documentation(
def generate_gbnf_grammar_and_documentation_from_dictionaries(
dictionaries: List[dict],
outer_object_name: str = None,
outer_object_content: str = None,
outer_object_name: str | None = None,
outer_object_content: str | None = None,
model_prefix: str = "Output Model",
fields_prefix: str = "Output Fields",
list_of_outputs: bool = False,

View File

@@ -1,7 +1,7 @@
from urllib.parse import urljoin
from letta.local_llm.settings.settings import get_completions_settings
from letta.local_llm.utils import count_tokens, post_json_auth_request
from letta.local_llm.utils import post_json_auth_request
KOBOLDCPP_API_SUFFIX = "/api/v1/generate"
@@ -10,7 +10,8 @@ def get_koboldcpp_completion(endpoint, auth_type, auth_key, prompt, context_wind
"""See https://lite.koboldai.net/koboldcpp_api for API spec"""
from letta.utils import printd
prompt_tokens = count_tokens(prompt)
# Approximate token count: bytes / 4
prompt_tokens = len(prompt.encode("utf-8")) // 4
if prompt_tokens > context_window:
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")

View File

@@ -1,7 +1,7 @@
from urllib.parse import urljoin
from letta.local_llm.settings.settings import get_completions_settings
from letta.local_llm.utils import count_tokens, post_json_auth_request
from letta.local_llm.utils import post_json_auth_request
LLAMACPP_API_SUFFIX = "/completion"
@@ -10,7 +10,8 @@ def get_llamacpp_completion(endpoint, auth_type, auth_key, prompt, context_windo
"""See https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md for instructions on how to run the LLM web server"""
from letta.utils import printd
prompt_tokens = count_tokens(prompt)
# Approximate token count: bytes / 4
prompt_tokens = len(prompt.encode("utf-8")) // 4
if prompt_tokens > context_window:
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")

Some files were not shown because too many files have changed in this diff Show More