chore: bump version 0.16.5 (#3202)
This commit is contained in:
1
.github/scripts/model-sweep/conftest.py
vendored
1
.github/scripts/model-sweep/conftest.py
vendored
@@ -16,7 +16,6 @@ from letta.schemas.agent import AgentState
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.user_manager import UserManager
|
||||
from letta.settings import tool_settings
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
|
||||
@@ -31,7 +31,7 @@ def get_support_status(passed_tests, feature_tests):
|
||||
|
||||
# Filter out error tests when checking for support
|
||||
non_error_tests = [test for test in feature_tests if not test.endswith("_error")]
|
||||
error_tests = [test for test in feature_tests if test.endswith("_error")]
|
||||
[test for test in feature_tests if test.endswith("_error")]
|
||||
|
||||
# Check which non-error tests passed
|
||||
passed_non_error_tests = [test for test in non_error_tests if test in passed_tests]
|
||||
@@ -137,7 +137,7 @@ def get_github_repo_info():
|
||||
else:
|
||||
return None
|
||||
return repo_path
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Default fallback
|
||||
@@ -335,7 +335,7 @@ def process_model_sweep_report(input_file, output_file, config_file=None, debug=
|
||||
# Format timestamp if it's a full ISO string
|
||||
if "T" in str(last_scanned):
|
||||
last_scanned = str(last_scanned).split("T")[0] # Just the date part
|
||||
except:
|
||||
except Exception:
|
||||
last_scanned = "Unknown"
|
||||
|
||||
# Calculate support score for ranking
|
||||
|
||||
6
.github/scripts/model-sweep/model_sweep.py
vendored
6
.github/scripts/model-sweep/model_sweep.py
vendored
@@ -1,16 +1,12 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import Letta, MessageCreate, Run
|
||||
from letta_client.core.api_error import ApiError
|
||||
from letta_client.types import (
|
||||
@@ -694,7 +690,7 @@ def test_token_streaming_agent_loop_error(
|
||||
stream_tokens=True,
|
||||
)
|
||||
list(response)
|
||||
except:
|
||||
except Exception:
|
||||
pass # only some models throw an error TODO: make this consistent
|
||||
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
|
||||
4
.github/workflows/reusable-test-workflow.yml
vendored
4
.github/workflows/reusable-test-workflow.yml
vendored
@@ -381,6 +381,10 @@ jobs:
|
||||
GOOGLE_CLOUD_PROJECT: ${{ secrets.GOOGLE_CLOUD_PROJECT }}
|
||||
GOOGLE_CLOUD_LOCATION: ${{ secrets.GOOGLE_CLOUD_LOCATION }}
|
||||
|
||||
# Real object store (required for git-backed memory integration test)
|
||||
# Use DEV bucket/prefix variable to avoid prod resources.
|
||||
LETTA_OBJECT_STORE_URI: ${{ vars.LETTA_OBJECT_STORE_URI_DEV }}
|
||||
|
||||
# Feature flags (shared across all test types)
|
||||
LETTA_ENABLE_BATCH_JOB_POLLING: true
|
||||
|
||||
|
||||
@@ -23,3 +23,10 @@ repos:
|
||||
- id: ruff-check
|
||||
args: [ --fix ]
|
||||
- id: ruff-format
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: ty
|
||||
name: ty check
|
||||
entry: uv run ty check .
|
||||
language: python
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Start with pgvector base for builder
|
||||
FROM ankane/pgvector:v0.5.1 AS builder
|
||||
FROM pgvector/pgvector:0.8.1-pg15 AS builder
|
||||
# comment to trigger ci
|
||||
# Install Python and required packages
|
||||
RUN apt-get update && apt-get install -y \
|
||||
@@ -39,7 +39,7 @@ COPY . .
|
||||
RUN uv sync --frozen --no-dev --all-extras --python 3.11
|
||||
|
||||
# Runtime stage
|
||||
FROM ankane/pgvector:v0.5.1 AS runtime
|
||||
FROM pgvector/pgvector:0.8.1-pg15 AS runtime
|
||||
|
||||
# Overridable Node.js version with --build-arg NODE_VERSION
|
||||
ARG NODE_VERSION=22
|
||||
|
||||
@@ -8,8 +8,6 @@ Create Date: 2025-10-07 13:01:17.872405
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
||||
@@ -8,8 +8,6 @@ Create Date: 2025-09-10 19:16:39.118760
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
||||
@@ -8,8 +8,6 @@ Create Date: 2025-12-17 15:46:06.184858
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
||||
@@ -8,8 +8,6 @@ Create Date: 2025-10-03 12:10:51.065067
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
||||
33
alembic/versions/3e54e2fa2f7e_add_usage_columns_to_steps.py
Normal file
33
alembic/versions/3e54e2fa2f7e_add_usage_columns_to_steps.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""add_usage_columns_to_steps
|
||||
|
||||
Revision ID: 3e54e2fa2f7e
|
||||
Revises: a1b2c3d4e5f8
|
||||
Create Date: 2026-02-03 16:35:51.327031
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "3e54e2fa2f7e"
|
||||
down_revision: Union[str, None] = "a1b2c3d4e5f8"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("steps", sa.Column("model_handle", sa.String(), nullable=True))
|
||||
op.add_column("steps", sa.Column("cached_input_tokens", sa.Integer(), nullable=True))
|
||||
op.add_column("steps", sa.Column("cache_write_tokens", sa.Integer(), nullable=True))
|
||||
op.add_column("steps", sa.Column("reasoning_tokens", sa.Integer(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("steps", "reasoning_tokens")
|
||||
op.drop_column("steps", "cache_write_tokens")
|
||||
op.drop_column("steps", "cached_input_tokens")
|
||||
op.drop_column("steps", "model_handle")
|
||||
@@ -9,6 +9,7 @@ Create Date: 2024-12-14 17:23:08.772554
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from alembic import op
|
||||
|
||||
@@ -8,8 +8,6 @@ Create Date: 2025-09-19 10:58:19.658106
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
||||
@@ -8,8 +8,6 @@ Create Date: 2025-10-06 13:17:09.918439
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
||||
@@ -8,8 +8,6 @@ Create Date: 2025-11-11 19:16:00.000000
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
from letta.settings import settings
|
||||
|
||||
|
||||
@@ -8,8 +8,6 @@ Create Date: 2025-12-07 15:30:43.407495
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
||||
@@ -8,8 +8,6 @@ Create Date: 2025-11-11 21:16:00.000000
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
from letta.settings import settings
|
||||
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
"""Add model and model_settings columns to conversations table for model overrides
|
||||
|
||||
Revision ID: b2c3d4e5f6a8
|
||||
Revises: 3e54e2fa2f7e
|
||||
Create Date: 2026-02-23 02:50:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "b2c3d4e5f6a8"
|
||||
down_revision: Union[str, None] = "3e54e2fa2f7e"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("conversations", sa.Column("model", sa.String(), nullable=True))
|
||||
op.add_column("conversations", sa.Column("model_settings", sa.JSON(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("conversations", "model_settings")
|
||||
op.drop_column("conversations", "model")
|
||||
@@ -23,7 +23,7 @@ depends_on: Union[str, Sequence[str], None] = None
|
||||
def upgrade() -> None:
|
||||
# determine backfill value based on current pinecone settings
|
||||
try:
|
||||
from pinecone import IndexEmbed, PineconeAsyncio
|
||||
from pinecone import IndexEmbed, PineconeAsyncio # noqa: F401
|
||||
|
||||
pinecone_available = True
|
||||
except ImportError:
|
||||
|
||||
@@ -10,8 +10,6 @@ import json
|
||||
import os
|
||||
|
||||
# Add the app directory to path to import our crypto utils
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
@@ -8,8 +8,6 @@ Create Date: 2025-11-07 15:43:59.446292
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
from letta.settings import settings
|
||||
|
||||
|
||||
@@ -8,8 +8,6 @@ Create Date: 2025-10-04 00:44:06.663817
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
||||
412
conf.yaml
Normal file
412
conf.yaml
Normal file
@@ -0,0 +1,412 @@
|
||||
# Letta Configuration File
|
||||
# Place at ~/.letta/conf.yaml, ./conf.yaml, or set LETTA_CONFIG_PATH
|
||||
# Environment variables take precedence over config file values
|
||||
#
|
||||
# Top-level keys and their env var mappings:
|
||||
# letta: -> LETTA_*
|
||||
# model: -> Provider-prefixed (OPENAI_*, ANTHROPIC_*, etc.)
|
||||
# tool: -> Prefix-based (E2B_*, MCP_*, TOOL_*, etc.)
|
||||
# datadog: -> DD_*
|
||||
|
||||
letta:
|
||||
# =============================================================================
|
||||
# Core Settings (LETTA_*)
|
||||
# =============================================================================
|
||||
debug: false
|
||||
# environment: ""
|
||||
|
||||
# Default handles
|
||||
# default_llm_handle: ""
|
||||
# default_embedding_handle: ""
|
||||
|
||||
# SSE Streaming
|
||||
enable_keepalive: true
|
||||
keepalive_interval: 50.0
|
||||
enable_cancellation_aware_streaming: true
|
||||
|
||||
# =============================================================================
|
||||
# PostgreSQL (LETTA_PG_*)
|
||||
# =============================================================================
|
||||
pg:
|
||||
# db: ""
|
||||
# user: ""
|
||||
# password: ""
|
||||
# host: ""
|
||||
# port: ""
|
||||
# uri: ""
|
||||
pool_size: 25
|
||||
max_overflow: 10
|
||||
pool_timeout: 30
|
||||
pool_recycle: 1800
|
||||
echo: false
|
||||
|
||||
# Connection pool settings (LETTA_POOL_*)
|
||||
pool:
|
||||
pre_ping: true
|
||||
use_lifo: true
|
||||
|
||||
# Database settings (LETTA_DB_*)
|
||||
# db:
|
||||
# max_concurrent_sessions: ""
|
||||
|
||||
disable_sqlalchemy_pooling: true
|
||||
enable_db_pool_monitoring: true
|
||||
db_pool_monitoring_interval: 30
|
||||
|
||||
# =============================================================================
|
||||
# Redis (LETTA_REDIS_*)
|
||||
# =============================================================================
|
||||
redis:
|
||||
# host: ""
|
||||
port: 6379
|
||||
|
||||
# =============================================================================
|
||||
# Multi-Agent (LETTA_MULTI_AGENT_*)
|
||||
# =============================================================================
|
||||
multi_agent:
|
||||
send_message_max_retries: 3
|
||||
send_message_timeout: 1200
|
||||
concurrent_sends: 50
|
||||
|
||||
# =============================================================================
|
||||
# OTEL / Observability (LETTA_OTEL_*, LETTA_CLICKHOUSE_*)
|
||||
# =============================================================================
|
||||
otel:
|
||||
# exporter_otlp_endpoint: ""
|
||||
preferred_temporality: 1
|
||||
|
||||
clickhouse:
|
||||
# endpoint: ""
|
||||
database: otel
|
||||
username: default
|
||||
# password: ""
|
||||
|
||||
disable_tracing: false
|
||||
llm_api_logging: true
|
||||
track_last_agent_run: false
|
||||
track_errored_messages: true
|
||||
track_stop_reason: true
|
||||
track_agent_run: true
|
||||
track_provider_trace: true
|
||||
|
||||
# =============================================================================
|
||||
# Uvicorn (LETTA_UVICORN_*)
|
||||
# =============================================================================
|
||||
uvicorn:
|
||||
workers: 1
|
||||
reload: false
|
||||
timeout_keep_alive: 5
|
||||
|
||||
# Runtime settings
|
||||
use_uvloop: false
|
||||
use_granian: false
|
||||
sqlalchemy_tracing: false
|
||||
event_loop_threadpool_max_workers: 43
|
||||
|
||||
# =============================================================================
|
||||
# Experimental
|
||||
# =============================================================================
|
||||
use_vertex_structured_outputs_experimental: false
|
||||
use_asyncio_shield: true
|
||||
|
||||
# =============================================================================
|
||||
# Lettuce (LETTA_USE_LETTUCE_*)
|
||||
# =============================================================================
|
||||
use_lettuce_for_file_uploads: false
|
||||
|
||||
# =============================================================================
|
||||
# Batch Job Polling (LETTA_POLL_*, LETTA_BATCH_*)
|
||||
# =============================================================================
|
||||
enable_batch_job_polling: false
|
||||
poll_running_llm_batches_interval_seconds: 300
|
||||
poll_lock_retry_interval_seconds: 480
|
||||
batch_job_polling_lookback_weeks: 2
|
||||
# batch_job_polling_batch_size: ""
|
||||
|
||||
# =============================================================================
|
||||
# LLM Timeouts (LETTA_LLM_*)
|
||||
# =============================================================================
|
||||
llm:
|
||||
request_timeout_seconds: 60.0
|
||||
stream_timeout_seconds: 600.0
|
||||
|
||||
# =============================================================================
|
||||
# Pinecone (LETTA_PINECONE_*, LETTA_ENABLE_PINECONE, LETTA_UPSERT_PINECONE_INDICES)
|
||||
# =============================================================================
|
||||
enable_pinecone: false
|
||||
upsert_pinecone_indices: false
|
||||
pinecone:
|
||||
# api_key: ""
|
||||
source_index: sources
|
||||
agent_index: recall
|
||||
|
||||
# =============================================================================
|
||||
# Turbopuffer (LETTA_TPUF_*, LETTA_USE_TPUF, LETTA_EMBED_*)
|
||||
# =============================================================================
|
||||
use_tpuf: false
|
||||
embed_all_messages: false
|
||||
embed_tools: false
|
||||
tpuf:
|
||||
# api_key: ""
|
||||
region: gcp-us-central1
|
||||
|
||||
# =============================================================================
|
||||
# File Processing (LETTA_FILE_PROCESSING_*)
|
||||
# =============================================================================
|
||||
file_processing:
|
||||
timeout_minutes: 30
|
||||
timeout_error_message: "File processing timed out after {} minutes. Please try again."
|
||||
|
||||
# =============================================================================
|
||||
# Letta Client (LETTA_DEFAULT_*)
|
||||
# =============================================================================
|
||||
default_base_url: http://localhost:8283
|
||||
# default_token: ""
|
||||
|
||||
# =============================================================================
|
||||
# Agent Architecture
|
||||
# =============================================================================
|
||||
use_letta_v1_agent: false
|
||||
archival_memory_token_limit: 8192
|
||||
|
||||
# =============================================================================
|
||||
# Security
|
||||
# =============================================================================
|
||||
no_default_actor: false
|
||||
# encryption_key: ""
|
||||
|
||||
# =============================================================================
|
||||
# OCR
|
||||
# =============================================================================
|
||||
# mistral_api_key: ""
|
||||
|
||||
# =============================================================================
|
||||
# Summarizer (LETTA_SUMMARIZER_*)
|
||||
# =============================================================================
|
||||
summarizer:
|
||||
mode: partial_evict_message_buffer_mode
|
||||
message_buffer_limit: 60
|
||||
message_buffer_min: 15
|
||||
enable_summarization: true
|
||||
max_summarization_retries: 3
|
||||
partial_evict_summarizer_percentage: 0.30
|
||||
evict_all_messages: false
|
||||
max_summarizer_retries: 3
|
||||
memory_warning_threshold: 0.75
|
||||
send_memory_warning_message: false
|
||||
desired_memory_token_pressure: 0.3
|
||||
keep_last_n_messages: 0
|
||||
|
||||
# =============================================================================
|
||||
# Logging (LETTA_LOGGING_*)
|
||||
# =============================================================================
|
||||
logging:
|
||||
debug: false
|
||||
json_logging: false
|
||||
log_level: WARNING
|
||||
verbose_telemetry_logging: false
|
||||
|
||||
# =============================================================================
|
||||
# Telemetry (LETTA_TELEMETRY_*)
|
||||
# =============================================================================
|
||||
telemetry:
|
||||
enable_datadog: false
|
||||
provider_trace_backend: postgres
|
||||
socket_path: /var/run/telemetry/telemetry.sock
|
||||
provider_trace_pg_metadata_only: false
|
||||
# source: ""
|
||||
|
||||
# Datadog settings (LETTA_TELEMETRY_DATADOG_*)
|
||||
datadog:
|
||||
agent_host: localhost
|
||||
agent_port: 8126
|
||||
service_name: letta-server
|
||||
profiling_enabled: false
|
||||
profiling_memory_enabled: false
|
||||
profiling_heap_enabled: false
|
||||
# git_repository_url: ""
|
||||
# git_commit_sha: ""
|
||||
main_package: letta
|
||||
|
||||
# =============================================================================
|
||||
# Model Settings (-> OPENAI_*, ANTHROPIC_*, AWS_*, etc.)
|
||||
# =============================================================================
|
||||
model:
|
||||
# Global settings
|
||||
global_max_context_window_limit: 32000
|
||||
inner_thoughts_kwarg: thinking
|
||||
default_prompt_formatter: chatml
|
||||
|
||||
# OpenAI (-> OPENAI_*)
|
||||
openai:
|
||||
# api_key: ""
|
||||
api_base: https://api.openai.com/v1
|
||||
|
||||
# Anthropic (-> ANTHROPIC_*)
|
||||
anthropic:
|
||||
# api_key: ""
|
||||
max_retries: 3
|
||||
sonnet_1m: false
|
||||
|
||||
# Azure OpenAI (-> AZURE_*)
|
||||
azure:
|
||||
# api_key: ""
|
||||
# base_url: ""
|
||||
api_version: "2024-09-01-preview"
|
||||
|
||||
# Google Gemini (-> GEMINI_*)
|
||||
gemini:
|
||||
# api_key: ""
|
||||
base_url: https://generativelanguage.googleapis.com/
|
||||
force_minimum_thinking_budget: false
|
||||
max_retries: 5
|
||||
|
||||
# Google Vertex (-> GOOGLE_CLOUD_*)
|
||||
# google_cloud:
|
||||
# project: ""
|
||||
# location: ""
|
||||
|
||||
# AWS Bedrock (-> AWS_*, BEDROCK_*)
|
||||
aws:
|
||||
# access_key_id: ""
|
||||
# secret_access_key: ""
|
||||
default_region: us-east-1
|
||||
|
||||
bedrock:
|
||||
anthropic_version: bedrock-2023-05-31
|
||||
|
||||
# OpenRouter (-> OPENROUTER_*)
|
||||
# openrouter:
|
||||
# api_key: ""
|
||||
# referer: ""
|
||||
# title: ""
|
||||
# handle_base: ""
|
||||
|
||||
# Groq (-> GROQ_*)
|
||||
# groq:
|
||||
# api_key: ""
|
||||
|
||||
# Together (-> TOGETHER_*)
|
||||
# together:
|
||||
# api_key: ""
|
||||
|
||||
# DeepSeek (-> DEEPSEEK_*)
|
||||
# deepseek:
|
||||
# api_key: ""
|
||||
|
||||
# xAI/Grok (-> XAI_*)
|
||||
# xai:
|
||||
# api_key: ""
|
||||
|
||||
# Z.ai/ZhipuAI (-> ZAI_*)
|
||||
zai:
|
||||
# api_key: ""
|
||||
base_url: https://api.z.ai/api/paas/v4/
|
||||
|
||||
# MiniMax (-> MINIMAX_*)
|
||||
# minimax:
|
||||
# api_key: ""
|
||||
|
||||
# Ollama (-> OLLAMA_*)
|
||||
# ollama:
|
||||
# base_url: ""
|
||||
|
||||
# vLLM (-> VLLM_*)
|
||||
# vllm:
|
||||
# api_base: ""
|
||||
# handle_base: ""
|
||||
|
||||
# SGLang (-> SGLANG_*)
|
||||
# sglang:
|
||||
# api_base: ""
|
||||
# handle_base: ""
|
||||
|
||||
# LM Studio (-> LMSTUDIO_*)
|
||||
# lmstudio:
|
||||
# base_url: ""
|
||||
|
||||
# OpenLLM (-> OPENLLM_*)
|
||||
# openllm:
|
||||
# auth_type: ""
|
||||
# api_key: ""
|
||||
|
||||
# =============================================================================
|
||||
# Tool Settings (-> E2B_*, MCP_*, MODAL_*, TOOL_*, etc.)
|
||||
# =============================================================================
|
||||
tool:
|
||||
# E2B Sandbox (-> E2B_*)
|
||||
# e2b:
|
||||
# api_key: ""
|
||||
# sandbox_template_id: ""
|
||||
|
||||
# Modal Sandbox (-> MODAL_*)
|
||||
# modal:
|
||||
# token_id: ""
|
||||
# token_secret: ""
|
||||
|
||||
# Search Providers (-> TAVILY_*, EXA_*)
|
||||
# tavily:
|
||||
# api_key: ""
|
||||
|
||||
# exa:
|
||||
# api_key: ""
|
||||
|
||||
# Local Sandbox (-> TOOL_*)
|
||||
tool:
|
||||
# exec_dir: ""
|
||||
sandbox_timeout: 180
|
||||
# exec_venv_name: ""
|
||||
exec_autoreload_venv: true
|
||||
|
||||
# MCP (-> MCP_*)
|
||||
mcp:
|
||||
connect_to_server_timeout: 30.0
|
||||
list_tools_timeout: 30.0
|
||||
execute_tool_timeout: 60.0
|
||||
read_from_config: false
|
||||
disable_stdio: true
|
||||
|
||||
# =============================================================================
|
||||
# Datadog Agent Settings (-> DD_*)
|
||||
# =============================================================================
|
||||
# datadog:
|
||||
# site: ""
|
||||
# service: ""
|
||||
# version: ""
|
||||
#
|
||||
# trace:
|
||||
# enabled: false
|
||||
# agent_url: ""
|
||||
# health_metrics_enabled: false
|
||||
#
|
||||
# dogstatsd:
|
||||
# url: ""
|
||||
#
|
||||
# logs:
|
||||
# injection: false
|
||||
#
|
||||
# runtime:
|
||||
# metrics_enabled: false
|
||||
#
|
||||
# appsec:
|
||||
# enabled: false
|
||||
# sca_enabled: false
|
||||
#
|
||||
# iast:
|
||||
# enabled: false
|
||||
#
|
||||
# exception:
|
||||
# replay_enabled: false
|
||||
#
|
||||
# llmobs:
|
||||
# enabled: false
|
||||
# ml_app: ""
|
||||
#
|
||||
# instrumentation:
|
||||
# install_type: ""
|
||||
#
|
||||
# git:
|
||||
# repository_url: ""
|
||||
# commit_sha: ""
|
||||
#
|
||||
# main_package: ""
|
||||
@@ -1,6 +1,6 @@
|
||||
services:
|
||||
letta_db:
|
||||
image: ankane/pgvector:v0.5.1
|
||||
image: pgvector/pgvector:0.8.1-pg15
|
||||
networks:
|
||||
default:
|
||||
aliases:
|
||||
|
||||
3678
fern/openapi.json
3678
fern/openapi.json
File diff suppressed because it is too large
Load Diff
@@ -5,7 +5,7 @@ try:
|
||||
__version__ = version("letta")
|
||||
except PackageNotFoundError:
|
||||
# Fallback for development installations
|
||||
__version__ = "0.16.4"
|
||||
__version__ = "0.16.5"
|
||||
|
||||
if os.environ.get("LETTA_VERSION"):
|
||||
__version__ = os.environ["LETTA_VERSION"]
|
||||
@@ -16,26 +16,32 @@ try:
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
|
||||
if settings.database_engine == DatabaseChoice.SQLITE:
|
||||
from letta.orm import sqlite_functions
|
||||
from letta.orm import sqlite_functions # noqa: F401
|
||||
except ImportError:
|
||||
# If sqlite_vec is not installed, it's fine for client usage
|
||||
pass
|
||||
|
||||
# # imports for easier access
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.schemas.job import Job
|
||||
from letta.schemas.letta_message import LettaMessage, LettaPing
|
||||
from letta.schemas.letta_stop_reason import LettaStopReason
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ArchivalMemorySummary, BasicBlockMemory, ChatMemory, Memory, RecallMemorySummary
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.source import Source
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.schemas.agent import AgentState as AgentState
|
||||
from letta.schemas.block import Block as Block
|
||||
from letta.schemas.embedding_config import EmbeddingConfig as EmbeddingConfig
|
||||
from letta.schemas.enums import JobStatus as JobStatus
|
||||
from letta.schemas.file import FileMetadata as FileMetadata
|
||||
from letta.schemas.job import Job as Job
|
||||
from letta.schemas.letta_message import LettaErrorMessage as LettaErrorMessage, LettaMessage as LettaMessage, LettaPing as LettaPing
|
||||
from letta.schemas.letta_stop_reason import LettaStopReason as LettaStopReason
|
||||
from letta.schemas.llm_config import LLMConfig as LLMConfig
|
||||
from letta.schemas.memory import (
|
||||
ArchivalMemorySummary as ArchivalMemorySummary,
|
||||
BasicBlockMemory as BasicBlockMemory,
|
||||
ChatMemory as ChatMemory,
|
||||
Memory as Memory,
|
||||
RecallMemorySummary as RecallMemorySummary,
|
||||
)
|
||||
from letta.schemas.message import Message as Message
|
||||
from letta.schemas.organization import Organization as Organization
|
||||
from letta.schemas.passage import Passage as Passage
|
||||
from letta.schemas.source import Source as Source
|
||||
from letta.schemas.tool import Tool as Tool
|
||||
from letta.schemas.usage import LettaUsageStatistics as LettaUsageStatistics
|
||||
from letta.schemas.user import User as User
|
||||
|
||||
@@ -2,10 +2,11 @@ from abc import ABC, abstractmethod
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from letta.llm_api.llm_client_base import LLMClientBase
|
||||
from letta.schemas.enums import LLMCallType
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, ToolCall
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, ChoiceLogprobs, ToolCall
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.services.telemetry_manager import TelemetryManager
|
||||
@@ -24,6 +25,7 @@ class LettaLLMAdapter(ABC):
|
||||
self,
|
||||
llm_client: LLMClientBase,
|
||||
llm_config: LLMConfig,
|
||||
call_type: LLMCallType,
|
||||
agent_id: str | None = None,
|
||||
agent_tags: list[str] | None = None,
|
||||
run_id: str | None = None,
|
||||
@@ -32,6 +34,7 @@ class LettaLLMAdapter(ABC):
|
||||
) -> None:
|
||||
self.llm_client: LLMClientBase = llm_client
|
||||
self.llm_config: LLMConfig = llm_config
|
||||
self.call_type: LLMCallType = call_type
|
||||
self.agent_id: str | None = agent_id
|
||||
self.agent_tags: list[str] | None = agent_tags
|
||||
self.run_id: str | None = run_id
|
||||
@@ -45,9 +48,14 @@ class LettaLLMAdapter(ABC):
|
||||
self.content: list[TextContent | ReasoningContent | RedactedReasoningContent] | None = None
|
||||
self.tool_call: ToolCall | None = None
|
||||
self.tool_calls: list[ToolCall] = []
|
||||
self.logprobs: ChoiceLogprobs | None = None
|
||||
# SGLang native endpoint data (for multi-turn RL training)
|
||||
self.output_ids: list[int] | None = None
|
||||
self.output_token_logprobs: list[list[float]] | None = None
|
||||
self.usage: LettaUsageStatistics = LettaUsageStatistics()
|
||||
self.telemetry_manager: TelemetryManager = TelemetryManager()
|
||||
self.llm_request_finish_timestamp_ns: int | None = None
|
||||
self._finish_reason: str | None = None
|
||||
|
||||
@abstractmethod
|
||||
async def invoke_llm(
|
||||
@@ -85,6 +93,8 @@ class LettaLLMAdapter(ABC):
|
||||
Returns:
|
||||
str | None: The finish_reason if available, None otherwise
|
||||
"""
|
||||
if self._finish_reason is not None:
|
||||
return self._finish_reason
|
||||
if self.chat_completions_response and self.chat_completions_response.choices:
|
||||
return self.chat_completions_response.choices[0].finish_reason
|
||||
return None
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import AsyncGenerator
|
||||
|
||||
from letta.adapters.letta_llm_adapter import LettaLLMAdapter
|
||||
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
|
||||
from letta.otel.tracing import log_attributes, log_event, safe_json_dumps, trace_method
|
||||
from letta.otel.tracing import log_attributes, safe_json_dumps, trace_method
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, TextContent
|
||||
from letta.schemas.provider_trace import ProviderTrace
|
||||
@@ -66,7 +66,13 @@ class LettaLLMRequestAdapter(LettaLLMAdapter):
|
||||
self.reasoning_content = [OmittedReasoningContent()]
|
||||
elif self.chat_completions_response.choices[0].message.content:
|
||||
# Reasoning placed into content for legacy reasons
|
||||
self.reasoning_content = [TextContent(text=self.chat_completions_response.choices[0].message.content)]
|
||||
# Carry thought_signature on TextContent when ReasoningContent doesn't exist to hold it
|
||||
self.reasoning_content = [
|
||||
TextContent(
|
||||
text=self.chat_completions_response.choices[0].message.content,
|
||||
signature=self.chat_completions_response.choices[0].message.reasoning_content_signature,
|
||||
)
|
||||
]
|
||||
else:
|
||||
# logger.info("No reasoning content found.")
|
||||
self.reasoning_content = None
|
||||
@@ -77,6 +83,9 @@ class LettaLLMRequestAdapter(LettaLLMAdapter):
|
||||
else:
|
||||
self.tool_call = None
|
||||
|
||||
# Extract logprobs if present
|
||||
self.logprobs = self.chat_completions_response.choices[0].logprobs
|
||||
|
||||
# Extract usage statistics
|
||||
self.usage.step_count = 1
|
||||
self.usage.completion_tokens = self.chat_completions_response.usage.completion_tokens
|
||||
@@ -127,6 +136,7 @@ class LettaLLMRequestAdapter(LettaLLMAdapter):
|
||||
agent_id=self.agent_id,
|
||||
agent_tags=self.agent_tags,
|
||||
run_id=self.run_id,
|
||||
call_type=self.call_type,
|
||||
org_id=self.org_id,
|
||||
user_id=self.user_id,
|
||||
llm_config=self.llm_config.model_dump() if self.llm_config else None,
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from letta.adapters.letta_llm_adapter import LettaLLMAdapter
|
||||
from letta.errors import LLMError
|
||||
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
|
||||
from letta.interfaces.anthropic_streaming_interface import AnthropicStreamingInterface
|
||||
from letta.interfaces.openai_streaming_interface import OpenAIStreamingInterface
|
||||
from letta.llm_api.llm_client_base import LLMClientBase
|
||||
from letta.otel.tracing import log_attributes, safe_json_dumps, trace_method
|
||||
from letta.schemas.enums import ProviderType
|
||||
from letta.schemas.enums import LLMCallType, ProviderType
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.provider_trace import ProviderTrace
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.settings import settings
|
||||
from letta.utils import safe_create_task
|
||||
@@ -30,13 +30,23 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
|
||||
self,
|
||||
llm_client: LLMClientBase,
|
||||
llm_config: LLMConfig,
|
||||
call_type: LLMCallType,
|
||||
agent_id: str | None = None,
|
||||
agent_tags: list[str] | None = None,
|
||||
run_id: str | None = None,
|
||||
org_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> None:
|
||||
super().__init__(llm_client, llm_config, agent_id=agent_id, agent_tags=agent_tags, run_id=run_id, org_id=org_id, user_id=user_id)
|
||||
super().__init__(
|
||||
llm_client,
|
||||
llm_config,
|
||||
call_type=call_type,
|
||||
agent_id=agent_id,
|
||||
agent_tags=agent_tags,
|
||||
run_id=run_id,
|
||||
org_id=org_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
self.interface: OpenAIStreamingInterface | AnthropicStreamingInterface | None = None
|
||||
|
||||
async def invoke_llm(
|
||||
@@ -88,11 +98,23 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
|
||||
# Extract optional parameters
|
||||
# ttft_span = kwargs.get('ttft_span', None)
|
||||
|
||||
request_start_ns = get_utc_timestamp_ns()
|
||||
|
||||
# Start the streaming request (map provider errors to common LLMError types)
|
||||
try:
|
||||
stream = await self.llm_client.stream_async(request_data, self.llm_config)
|
||||
except Exception as e:
|
||||
raise self.llm_client.handle_llm_error(e)
|
||||
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
|
||||
latency_ms = int((self.llm_request_finish_timestamp_ns - request_start_ns) / 1_000_000)
|
||||
await self.llm_client.log_provider_trace_async(
|
||||
request_data=request_data,
|
||||
response_json=None,
|
||||
llm_config=self.llm_config,
|
||||
latency_ms=latency_ms,
|
||||
error_msg=str(e),
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
raise self.llm_client.handle_llm_error(e, llm_config=self.llm_config)
|
||||
|
||||
# Process the stream and yield chunks immediately for TTFT
|
||||
# Wrap in error handling to convert provider errors to common LLMError types
|
||||
@@ -101,7 +123,19 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
|
||||
# Yield each chunk immediately as it arrives
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
raise self.llm_client.handle_llm_error(e)
|
||||
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
|
||||
latency_ms = int((self.llm_request_finish_timestamp_ns - request_start_ns) / 1_000_000)
|
||||
await self.llm_client.log_provider_trace_async(
|
||||
request_data=request_data,
|
||||
response_json=None,
|
||||
llm_config=self.llm_config,
|
||||
latency_ms=latency_ms,
|
||||
error_msg=str(e),
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
if isinstance(e, LLMError):
|
||||
raise
|
||||
raise self.llm_client.handle_llm_error(e, llm_config=self.llm_config)
|
||||
|
||||
# After streaming completes, extract the accumulated data
|
||||
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
|
||||
@@ -109,7 +143,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
|
||||
# Extract tool call from the interface
|
||||
try:
|
||||
self.tool_call = self.interface.get_tool_call_object()
|
||||
except ValueError as e:
|
||||
except ValueError:
|
||||
# No tool call, handle upstream
|
||||
self.tool_call = None
|
||||
|
||||
@@ -183,6 +217,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
|
||||
agent_id=self.agent_id,
|
||||
agent_tags=self.agent_tags,
|
||||
run_id=self.run_id,
|
||||
call_type=self.call_type,
|
||||
org_id=self.org_id,
|
||||
user_id=self.user_id,
|
||||
llm_config=self.llm_config.model_dump() if self.llm_config else None,
|
||||
|
||||
515
letta/adapters/sglang_native_adapter.py
Normal file
515
letta/adapters/sglang_native_adapter.py
Normal file
@@ -0,0 +1,515 @@
|
||||
"""
|
||||
SGLang Native Adapter for multi-turn RL training.
|
||||
|
||||
This adapter uses SGLang's native /generate endpoint instead of the OpenAI-compatible
|
||||
endpoint to get token IDs and per-token logprobs, which are essential for proper
|
||||
multi-turn RL training with loss masking.
|
||||
|
||||
Uses HuggingFace tokenizer's apply_chat_template() for proper tool formatting.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Optional
|
||||
|
||||
from letta.adapters.simple_llm_request_adapter import SimpleLLMRequestAdapter
|
||||
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
|
||||
from letta.llm_api.sglang_native_client import SGLangNativeClient
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.openai.chat_completion_response import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionTokenLogprob,
|
||||
Choice,
|
||||
ChoiceLogprobs,
|
||||
FunctionCall,
|
||||
Message as ChoiceMessage,
|
||||
ToolCall,
|
||||
UsageStatistics,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Global tokenizer cache
|
||||
_tokenizer_cache: dict[str, Any] = {}
|
||||
|
||||
|
||||
class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
||||
"""
|
||||
Adapter that uses SGLang's native /generate endpoint for multi-turn RL training.
|
||||
|
||||
Key differences from SimpleLLMRequestAdapter:
|
||||
- Uses /generate instead of /v1/chat/completions
|
||||
- Returns output_ids (token IDs) in addition to text
|
||||
- Returns output_token_logprobs with [logprob, token_id] pairs
|
||||
- Formats tools into prompt and parses tool calls from response
|
||||
|
||||
These are essential for building accurate loss masks in multi-turn training.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._sglang_client: Optional[SGLangNativeClient] = None
|
||||
self._tokenizer: Any = None
|
||||
|
||||
def _get_tokenizer(self) -> Any:
|
||||
"""Get or create tokenizer for the model."""
|
||||
global _tokenizer_cache
|
||||
|
||||
# Get model name from llm_config
|
||||
model_name = self.llm_config.model
|
||||
if not model_name:
|
||||
logger.warning("No model name in llm_config, cannot load tokenizer")
|
||||
return None
|
||||
|
||||
# Check cache
|
||||
if model_name in _tokenizer_cache:
|
||||
return _tokenizer_cache[model_name]
|
||||
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
logger.info(f"Loading tokenizer for model: {model_name}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||
_tokenizer_cache[model_name] = tokenizer
|
||||
return tokenizer
|
||||
except ImportError:
|
||||
logger.warning("transformers not installed, falling back to manual formatting")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load tokenizer: {e}, falling back to manual formatting")
|
||||
return None
|
||||
|
||||
def _get_sglang_client(self) -> SGLangNativeClient:
|
||||
"""Get or create SGLang native client."""
|
||||
if self._sglang_client is None:
|
||||
# Get base URL from llm_config, removing /v1 suffix if present
|
||||
base_url = self.llm_config.model_endpoint or ""
|
||||
# SGLang local instances typically don't need API key
|
||||
self._sglang_client = SGLangNativeClient(
|
||||
base_url=base_url,
|
||||
api_key=None,
|
||||
)
|
||||
return self._sglang_client
|
||||
|
||||
def _format_tools_for_prompt(self, tools: list) -> str:
|
||||
"""
|
||||
Format tools in Qwen3 chat template format for the system prompt.
|
||||
|
||||
This matches the exact format produced by Qwen3's tokenizer.apply_chat_template()
|
||||
with tools parameter.
|
||||
"""
|
||||
if not tools:
|
||||
return ""
|
||||
|
||||
# Format each tool as JSON (matching Qwen3 template exactly)
|
||||
tool_jsons = []
|
||||
for tool in tools:
|
||||
# Handle both dict and object formats
|
||||
if isinstance(tool, dict):
|
||||
# Already in OpenAI format
|
||||
tool_jsons.append(json.dumps(tool))
|
||||
else:
|
||||
# Convert object to dict
|
||||
tool_dict = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": getattr(getattr(tool, "function", tool), "name", ""),
|
||||
"description": getattr(getattr(tool, "function", tool), "description", ""),
|
||||
"parameters": getattr(getattr(tool, "function", tool), "parameters", {}),
|
||||
},
|
||||
}
|
||||
tool_jsons.append(json.dumps(tool_dict))
|
||||
|
||||
# Use exact Qwen3 format
|
||||
tools_section = (
|
||||
"\n\n# Tools\n\n"
|
||||
"You may call one or more functions to assist with the user query.\n\n"
|
||||
"You are provided with function signatures within <tools></tools> XML tags:\n"
|
||||
"<tools>\n" + "\n".join(tool_jsons) + "\n"
|
||||
"</tools>\n\n"
|
||||
"For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n"
|
||||
"<tool_call>\n"
|
||||
'{"name": <function-name>, "arguments": <args-json-object>}\n'
|
||||
"</tool_call>"
|
||||
)
|
||||
|
||||
return tools_section
|
||||
|
||||
def _convert_messages_to_openai_format(self, messages: list) -> list[dict]:
|
||||
"""Convert Letta Message objects to OpenAI-style message dicts."""
|
||||
openai_messages = []
|
||||
|
||||
for msg in messages:
|
||||
# Handle both dict and Pydantic Message objects
|
||||
if hasattr(msg, "role"):
|
||||
role = msg.role
|
||||
content = msg.content if hasattr(msg, "content") else ""
|
||||
# Handle content that might be a list of content parts
|
||||
if isinstance(content, list):
|
||||
content = " ".join([c.text if hasattr(c, "text") else str(c) for c in content])
|
||||
elif content is None:
|
||||
content = ""
|
||||
tool_calls = getattr(msg, "tool_calls", None)
|
||||
tool_call_id = getattr(msg, "tool_call_id", None)
|
||||
name = getattr(msg, "name", None)
|
||||
else:
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content", "")
|
||||
tool_calls = msg.get("tool_calls", None)
|
||||
tool_call_id = msg.get("tool_call_id", None)
|
||||
name = msg.get("name", None)
|
||||
|
||||
openai_msg = {"role": role, "content": content}
|
||||
|
||||
if tool_calls:
|
||||
# Convert tool calls to OpenAI format
|
||||
openai_tool_calls = []
|
||||
for tc in tool_calls:
|
||||
if hasattr(tc, "function"):
|
||||
tc_dict = {
|
||||
"id": getattr(tc, "id", f"call_{uuid.uuid4().hex[:8]}"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.function.name,
|
||||
"arguments": tc.function.arguments
|
||||
if isinstance(tc.function.arguments, str)
|
||||
else json.dumps(tc.function.arguments),
|
||||
},
|
||||
}
|
||||
else:
|
||||
tc_dict = {
|
||||
"id": tc.get("id", f"call_{uuid.uuid4().hex[:8]}"),
|
||||
"type": "function",
|
||||
"function": tc.get("function", {}),
|
||||
}
|
||||
openai_tool_calls.append(tc_dict)
|
||||
openai_msg["tool_calls"] = openai_tool_calls
|
||||
|
||||
if tool_call_id:
|
||||
openai_msg["tool_call_id"] = tool_call_id
|
||||
|
||||
if name and role == "tool":
|
||||
openai_msg["name"] = name
|
||||
|
||||
openai_messages.append(openai_msg)
|
||||
|
||||
return openai_messages
|
||||
|
||||
def _convert_tools_to_openai_format(self, tools: list) -> list[dict]:
|
||||
"""Convert tools to OpenAI format for tokenizer."""
|
||||
openai_tools = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, dict):
|
||||
# Already a dict, ensure it's in the right format
|
||||
if "function" in tool:
|
||||
openai_tools.append(tool)
|
||||
else:
|
||||
# Might be the function directly
|
||||
openai_tools.append({"type": "function", "function": tool})
|
||||
else:
|
||||
# Convert object to dict
|
||||
func = getattr(tool, "function", tool)
|
||||
tool_dict = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": getattr(func, "name", ""),
|
||||
"description": getattr(func, "description", ""),
|
||||
"parameters": getattr(func, "parameters", {}),
|
||||
},
|
||||
}
|
||||
openai_tools.append(tool_dict)
|
||||
return openai_tools
|
||||
|
||||
def _format_messages_to_text(self, messages: list, tools: list) -> str:
|
||||
"""
|
||||
Format messages to text using tokenizer's apply_chat_template if available.
|
||||
|
||||
Falls back to manual formatting if tokenizer is not available.
|
||||
"""
|
||||
tokenizer = self._get_tokenizer()
|
||||
|
||||
if tokenizer is not None:
|
||||
# Use tokenizer's apply_chat_template for proper formatting
|
||||
openai_messages = self._convert_messages_to_openai_format(messages)
|
||||
openai_tools = self._convert_tools_to_openai_format(tools) if tools else None
|
||||
|
||||
try:
|
||||
formatted = tokenizer.apply_chat_template(
|
||||
openai_messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
tools=openai_tools,
|
||||
)
|
||||
logger.debug(f"Formatted prompt using tokenizer ({len(formatted)} chars)")
|
||||
return formatted
|
||||
except Exception as e:
|
||||
logger.warning(f"apply_chat_template failed: {e}, falling back to manual formatting")
|
||||
|
||||
# Fallback to manual formatting
|
||||
return self._format_messages_to_text_manual(messages, tools)
|
||||
|
||||
def _format_messages_to_text_manual(self, messages: list, tools: list) -> str:
|
||||
"""Manual fallback formatting for when tokenizer is not available."""
|
||||
formatted_parts = []
|
||||
tools_section = self._format_tools_for_prompt(tools)
|
||||
|
||||
for msg in messages:
|
||||
# Handle both dict and Pydantic Message objects
|
||||
if hasattr(msg, "role"):
|
||||
role = msg.role
|
||||
content = msg.content if hasattr(msg, "content") else ""
|
||||
if isinstance(content, list):
|
||||
content = " ".join([c.text if hasattr(c, "text") else str(c) for c in content])
|
||||
elif content is None:
|
||||
content = ""
|
||||
tool_calls = getattr(msg, "tool_calls", None)
|
||||
else:
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content", "")
|
||||
tool_calls = msg.get("tool_calls", None)
|
||||
|
||||
if role == "system":
|
||||
system_content = content + tools_section if tools_section else content
|
||||
formatted_parts.append(f"<|im_start|>system\n{system_content}<|im_end|>")
|
||||
tools_section = ""
|
||||
elif role == "user":
|
||||
formatted_parts.append(f"<|im_start|>user\n{content}<|im_end|>")
|
||||
elif role == "assistant":
|
||||
if tool_calls:
|
||||
tc_parts = []
|
||||
for tc in tool_calls:
|
||||
if hasattr(tc, "function"):
|
||||
tc_name = tc.function.name
|
||||
tc_args = tc.function.arguments
|
||||
else:
|
||||
tc_name = tc.get("function", {}).get("name", "")
|
||||
tc_args = tc.get("function", {}).get("arguments", "{}")
|
||||
|
||||
if isinstance(tc_args, str):
|
||||
try:
|
||||
tc_args = json.loads(tc_args)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
tc_parts.append(f'<tool_call>\n{{"name": "{tc_name}", "arguments": {json.dumps(tc_args)}}}\n</tool_call>')
|
||||
|
||||
assistant_content = content + "\n" + "\n".join(tc_parts) if content else "\n".join(tc_parts)
|
||||
formatted_parts.append(f"<|im_start|>assistant\n{assistant_content}<|im_end|>")
|
||||
elif content:
|
||||
formatted_parts.append(f"<|im_start|>assistant\n{content}<|im_end|>")
|
||||
elif role == "tool":
|
||||
formatted_parts.append(f"<|im_start|>user\n<tool_response>\n{content}\n</tool_response><|im_end|>")
|
||||
|
||||
formatted_parts.append("<|im_start|>assistant\n")
|
||||
return "\n".join(formatted_parts)
|
||||
|
||||
def _parse_tool_calls(self, text: str) -> list[ToolCall]:
|
||||
"""
|
||||
Parse tool calls from response text.
|
||||
|
||||
Looks for patterns like:
|
||||
<tool_call>
|
||||
{"name": "tool_name", "arguments": {...}}
|
||||
</tool_call>
|
||||
"""
|
||||
tool_calls = []
|
||||
|
||||
# Find all tool_call blocks
|
||||
pattern = r"<tool_call>\s*(\{.*?\})\s*</tool_call>"
|
||||
matches = re.findall(pattern, text, re.DOTALL)
|
||||
|
||||
for match in matches:
|
||||
try:
|
||||
tc_data = json.loads(match)
|
||||
name = tc_data.get("name", "")
|
||||
arguments = tc_data.get("arguments", {})
|
||||
|
||||
if isinstance(arguments, dict):
|
||||
arguments = json.dumps(arguments)
|
||||
|
||||
tool_call = ToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
),
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Failed to parse tool call JSON: {e}")
|
||||
continue
|
||||
|
||||
return tool_calls
|
||||
|
||||
def _extract_content_without_tool_calls(self, text: str) -> str:
|
||||
"""Extract content from response, removing tool_call blocks."""
|
||||
# Remove tool_call blocks
|
||||
cleaned = re.sub(r"<tool_call>.*?</tool_call>", "", text, flags=re.DOTALL)
|
||||
# Clean up whitespace
|
||||
cleaned = cleaned.strip()
|
||||
return cleaned
|
||||
|
||||
async def invoke_llm(
|
||||
self,
|
||||
request_data: dict,
|
||||
messages: list,
|
||||
tools: list,
|
||||
use_assistant_message: bool,
|
||||
requires_approval_tools: list[str] = [],
|
||||
step_id: str | None = None,
|
||||
actor: str | None = None,
|
||||
) -> AsyncGenerator[LettaMessage | None, None]:
|
||||
"""
|
||||
Execute LLM request using SGLang native endpoint.
|
||||
|
||||
This method:
|
||||
1. Formats messages and tools to text using chat template
|
||||
2. Calls SGLang native /generate endpoint
|
||||
3. Extracts output_ids and output_token_logprobs
|
||||
4. Parses tool calls from response
|
||||
5. Converts response to standard format
|
||||
"""
|
||||
self.request_data = request_data
|
||||
|
||||
# Get sampling params from request_data
|
||||
sampling_params = {
|
||||
"temperature": request_data.get("temperature", 0.7),
|
||||
"max_new_tokens": request_data.get("max_tokens", 4096),
|
||||
"top_p": request_data.get("top_p", 0.9),
|
||||
}
|
||||
|
||||
# Format messages to text (includes tools in prompt)
|
||||
text_input = self._format_messages_to_text(messages, tools)
|
||||
|
||||
# Call SGLang native endpoint
|
||||
client = self._get_sglang_client()
|
||||
|
||||
try:
|
||||
response = await client.generate(
|
||||
text=text_input,
|
||||
sampling_params=sampling_params,
|
||||
return_logprob=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"SGLang native endpoint error: {e}")
|
||||
raise
|
||||
|
||||
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
|
||||
|
||||
# Store native response data
|
||||
self.response_data = response
|
||||
|
||||
# Extract SGLang native data
|
||||
self.output_ids = response.get("output_ids")
|
||||
# output_token_logprobs is inside meta_info
|
||||
meta_info = response.get("meta_info", {})
|
||||
self.output_token_logprobs = meta_info.get("output_token_logprobs")
|
||||
|
||||
# Extract text response
|
||||
text_response = response.get("text", "")
|
||||
|
||||
# Remove trailing end token if present
|
||||
if text_response.endswith("<|im_end|>"):
|
||||
text_response = text_response[:-10]
|
||||
|
||||
# Parse tool calls from response
|
||||
parsed_tool_calls = self._parse_tool_calls(text_response)
|
||||
|
||||
# Extract content (text without tool_call blocks)
|
||||
content_text = self._extract_content_without_tool_calls(text_response)
|
||||
|
||||
# Determine finish reason
|
||||
meta_info = response.get("meta_info", {})
|
||||
finish_reason_info = meta_info.get("finish_reason", {})
|
||||
if isinstance(finish_reason_info, dict):
|
||||
finish_reason = finish_reason_info.get("type", "stop")
|
||||
else:
|
||||
finish_reason = "stop"
|
||||
|
||||
# If we have tool calls, set finish_reason to tool_calls
|
||||
if parsed_tool_calls:
|
||||
finish_reason = "tool_calls"
|
||||
|
||||
# Convert to standard ChatCompletionResponse format for compatibility
|
||||
# Build logprobs in OpenAI format from SGLang format
|
||||
logprobs_content = None
|
||||
if self.output_token_logprobs:
|
||||
logprobs_content = []
|
||||
for i, lp_data in enumerate(self.output_token_logprobs):
|
||||
# SGLang format: [logprob, token_id, top_logprob]
|
||||
logprob = lp_data[0] if len(lp_data) > 0 else 0.0
|
||||
token_id = lp_data[1] if len(lp_data) > 1 else 0
|
||||
logprobs_content.append(
|
||||
ChatCompletionTokenLogprob(
|
||||
token=str(token_id),
|
||||
logprob=logprob,
|
||||
bytes=None,
|
||||
top_logprobs=[],
|
||||
)
|
||||
)
|
||||
|
||||
choice_logprobs = ChoiceLogprobs(content=logprobs_content) if logprobs_content else None
|
||||
|
||||
# Build chat completion response
|
||||
prompt_tokens = meta_info.get("prompt_tokens", 0)
|
||||
completion_tokens = len(self.output_ids) if self.output_ids else 0
|
||||
|
||||
self.chat_completions_response = ChatCompletionResponse(
|
||||
id=meta_info.get("id", "sglang-native"),
|
||||
created=int(time.time()),
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason=finish_reason,
|
||||
index=0,
|
||||
message=ChoiceMessage(
|
||||
role="assistant",
|
||||
content=content_text if content_text else None,
|
||||
tool_calls=parsed_tool_calls if parsed_tool_calls else None,
|
||||
),
|
||||
logprobs=choice_logprobs,
|
||||
)
|
||||
],
|
||||
usage=UsageStatistics(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
),
|
||||
)
|
||||
|
||||
# Extract content
|
||||
if content_text:
|
||||
self.content = [TextContent(text=content_text)]
|
||||
else:
|
||||
self.content = None
|
||||
|
||||
# No reasoning content from native endpoint
|
||||
self.reasoning_content = None
|
||||
|
||||
# Set tool calls
|
||||
self.tool_calls = parsed_tool_calls
|
||||
self.tool_call = parsed_tool_calls[0] if parsed_tool_calls else None
|
||||
|
||||
# Set logprobs
|
||||
self.logprobs = choice_logprobs
|
||||
|
||||
# Extract usage statistics
|
||||
self.usage.step_count = 1
|
||||
self.usage.completion_tokens = completion_tokens
|
||||
self.usage.prompt_tokens = prompt_tokens
|
||||
self.usage.total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
self.log_provider_trace(step_id=step_id, actor=actor)
|
||||
|
||||
logger.info(
|
||||
f"SGLang native response: {len(self.output_ids or [])} tokens, "
|
||||
f"{len(self.output_token_logprobs or [])} logprobs, "
|
||||
f"{len(parsed_tool_calls)} tool calls"
|
||||
)
|
||||
|
||||
yield None
|
||||
return
|
||||
@@ -1,7 +1,9 @@
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from letta.adapters.letta_llm_request_adapter import LettaLLMRequestAdapter
|
||||
from letta.errors import LLMError
|
||||
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
|
||||
from letta.schemas.enums import LLMCallType
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, TextContent
|
||||
from letta.schemas.usage import normalize_cache_tokens, normalize_reasoning_tokens
|
||||
@@ -45,7 +47,7 @@ class SimpleLLMRequestAdapter(LettaLLMRequestAdapter):
|
||||
agent_id=self.agent_id,
|
||||
agent_tags=self.agent_tags,
|
||||
run_id=self.run_id,
|
||||
call_type="agent_step",
|
||||
call_type=LLMCallType.agent_step,
|
||||
org_id=self.org_id,
|
||||
user_id=self.user_id,
|
||||
llm_config=self.llm_config.model_dump() if self.llm_config else None,
|
||||
@@ -53,7 +55,9 @@ class SimpleLLMRequestAdapter(LettaLLMRequestAdapter):
|
||||
try:
|
||||
self.response_data = await self.llm_client.request_async_with_telemetry(request_data, self.llm_config)
|
||||
except Exception as e:
|
||||
raise self.llm_client.handle_llm_error(e)
|
||||
if isinstance(e, LLMError):
|
||||
raise
|
||||
raise self.llm_client.handle_llm_error(e, llm_config=self.llm_config)
|
||||
|
||||
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
|
||||
|
||||
@@ -80,7 +84,12 @@ class SimpleLLMRequestAdapter(LettaLLMRequestAdapter):
|
||||
if self.chat_completions_response.choices[0].message.content:
|
||||
# NOTE: big difference - 'content' goes into 'content'
|
||||
# Reasoning placed into content for legacy reasons
|
||||
self.content = [TextContent(text=self.chat_completions_response.choices[0].message.content)]
|
||||
# Carry thought_signature on TextContent when ReasoningContent doesn't exist to hold it
|
||||
# (e.g. Gemini 2.5 Flash with include_thoughts=False still returns thought_signature)
|
||||
orphan_sig = (
|
||||
self.chat_completions_response.choices[0].message.reasoning_content_signature if not self.reasoning_content else None
|
||||
)
|
||||
self.content = [TextContent(text=self.chat_completions_response.choices[0].message.content, signature=orphan_sig)]
|
||||
else:
|
||||
self.content = None
|
||||
|
||||
@@ -93,6 +102,9 @@ class SimpleLLMRequestAdapter(LettaLLMRequestAdapter):
|
||||
self.tool_calls = list(tool_calls)
|
||||
self.tool_call = self.tool_calls[0] if self.tool_calls else None
|
||||
|
||||
# Extract logprobs if present
|
||||
self.logprobs = self.chat_completions_response.choices[0].logprobs
|
||||
|
||||
# Extract usage statistics
|
||||
self.usage.step_count = 1
|
||||
self.usage.completion_tokens = self.chat_completions_response.usage.completion_tokens
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from typing import AsyncGenerator, List
|
||||
|
||||
from letta.adapters.letta_llm_stream_adapter import LettaLLMStreamAdapter
|
||||
from letta.errors import LLMError
|
||||
from letta.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -70,6 +70,9 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
|
||||
# Store request data
|
||||
self.request_data = request_data
|
||||
|
||||
# Track request start time for latency calculation
|
||||
request_start_ns = get_utc_timestamp_ns()
|
||||
|
||||
# Get cancellation event for this run to enable graceful cancellation (before branching)
|
||||
cancellation_event = get_cancellation_event_for_run(self.run_id) if self.run_id else None
|
||||
|
||||
@@ -138,7 +141,19 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
|
||||
else:
|
||||
stream = await self.llm_client.stream_async(request_data, self.llm_config)
|
||||
except Exception as e:
|
||||
raise self.llm_client.handle_llm_error(e)
|
||||
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
|
||||
latency_ms = int((self.llm_request_finish_timestamp_ns - request_start_ns) / 1_000_000)
|
||||
await self.llm_client.log_provider_trace_async(
|
||||
request_data=request_data,
|
||||
response_json=None,
|
||||
llm_config=self.llm_config,
|
||||
latency_ms=latency_ms,
|
||||
error_msg=str(e),
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
if isinstance(e, LLMError):
|
||||
raise
|
||||
raise self.llm_client.handle_llm_error(e, llm_config=self.llm_config)
|
||||
|
||||
# Process the stream and yield chunks immediately for TTFT
|
||||
try:
|
||||
@@ -146,8 +161,19 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
|
||||
# Yield each chunk immediately as it arrives
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
# Map provider-specific errors during streaming to common LLMError types
|
||||
raise self.llm_client.handle_llm_error(e)
|
||||
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
|
||||
latency_ms = int((self.llm_request_finish_timestamp_ns - request_start_ns) / 1_000_000)
|
||||
await self.llm_client.log_provider_trace_async(
|
||||
request_data=request_data,
|
||||
response_json=None,
|
||||
llm_config=self.llm_config,
|
||||
latency_ms=latency_ms,
|
||||
error_msg=str(e),
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
if isinstance(e, LLMError):
|
||||
raise
|
||||
raise self.llm_client.handle_llm_error(e, llm_config=self.llm_config)
|
||||
|
||||
# After streaming completes, extract the accumulated data
|
||||
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
|
||||
@@ -172,6 +198,22 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
|
||||
# Store any additional data from the interface
|
||||
self.message_id = self.interface.letta_message_id
|
||||
|
||||
# Populate finish_reason for downstream continuation logic.
|
||||
# In Responses streaming, max_output_tokens is expressed via incomplete_details.reason.
|
||||
if hasattr(self.interface, "final_response") and self.interface.final_response is not None:
|
||||
resp = self.interface.final_response
|
||||
incomplete_details = getattr(resp, "incomplete_details", None)
|
||||
incomplete_reason = getattr(incomplete_details, "reason", None) if incomplete_details else None
|
||||
if incomplete_reason == "max_output_tokens":
|
||||
self._finish_reason = "length"
|
||||
elif incomplete_reason == "content_filter":
|
||||
self._finish_reason = "content_filter"
|
||||
elif incomplete_reason is not None:
|
||||
# Unknown incomplete reason — preserve it as-is for diagnostics
|
||||
self._finish_reason = incomplete_reason
|
||||
elif getattr(resp, "status", None) == "completed":
|
||||
self._finish_reason = "stop"
|
||||
|
||||
# Log request and response data
|
||||
self.log_provider_trace(step_id=step_id, actor=actor)
|
||||
|
||||
@@ -232,6 +274,7 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
|
||||
agent_id=self.agent_id,
|
||||
agent_tags=self.agent_tags,
|
||||
run_id=self.run_id,
|
||||
call_type=self.call_type,
|
||||
org_id=self.org_id,
|
||||
user_id=self.user_id,
|
||||
llm_config=self.llm_config.model_dump() if self.llm_config else None,
|
||||
|
||||
@@ -123,32 +123,17 @@ class BaseAgent(ABC):
|
||||
curr_system_message = in_context_messages[0]
|
||||
curr_system_message_text = curr_system_message.content[0].text
|
||||
|
||||
# extract the dynamic section that includes memory blocks, tool rules, and directories
|
||||
# this avoids timestamp comparison issues
|
||||
def extract_dynamic_section(text):
|
||||
start_marker = "</base_instructions>"
|
||||
end_marker = "<memory_metadata>"
|
||||
|
||||
start_idx = text.find(start_marker)
|
||||
end_idx = text.find(end_marker)
|
||||
|
||||
if start_idx != -1 and end_idx != -1:
|
||||
return text[start_idx:end_idx]
|
||||
return text # fallback to full text if markers not found
|
||||
|
||||
curr_dynamic_section = extract_dynamic_section(curr_system_message_text)
|
||||
|
||||
# generate just the memory string with current state for comparison
|
||||
# generate memory string with current state for comparison
|
||||
curr_memory_str = agent_state.memory.compile(
|
||||
tool_usage_rules=tool_constraint_block,
|
||||
sources=agent_state.sources,
|
||||
max_files_open=agent_state.max_files_open,
|
||||
llm_config=agent_state.llm_config,
|
||||
)
|
||||
new_dynamic_section = extract_dynamic_section(curr_memory_str)
|
||||
|
||||
# compare just the dynamic sections (memory blocks, tool rules, directories)
|
||||
if curr_dynamic_section == new_dynamic_section:
|
||||
system_prompt_changed = agent_state.system not in curr_system_message_text
|
||||
memory_changed = curr_memory_str not in curr_system_message_text
|
||||
if (not system_prompt_changed) and (not memory_changed):
|
||||
logger.debug(
|
||||
f"Memory and sources haven't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
|
||||
)
|
||||
@@ -183,7 +168,7 @@ class BaseAgent(ABC):
|
||||
actor=self.actor,
|
||||
project_id=agent_state.project_id,
|
||||
)
|
||||
return [new_system_message] + in_context_messages[1:]
|
||||
return [new_system_message, *in_context_messages[1:]]
|
||||
|
||||
else:
|
||||
return in_context_messages
|
||||
|
||||
@@ -25,6 +25,11 @@ class BaseAgentV2(ABC):
|
||||
self.actor = actor
|
||||
self.logger = get_logger(agent_state.id)
|
||||
|
||||
@property
|
||||
def agent_id(self) -> str:
|
||||
"""Return the agent ID for backward compatibility with code expecting self.agent_id."""
|
||||
return self.agent_state.id
|
||||
|
||||
@abstractmethod
|
||||
async def build_request(
|
||||
self,
|
||||
@@ -46,6 +51,7 @@ class BaseAgentV2(ABC):
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
client_tools: list["ClientToolSchema"] | None = None,
|
||||
include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility
|
||||
) -> LettaResponse:
|
||||
"""
|
||||
Execute the agent loop in blocking mode, returning all messages at once.
|
||||
@@ -53,6 +59,7 @@ class BaseAgentV2(ABC):
|
||||
Args:
|
||||
client_tools: Optional list of client-side tools. When called, execution pauses
|
||||
for client to provide tool returns.
|
||||
include_compaction_messages: Not used in V2, but accepted for API compatibility.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -66,8 +73,9 @@ class BaseAgentV2(ABC):
|
||||
use_assistant_message: bool = True,
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
conversation_id: str | None = None,
|
||||
conversation_id: str | None = None,
|
||||
client_tools: list["ClientToolSchema"] | None = None,
|
||||
include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility
|
||||
) -> AsyncGenerator[LettaMessage | LegacyLettaMessage | MessageStreamStatus, None]:
|
||||
"""
|
||||
Execute the agent loop in streaming mode, yielding chunks as they become available.
|
||||
@@ -78,5 +86,6 @@ class BaseAgentV2(ABC):
|
||||
Args:
|
||||
client_tools: Optional list of client-side tools. When called, execution pauses
|
||||
for client to provide tool returns.
|
||||
include_compaction_messages: Not used in V2, but accepted for API compatibility.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -8,7 +8,7 @@ from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.prompts.gpt_system import get_system_text
|
||||
from letta.schemas.block import Block, BlockUpdate
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.enums import LLMCallType, MessageRole
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.user import User
|
||||
@@ -79,7 +79,7 @@ class EphemeralSummaryAgent(BaseAgent):
|
||||
content=[TextContent(text=get_system_text("summary_system_prompt"))],
|
||||
)
|
||||
messages = await convert_message_creates_to_messages(
|
||||
message_creates=[system_message_create] + input_messages,
|
||||
message_creates=[system_message_create, *input_messages],
|
||||
agent_id=self.agent_id,
|
||||
timezone=agent_state.timezone,
|
||||
run_id=None, # TODO: add this
|
||||
@@ -92,7 +92,7 @@ class EphemeralSummaryAgent(BaseAgent):
|
||||
telemetry_manager=TelemetryManager(),
|
||||
agent_id=self.agent_id,
|
||||
agent_tags=agent_state.tags,
|
||||
call_type="summarization",
|
||||
call_type=LLMCallType.summarization,
|
||||
)
|
||||
response_data = await llm_client.request_async_with_telemetry(request_data, agent_state.llm_config)
|
||||
response = await llm_client.convert_response_to_chat_completion(response_data, messages, agent_state.llm_config)
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import json
|
||||
import uuid
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from letta.errors import PendingApprovalError
|
||||
if TYPE_CHECKING:
|
||||
from letta.schemas.tool import Tool
|
||||
|
||||
from letta.errors import LettaError, PendingApprovalError
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.log import get_logger
|
||||
@@ -233,6 +235,11 @@ async def _prepare_in_context_messages_no_persist_async(
|
||||
current_in_context_messages = [system_message]
|
||||
else:
|
||||
# Default mode: load messages from agent_state.message_ids
|
||||
if not agent_state.message_ids:
|
||||
raise LettaError(
|
||||
message=f"Agent {agent_state.id} has no in-context messages. "
|
||||
"This typically means the agent's system message was not initialized correctly.",
|
||||
)
|
||||
if agent_state.message_buffer_autoclear:
|
||||
# If autoclear is enabled, only include the most recent system message (usually at index 0)
|
||||
current_in_context_messages = [
|
||||
@@ -242,6 +249,14 @@ async def _prepare_in_context_messages_no_persist_async(
|
||||
# Otherwise, include the full list of messages by ID for context
|
||||
current_in_context_messages = await message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor)
|
||||
|
||||
# Convert ToolReturnCreate to ApprovalCreate for unified processing
|
||||
if input_messages[0].type == "tool_return":
|
||||
tool_return_msg = input_messages[0]
|
||||
input_messages = [
|
||||
ApprovalCreate(approvals=tool_return_msg.tool_returns),
|
||||
*input_messages[1:],
|
||||
]
|
||||
|
||||
# Check for approval-related message validation
|
||||
if input_messages[0].type == "approval":
|
||||
# User is trying to send an approval response
|
||||
@@ -254,12 +269,31 @@ async def _prepare_in_context_messages_no_persist_async(
|
||||
for msg in reversed(recent_messages):
|
||||
if msg.role == "tool" and validate_persisted_tool_call_ids(msg, input_messages[0]):
|
||||
logger.info(
|
||||
f"Idempotency check: Found matching tool return in recent history. "
|
||||
f"Idempotency check: Found matching tool return in recent in-context history. "
|
||||
f"tool_returns={msg.tool_returns}, approval_response.approvals={input_messages[0].approvals}"
|
||||
)
|
||||
approval_already_processed = True
|
||||
break
|
||||
|
||||
# If not found in context and summarization just happened, check full history
|
||||
non_system_summary_messages = [
|
||||
m for m in current_in_context_messages if m.role not in (MessageRole.system, MessageRole.summary)
|
||||
]
|
||||
if not approval_already_processed and len(non_system_summary_messages) == 0:
|
||||
last_tool_messages = await message_manager.list_messages(
|
||||
actor=actor,
|
||||
agent_id=agent_state.id,
|
||||
roles=[MessageRole.tool],
|
||||
limit=1,
|
||||
ascending=False, # Most recent first
|
||||
)
|
||||
if len(last_tool_messages) == 1 and validate_persisted_tool_call_ids(last_tool_messages[0], input_messages[0]):
|
||||
logger.info(
|
||||
f"Idempotency check: Found matching tool return in full history (post-compaction). "
|
||||
f"tool_returns={last_tool_messages[0].tool_returns}, approval_response.approvals={input_messages[0].approvals}"
|
||||
)
|
||||
approval_already_processed = True
|
||||
|
||||
if approval_already_processed:
|
||||
# Approval already handled, just process follow-up messages if any or manually inject keep-alive message
|
||||
keep_alive_messages = input_messages[1:] or [
|
||||
|
||||
@@ -13,14 +13,13 @@ from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent
|
||||
from letta.agents.helpers import (
|
||||
_build_rule_violation_result,
|
||||
_create_letta_response,
|
||||
_load_last_function_response,
|
||||
_pop_heartbeat,
|
||||
_prepare_in_context_messages_no_persist_async,
|
||||
_safe_load_tool_call_str,
|
||||
generate_step_id,
|
||||
)
|
||||
from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX, REQUEST_HEARTBEAT_PARAM
|
||||
from letta.errors import ContextWindowExceededError
|
||||
from letta.errors import ContextWindowExceededError, LLMError
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.datetime_helpers import AsyncTimer, get_utc_time, get_utc_timestamp_ns, ns_to_ms
|
||||
from letta.helpers.reasoning_helper import scrub_inner_thoughts_from_messages
|
||||
@@ -35,7 +34,7 @@ from letta.otel.context import get_ctx_attributes
|
||||
from letta.otel.metric_registry import MetricRegistry
|
||||
from letta.otel.tracing import log_event, trace_method, tracer
|
||||
from letta.schemas.agent import AgentState, UpdateAgent
|
||||
from letta.schemas.enums import JobStatus, ProviderType, StepStatus, ToolType
|
||||
from letta.schemas.enums import JobStatus, LLMCallType, ProviderType, StepStatus, ToolType
|
||||
from letta.schemas.letta_message import MessageType
|
||||
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
@@ -49,7 +48,6 @@ from letta.schemas.openai.chat_completion_response import (
|
||||
UsageStatisticsCompletionTokenDetails,
|
||||
UsageStatisticsPromptTokenDetails,
|
||||
)
|
||||
from letta.schemas.provider_trace import ProviderTrace
|
||||
from letta.schemas.step import StepProgression
|
||||
from letta.schemas.step_metrics import StepMetrics
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
@@ -294,6 +292,7 @@ class LettaAgent(BaseAgent):
|
||||
agent_step_span.set_attributes({"step_id": step_id})
|
||||
|
||||
step_progression = StepProgression.START
|
||||
caught_exception = None
|
||||
should_continue = False
|
||||
step_metrics = StepMetrics(id=step_id) # Initialize metrics tracking
|
||||
|
||||
@@ -312,6 +311,7 @@ class LettaAgent(BaseAgent):
|
||||
step_id=step_id,
|
||||
project_id=agent_state.project_id,
|
||||
status=StepStatus.PENDING,
|
||||
model_handle=agent_state.llm_config.handle,
|
||||
)
|
||||
# Only use step_id in messages if step was actually created
|
||||
effective_step_id = step_id if logged_step else None
|
||||
@@ -370,8 +370,12 @@ class LettaAgent(BaseAgent):
|
||||
elif response.choices[0].message.omitted_reasoning_content:
|
||||
reasoning = [OmittedReasoningContent()]
|
||||
elif response.choices[0].message.content:
|
||||
# Carry thought_signature on TextContent when ReasoningContent doesn't exist to hold it
|
||||
reasoning = [
|
||||
TextContent(text=response.choices[0].message.content)
|
||||
TextContent(
|
||||
text=response.choices[0].message.content,
|
||||
signature=response.choices[0].message.reasoning_content_signature,
|
||||
)
|
||||
] # reasoning placed into content for legacy reasons
|
||||
else:
|
||||
self.logger.info("No reasoning content found.")
|
||||
@@ -409,24 +413,6 @@ class LettaAgent(BaseAgent):
|
||||
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
|
||||
agent_step_span.end()
|
||||
|
||||
# Log LLM Trace
|
||||
if settings.track_provider_trace:
|
||||
await self.telemetry_manager.create_provider_trace_async(
|
||||
actor=self.actor,
|
||||
provider_trace=ProviderTrace(
|
||||
request_json=request_data,
|
||||
response_json=response_data,
|
||||
step_id=step_id,
|
||||
agent_id=self.agent_id,
|
||||
agent_tags=agent_state.tags,
|
||||
run_id=self.current_run_id,
|
||||
org_id=self.actor.organization_id,
|
||||
user_id=self.actor.id,
|
||||
llm_config=self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None,
|
||||
),
|
||||
)
|
||||
step_progression = StepProgression.LOGGED_TRACE
|
||||
|
||||
# stream step
|
||||
# TODO: improve TTFT
|
||||
filter_user_messages = [m for m in persisted_messages if m.role != "user"]
|
||||
@@ -453,6 +439,7 @@ class LettaAgent(BaseAgent):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
caught_exception = e
|
||||
# Handle any unexpected errors during step processing
|
||||
self.logger.error(f"Error during step processing: {e}")
|
||||
job_update_metadata = {"error": str(e)}
|
||||
@@ -499,8 +486,8 @@ class LettaAgent(BaseAgent):
|
||||
await self.step_manager.update_step_error_async(
|
||||
actor=self.actor,
|
||||
step_id=step_id, # Use original step_id for telemetry
|
||||
error_type=type(e).__name__ if "e" in locals() else "Unknown",
|
||||
error_message=str(e) if "e" in locals() else "Unknown error",
|
||||
error_type=type(caught_exception).__name__ if caught_exception is not None else "Unknown",
|
||||
error_message=str(caught_exception) if caught_exception is not None else "Unknown error",
|
||||
error_traceback=traceback.format_exc(),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
@@ -646,6 +633,7 @@ class LettaAgent(BaseAgent):
|
||||
agent_step_span.set_attributes({"step_id": step_id})
|
||||
|
||||
step_progression = StepProgression.START
|
||||
caught_exception = None
|
||||
should_continue = False
|
||||
step_metrics = StepMetrics(id=step_id) # Initialize metrics tracking
|
||||
|
||||
@@ -664,6 +652,7 @@ class LettaAgent(BaseAgent):
|
||||
step_id=step_id,
|
||||
project_id=agent_state.project_id,
|
||||
status=StepStatus.PENDING,
|
||||
model_handle=agent_state.llm_config.handle,
|
||||
)
|
||||
# Only use step_id in messages if step was actually created
|
||||
effective_step_id = step_id if logged_step else None
|
||||
@@ -720,8 +709,12 @@ class LettaAgent(BaseAgent):
|
||||
)
|
||||
]
|
||||
elif response.choices[0].message.content:
|
||||
# Carry thought_signature on TextContent when ReasoningContent doesn't exist to hold it
|
||||
reasoning = [
|
||||
TextContent(text=response.choices[0].message.content)
|
||||
TextContent(
|
||||
text=response.choices[0].message.content,
|
||||
signature=response.choices[0].message.reasoning_content_signature,
|
||||
)
|
||||
] # reasoning placed into content for legacy reasons
|
||||
elif response.choices[0].message.omitted_reasoning_content:
|
||||
reasoning = [OmittedReasoningContent()]
|
||||
@@ -762,24 +755,6 @@ class LettaAgent(BaseAgent):
|
||||
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
|
||||
agent_step_span.end()
|
||||
|
||||
# Log LLM Trace
|
||||
if settings.track_provider_trace:
|
||||
await self.telemetry_manager.create_provider_trace_async(
|
||||
actor=self.actor,
|
||||
provider_trace=ProviderTrace(
|
||||
request_json=request_data,
|
||||
response_json=response_data,
|
||||
step_id=step_id,
|
||||
agent_id=self.agent_id,
|
||||
agent_tags=agent_state.tags,
|
||||
run_id=self.current_run_id,
|
||||
org_id=self.actor.organization_id,
|
||||
user_id=self.actor.id,
|
||||
llm_config=self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None,
|
||||
),
|
||||
)
|
||||
step_progression = StepProgression.LOGGED_TRACE
|
||||
|
||||
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
|
||||
step_progression = StepProgression.FINISHED
|
||||
|
||||
@@ -795,6 +770,7 @@ class LettaAgent(BaseAgent):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
caught_exception = e
|
||||
# Handle any unexpected errors during step processing
|
||||
self.logger.error(f"Error during step processing: {e}")
|
||||
job_update_metadata = {"error": str(e)}
|
||||
@@ -837,8 +813,8 @@ class LettaAgent(BaseAgent):
|
||||
await self.step_manager.update_step_error_async(
|
||||
actor=self.actor,
|
||||
step_id=step_id, # Use original step_id for telemetry
|
||||
error_type=type(e).__name__ if "e" in locals() else "Unknown",
|
||||
error_message=str(e) if "e" in locals() else "Unknown error",
|
||||
error_type=type(caught_exception).__name__ if caught_exception is not None else "Unknown",
|
||||
error_message=str(caught_exception) if caught_exception is not None else "Unknown error",
|
||||
error_traceback=traceback.format_exc(),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
@@ -1000,6 +976,7 @@ class LettaAgent(BaseAgent):
|
||||
agent_step_span.set_attributes({"step_id": step_id})
|
||||
|
||||
step_progression = StepProgression.START
|
||||
caught_exception = None
|
||||
should_continue = False
|
||||
step_metrics = StepMetrics(id=step_id) # Initialize metrics tracking
|
||||
|
||||
@@ -1018,6 +995,7 @@ class LettaAgent(BaseAgent):
|
||||
step_id=step_id,
|
||||
project_id=agent_state.project_id,
|
||||
status=StepStatus.PENDING,
|
||||
model_handle=agent_state.llm_config.handle,
|
||||
)
|
||||
# Only use step_id in messages if step was actually created
|
||||
effective_step_id = step_id if logged_step else None
|
||||
@@ -1152,6 +1130,8 @@ class LettaAgent(BaseAgent):
|
||||
"output_tokens": interface.output_tokens,
|
||||
},
|
||||
},
|
||||
llm_config=agent_state.llm_config,
|
||||
latency_ms=int(llm_request_ms),
|
||||
)
|
||||
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
|
||||
tool_call,
|
||||
@@ -1220,41 +1200,6 @@ class LettaAgent(BaseAgent):
|
||||
# TODO (cliandy): the stream POST request span has ended at this point, we should tie this to the stream
|
||||
# log_event("agent.stream.llm_response.processed") # [4^]
|
||||
|
||||
# Log LLM Trace
|
||||
# We are piecing together the streamed response here.
|
||||
# Content here does not match the actual response schema as streams come in chunks.
|
||||
if settings.track_provider_trace:
|
||||
await self.telemetry_manager.create_provider_trace_async(
|
||||
actor=self.actor,
|
||||
provider_trace=ProviderTrace(
|
||||
request_json=request_data,
|
||||
response_json={
|
||||
"content": {
|
||||
"tool_call": tool_call.model_dump_json(),
|
||||
"reasoning": [content.model_dump_json() for content in reasoning_content],
|
||||
},
|
||||
"id": interface.message_id,
|
||||
"model": interface.model,
|
||||
"role": "assistant",
|
||||
# "stop_reason": "",
|
||||
# "stop_sequence": None,
|
||||
"type": "message",
|
||||
"usage": {
|
||||
"input_tokens": usage.prompt_tokens,
|
||||
"output_tokens": usage.completion_tokens,
|
||||
},
|
||||
},
|
||||
step_id=step_id,
|
||||
agent_id=self.agent_id,
|
||||
agent_tags=agent_state.tags,
|
||||
run_id=self.current_run_id,
|
||||
org_id=self.actor.organization_id,
|
||||
user_id=self.actor.id,
|
||||
llm_config=self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None,
|
||||
),
|
||||
)
|
||||
step_progression = StepProgression.LOGGED_TRACE
|
||||
|
||||
if persisted_messages[-1].role != "approval":
|
||||
# yields tool response as this is handled from Letta and not the response from the LLM provider
|
||||
tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
|
||||
@@ -1287,6 +1232,7 @@ class LettaAgent(BaseAgent):
|
||||
self.logger.warning(f"Failed to record step metrics: {metrics_error}")
|
||||
|
||||
except Exception as e:
|
||||
caught_exception = e
|
||||
# Handle any unexpected errors during step processing
|
||||
self.logger.error(f"Error during step processing: {e}")
|
||||
job_update_metadata = {"error": str(e)}
|
||||
@@ -1333,8 +1279,8 @@ class LettaAgent(BaseAgent):
|
||||
await self.step_manager.update_step_error_async(
|
||||
actor=self.actor,
|
||||
step_id=step_id, # Use original step_id for telemetry
|
||||
error_type=type(e).__name__ if "e" in locals() else "Unknown",
|
||||
error_message=str(e) if "e" in locals() else "Unknown error",
|
||||
error_type=type(caught_exception).__name__ if caught_exception is not None else "Unknown",
|
||||
error_message=str(caught_exception) if caught_exception is not None else "Unknown error",
|
||||
error_traceback=traceback.format_exc(),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
@@ -1481,7 +1427,7 @@ class LettaAgent(BaseAgent):
|
||||
agent_tags=agent_state.tags,
|
||||
run_id=self.current_run_id,
|
||||
step_id=step_metrics.id,
|
||||
call_type="agent_step",
|
||||
call_type=LLMCallType.agent_step,
|
||||
)
|
||||
response = await llm_client.request_async_with_telemetry(request_data, agent_state.llm_config)
|
||||
|
||||
@@ -1554,13 +1500,13 @@ class LettaAgent(BaseAgent):
|
||||
agent_tags=agent_state.tags,
|
||||
run_id=self.current_run_id,
|
||||
step_id=step_id,
|
||||
call_type="agent_step",
|
||||
call_type=LLMCallType.agent_step,
|
||||
)
|
||||
|
||||
# Attempt LLM request with telemetry wrapper
|
||||
return (
|
||||
request_data,
|
||||
await llm_client.stream_async_with_telemetry(request_data, agent_state.llm_config),
|
||||
await llm_client.stream_async(request_data, agent_state.llm_config),
|
||||
current_in_context_messages,
|
||||
new_in_context_messages,
|
||||
valid_tool_names,
|
||||
@@ -1605,8 +1551,10 @@ class LettaAgent(BaseAgent):
|
||||
run_id=run_id,
|
||||
step_id=step_id,
|
||||
)
|
||||
elif isinstance(e, LLMError):
|
||||
raise
|
||||
else:
|
||||
raise llm_client.handle_llm_error(e)
|
||||
raise llm_client.handle_llm_error(e, llm_config=llm_config)
|
||||
|
||||
@trace_method
|
||||
async def _rebuild_context_window(
|
||||
@@ -1626,7 +1574,7 @@ class LettaAgent(BaseAgent):
|
||||
self.logger.warning(
|
||||
f"Total tokens {total_tokens} exceeds configured max tokens {llm_config.context_window}, forcefully clearing message history."
|
||||
)
|
||||
new_in_context_messages, updated = await self.summarizer.summarize(
|
||||
new_in_context_messages, _updated = await self.summarizer.summarize(
|
||||
in_context_messages=in_context_messages,
|
||||
new_letta_messages=new_letta_messages,
|
||||
force=True,
|
||||
@@ -1639,7 +1587,7 @@ class LettaAgent(BaseAgent):
|
||||
self.logger.info(
|
||||
f"Total tokens {total_tokens} does not exceed configured max tokens {llm_config.context_window}, passing summarizing w/o force."
|
||||
)
|
||||
new_in_context_messages, updated = await self.summarizer.summarize(
|
||||
new_in_context_messages, _updated = await self.summarizer.summarize(
|
||||
in_context_messages=in_context_messages,
|
||||
new_letta_messages=new_letta_messages,
|
||||
run_id=run_id,
|
||||
@@ -1659,7 +1607,7 @@ class LettaAgent(BaseAgent):
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=self.agent_id, actor=self.actor)
|
||||
message_ids = agent_state.message_ids
|
||||
in_context_messages = await self.message_manager.get_messages_by_ids_async(message_ids=message_ids, actor=self.actor)
|
||||
new_in_context_messages, updated = await self.summarizer.summarize(
|
||||
new_in_context_messages, _updated = await self.summarizer.summarize(
|
||||
in_context_messages=in_context_messages, new_letta_messages=[], force=True
|
||||
)
|
||||
return await self.agent_manager.update_message_ids_async(
|
||||
|
||||
@@ -217,7 +217,7 @@ class LettaAgentBatch(BaseAgent):
|
||||
|
||||
if batch_items:
|
||||
log_event(name="bulk_create_batch_items")
|
||||
batch_items_persisted = await self.batch_manager.create_llm_batch_items_bulk_async(batch_items, actor=self.actor)
|
||||
await self.batch_manager.create_llm_batch_items_bulk_async(batch_items, actor=self.actor)
|
||||
|
||||
log_event(name="return_batch_response")
|
||||
return LettaBatchResponse(
|
||||
|
||||
@@ -9,7 +9,6 @@ from letta.adapters.letta_llm_adapter import LettaLLMAdapter
|
||||
from letta.adapters.letta_llm_request_adapter import LettaLLMRequestAdapter
|
||||
from letta.adapters.letta_llm_stream_adapter import LettaLLMStreamAdapter
|
||||
from letta.agents.base_agent_v2 import BaseAgentV2
|
||||
from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent
|
||||
from letta.agents.helpers import (
|
||||
_build_rule_violation_result,
|
||||
_load_last_function_response,
|
||||
@@ -20,7 +19,7 @@ from letta.agents.helpers import (
|
||||
generate_step_id,
|
||||
)
|
||||
from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX, REQUEST_HEARTBEAT_PARAM
|
||||
from letta.errors import ContextWindowExceededError, LLMError
|
||||
from letta.errors import ContextWindowExceededError, InsufficientCreditsError, LLMError
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns, ns_to_ms
|
||||
from letta.helpers.reasoning_helper import scrub_inner_thoughts_from_messages
|
||||
@@ -31,7 +30,7 @@ from letta.log import get_logger
|
||||
from letta.otel.tracing import log_event, trace_method, tracer
|
||||
from letta.prompts.prompt_generator import PromptGenerator
|
||||
from letta.schemas.agent import AgentState, UpdateAgent
|
||||
from letta.schemas.enums import AgentType, MessageStreamStatus, RunStatus, StepStatus
|
||||
from letta.schemas.enums import AgentType, LLMCallType, MessageStreamStatus, RunStatus, StepStatus
|
||||
from letta.schemas.letta_message import LettaMessage, MessageType
|
||||
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
|
||||
from letta.schemas.letta_request import ClientToolSchema
|
||||
@@ -58,6 +57,7 @@ from letta.server.rest_api.utils import (
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.archive_manager import ArchiveManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.credit_verification_service import CreditVerificationService
|
||||
from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
@@ -67,10 +67,10 @@ from letta.services.summarizer.enums import SummarizationMode
|
||||
from letta.services.summarizer.summarizer import Summarizer
|
||||
from letta.services.telemetry_manager import TelemetryManager
|
||||
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
|
||||
from letta.settings import model_settings, settings, summarizer_settings
|
||||
from letta.settings import settings, summarizer_settings
|
||||
from letta.system import package_function_response
|
||||
from letta.types import JsonDict
|
||||
from letta.utils import log_telemetry, safe_create_task, united_diff, validate_function_response
|
||||
from letta.utils import log_telemetry, safe_create_task, safe_create_task_with_return, united_diff, validate_function_response
|
||||
|
||||
|
||||
class LettaAgentV2(BaseAgentV2):
|
||||
@@ -106,6 +106,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
self.passage_manager = PassageManager()
|
||||
self.step_manager = StepManager()
|
||||
self.telemetry_manager = TelemetryManager()
|
||||
self.credit_verification_service = CreditVerificationService()
|
||||
|
||||
## TODO: Expand to more
|
||||
# if summarizer_settings.enable_summarization and model_settings.openai_api_key:
|
||||
@@ -158,6 +159,8 @@ class LettaAgentV2(BaseAgentV2):
|
||||
llm_adapter=LettaLLMRequestAdapter(
|
||||
llm_client=self.llm_client,
|
||||
llm_config=self.agent_state.llm_config,
|
||||
call_type=LLMCallType.agent_step,
|
||||
agent_id=self.agent_state.id,
|
||||
agent_tags=self.agent_state.tags,
|
||||
org_id=self.actor.organization_id,
|
||||
user_id=self.actor.id,
|
||||
@@ -181,6 +184,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
client_tools: list[ClientToolSchema] | None = None,
|
||||
include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility
|
||||
) -> LettaResponse:
|
||||
"""
|
||||
Execute the agent loop in blocking mode, returning all messages at once.
|
||||
@@ -193,6 +197,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
include_return_message_types: Filter for which message types to return
|
||||
request_start_timestamp_ns: Start time for tracking request duration
|
||||
client_tools: Optional list of client-side tools (not used in V2, for API compatibility)
|
||||
include_compaction_messages: Not used in V2, but accepted for API compatibility.
|
||||
|
||||
Returns:
|
||||
LettaResponse: Complete response with all messages and metadata
|
||||
@@ -205,15 +210,25 @@ class LettaAgentV2(BaseAgentV2):
|
||||
)
|
||||
in_context_messages = in_context_messages + input_messages_to_persist
|
||||
response_letta_messages = []
|
||||
credit_task = None
|
||||
for i in range(max_steps):
|
||||
remaining_turns = max_steps - i - 1
|
||||
|
||||
# Await credit check from previous iteration before running next step
|
||||
if credit_task is not None:
|
||||
if not await credit_task:
|
||||
self.should_continue = False
|
||||
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.insufficient_credits)
|
||||
break
|
||||
credit_task = None
|
||||
|
||||
response = self._step(
|
||||
messages=in_context_messages + self.response_messages,
|
||||
input_messages_to_persist=input_messages_to_persist,
|
||||
llm_adapter=LettaLLMRequestAdapter(
|
||||
llm_client=self.llm_client,
|
||||
llm_config=self.agent_state.llm_config,
|
||||
call_type=LLMCallType.agent_step,
|
||||
agent_id=self.agent_state.id,
|
||||
agent_tags=self.agent_state.tags,
|
||||
run_id=run_id,
|
||||
@@ -233,6 +248,9 @@ class LettaAgentV2(BaseAgentV2):
|
||||
if not self.should_continue:
|
||||
break
|
||||
|
||||
# Fire credit check to run in parallel with loop overhead / next step setup
|
||||
credit_task = safe_create_task_with_return(self._check_credits())
|
||||
|
||||
input_messages_to_persist = []
|
||||
|
||||
# Rebuild context window after stepping
|
||||
@@ -271,6 +289,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
conversation_id: str | None = None, # Not used in V2, but accepted for API compatibility
|
||||
client_tools: list[ClientToolSchema] | None = None,
|
||||
include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Execute the agent loop in streaming mode, yielding chunks as they become available.
|
||||
@@ -289,6 +308,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
include_return_message_types: Filter for which message types to return
|
||||
request_start_timestamp_ns: Start time for tracking request duration
|
||||
client_tools: Optional list of client-side tools (not used in V2, for API compatibility)
|
||||
include_compaction_messages: Not used in V2, but accepted for API compatibility.
|
||||
|
||||
Yields:
|
||||
str: JSON-formatted SSE data chunks for each completed step
|
||||
@@ -301,6 +321,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
llm_adapter = LettaLLMStreamAdapter(
|
||||
llm_client=self.llm_client,
|
||||
llm_config=self.agent_state.llm_config,
|
||||
call_type=LLMCallType.agent_step,
|
||||
agent_id=self.agent_state.id,
|
||||
agent_tags=self.agent_state.tags,
|
||||
run_id=run_id,
|
||||
@@ -311,6 +332,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
llm_adapter = LettaLLMRequestAdapter(
|
||||
llm_client=self.llm_client,
|
||||
llm_config=self.agent_state.llm_config,
|
||||
call_type=LLMCallType.agent_step,
|
||||
agent_id=self.agent_state.id,
|
||||
agent_tags=self.agent_state.tags,
|
||||
run_id=run_id,
|
||||
@@ -323,7 +345,16 @@ class LettaAgentV2(BaseAgentV2):
|
||||
input_messages, self.agent_state, self.message_manager, self.actor, run_id
|
||||
)
|
||||
in_context_messages = in_context_messages + input_messages_to_persist
|
||||
credit_task = None
|
||||
for i in range(max_steps):
|
||||
# Await credit check from previous iteration before running next step
|
||||
if credit_task is not None:
|
||||
if not await credit_task:
|
||||
self.should_continue = False
|
||||
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.insufficient_credits)
|
||||
break
|
||||
credit_task = None
|
||||
|
||||
response = self._step(
|
||||
messages=in_context_messages + self.response_messages,
|
||||
input_messages_to_persist=input_messages_to_persist,
|
||||
@@ -342,6 +373,9 @@ class LettaAgentV2(BaseAgentV2):
|
||||
if not self.should_continue:
|
||||
break
|
||||
|
||||
# Fire credit check to run in parallel with loop overhead / next step setup
|
||||
credit_task = safe_create_task_with_return(self._check_credits())
|
||||
|
||||
input_messages_to_persist = []
|
||||
|
||||
if self.stop_reason is None:
|
||||
@@ -420,8 +454,9 @@ class LettaAgentV2(BaseAgentV2):
|
||||
raise AssertionError("run_id is required when enforce_run_id_set is True")
|
||||
|
||||
step_progression = StepProgression.START
|
||||
caught_exception = None
|
||||
# TODO(@caren): clean this up
|
||||
tool_call, reasoning_content, agent_step_span, first_chunk, step_id, logged_step, step_start_ns, step_metrics = (
|
||||
tool_call, reasoning_content, agent_step_span, first_chunk, step_id, logged_step, _step_start_ns, step_metrics = (
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
@@ -580,6 +615,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
)
|
||||
step_progression, step_metrics = await self._step_checkpoint_finish(step_metrics, agent_step_span, logged_step)
|
||||
except Exception as e:
|
||||
caught_exception = e
|
||||
self.logger.warning(f"Error during step processing: {e}")
|
||||
self.job_update_metadata = {"error": str(e)}
|
||||
|
||||
@@ -615,8 +651,8 @@ class LettaAgentV2(BaseAgentV2):
|
||||
await self.step_manager.update_step_error_async(
|
||||
actor=self.actor,
|
||||
step_id=step_id, # Use original step_id for telemetry
|
||||
error_type=type(e).__name__ if "e" in locals() else "Unknown",
|
||||
error_message=str(e) if "e" in locals() else "Unknown error",
|
||||
error_type=type(caught_exception).__name__ if caught_exception is not None else "Unknown",
|
||||
error_message=str(caught_exception) if caught_exception is not None else "Unknown error",
|
||||
error_traceback=traceback.format_exc(),
|
||||
stop_reason=self.stop_reason,
|
||||
)
|
||||
@@ -667,6 +703,17 @@ class LettaAgentV2(BaseAgentV2):
|
||||
self.last_function_response = None
|
||||
self.response_messages = []
|
||||
|
||||
async def _check_credits(self) -> bool:
|
||||
"""Check if the organization still has credits. Returns True if OK or not configured."""
|
||||
try:
|
||||
await self.credit_verification_service.verify_credits(self.actor.organization_id, self.agent_state.id)
|
||||
return True
|
||||
except InsufficientCreditsError:
|
||||
self.logger.warning(
|
||||
f"Insufficient credits for organization {self.actor.organization_id}, agent {self.agent_state.id}, stopping agent loop"
|
||||
)
|
||||
return False
|
||||
|
||||
@trace_method
|
||||
async def _check_run_cancellation(self, run_id) -> bool:
|
||||
try:
|
||||
@@ -678,20 +725,37 @@ class LettaAgentV2(BaseAgentV2):
|
||||
return False
|
||||
|
||||
@trace_method
|
||||
async def _refresh_messages(self, in_context_messages: list[Message]):
|
||||
num_messages = await self.message_manager.size_async(
|
||||
agent_id=self.agent_state.id,
|
||||
actor=self.actor,
|
||||
)
|
||||
num_archival_memories = await self.passage_manager.agent_passage_size_async(
|
||||
agent_id=self.agent_state.id,
|
||||
actor=self.actor,
|
||||
)
|
||||
in_context_messages = await self._rebuild_memory(
|
||||
in_context_messages,
|
||||
num_messages=num_messages,
|
||||
num_archival_memories=num_archival_memories,
|
||||
)
|
||||
async def _refresh_messages(self, in_context_messages: list[Message], force_system_prompt_refresh: bool = False):
|
||||
"""Refresh in-context messages.
|
||||
|
||||
This performs two tasks:
|
||||
1) Rebuild the *system prompt* only if the memory/tool-rules/directories section has changed.
|
||||
This avoids rebuilding the system prompt on every step due to dynamic metadata (e.g. message counts),
|
||||
which can bust prefix caching.
|
||||
2) Scrub inner thoughts from messages.
|
||||
|
||||
Args:
|
||||
in_context_messages: Current in-context messages
|
||||
force_system_prompt_refresh: If True, forces evaluation of whether the system prompt needs to be rebuilt.
|
||||
(The rebuild will still be skipped if memory/tool-rules/directories haven't changed.)
|
||||
|
||||
Returns:
|
||||
Refreshed in-context messages.
|
||||
"""
|
||||
# Only rebuild when explicitly forced (e.g., after compaction).
|
||||
# Normal turns should not trigger system prompt recompilation.
|
||||
if force_system_prompt_refresh:
|
||||
try:
|
||||
in_context_messages = await self._rebuild_memory(
|
||||
in_context_messages,
|
||||
num_messages=None,
|
||||
num_archival_memories=None,
|
||||
force=True,
|
||||
)
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
# Always scrub inner thoughts regardless of system prompt refresh
|
||||
in_context_messages = scrub_inner_thoughts_from_messages(in_context_messages, self.agent_state.llm_config)
|
||||
return in_context_messages
|
||||
|
||||
@@ -699,8 +763,9 @@ class LettaAgentV2(BaseAgentV2):
|
||||
async def _rebuild_memory(
|
||||
self,
|
||||
in_context_messages: list[Message],
|
||||
num_messages: int,
|
||||
num_archival_memories: int,
|
||||
num_messages: int | None,
|
||||
num_archival_memories: int | None,
|
||||
force: bool = False,
|
||||
):
|
||||
agent_state = await self.agent_manager.refresh_memory_async(agent_state=self.agent_state, actor=self.actor)
|
||||
|
||||
@@ -721,49 +786,26 @@ class LettaAgentV2(BaseAgentV2):
|
||||
else:
|
||||
archive_tags = None
|
||||
|
||||
# TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
|
||||
curr_system_message = in_context_messages[0]
|
||||
curr_system_message_text = curr_system_message.content[0].text
|
||||
|
||||
# Extract the memory section that includes <memory_blocks>, tool rules, and directories.
|
||||
# This avoids timestamp comparison issues in <memory_metadata>, which is dynamic.
|
||||
def extract_memory_section(text: str) -> str:
|
||||
# Primary pattern: everything from <memory_blocks> up to <memory_metadata>
|
||||
mem_start = text.find("<memory_blocks>")
|
||||
meta_start = text.find("<memory_metadata>")
|
||||
if mem_start != -1:
|
||||
if meta_start != -1 and meta_start > mem_start:
|
||||
return text[mem_start:meta_start]
|
||||
return text[mem_start:]
|
||||
|
||||
# Fallback pattern used in some legacy prompts: between </base_instructions> and <memory_metadata>
|
||||
base_end = text.find("</base_instructions>")
|
||||
if base_end != -1:
|
||||
if meta_start != -1 and meta_start > base_end:
|
||||
return text[base_end + len("</base_instructions>") : meta_start]
|
||||
return text[base_end + len("</base_instructions>") :]
|
||||
|
||||
# Last resort: return full text
|
||||
return text
|
||||
|
||||
curr_memory_section = extract_memory_section(curr_system_message_text)
|
||||
|
||||
# refresh files
|
||||
agent_state = await self.agent_manager.refresh_file_blocks(agent_state=agent_state, actor=self.actor)
|
||||
|
||||
# generate just the memory string with current state for comparison
|
||||
# generate memory string with current state
|
||||
curr_memory_str = agent_state.memory.compile(
|
||||
tool_usage_rules=tool_constraint_block,
|
||||
sources=agent_state.sources,
|
||||
max_files_open=agent_state.max_files_open,
|
||||
llm_config=agent_state.llm_config,
|
||||
)
|
||||
new_memory_section = extract_memory_section(curr_memory_str)
|
||||
|
||||
# compare just the memory sections (memory blocks, tool rules, directories)
|
||||
if curr_memory_section.strip() == new_memory_section.strip():
|
||||
# Skip rebuild unless explicitly forced and unless system/memory content actually changed.
|
||||
system_prompt_changed = agent_state.system not in curr_system_message_text
|
||||
memory_changed = curr_memory_str not in curr_system_message_text
|
||||
if (not force) and (not system_prompt_changed) and (not memory_changed):
|
||||
self.logger.debug(
|
||||
f"Memory and sources haven't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
|
||||
f"Memory, sources, and system prompt haven't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
|
||||
)
|
||||
return in_context_messages
|
||||
|
||||
@@ -793,7 +835,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
new_system_message = await self.message_manager.update_message_by_id_async(
|
||||
curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
|
||||
)
|
||||
return [new_system_message] + in_context_messages[1:]
|
||||
return [new_system_message, *in_context_messages[1:]]
|
||||
|
||||
else:
|
||||
return in_context_messages
|
||||
@@ -864,6 +906,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
step_id=step_id,
|
||||
project_id=self.agent_state.project_id,
|
||||
status=StepStatus.PENDING,
|
||||
model_handle=self.agent_state.llm_config.handle,
|
||||
)
|
||||
|
||||
# Also create step metrics early and update at the end of the step
|
||||
@@ -1279,7 +1322,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
self.logger.warning(
|
||||
f"Total tokens {total_tokens} exceeds configured max tokens {self.agent_state.llm_config.context_window}, forcefully clearing message history."
|
||||
)
|
||||
new_in_context_messages, updated = await self.summarizer.summarize(
|
||||
new_in_context_messages, _updated = await self.summarizer.summarize(
|
||||
in_context_messages=in_context_messages,
|
||||
new_letta_messages=new_letta_messages,
|
||||
force=True,
|
||||
@@ -1292,7 +1335,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
self.logger.info(
|
||||
f"Total tokens {total_tokens} does not exceed configured max tokens {self.agent_state.llm_config.context_window}, passing summarizing w/o force."
|
||||
)
|
||||
new_in_context_messages, updated = await self.summarizer.summarize(
|
||||
new_in_context_messages, _updated = await self.summarizer.summarize(
|
||||
in_context_messages=in_context_messages,
|
||||
new_letta_messages=new_letta_messages,
|
||||
run_id=run_id,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,10 +1,13 @@
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
import openai
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.agents.exceptions import IncompatibleAgentType
|
||||
from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent
|
||||
@@ -250,7 +253,6 @@ class VoiceAgent(BaseAgent):
|
||||
agent_state=agent_state,
|
||||
)
|
||||
tool_result = tool_execution_result.func_return
|
||||
success_flag = tool_execution_result.success_flag
|
||||
|
||||
# 3. Provide function_call response back into the conversation
|
||||
# TODO: fix this tool format
|
||||
@@ -292,7 +294,7 @@ class VoiceAgent(BaseAgent):
|
||||
new_letta_messages = await self.message_manager.create_many_messages_async(letta_message_db_queue, actor=self.actor)
|
||||
|
||||
# TODO: Make this more general and configurable, less brittle
|
||||
new_in_context_messages, updated = await summarizer.summarize(
|
||||
new_in_context_messages, _updated = await summarizer.summarize(
|
||||
in_context_messages=in_context_messages, new_letta_messages=new_letta_messages
|
||||
)
|
||||
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
from typing import AsyncGenerator, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, List, Optional, Tuple, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span
|
||||
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
|
||||
from letta.agents.helpers import _create_letta_response, serialize_message_history
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
@@ -89,7 +94,7 @@ class VoiceSleeptimeAgent(LettaAgent):
|
||||
current_in_context_messages, new_in_context_messages, stop_reason, usage = await super()._step(
|
||||
agent_state=agent_state, input_messages=input_messages, max_steps=max_steps
|
||||
)
|
||||
new_in_context_messages, updated = await self.summarizer.summarize(
|
||||
new_in_context_messages, _updated = await self.summarizer.summarize(
|
||||
in_context_messages=current_in_context_messages, new_letta_messages=new_in_context_messages
|
||||
)
|
||||
self.agent_manager.set_in_context_messages(
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Annotated, Optional
|
||||
import typer
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.streaming_interface import StreamingRefreshCLIInterface as interface # for printing to terminal
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
232
letta/config_file.py
Normal file
232
letta/config_file.py
Normal file
@@ -0,0 +1,232 @@
|
||||
"""
|
||||
Letta Configuration File Support
|
||||
|
||||
Loads hierarchical YAML config and maps it to environment variables.
|
||||
|
||||
Supported top-level keys and their env var prefixes:
|
||||
letta: -> LETTA_*
|
||||
model: -> * (provider-prefixed: OPENAI_*, ANTHROPIC_*, etc.)
|
||||
tool: -> * (prefix-based: E2B_*, MCP_*, TOOL_*, etc.)
|
||||
datadog: -> DD_*
|
||||
|
||||
Config file format:
|
||||
letta:
|
||||
telemetry:
|
||||
enable_datadog: true
|
||||
pg:
|
||||
host: localhost
|
||||
model:
|
||||
openai:
|
||||
api_key: sk-xxx
|
||||
anthropic:
|
||||
api_key: sk-yyy
|
||||
tool:
|
||||
e2b:
|
||||
api_key: xxx
|
||||
mcp:
|
||||
disable_stdio: true
|
||||
datadog:
|
||||
site: us5.datadoghq.com
|
||||
service: memgpt-server
|
||||
|
||||
This maps to environment variables:
|
||||
LETTA_TELEMETRY_ENABLE_DATADOG=true
|
||||
LETTA_PG_HOST=localhost
|
||||
OPENAI_API_KEY=sk-xxx
|
||||
ANTHROPIC_API_KEY=sk-yyy
|
||||
E2B_API_KEY=xxx
|
||||
MCP_DISABLE_STDIO=true
|
||||
DD_SITE=us5.datadoghq.com
|
||||
DD_SERVICE=memgpt-server
|
||||
|
||||
Config file locations (in order of precedence):
|
||||
1. ~/.letta/conf.yaml
|
||||
2. ./conf.yaml
|
||||
3. LETTA_CONFIG_PATH environment variable
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
# Config file locations
|
||||
DEFAULT_USER_CONFIG = Path.home() / ".letta" / "conf.yaml"
|
||||
DEFAULT_PROJECT_CONFIG = Path.cwd() / "conf.yaml"
|
||||
|
||||
|
||||
def load_config_file(config_path: str | Path | None = None) -> dict[str, Any]:
|
||||
"""
|
||||
Load configuration from YAML file.
|
||||
|
||||
Args:
|
||||
config_path: Optional explicit path to config file
|
||||
|
||||
Returns:
|
||||
Loaded config dict, or empty dict if no config found
|
||||
"""
|
||||
paths_to_check = []
|
||||
|
||||
# Check in order of precedence (lowest to highest)
|
||||
if DEFAULT_USER_CONFIG.exists():
|
||||
paths_to_check.append(DEFAULT_USER_CONFIG)
|
||||
|
||||
if DEFAULT_PROJECT_CONFIG.exists():
|
||||
paths_to_check.append(DEFAULT_PROJECT_CONFIG)
|
||||
|
||||
# Environment variable override
|
||||
env_path = os.environ.get("LETTA_CONFIG_PATH")
|
||||
if env_path and Path(env_path).exists():
|
||||
paths_to_check.append(Path(env_path))
|
||||
|
||||
# Explicit path has highest precedence
|
||||
if config_path:
|
||||
p = Path(config_path)
|
||||
if p.exists():
|
||||
paths_to_check.append(p)
|
||||
|
||||
# Merge configs (later files override earlier)
|
||||
config: dict[str, Any] = {}
|
||||
for path in paths_to_check:
|
||||
try:
|
||||
with open(path, "r") as f:
|
||||
file_config = yaml.safe_load(f)
|
||||
if file_config:
|
||||
config = _deep_merge(config, file_config)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _deep_merge(base: dict, override: dict) -> dict:
|
||||
"""Deep merge two dicts, override values take precedence."""
|
||||
result = base.copy()
|
||||
for key, value in override.items():
|
||||
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
|
||||
result[key] = _deep_merge(result[key], value)
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
|
||||
def _flatten_with_prefix(d: dict, prefix: str, env_vars: dict[str, str]) -> None:
|
||||
"""Flatten a dict with a given prefix."""
|
||||
for key, value in d.items():
|
||||
env_key = f"{prefix}_{key}".upper() if prefix else key.upper()
|
||||
if isinstance(value, dict):
|
||||
_flatten_with_prefix(value, env_key, env_vars)
|
||||
elif value is not None:
|
||||
if isinstance(value, bool):
|
||||
env_vars[env_key] = str(value).lower()
|
||||
else:
|
||||
env_vars[env_key] = str(value)
|
||||
|
||||
|
||||
def _flatten_model_settings(d: dict, env_vars: dict[str, str]) -> None:
|
||||
"""
|
||||
Flatten model settings where nested keys become prefixes.
|
||||
|
||||
model:
|
||||
openai:
|
||||
api_key: xxx -> OPENAI_API_KEY
|
||||
api_base: yyy -> OPENAI_API_BASE
|
||||
anthropic:
|
||||
api_key: zzz -> ANTHROPIC_API_KEY
|
||||
global_max_context_window_limit: 32000 -> GLOBAL_MAX_CONTEXT_WINDOW_LIMIT
|
||||
"""
|
||||
for key, value in d.items():
|
||||
if isinstance(value, dict):
|
||||
# Nested provider config: openai.api_key -> OPENAI_API_KEY
|
||||
_flatten_with_prefix(value, key.upper(), env_vars)
|
||||
elif value is not None:
|
||||
# Top-level model setting
|
||||
env_key = key.upper()
|
||||
if isinstance(value, bool):
|
||||
env_vars[env_key] = str(value).lower()
|
||||
else:
|
||||
env_vars[env_key] = str(value)
|
||||
|
||||
|
||||
def _flatten_tool_settings(d: dict, env_vars: dict[str, str]) -> None:
|
||||
"""
|
||||
Flatten tool settings where nested keys become prefixes.
|
||||
|
||||
tool:
|
||||
e2b:
|
||||
api_key: xxx -> E2B_API_KEY
|
||||
sandbox_template_id: y -> E2B_SANDBOX_TEMPLATE_ID
|
||||
mcp:
|
||||
disable_stdio: true -> MCP_DISABLE_STDIO
|
||||
tool_sandbox_timeout: 180 -> TOOL_SANDBOX_TIMEOUT
|
||||
"""
|
||||
for key, value in d.items():
|
||||
if isinstance(value, dict):
|
||||
# Nested tool config: e2b.api_key -> E2B_API_KEY
|
||||
_flatten_with_prefix(value, key.upper(), env_vars)
|
||||
elif value is not None:
|
||||
# Top-level tool setting
|
||||
env_key = key.upper()
|
||||
if isinstance(value, bool):
|
||||
env_vars[env_key] = str(value).lower()
|
||||
else:
|
||||
env_vars[env_key] = str(value)
|
||||
|
||||
|
||||
def config_to_env_vars(config: dict[str, Any]) -> dict[str, str]:
|
||||
"""
|
||||
Convert hierarchical config to flat environment variables.
|
||||
|
||||
Supports multiple top-level keys with different prefix behaviors:
|
||||
- letta: -> LETTA_* prefix
|
||||
- model: -> provider-prefixed (OPENAI_*, ANTHROPIC_*, etc.)
|
||||
- tool: -> prefix-based (E2B_*, MCP_*, TOOL_*, etc.)
|
||||
- datadog: -> DD_* prefix
|
||||
|
||||
Args:
|
||||
config: Hierarchical config dict
|
||||
|
||||
Returns:
|
||||
Dict of environment variable name -> value
|
||||
"""
|
||||
env_vars: dict[str, str] = {}
|
||||
|
||||
# Handle 'letta' section with LETTA_ prefix
|
||||
if "letta" in config:
|
||||
_flatten_with_prefix(config["letta"], "LETTA", env_vars)
|
||||
|
||||
# Handle 'model' section (provider-prefixed env vars)
|
||||
if "model" in config:
|
||||
_flatten_model_settings(config["model"], env_vars)
|
||||
|
||||
# Handle 'tool' section (prefix-based env vars)
|
||||
if "tool" in config:
|
||||
_flatten_tool_settings(config["tool"], env_vars)
|
||||
|
||||
# Handle 'datadog' section with DD_ prefix
|
||||
if "datadog" in config:
|
||||
_flatten_with_prefix(config["datadog"], "DD", env_vars)
|
||||
|
||||
return env_vars
|
||||
|
||||
|
||||
def apply_config_to_env(config_path: str | Path | None = None) -> None:
|
||||
"""
|
||||
Load config file and apply values to environment variables.
|
||||
|
||||
Environment variables already set take precedence over config file values.
|
||||
|
||||
Args:
|
||||
config_path: Optional explicit path to config file
|
||||
"""
|
||||
config = load_config_file(config_path)
|
||||
if not config:
|
||||
return
|
||||
|
||||
env_vars = config_to_env_vars(config)
|
||||
|
||||
for key, value in env_vars.items():
|
||||
# Only set if not already in environment (env vars take precedence)
|
||||
if key not in os.environ:
|
||||
os.environ[key] = value
|
||||
@@ -78,7 +78,7 @@ DEFAULT_CONTEXT_WINDOW = 32000
|
||||
|
||||
# Summarization trigger threshold (multiplier of context_window limit)
|
||||
# Summarization triggers when step usage > context_window * SUMMARIZATION_TRIGGER_MULTIPLIER
|
||||
SUMMARIZATION_TRIGGER_MULTIPLIER = 1.0
|
||||
SUMMARIZATION_TRIGGER_MULTIPLIER = 0.9 # using instead of 1.0 to avoid "too many tokens in prompt" fallbacks
|
||||
|
||||
# number of concurrent embedding requests to sent
|
||||
EMBEDDING_BATCH_SIZE = 200
|
||||
@@ -252,8 +252,11 @@ LLM_MAX_CONTEXT_WINDOW = {
|
||||
"deepseek-chat": 64000,
|
||||
"deepseek-reasoner": 64000,
|
||||
# glm (Z.AI)
|
||||
"glm-4.6": 200000,
|
||||
"glm-4.5": 128000,
|
||||
"glm-4.6": 200000,
|
||||
"glm-4.7": 200000,
|
||||
"glm-5": 200000,
|
||||
"glm-5-code": 200000,
|
||||
## OpenAI models: https://platform.openai.com/docs/models/overview
|
||||
# gpt-5
|
||||
"gpt-5": 272000,
|
||||
@@ -383,6 +386,7 @@ LLM_MAX_CONTEXT_WINDOW = {
|
||||
"gemini-2.5-computer-use-preview-10-2025": 1048576,
|
||||
# gemini 3
|
||||
"gemini-3-pro-preview": 1048576,
|
||||
"gemini-3.1-pro-preview": 1048576,
|
||||
"gemini-3-flash-preview": 1048576,
|
||||
# gemini latest aliases
|
||||
"gemini-flash-latest": 1048576,
|
||||
@@ -457,10 +461,18 @@ REDIS_RUN_ID_PREFIX = "agent:send_message:run_id"
|
||||
CONVERSATION_LOCK_PREFIX = "conversation:lock:"
|
||||
CONVERSATION_LOCK_TTL_SECONDS = 300 # 5 minutes
|
||||
|
||||
# Memory repo locks - prevents concurrent modifications to git-based memory
|
||||
MEMORY_REPO_LOCK_PREFIX = "memory_repo:lock:"
|
||||
MEMORY_REPO_LOCK_TTL_SECONDS = 60 # 1 minute (git operations should be fast)
|
||||
|
||||
# TODO: This is temporary, eventually use token-based eviction
|
||||
# File based controls
|
||||
DEFAULT_MAX_FILES_OPEN = 5
|
||||
DEFAULT_CORE_MEMORY_SOURCE_CHAR_LIMIT: int = 50000
|
||||
# Max values for file controls (int32 limit to match database INTEGER type)
|
||||
MAX_INT32: int = 2147483647
|
||||
MAX_PER_FILE_VIEW_WINDOW_CHAR_LIMIT: int = MAX_INT32
|
||||
MAX_FILES_OPEN_LIMIT: int = 1000 # Practical limit - no agent needs 1000+ files open
|
||||
|
||||
GET_PROVIDERS_TIMEOUT_SECONDS = 10
|
||||
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from typing import Dict, Iterator, List, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, Iterator, List, Tuple
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.schemas.user import User
|
||||
|
||||
import typer
|
||||
|
||||
@@ -143,7 +146,13 @@ async def load_data(connector: DataConnector, source: Source, passage_manager: P
|
||||
|
||||
|
||||
class DirectoryConnector(DataConnector):
|
||||
def __init__(self, input_files: List[str] = None, input_directory: str = None, recursive: bool = False, extensions: List[str] = None):
|
||||
def __init__(
|
||||
self,
|
||||
input_files: List[str] | None = None,
|
||||
input_directory: str | None = None,
|
||||
recursive: bool = False,
|
||||
extensions: List[str] | None = None,
|
||||
):
|
||||
"""
|
||||
Connector for reading text data from a directory of files.
|
||||
|
||||
|
||||
@@ -2,8 +2,16 @@ import asyncio
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
|
||||
from letta.constants import CONVERSATION_LOCK_PREFIX, CONVERSATION_LOCK_TTL_SECONDS, REDIS_EXCLUDE, REDIS_INCLUDE, REDIS_SET_DEFAULT_VAL
|
||||
from letta.errors import ConversationBusyError
|
||||
from letta.constants import (
|
||||
CONVERSATION_LOCK_PREFIX,
|
||||
CONVERSATION_LOCK_TTL_SECONDS,
|
||||
MEMORY_REPO_LOCK_PREFIX,
|
||||
MEMORY_REPO_LOCK_TTL_SECONDS,
|
||||
REDIS_EXCLUDE,
|
||||
REDIS_INCLUDE,
|
||||
REDIS_SET_DEFAULT_VAL,
|
||||
)
|
||||
from letta.errors import ConversationBusyError, MemoryRepoBusyError
|
||||
from letta.log import get_logger
|
||||
from letta.settings import settings
|
||||
|
||||
@@ -141,7 +149,7 @@ class AsyncRedisClient:
|
||||
try:
|
||||
client = await self.get_client()
|
||||
return await client.get(key)
|
||||
except:
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
@with_retry()
|
||||
@@ -230,6 +238,64 @@ class AsyncRedisClient:
|
||||
logger.warning(f"Failed to release conversation lock for conversation {conversation_id}: {e}")
|
||||
return False
|
||||
|
||||
async def acquire_memory_repo_lock(
|
||||
self,
|
||||
agent_id: str,
|
||||
token: str,
|
||||
) -> Optional["Lock"]:
|
||||
"""
|
||||
Acquire a distributed lock for a memory repository.
|
||||
|
||||
Prevents concurrent modifications to an agent's git-based memory.
|
||||
|
||||
Args:
|
||||
agent_id: The agent ID whose memory is being modified
|
||||
token: Unique identifier for the lock holder (for debugging/tracing)
|
||||
|
||||
Returns:
|
||||
Lock object if acquired, raises MemoryRepoBusyError if in use
|
||||
"""
|
||||
if Lock is None:
|
||||
return None
|
||||
client = await self.get_client()
|
||||
lock_key = f"{MEMORY_REPO_LOCK_PREFIX}{agent_id}"
|
||||
lock = Lock(
|
||||
client,
|
||||
lock_key,
|
||||
timeout=MEMORY_REPO_LOCK_TTL_SECONDS,
|
||||
blocking=False,
|
||||
thread_local=False,
|
||||
raise_on_release_error=False,
|
||||
)
|
||||
|
||||
if await lock.acquire(token=token):
|
||||
return lock
|
||||
|
||||
lock_holder_token = await client.get(lock_key)
|
||||
raise MemoryRepoBusyError(
|
||||
agent_id=agent_id,
|
||||
lock_holder_token=lock_holder_token,
|
||||
)
|
||||
|
||||
async def release_memory_repo_lock(self, agent_id: str) -> bool:
|
||||
"""
|
||||
Release a memory repo lock by agent_id.
|
||||
|
||||
Args:
|
||||
agent_id: The agent ID to release the lock for
|
||||
|
||||
Returns:
|
||||
True if lock was released, False if release failed
|
||||
"""
|
||||
try:
|
||||
client = await self.get_client()
|
||||
lock_key = f"{MEMORY_REPO_LOCK_PREFIX}{agent_id}"
|
||||
await client.delete(lock_key)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to release memory repo lock for agent {agent_id}: {e}")
|
||||
return False
|
||||
|
||||
@with_retry()
|
||||
async def exists(self, *keys: str) -> int:
|
||||
"""Check if keys exist."""
|
||||
@@ -254,7 +320,7 @@ class AsyncRedisClient:
|
||||
client = await self.get_client()
|
||||
result = await client.smismember(key, values)
|
||||
return result if isinstance(values, list) else result[0]
|
||||
except:
|
||||
except Exception:
|
||||
return [0] * len(values) if isinstance(values, list) else 0
|
||||
|
||||
async def srem(self, key: str, *members: Union[str, int, float]) -> int:
|
||||
@@ -464,6 +530,16 @@ class NoopAsyncRedisClient(AsyncRedisClient):
|
||||
async def release_conversation_lock(self, conversation_id: str) -> bool:
|
||||
return False
|
||||
|
||||
async def acquire_memory_repo_lock(
|
||||
self,
|
||||
agent_id: str,
|
||||
token: str,
|
||||
) -> Optional["Lock"]:
|
||||
return None
|
||||
|
||||
async def release_memory_repo_lock(self, agent_id: str) -> bool:
|
||||
return False
|
||||
|
||||
async def check_inclusion_and_exclusion(self, member: str, group: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||
|
||||
# Avoid circular imports
|
||||
if TYPE_CHECKING:
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.message import Message
|
||||
|
||||
|
||||
@@ -20,6 +21,7 @@ class ErrorCode(Enum):
|
||||
TIMEOUT = "TIMEOUT"
|
||||
CONFLICT = "CONFLICT"
|
||||
EXPIRED = "EXPIRED"
|
||||
PAYMENT_REQUIRED = "PAYMENT_REQUIRED"
|
||||
|
||||
|
||||
class LettaError(Exception):
|
||||
@@ -91,6 +93,22 @@ class ConversationBusyError(LettaError):
|
||||
super().__init__(message=message, code=code, details=details)
|
||||
|
||||
|
||||
class MemoryRepoBusyError(LettaError):
|
||||
"""Error raised when attempting to modify memory while another operation is in progress."""
|
||||
|
||||
def __init__(self, agent_id: str, lock_holder_token: Optional[str] = None):
|
||||
self.agent_id = agent_id
|
||||
self.lock_holder_token = lock_holder_token
|
||||
message = "Cannot modify memory: Another operation is currently in progress for this agent's memory. Please wait for the current operation to complete."
|
||||
code = ErrorCode.CONFLICT
|
||||
details = {
|
||||
"error_code": "MEMORY_REPO_BUSY",
|
||||
"agent_id": agent_id,
|
||||
"lock_holder_token": lock_holder_token,
|
||||
}
|
||||
super().__init__(message=message, code=code, details=details)
|
||||
|
||||
|
||||
class LettaToolCreateError(LettaError):
|
||||
"""Error raised when a tool cannot be created."""
|
||||
|
||||
@@ -167,7 +185,9 @@ class LettaImageFetchError(LettaError):
|
||||
def __init__(self, url: str, reason: str):
|
||||
details = {"url": url, "reason": reason}
|
||||
super().__init__(
|
||||
message=f"Failed to fetch image from {url}: {reason}", code=ErrorCode.INVALID_ARGUMENT, details=details,
|
||||
message=f"Failed to fetch image from {url}: {reason}",
|
||||
code=ErrorCode.INVALID_ARGUMENT,
|
||||
details=details,
|
||||
)
|
||||
|
||||
|
||||
@@ -238,6 +258,10 @@ class LLMBadRequestError(LLMError):
|
||||
"""Error when LLM service cannot process request"""
|
||||
|
||||
|
||||
class LLMInsufficientCreditsError(LLMError):
|
||||
"""Error when LLM provider reports insufficient credits or quota"""
|
||||
|
||||
|
||||
class LLMAuthenticationError(LLMError):
|
||||
"""Error when authentication fails with LLM service"""
|
||||
|
||||
@@ -308,7 +332,9 @@ class ContextWindowExceededError(LettaError):
|
||||
def __init__(self, message: str, details: dict = {}):
|
||||
error_message = f"{message} ({details})"
|
||||
super().__init__(
|
||||
message=error_message, code=ErrorCode.CONTEXT_WINDOW_EXCEEDED, details=details,
|
||||
message=error_message,
|
||||
code=ErrorCode.CONTEXT_WINDOW_EXCEEDED,
|
||||
details=details,
|
||||
)
|
||||
|
||||
|
||||
@@ -328,7 +354,9 @@ class RateLimitExceededError(LettaError):
|
||||
def __init__(self, message: str, max_retries: int):
|
||||
error_message = f"{message} ({max_retries})"
|
||||
super().__init__(
|
||||
message=error_message, code=ErrorCode.RATE_LIMIT_EXCEEDED, details={"max_retries": max_retries},
|
||||
message=error_message,
|
||||
code=ErrorCode.RATE_LIMIT_EXCEEDED,
|
||||
details={"max_retries": max_retries},
|
||||
)
|
||||
|
||||
|
||||
@@ -383,7 +411,8 @@ class HandleNotFoundError(LettaError):
|
||||
|
||||
def __init__(self, handle: str, available_handles: List[str]):
|
||||
super().__init__(
|
||||
message=f"Handle {handle} not found, must be one of {available_handles}", code=ErrorCode.NOT_FOUND,
|
||||
message=f"Handle {handle} not found, must be one of {available_handles}",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
@@ -423,6 +452,16 @@ class AgentFileImportError(Exception):
|
||||
"""Exception raised during agent file import operations"""
|
||||
|
||||
|
||||
class InsufficientCreditsError(LettaError):
|
||||
"""Raised when an organization has no remaining credits."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
message="Insufficient credits to process this request.",
|
||||
details={"error_code": "INSUFFICIENT_CREDITS"},
|
||||
)
|
||||
|
||||
|
||||
class RunCancelError(LettaError):
|
||||
"""Error raised when a run cannot be cancelled."""
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from typing import TYPE_CHECKING, Any, List, Literal, Optional
|
||||
|
||||
from letta.constants import CORE_MEMORY_LINE_NUMBER_WARNING
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.agents.letta_agent import LettaAgent as Agent
|
||||
from letta.schemas.agent import AgentState
|
||||
|
||||
from letta.constants import CORE_MEMORY_LINE_NUMBER_WARNING
|
||||
|
||||
|
||||
def memory(
|
||||
agent_state: "AgentState",
|
||||
@@ -242,7 +243,7 @@ async def archival_memory_search(
|
||||
raise NotImplementedError("This should never be invoked directly. Contact Letta if you see this error message.")
|
||||
|
||||
|
||||
def core_memory_append(agent_state: "AgentState", label: str, content: str) -> Optional[str]: # type: ignore
|
||||
def core_memory_append(agent_state: "AgentState", label: str, content: str) -> str: # type: ignore
|
||||
"""
|
||||
Append to the contents of core memory.
|
||||
|
||||
@@ -251,15 +252,15 @@ def core_memory_append(agent_state: "AgentState", label: str, content: str) -> O
|
||||
content (str): Content to write to the memory. All unicode (including emojis) are supported.
|
||||
|
||||
Returns:
|
||||
Optional[str]: None is always returned as this function does not produce a response.
|
||||
str: The updated value of the memory block.
|
||||
"""
|
||||
current_value = str(agent_state.memory.get_block(label).value)
|
||||
new_value = current_value + "\n" + str(content)
|
||||
agent_state.memory.update_block_value(label=label, value=new_value)
|
||||
return None
|
||||
return new_value
|
||||
|
||||
|
||||
def core_memory_replace(agent_state: "AgentState", label: str, old_content: str, new_content: str) -> Optional[str]: # type: ignore
|
||||
def core_memory_replace(agent_state: "AgentState", label: str, old_content: str, new_content: str) -> str: # type: ignore
|
||||
"""
|
||||
Replace the contents of core memory. To delete memories, use an empty string for new_content.
|
||||
|
||||
@@ -269,14 +270,14 @@ def core_memory_replace(agent_state: "AgentState", label: str, old_content: str,
|
||||
new_content (str): Content to write to the memory. All unicode (including emojis) are supported.
|
||||
|
||||
Returns:
|
||||
Optional[str]: None is always returned as this function does not produce a response.
|
||||
str: The updated value of the memory block.
|
||||
"""
|
||||
current_value = str(agent_state.memory.get_block(label).value)
|
||||
if old_content not in current_value:
|
||||
raise ValueError(f"Old content '{old_content}' not found in memory block '{label}'")
|
||||
new_value = current_value.replace(str(old_content), str(new_content))
|
||||
agent_state.memory.update_block_value(label=label, value=new_value)
|
||||
return None
|
||||
return new_value
|
||||
|
||||
|
||||
def rethink_memory(agent_state: "AgentState", new_memory: str, target_block_label: str) -> None:
|
||||
@@ -307,125 +308,118 @@ SNIPPET_LINES: int = 4
|
||||
|
||||
|
||||
# Based off of: https://github.com/anthropics/anthropic-quickstarts/blob/main/computer-use-demo/computer_use_demo/tools/edit.py?ref=musings.yasyf.com#L154
|
||||
def memory_replace(agent_state: "AgentState", label: str, old_str: str, new_str: str) -> str: # type: ignore
|
||||
def memory_replace(agent_state: "AgentState", label: str, old_string: str, new_string: str) -> str: # type: ignore
|
||||
"""
|
||||
The memory_replace command allows you to replace a specific string in a memory block with a new string. This is used for making precise edits.
|
||||
Do NOT attempt to replace long strings, e.g. do not attempt to replace the entire contents of a memory block with a new string.
|
||||
|
||||
Args:
|
||||
label (str): Section of the memory to be edited, identified by its label.
|
||||
old_str (str): The text to replace (must match exactly, including whitespace and indentation).
|
||||
new_str (str): The new text to insert in place of the old text. Do not include line number prefixes.
|
||||
old_string (str): The text to replace (must match exactly, including whitespace and indentation).
|
||||
new_string (str): The new text to insert in place of the old text. Do not include line number prefixes.
|
||||
|
||||
Examples:
|
||||
# Update a block containing information about the user
|
||||
memory_replace(label="human", old_str="Their name is Alice", new_str="Their name is Bob")
|
||||
memory_replace(label="human", old_string="Their name is Alice", new_string="Their name is Bob")
|
||||
|
||||
# Update a block containing a todo list
|
||||
memory_replace(label="todos", old_str="- [ ] Step 5: Search the web", new_str="- [x] Step 5: Search the web")
|
||||
memory_replace(label="todos", old_string="- [ ] Step 5: Search the web", new_string="- [x] Step 5: Search the web")
|
||||
|
||||
# Pass an empty string to
|
||||
memory_replace(label="human", old_str="Their name is Alice", new_str="")
|
||||
memory_replace(label="human", old_string="Their name is Alice", new_string="")
|
||||
|
||||
# Bad example - do NOT add (view-only) line numbers to the args
|
||||
memory_replace(label="human", old_str="1: Their name is Alice", new_str="1: Their name is Bob")
|
||||
memory_replace(label="human", old_string="1: Their name is Alice", new_string="1: Their name is Bob")
|
||||
|
||||
# Bad example - do NOT include the line number warning either
|
||||
memory_replace(label="human", old_str="# NOTE: Line numbers shown below (with arrows like '1→') are to help during editing. Do NOT include line number prefixes in your memory edit tool calls.\\n1→ Their name is Alice", new_str="1→ Their name is Bob")
|
||||
memory_replace(label="human", old_string="# NOTE: Line numbers shown below (with arrows like '1→') are to help during editing. Do NOT include line number prefixes in your memory edit tool calls.\\n1→ Their name is Alice", new_string="1→ Their name is Bob")
|
||||
|
||||
# Good example - no line numbers or line number warning (they are view-only), just the text
|
||||
memory_replace(label="human", old_str="Their name is Alice", new_str="Their name is Bob")
|
||||
memory_replace(label="human", old_string="Their name is Alice", new_string="Their name is Bob")
|
||||
|
||||
Returns:
|
||||
str: The success message
|
||||
str: The updated value of the memory block.
|
||||
"""
|
||||
import re
|
||||
|
||||
if bool(re.search(r"\nLine \d+: ", old_str)):
|
||||
if bool(re.search(r"\nLine \d+: ", old_string)):
|
||||
raise ValueError(
|
||||
"old_str contains a line number prefix, which is not allowed. Do not include line numbers when calling memory tools (line numbers are for display purposes only)."
|
||||
"old_string contains a line number prefix, which is not allowed. Do not include line numbers when calling memory tools (line numbers are for display purposes only)."
|
||||
)
|
||||
if CORE_MEMORY_LINE_NUMBER_WARNING in old_str:
|
||||
if CORE_MEMORY_LINE_NUMBER_WARNING in old_string:
|
||||
raise ValueError(
|
||||
"old_str contains a line number warning, which is not allowed. Do not include line number information when calling memory tools (line numbers are for display purposes only)."
|
||||
"old_string contains a line number warning, which is not allowed. Do not include line number information when calling memory tools (line numbers are for display purposes only)."
|
||||
)
|
||||
if bool(re.search(r"\nLine \d+: ", new_str)):
|
||||
if bool(re.search(r"\nLine \d+: ", new_string)):
|
||||
raise ValueError(
|
||||
"new_str contains a line number prefix, which is not allowed. Do not include line numbers when calling memory tools (line numbers are for display purposes only)."
|
||||
"new_string contains a line number prefix, which is not allowed. Do not include line numbers when calling memory tools (line numbers are for display purposes only)."
|
||||
)
|
||||
|
||||
old_str = str(old_str).expandtabs()
|
||||
new_str = str(new_str).expandtabs()
|
||||
old_string = str(old_string).expandtabs()
|
||||
new_string = str(new_string).expandtabs()
|
||||
current_value = str(agent_state.memory.get_block(label).value).expandtabs()
|
||||
|
||||
# Check if old_str is unique in the block
|
||||
occurences = current_value.count(old_str)
|
||||
# Check if old_string is unique in the block
|
||||
occurences = current_value.count(old_string)
|
||||
if occurences == 0:
|
||||
raise ValueError(f"No replacement was performed, old_str `{old_str}` did not appear verbatim in memory block with label `{label}`.")
|
||||
raise ValueError(
|
||||
f"No replacement was performed, old_string `{old_string}` did not appear verbatim in memory block with label `{label}`."
|
||||
)
|
||||
elif occurences > 1:
|
||||
content_value_lines = current_value.split("\n")
|
||||
lines = [idx + 1 for idx, line in enumerate(content_value_lines) if old_str in line]
|
||||
lines = [idx + 1 for idx, line in enumerate(content_value_lines) if old_string in line]
|
||||
raise ValueError(
|
||||
f"No replacement was performed. Multiple occurrences of old_str `{old_str}` in lines {lines}. Please ensure it is unique."
|
||||
f"No replacement was performed. Multiple occurrences of old_string `{old_string}` in lines {lines}. Please ensure it is unique."
|
||||
)
|
||||
|
||||
# Replace old_str with new_str
|
||||
new_value = current_value.replace(str(old_str), str(new_str))
|
||||
# Replace old_string with new_string
|
||||
new_value = current_value.replace(str(old_string), str(new_string))
|
||||
|
||||
# Write the new content to the block
|
||||
agent_state.memory.update_block_value(label=label, value=new_value)
|
||||
|
||||
# Create a snippet of the edited section
|
||||
# SNIPPET_LINES = 3
|
||||
# replacement_line = current_value.split(old_str)[0].count("\n")
|
||||
# replacement_line = current_value.split(old_string)[0].count("\n")
|
||||
# start_line = max(0, replacement_line - SNIPPET_LINES)
|
||||
# end_line = replacement_line + SNIPPET_LINES + new_str.count("\n")
|
||||
# end_line = replacement_line + SNIPPET_LINES + new_string.count("\n")
|
||||
# snippet = "\n".join(new_value.split("\n")[start_line : end_line + 1])
|
||||
|
||||
# Prepare the success message
|
||||
success_msg = (
|
||||
f"The core memory block with label `{label}` has been successfully edited. "
|
||||
f"Your system prompt has been recompiled with the updated memory contents and is now active in your context. "
|
||||
f"Review the changes and make sure they are as expected (correct indentation, "
|
||||
f"no duplicate lines, etc). Edit the memory block again if necessary."
|
||||
)
|
||||
|
||||
# return None
|
||||
return success_msg
|
||||
return new_value
|
||||
|
||||
|
||||
def memory_insert(agent_state: "AgentState", label: str, new_str: str, insert_line: int = -1) -> Optional[str]: # type: ignore
|
||||
def memory_insert(agent_state: "AgentState", label: str, new_string: str, insert_line: int = -1) -> str: # type: ignore
|
||||
"""
|
||||
The memory_insert command allows you to insert text at a specific location in a memory block.
|
||||
|
||||
Args:
|
||||
label (str): Section of the memory to be edited, identified by its label.
|
||||
new_str (str): The text to insert. Do not include line number prefixes.
|
||||
new_string (str): The text to insert. Do not include line number prefixes.
|
||||
insert_line (int): The line number after which to insert the text (0 for beginning of file). Defaults to -1 (end of the file).
|
||||
|
||||
Examples:
|
||||
# Update a block containing information about the user (append to the end of the block)
|
||||
memory_insert(label="customer", new_str="The customer's ticket number is 12345")
|
||||
memory_insert(label="customer", new_string="The customer's ticket number is 12345")
|
||||
|
||||
# Update a block containing information about the user (insert at the beginning of the block)
|
||||
memory_insert(label="customer", new_str="The customer's ticket number is 12345", insert_line=0)
|
||||
memory_insert(label="customer", new_string="The customer's ticket number is 12345", insert_line=0)
|
||||
|
||||
Returns:
|
||||
Optional[str]: None is always returned as this function does not produce a response.
|
||||
"""
|
||||
import re
|
||||
|
||||
if bool(re.search(r"\nLine \d+: ", new_str)):
|
||||
if bool(re.search(r"\nLine \d+: ", new_string)):
|
||||
raise ValueError(
|
||||
"new_str contains a line number prefix, which is not allowed. Do not include line numbers when calling memory tools (line numbers are for display purposes only)."
|
||||
"new_string contains a line number prefix, which is not allowed. Do not include line numbers when calling memory tools (line numbers are for display purposes only)."
|
||||
)
|
||||
if CORE_MEMORY_LINE_NUMBER_WARNING in new_str:
|
||||
if CORE_MEMORY_LINE_NUMBER_WARNING in new_string:
|
||||
raise ValueError(
|
||||
"new_str contains a line number warning, which is not allowed. Do not include line number information when calling memory tools (line numbers are for display purposes only)."
|
||||
"new_string contains a line number warning, which is not allowed. Do not include line number information when calling memory tools (line numbers are for display purposes only)."
|
||||
)
|
||||
|
||||
current_value = str(agent_state.memory.get_block(label).value).expandtabs()
|
||||
new_str = str(new_str).expandtabs()
|
||||
new_string = str(new_string).expandtabs()
|
||||
current_value_lines = current_value.split("\n")
|
||||
n_lines = len(current_value_lines)
|
||||
|
||||
@@ -438,11 +432,11 @@ def memory_insert(agent_state: "AgentState", label: str, new_str: str, insert_li
|
||||
)
|
||||
|
||||
# Insert the new string as a line
|
||||
new_str_lines = new_str.split("\n")
|
||||
new_value_lines = current_value_lines[:insert_line] + new_str_lines + current_value_lines[insert_line:]
|
||||
snippet_lines = (
|
||||
new_string_lines = new_string.split("\n")
|
||||
new_value_lines = current_value_lines[:insert_line] + new_string_lines + current_value_lines[insert_line:]
|
||||
(
|
||||
current_value_lines[max(0, insert_line - SNIPPET_LINES) : insert_line]
|
||||
+ new_str_lines
|
||||
+ new_string_lines
|
||||
+ current_value_lines[insert_line : insert_line + SNIPPET_LINES]
|
||||
)
|
||||
|
||||
@@ -453,15 +447,7 @@ def memory_insert(agent_state: "AgentState", label: str, new_str: str, insert_li
|
||||
# Write into the block
|
||||
agent_state.memory.update_block_value(label=label, value=new_value)
|
||||
|
||||
# Prepare the success message
|
||||
success_msg = (
|
||||
f"The core memory block with label `{label}` has been successfully edited. "
|
||||
f"Your system prompt has been recompiled with the updated memory contents and is now active in your context. "
|
||||
f"Review the changes and make sure they are as expected (correct indentation, "
|
||||
f"no duplicate lines, etc). Edit the memory block again if necessary."
|
||||
)
|
||||
|
||||
return success_msg
|
||||
return new_value
|
||||
|
||||
|
||||
def memory_apply_patch(agent_state: "AgentState", label: str, patch: str) -> str: # type: ignore
|
||||
@@ -499,7 +485,7 @@ def memory_apply_patch(agent_state: "AgentState", label: str, patch: str) -> str
|
||||
raise NotImplementedError("This should never be invoked directly. Contact Letta if you see this error message.")
|
||||
|
||||
|
||||
def memory_rethink(agent_state: "AgentState", label: str, new_memory: str) -> None:
|
||||
def memory_rethink(agent_state: "AgentState", label: str, new_memory: str) -> str:
|
||||
"""
|
||||
The memory_rethink command allows you to completely rewrite the contents of a memory block. Use this tool to make large sweeping changes (e.g. when you want to condense or reorganize the memory blocks), do NOT use this tool to make small precise edits (e.g. add or remove a line, replace a specific string, etc).
|
||||
|
||||
@@ -528,17 +514,7 @@ def memory_rethink(agent_state: "AgentState", label: str, new_memory: str) -> No
|
||||
agent_state.memory.set_block(new_block)
|
||||
|
||||
agent_state.memory.update_block_value(label=label, value=new_memory)
|
||||
|
||||
# Prepare the success message
|
||||
success_msg = (
|
||||
f"The core memory block with label `{label}` has been successfully edited. "
|
||||
f"Your system prompt has been recompiled with the updated memory contents and is now active in your context. "
|
||||
f"Review the changes and make sure they are as expected (correct indentation, "
|
||||
f"no duplicate lines, etc). Edit the memory block again if necessary."
|
||||
)
|
||||
|
||||
# return None
|
||||
return success_msg
|
||||
return new_memory
|
||||
|
||||
|
||||
def memory_finish_edits(agent_state: "AgentState") -> None: # type: ignore
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import asyncio
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.agents.letta_agent import LettaAgent as Agent
|
||||
|
||||
from letta.functions.helpers import (
|
||||
_send_message_to_agents_matching_tags_async,
|
||||
_send_message_to_all_agents_in_group_async,
|
||||
execute_send_message_to_agent,
|
||||
extract_send_message_from_steps_messages,
|
||||
fire_and_forget_send_to_agent,
|
||||
)
|
||||
from letta.schemas.enums import MessageRole
|
||||
@@ -27,9 +28,7 @@ def send_message_to_agent_and_wait_for_reply(self: "Agent", message: str, other_
|
||||
str: The response from the target agent.
|
||||
"""
|
||||
augmented_message = (
|
||||
f"[Incoming message from agent with ID '{self.agent_state.id}' - to reply to this message, "
|
||||
f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] "
|
||||
f"{message}"
|
||||
f"[Incoming message from agent with ID '{self.agent_state.id}' - your response will be delivered to the sender] {message}"
|
||||
)
|
||||
messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=self.agent_state.name)]
|
||||
|
||||
@@ -56,57 +55,18 @@ def send_message_to_agents_matching_tags(self: "Agent", message: str, match_all:
|
||||
in the returned list.
|
||||
"""
|
||||
server = get_letta_server()
|
||||
augmented_message = (
|
||||
f"[Incoming message from external Letta agent - to reply to this message, "
|
||||
f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] "
|
||||
f"{message}"
|
||||
)
|
||||
augmented_message = f"[Incoming message from external Letta agent - your response will be delivered to the sender] {message}"
|
||||
|
||||
# Find matching agents
|
||||
matching_agents = server.agent_manager.list_agents_matching_tags(actor=self.user, match_all=match_all, match_some=match_some)
|
||||
if not matching_agents:
|
||||
return []
|
||||
|
||||
def process_agent(agent_id: str) -> str:
|
||||
"""Loads an agent, formats the message, and executes .step()"""
|
||||
actor = self.user # Ensure correct actor context
|
||||
agent = server.load_agent(agent_id=agent_id, interface=None, actor=actor)
|
||||
# Prepare the message
|
||||
messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=self.agent_state.name)]
|
||||
|
||||
# Prepare the message
|
||||
messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=self.agent_state.name)]
|
||||
|
||||
# Run .step() and return the response
|
||||
usage_stats = agent.step(
|
||||
input_messages=messages,
|
||||
chaining=True,
|
||||
max_chaining_steps=None,
|
||||
stream=False,
|
||||
skip_verify=True,
|
||||
metadata=None,
|
||||
put_inner_thoughts_first=True,
|
||||
)
|
||||
|
||||
send_messages = extract_send_message_from_steps_messages(usage_stats.steps_messages, logger=agent.logger)
|
||||
response_data = {
|
||||
"agent_id": agent_id,
|
||||
"response_messages": send_messages if send_messages else ["<no response>"],
|
||||
}
|
||||
|
||||
return json.dumps(response_data, indent=2)
|
||||
|
||||
# Use ThreadPoolExecutor for parallel execution
|
||||
results = []
|
||||
with ThreadPoolExecutor(max_workers=settings.multi_agent_concurrent_sends) as executor:
|
||||
future_to_agent = {executor.submit(process_agent, agent_state.id): agent_state for agent_state in matching_agents}
|
||||
|
||||
for future in as_completed(future_to_agent):
|
||||
try:
|
||||
results.append(future.result()) # Collect results
|
||||
except Exception as e:
|
||||
# Log or handle failure for specific agents if needed
|
||||
self.logger.exception(f"Error processing agent {future_to_agent[future]}: {e}")
|
||||
|
||||
return results
|
||||
# Use async helper for parallel message sending
|
||||
return asyncio.run(_send_message_to_agents_matching_tags_async(self, server, messages, matching_agents))
|
||||
|
||||
|
||||
def send_message_to_all_agents_in_group(self: "Agent", message: str) -> List[str]:
|
||||
@@ -138,8 +98,8 @@ def send_message_to_agent_async(self: "Agent", message: str, other_agent_id: str
|
||||
raise RuntimeError("This tool is not allowed to be run on Letta Cloud.")
|
||||
|
||||
message = (
|
||||
f"[Incoming message from agent with ID '{self.agent_state.id}' - to reply to this message, "
|
||||
f"make sure to use the 'send_message_to_agent_async' tool, or the agent will not receive your message] "
|
||||
f"[Incoming message from agent with ID '{self.agent_state.id}' - "
|
||||
f"this is a one-way notification; if you need to respond, use an agent-to-agent messaging tool if available] "
|
||||
f"{message}"
|
||||
)
|
||||
messages = [MessageCreate(role=MessageRole.system, content=message, name=self.agent_state.name)]
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
## Voice chat + sleeptime tools
|
||||
from typing import List, Optional
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.schemas.agent import AgentState
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@@ -179,7 +179,7 @@ def _extract_pydantic_classes(tree: ast.AST, imports_map: Dict[str, Any]) -> Dic
|
||||
pass # Field is required, no default
|
||||
else:
|
||||
field_kwargs["default"] = default_val
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
fields[field_name] = Field(**field_kwargs)
|
||||
@@ -188,7 +188,7 @@ def _extract_pydantic_classes(tree: ast.AST, imports_map: Dict[str, Any]) -> Dic
|
||||
try:
|
||||
default_val = ast.literal_eval(stmt.value)
|
||||
fields[field_name] = default_val
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Create the dynamic Pydantic model
|
||||
|
||||
@@ -3,7 +3,17 @@ import json
|
||||
import logging
|
||||
import threading
|
||||
from random import uniform
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.agents.letta_agent import LettaAgent as Agent
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
try:
|
||||
from langchain.tools.base import BaseTool as LangChainBaseTool
|
||||
except ImportError:
|
||||
LangChainBaseTool = None
|
||||
|
||||
import humps
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
@@ -21,6 +31,8 @@ from letta.server.rest_api.dependencies import get_letta_server
|
||||
from letta.settings import settings
|
||||
from letta.utils import safe_create_task
|
||||
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
# TODO needed?
|
||||
def generate_mcp_tool_wrapper(mcp_tool_name: str) -> tuple[str, str]:
|
||||
@@ -36,7 +48,8 @@ def {mcp_tool_name}(**kwargs):
|
||||
|
||||
|
||||
def generate_langchain_tool_wrapper(
|
||||
tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None
|
||||
tool: "LangChainBaseTool",
|
||||
additional_imports_module_attr_map: dict[str, str] | None = None,
|
||||
) -> tuple[str, str]:
|
||||
tool_name = tool.__class__.__name__
|
||||
import_statement = f"from langchain_community.tools import {tool_name}"
|
||||
@@ -428,15 +441,18 @@ def fire_and_forget_send_to_agent(
|
||||
# 4) Try to schedule the coroutine in an existing loop, else spawn a thread
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
# If we get here, a loop is running; schedule the coroutine in background
|
||||
loop.create_task(background_task())
|
||||
task = loop.create_task(background_task())
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
except RuntimeError:
|
||||
# Means no event loop is running in this thread
|
||||
run_in_background_thread(background_task())
|
||||
|
||||
|
||||
async def _send_message_to_agents_matching_tags_async(
|
||||
sender_agent: "Agent", server: "SyncServer", messages: List[MessageCreate], matching_agents: List["AgentState"]
|
||||
sender_agent: "Agent",
|
||||
server: "SyncServer",
|
||||
messages: List[MessageCreate],
|
||||
matching_agents: List["AgentState"],
|
||||
) -> List[str]:
|
||||
async def _send_single(agent_state):
|
||||
return await _async_send_message_with_retries(
|
||||
@@ -464,9 +480,7 @@ async def _send_message_to_all_agents_in_group_async(sender_agent: "Agent", mess
|
||||
server = get_letta_server()
|
||||
|
||||
augmented_message = (
|
||||
f"[Incoming message from agent with ID '{sender_agent.agent_state.id}' - to reply to this message, "
|
||||
f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] "
|
||||
f"{message}"
|
||||
f"[Incoming message from agent with ID '{sender_agent.agent_state.id}' - your response will be delivered to the sender] {message}"
|
||||
)
|
||||
|
||||
worker_agents_ids = sender_agent.agent_state.multi_agent_group.agent_ids
|
||||
@@ -520,7 +534,9 @@ def generate_model_from_args_json_schema(schema: Dict[str, Any]) -> Type[BaseMod
|
||||
return _create_model_from_schema(schema.get("title", "DynamicModel"), schema, nested_models)
|
||||
|
||||
|
||||
def _create_model_from_schema(name: str, model_schema: Dict[str, Any], nested_models: Dict[str, Type[BaseModel]] = None) -> Type[BaseModel]:
|
||||
def _create_model_from_schema(
|
||||
name: str, model_schema: Dict[str, Any], nested_models: Dict[str, Type[BaseModel]] | None = None
|
||||
) -> Type[BaseModel]:
|
||||
fields = {}
|
||||
for field_name, field_schema in model_schema["properties"].items():
|
||||
field_type = _get_field_type(field_schema, nested_models)
|
||||
@@ -531,7 +547,7 @@ def _create_model_from_schema(name: str, model_schema: Dict[str, Any], nested_mo
|
||||
return create_model(name, **fields)
|
||||
|
||||
|
||||
def _get_field_type(field_schema: Dict[str, Any], nested_models: Dict[str, Type[BaseModel]] = None) -> Any:
|
||||
def _get_field_type(field_schema: Dict[str, Any], nested_models: Dict[str, Type[BaseModel]] | None = None) -> Any:
|
||||
"""Helper to convert JSON schema types to Python types."""
|
||||
if field_schema.get("type") == "string":
|
||||
return str
|
||||
|
||||
@@ -98,6 +98,32 @@ class BaseServerConfig(BaseModel):
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_dict_key(key: str) -> str:
|
||||
"""Strip surrounding quotes and trailing colons from a dict key."""
|
||||
key = key.strip()
|
||||
for quote in ('"', "'"):
|
||||
if key.startswith(quote) and key.endswith(quote):
|
||||
key = key[1:-1]
|
||||
break
|
||||
key = key.rstrip(":")
|
||||
return key.strip()
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_dict_value(value: str) -> str:
|
||||
"""Strip surrounding quotes from a dict value."""
|
||||
value = value.strip()
|
||||
for quote in ('"', "'"):
|
||||
if value.startswith(quote) and value.endswith(quote):
|
||||
value = value[1:-1]
|
||||
break
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _sanitize_dict(cls, d: Dict[str, str]) -> Dict[str, str]:
|
||||
"""Sanitize a string dict by stripping quotes from keys and values."""
|
||||
return {cls._sanitize_dict_key(k): cls._sanitize_dict_value(v) for k, v in d.items()}
|
||||
|
||||
def resolve_custom_headers(
|
||||
self, custom_headers: Optional[Dict[str, str]], environment_variables: Optional[Dict[str, str]] = None
|
||||
) -> Optional[Dict[str, str]]:
|
||||
@@ -114,6 +140,8 @@ class BaseServerConfig(BaseModel):
|
||||
if custom_headers is None:
|
||||
return None
|
||||
|
||||
custom_headers = self._sanitize_dict(custom_headers)
|
||||
|
||||
resolved_headers = {}
|
||||
for key, value in custom_headers.items():
|
||||
# Resolve templated variables in each header value
|
||||
@@ -164,8 +192,12 @@ class HTTPBasedServerConfig(BaseServerConfig):
|
||||
return None
|
||||
|
||||
def resolve_environment_variables(self, environment_variables: Optional[Dict[str, str]] = None) -> None:
|
||||
if self.auth_token and super().is_templated_tool_variable(self.auth_token):
|
||||
self.auth_token = super().get_tool_variable(self.auth_token, environment_variables)
|
||||
if self.auth_header:
|
||||
self.auth_header = self._sanitize_dict_key(self.auth_header)
|
||||
if self.auth_token:
|
||||
self.auth_token = self._sanitize_dict_value(self.auth_token)
|
||||
if super().is_templated_tool_variable(self.auth_token):
|
||||
self.auth_token = super().get_tool_variable(self.auth_token, environment_variables)
|
||||
|
||||
self.custom_headers = super().resolve_custom_headers(self.custom_headers, environment_variables)
|
||||
|
||||
@@ -176,11 +208,11 @@ class HTTPBasedServerConfig(BaseServerConfig):
|
||||
Returns:
|
||||
Dictionary of headers or None if no headers are configured
|
||||
"""
|
||||
if self.custom_headers is not None or (self.auth_header is not None and self.auth_token is not None):
|
||||
if self.custom_headers is not None or (self.auth_header and self.auth_token):
|
||||
headers = self.custom_headers.copy() if self.custom_headers else {}
|
||||
|
||||
# Add auth header if specified
|
||||
if self.auth_header is not None and self.auth_token is not None:
|
||||
# Add auth header if specified (skip if either is empty to avoid illegal header values)
|
||||
if self.auth_header and self.auth_token:
|
||||
headers[self.auth_header] = self.auth_token
|
||||
|
||||
return headers
|
||||
|
||||
@@ -96,7 +96,7 @@ def type_to_json_schema_type(py_type) -> dict:
|
||||
|
||||
# Handle array types
|
||||
origin = get_origin(py_type)
|
||||
if py_type == list or origin in (list, List):
|
||||
if py_type is list or origin in (list, List):
|
||||
args = get_args(py_type)
|
||||
if len(args) == 0:
|
||||
# is this correct
|
||||
@@ -142,7 +142,7 @@ def type_to_json_schema_type(py_type) -> dict:
|
||||
}
|
||||
|
||||
# Handle object types
|
||||
if py_type == dict or origin in (dict, Dict):
|
||||
if py_type is dict or origin in (dict, Dict):
|
||||
args = get_args(py_type)
|
||||
if not args:
|
||||
# Generic dict without type arguments
|
||||
@@ -704,8 +704,9 @@ def generate_tool_schema_for_mcp(
|
||||
name = mcp_tool.name
|
||||
description = mcp_tool.description
|
||||
|
||||
assert "type" in parameters_schema, parameters_schema
|
||||
assert "properties" in parameters_schema, parameters_schema
|
||||
if "type" not in parameters_schema:
|
||||
parameters_schema["type"] = "object"
|
||||
parameters_schema.setdefault("properties", {})
|
||||
# assert "required" in parameters_schema, parameters_schema
|
||||
|
||||
# Normalize the schema to fix common issues with MCP schemas
|
||||
|
||||
@@ -56,7 +56,7 @@ def validate_complete_json_schema(schema: Dict[str, Any]) -> Tuple[SchemaHealth,
|
||||
"""
|
||||
if obj_schema.get("type") != "object":
|
||||
return False
|
||||
props = obj_schema.get("properties", {})
|
||||
obj_schema.get("properties", {})
|
||||
required = obj_schema.get("required", [])
|
||||
additional = obj_schema.get("additionalProperties", True)
|
||||
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from typing import List, Optional
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.agents.letta_agent import LettaAgent as Agent
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
@@ -92,7 +95,7 @@ class DynamicMultiAgent(BaseAgent):
|
||||
|
||||
# Parse manager response
|
||||
responses = Message.to_letta_messages_from_list(manager_agent.last_response_messages)
|
||||
assistant_message = [response for response in responses if response.message_type == "assistant_message"][0]
|
||||
assistant_message = next(response for response in responses if response.message_type == "assistant_message")
|
||||
for name, agent_id in [(agents[agent_id].agent_state.name, agent_id) for agent_id in agent_id_options]:
|
||||
if name.lower() in assistant_message.content.lower():
|
||||
speaker_id = agent_id
|
||||
|
||||
@@ -98,7 +98,7 @@ def stringify_message(message: Message, use_assistant_name: bool = False) -> str
|
||||
elif isinstance(content, ImageContent):
|
||||
messages.append(f"{message.name or 'user'}: [Image Here]")
|
||||
return "\n".join(messages)
|
||||
except:
|
||||
except Exception:
|
||||
if message.content and len(message.content) > 0:
|
||||
return f"{message.name or 'user'}: {message.content[0].text}"
|
||||
return None
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
|
||||
@@ -213,7 +212,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
group_id=self.group.id, last_processed_message_id=last_response_messages[-1].id, actor=self.actor
|
||||
)
|
||||
for sleeptime_agent_id in self.group.agent_ids:
|
||||
run_id = await self._issue_background_task(
|
||||
await self._issue_background_task(
|
||||
sleeptime_agent_id,
|
||||
last_response_messages,
|
||||
last_processed_message_id,
|
||||
|
||||
@@ -9,7 +9,6 @@ from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import RunStatus
|
||||
from letta.schemas.group import Group, ManagerType
|
||||
from letta.schemas.job import JobUpdate
|
||||
from letta.schemas.letta_message import MessageType
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.letta_request import ClientToolSchema
|
||||
@@ -47,6 +46,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
client_tools: list[ClientToolSchema] | None = None,
|
||||
include_compaction_messages: bool = False,
|
||||
) -> LettaResponse:
|
||||
self.run_ids = []
|
||||
|
||||
@@ -61,6 +61,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
|
||||
include_return_message_types=include_return_message_types,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
client_tools=client_tools,
|
||||
include_compaction_messages=include_compaction_messages,
|
||||
)
|
||||
|
||||
await self.run_sleeptime_agents()
|
||||
@@ -79,6 +80,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
client_tools: list[ClientToolSchema] | None = None,
|
||||
include_compaction_messages: bool = False,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
self.run_ids = []
|
||||
|
||||
@@ -96,6 +98,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
|
||||
include_return_message_types=include_return_message_types,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
client_tools=client_tools,
|
||||
include_compaction_messages=include_compaction_messages,
|
||||
):
|
||||
yield chunk
|
||||
finally:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
|
||||
@@ -7,9 +6,8 @@ from letta.constants import DEFAULT_MAX_STEPS
|
||||
from letta.groups.helpers import stringify_message
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import JobStatus, RunStatus
|
||||
from letta.schemas.enums import RunStatus
|
||||
from letta.schemas.group import Group, ManagerType
|
||||
from letta.schemas.job import JobUpdate
|
||||
from letta.schemas.letta_message import MessageType
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.letta_request import ClientToolSchema
|
||||
@@ -48,6 +46,7 @@ class SleeptimeMultiAgentV4(LettaAgentV3):
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
conversation_id: str | None = None,
|
||||
client_tools: list[ClientToolSchema] | None = None,
|
||||
include_compaction_messages: bool = False,
|
||||
) -> LettaResponse:
|
||||
self.run_ids = []
|
||||
|
||||
@@ -63,6 +62,7 @@ class SleeptimeMultiAgentV4(LettaAgentV3):
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
conversation_id=conversation_id,
|
||||
client_tools=client_tools,
|
||||
include_compaction_messages=include_compaction_messages,
|
||||
)
|
||||
|
||||
run_ids = await self.run_sleeptime_agents()
|
||||
@@ -81,6 +81,7 @@ class SleeptimeMultiAgentV4(LettaAgentV3):
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
conversation_id: str | None = None,
|
||||
client_tools: list[ClientToolSchema] | None = None,
|
||||
include_compaction_messages: bool = False,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
self.run_ids = []
|
||||
|
||||
@@ -99,6 +100,7 @@ class SleeptimeMultiAgentV4(LettaAgentV3):
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
conversation_id=conversation_id,
|
||||
client_tools=client_tools,
|
||||
include_compaction_messages=include_compaction_messages,
|
||||
):
|
||||
yield chunk
|
||||
finally:
|
||||
|
||||
@@ -1,19 +1,9 @@
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL
|
||||
from letta.functions.function_sets.multi_agent import send_message_to_all_agents_in_group
|
||||
from letta.functions.functions import parse_source_code
|
||||
from letta.functions.schema_generator import generate_schema
|
||||
from letta.interface import AgentInterface
|
||||
from letta.orm import User
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import ToolType
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
from letta.helpers.tool_rule_solver import ToolRulesSolver
|
||||
from letta.helpers.tool_rule_solver import ToolRulesSolver as ToolRulesSolver
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.services.summarizer.summarizer_config import CompactionSettings
|
||||
|
||||
import numpy as np
|
||||
from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse
|
||||
|
||||
@@ -4,6 +4,53 @@ from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
||||
def sanitize_unicode_surrogates(value: Any) -> Any:
|
||||
"""Recursively remove invalid Unicode surrogate characters from strings.
|
||||
|
||||
Unicode surrogate pairs (U+D800 to U+DFFF) are used internally by UTF-16 encoding
|
||||
but are invalid as standalone characters in UTF-8. When present, they cause
|
||||
UnicodeEncodeError when encoding to UTF-8, breaking API requests that need to
|
||||
serialize data to JSON.
|
||||
|
||||
This function sanitizes:
|
||||
- Strings: removes unpaired surrogates that can't be encoded to UTF-8
|
||||
- Dicts: recursively sanitizes all string values
|
||||
- Lists: recursively sanitizes all elements
|
||||
- Other types: returned as-is
|
||||
|
||||
Args:
|
||||
value: The value to sanitize
|
||||
|
||||
Returns:
|
||||
The sanitized value with surrogate characters removed from all strings
|
||||
"""
|
||||
if isinstance(value, str):
|
||||
# Remove lone surrogate characters (U+D800 to U+DFFF) which are invalid in UTF-8
|
||||
# Using character filtering is more reliable than encode/decode for edge cases
|
||||
try:
|
||||
# Filter out any character in the surrogate range
|
||||
return "".join(char for char in value if not (0xD800 <= ord(char) <= 0xDFFF))
|
||||
except Exception:
|
||||
# Fallback: try encode with errors="replace" which replaces surrogates with <20>
|
||||
try:
|
||||
return value.encode("utf-8", errors="replace").decode("utf-8")
|
||||
except Exception:
|
||||
# Last resort: return original (should never reach here)
|
||||
return value
|
||||
elif isinstance(value, dict):
|
||||
# Recursively sanitize dictionary keys and values
|
||||
return {sanitize_unicode_surrogates(k): sanitize_unicode_surrogates(v) for k, v in value.items()}
|
||||
elif isinstance(value, list):
|
||||
# Recursively sanitize list elements
|
||||
return [sanitize_unicode_surrogates(item) for item in value]
|
||||
elif isinstance(value, tuple):
|
||||
# Recursively sanitize tuple elements (return as tuple)
|
||||
return tuple(sanitize_unicode_surrogates(item) for item in value)
|
||||
else:
|
||||
# Return other types as-is (int, float, bool, None, etc.)
|
||||
return value
|
||||
|
||||
|
||||
def sanitize_null_bytes(value: Any) -> Any:
|
||||
"""Recursively remove null bytes (0x00) from strings.
|
||||
|
||||
|
||||
@@ -139,6 +139,20 @@ async def _convert_message_create_to_message(
|
||||
image_media_type, _ = mimetypes.guess_type(file_path)
|
||||
if not image_media_type:
|
||||
image_media_type = "image/jpeg" # default fallback
|
||||
elif url.startswith("data:"):
|
||||
# Handle data: URLs (inline base64 encoded images)
|
||||
# Format: data:[<mediatype>][;base64],<data>
|
||||
try:
|
||||
# Split header from data
|
||||
header, image_data = url.split(",", 1)
|
||||
# Extract media type from header (e.g., "data:image/jpeg;base64")
|
||||
header_parts = header.split(";")
|
||||
image_media_type = header_parts[0].replace("data:", "") or "image/jpeg"
|
||||
# Data is already base64 encoded, set directly and continue
|
||||
content.source = Base64Image(media_type=image_media_type, data=image_data)
|
||||
continue # Skip the common conversion path below
|
||||
except ValueError:
|
||||
raise LettaImageFetchError(url=url[:100] + "...", reason="Invalid data URL format")
|
||||
else:
|
||||
# Handle http(s):// URLs using async httpx
|
||||
image_bytes, image_media_type = await _fetch_image_from_url(url)
|
||||
|
||||
@@ -306,7 +306,9 @@ async def search_pinecone_index(query: str, limit: int, filter: Dict[str, Any],
|
||||
|
||||
@pinecone_retry()
|
||||
@trace_method
|
||||
async def list_pinecone_index_for_files(file_id: str, actor: User, limit: int = None, pagination_token: str = None) -> List[str]:
|
||||
async def list_pinecone_index_for_files(
|
||||
file_id: str, actor: User, limit: int | None = None, pagination_token: str | None = None
|
||||
) -> List[str]:
|
||||
if not PINECONE_AVAILABLE:
|
||||
raise ImportError("Pinecone is not available. Please install pinecone to use this feature.")
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import copy
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from letta.constants import PRE_EXECUTION_MESSAGE_ARG
|
||||
from letta.schemas.tool import MCP_TOOL_METADATA_SCHEMA_STATUS, MCP_TOOL_METADATA_SCHEMA_WARNINGS
|
||||
@@ -201,7 +201,7 @@ def add_pre_execution_message(tool_schema: Dict[str, Any], description: Optional
|
||||
|
||||
# Ensure pre-execution message is the first required field
|
||||
if PRE_EXECUTION_MESSAGE_ARG not in required:
|
||||
required = [PRE_EXECUTION_MESSAGE_ARG] + required
|
||||
required = [PRE_EXECUTION_MESSAGE_ARG, *required]
|
||||
|
||||
# Update the schema with ordered properties and required list
|
||||
schema["parameters"] = {
|
||||
|
||||
@@ -3,12 +3,20 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
|
||||
import httpx
|
||||
|
||||
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
|
||||
from letta.errors import LettaInvalidArgumentError
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.otel.tracing import log_event, trace_method
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole, TagMatchMode
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
@@ -16,6 +24,136 @@ from letta.settings import model_settings, settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type variable for generic async retry decorator
|
||||
T = TypeVar("T")
|
||||
|
||||
# Default retry configuration for turbopuffer operations
|
||||
TPUF_MAX_RETRIES = 3
|
||||
TPUF_INITIAL_DELAY = 1.0 # seconds
|
||||
TPUF_EXPONENTIAL_BASE = 2.0
|
||||
TPUF_JITTER = True
|
||||
|
||||
|
||||
def is_transient_error(error: Exception) -> bool:
|
||||
"""Check if an error is transient and should be retried.
|
||||
|
||||
Args:
|
||||
error: The exception to check
|
||||
|
||||
Returns:
|
||||
True if the error is transient and can be retried
|
||||
"""
|
||||
# httpx connection errors (network issues, DNS failures, etc.)
|
||||
if isinstance(error, httpx.ConnectError):
|
||||
return True
|
||||
|
||||
# httpx timeout errors
|
||||
if isinstance(error, httpx.TimeoutException):
|
||||
return True
|
||||
|
||||
# httpx network errors
|
||||
if isinstance(error, httpx.NetworkError):
|
||||
return True
|
||||
|
||||
# Check for connection-related errors in the error message
|
||||
error_str = str(error).lower()
|
||||
transient_patterns = [
|
||||
"connect call failed",
|
||||
"connection refused",
|
||||
"connection reset",
|
||||
"connection timed out",
|
||||
"temporary failure",
|
||||
"name resolution",
|
||||
"dns",
|
||||
"network unreachable",
|
||||
"no route to host",
|
||||
"ssl handshake",
|
||||
]
|
||||
for pattern in transient_patterns:
|
||||
if pattern in error_str:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def async_retry_with_backoff(
|
||||
max_retries: int = TPUF_MAX_RETRIES,
|
||||
initial_delay: float = TPUF_INITIAL_DELAY,
|
||||
exponential_base: float = TPUF_EXPONENTIAL_BASE,
|
||||
jitter: bool = TPUF_JITTER,
|
||||
):
|
||||
"""Decorator for async functions that retries on transient errors with exponential backoff.
|
||||
|
||||
Args:
|
||||
max_retries: Maximum number of retry attempts
|
||||
initial_delay: Initial delay between retries in seconds
|
||||
exponential_base: Base for exponential backoff calculation
|
||||
jitter: Whether to add random jitter to delays
|
||||
|
||||
Returns:
|
||||
Decorated async function with retry logic
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs) -> Any:
|
||||
num_retries = 0
|
||||
delay = initial_delay
|
||||
|
||||
while True:
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
# Check if this is a retryable error
|
||||
if not is_transient_error(e):
|
||||
# Not a transient error, re-raise immediately
|
||||
raise
|
||||
|
||||
num_retries += 1
|
||||
|
||||
# Log the retry attempt
|
||||
log_event(
|
||||
"turbopuffer_retry_attempt",
|
||||
{
|
||||
"attempt": num_retries,
|
||||
"delay": delay,
|
||||
"error_type": type(e).__name__,
|
||||
"error": str(e),
|
||||
"function": func.__name__,
|
||||
},
|
||||
)
|
||||
logger.warning(
|
||||
f"Turbopuffer operation '{func.__name__}' failed with transient error "
|
||||
f"(attempt {num_retries}/{max_retries}): {e}. Retrying in {delay:.1f}s..."
|
||||
)
|
||||
|
||||
# Check if max retries exceeded
|
||||
if num_retries > max_retries:
|
||||
log_event(
|
||||
"turbopuffer_max_retries_exceeded",
|
||||
{
|
||||
"max_retries": max_retries,
|
||||
"error_type": type(e).__name__,
|
||||
"error": str(e),
|
||||
"function": func.__name__,
|
||||
},
|
||||
)
|
||||
logger.error(f"Turbopuffer operation '{func.__name__}' failed after {max_retries} retries: {e}")
|
||||
raise
|
||||
|
||||
# Wait with exponential backoff
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# Calculate next delay with optional jitter
|
||||
delay *= exponential_base
|
||||
if jitter:
|
||||
delay *= 1 + random.random() * 0.1 # Add up to 10% jitter
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# Global semaphore for Turbopuffer operations to prevent overwhelming the service
|
||||
# This is separate from embedding semaphore since Turbopuffer can handle more concurrency
|
||||
_GLOBAL_TURBOPUFFER_SEMAPHORE = asyncio.Semaphore(5)
|
||||
@@ -25,11 +163,11 @@ def _run_turbopuffer_write_in_thread(
|
||||
api_key: str,
|
||||
region: str,
|
||||
namespace_name: str,
|
||||
upsert_columns: dict = None,
|
||||
deletes: list = None,
|
||||
delete_by_filter: tuple = None,
|
||||
upsert_columns: dict | None = None,
|
||||
deletes: list | None = None,
|
||||
delete_by_filter: tuple | None = None,
|
||||
distance_metric: str = "cosine_distance",
|
||||
schema: dict = None,
|
||||
schema: dict | None = None,
|
||||
):
|
||||
"""
|
||||
Sync wrapper to run turbopuffer write in isolated event loop.
|
||||
@@ -93,7 +231,7 @@ class TurbopufferClient:
|
||||
embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
|
||||
)
|
||||
|
||||
def __init__(self, api_key: str = None, region: str = None):
|
||||
def __init__(self, api_key: str | None = None, region: str | None = None):
|
||||
"""Initialize Turbopuffer client."""
|
||||
self.api_key = api_key or settings.tpuf_api_key
|
||||
self.region = region or settings.tpuf_region
|
||||
@@ -222,6 +360,7 @@ class TurbopufferClient:
|
||||
return json.dumps(parts)
|
||||
|
||||
@trace_method
|
||||
@async_retry_with_backoff()
|
||||
async def insert_tools(
|
||||
self,
|
||||
tools: List["PydanticTool"],
|
||||
@@ -238,7 +377,6 @@ class TurbopufferClient:
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
|
||||
if not tools:
|
||||
return True
|
||||
@@ -313,6 +451,7 @@ class TurbopufferClient:
|
||||
raise
|
||||
|
||||
@trace_method
|
||||
@async_retry_with_backoff()
|
||||
async def insert_archival_memories(
|
||||
self,
|
||||
archive_id: str,
|
||||
@@ -339,7 +478,6 @@ class TurbopufferClient:
|
||||
Returns:
|
||||
List of PydanticPassage objects that were inserted
|
||||
"""
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
|
||||
# filter out empty text chunks
|
||||
filtered_chunks = [(i, text) for i, text in enumerate(text_chunks) if text.strip()]
|
||||
@@ -464,6 +602,7 @@ class TurbopufferClient:
|
||||
raise
|
||||
|
||||
@trace_method
|
||||
@async_retry_with_backoff()
|
||||
async def insert_messages(
|
||||
self,
|
||||
agent_id: str,
|
||||
@@ -494,7 +633,6 @@ class TurbopufferClient:
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
|
||||
# filter out empty message texts
|
||||
filtered_messages = [(i, text) for i, text in enumerate(message_texts) if text.strip()]
|
||||
@@ -609,6 +747,7 @@ class TurbopufferClient:
|
||||
raise
|
||||
|
||||
@trace_method
|
||||
@async_retry_with_backoff()
|
||||
async def _execute_query(
|
||||
self,
|
||||
namespace_name: str,
|
||||
@@ -1377,9 +1516,9 @@ class TurbopufferClient:
|
||||
return sorted_results[:top_k]
|
||||
|
||||
@trace_method
|
||||
@async_retry_with_backoff()
|
||||
async def delete_passage(self, archive_id: str, passage_id: str) -> bool:
|
||||
"""Delete a passage from Turbopuffer."""
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
|
||||
namespace_name = await self._get_archive_namespace_name(archive_id)
|
||||
|
||||
@@ -1399,9 +1538,9 @@ class TurbopufferClient:
|
||||
raise
|
||||
|
||||
@trace_method
|
||||
@async_retry_with_backoff()
|
||||
async def delete_passages(self, archive_id: str, passage_ids: List[str]) -> bool:
|
||||
"""Delete multiple passages from Turbopuffer."""
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
|
||||
if not passage_ids:
|
||||
return True
|
||||
@@ -1424,6 +1563,7 @@ class TurbopufferClient:
|
||||
raise
|
||||
|
||||
@trace_method
|
||||
@async_retry_with_backoff()
|
||||
async def delete_all_passages(self, archive_id: str) -> bool:
|
||||
"""Delete all passages for an archive from Turbopuffer."""
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
@@ -1442,9 +1582,9 @@ class TurbopufferClient:
|
||||
raise
|
||||
|
||||
@trace_method
|
||||
@async_retry_with_backoff()
|
||||
async def delete_messages(self, agent_id: str, organization_id: str, message_ids: List[str]) -> bool:
|
||||
"""Delete multiple messages from Turbopuffer."""
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
|
||||
if not message_ids:
|
||||
return True
|
||||
@@ -1467,9 +1607,9 @@ class TurbopufferClient:
|
||||
raise
|
||||
|
||||
@trace_method
|
||||
@async_retry_with_backoff()
|
||||
async def delete_all_messages(self, agent_id: str, organization_id: str) -> bool:
|
||||
"""Delete all messages for an agent from Turbopuffer."""
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
|
||||
namespace_name = await self._get_message_namespace_name(organization_id)
|
||||
|
||||
@@ -1509,6 +1649,7 @@ class TurbopufferClient:
|
||||
return namespace_name
|
||||
|
||||
@trace_method
|
||||
@async_retry_with_backoff()
|
||||
async def insert_file_passages(
|
||||
self,
|
||||
source_id: str,
|
||||
@@ -1531,7 +1672,6 @@ class TurbopufferClient:
|
||||
Returns:
|
||||
List of PydanticPassage objects that were inserted
|
||||
"""
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
|
||||
if not text_chunks:
|
||||
return []
|
||||
@@ -1765,9 +1905,9 @@ class TurbopufferClient:
|
||||
return passages_with_scores
|
||||
|
||||
@trace_method
|
||||
@async_retry_with_backoff()
|
||||
async def delete_file_passages(self, source_id: str, file_id: str, organization_id: str) -> bool:
|
||||
"""Delete all passages for a specific file from Turbopuffer."""
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
|
||||
namespace_name = await self._get_file_passages_namespace_name(organization_id)
|
||||
|
||||
@@ -1793,9 +1933,9 @@ class TurbopufferClient:
|
||||
raise
|
||||
|
||||
@trace_method
|
||||
@async_retry_with_backoff()
|
||||
async def delete_source_passages(self, source_id: str, organization_id: str) -> bool:
|
||||
"""Delete all passages for a source from Turbopuffer."""
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
|
||||
namespace_name = await self._get_file_passages_namespace_name(organization_id)
|
||||
|
||||
@@ -1817,6 +1957,7 @@ class TurbopufferClient:
|
||||
# tool methods
|
||||
|
||||
@trace_method
|
||||
@async_retry_with_backoff()
|
||||
async def delete_tools(self, organization_id: str, tool_ids: List[str]) -> bool:
|
||||
"""Delete tools from Turbopuffer.
|
||||
|
||||
@@ -1827,7 +1968,6 @@ class TurbopufferClient:
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
|
||||
if not tool_ids:
|
||||
return True
|
||||
|
||||
@@ -136,7 +136,7 @@ class CLIInterface(AgentInterface):
|
||||
else:
|
||||
try:
|
||||
msg_json = json_loads(msg)
|
||||
except:
|
||||
except Exception:
|
||||
printd(f"{CLI_WARNING_PREFIX}failed to parse user message into json")
|
||||
printd_user_message("🧑", msg)
|
||||
return
|
||||
|
||||
@@ -3,7 +3,12 @@ import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span
|
||||
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
|
||||
from anthropic import AsyncStream
|
||||
from anthropic.types.beta import (
|
||||
@@ -274,7 +279,13 @@ class SimpleAnthropicStreamingInterface:
|
||||
attributes={"stop_reason": StopReasonType.error.value, "error": str(e), "stacktrace": traceback.format_exc()},
|
||||
)
|
||||
yield LettaStopReason(stop_reason=StopReasonType.error)
|
||||
raise e
|
||||
|
||||
# Transform Anthropic errors into our custom error types for consistent handling
|
||||
from letta.llm_api.anthropic_client import AnthropicClient
|
||||
|
||||
client = AnthropicClient()
|
||||
transformed_error = client.handle_llm_error(e)
|
||||
raise transformed_error
|
||||
finally:
|
||||
logger.info("AnthropicStreamingInterface: Stream processing complete.")
|
||||
|
||||
@@ -316,6 +327,7 @@ class SimpleAnthropicStreamingInterface:
|
||||
id=decrement_message_uuid(self.letta_message_id),
|
||||
# Do not emit placeholder arguments here to avoid UI duplicates
|
||||
tool_call=ToolCallDelta(name=name, tool_call_id=call_id),
|
||||
tool_calls=ToolCallDelta(name=name, tool_call_id=call_id),
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
|
||||
run_id=self.run_id,
|
||||
@@ -421,6 +433,7 @@ class SimpleAnthropicStreamingInterface:
|
||||
tool_call_msg = ApprovalRequestMessage(
|
||||
id=decrement_message_uuid(self.letta_message_id),
|
||||
tool_call=ToolCallDelta(name=name, tool_call_id=call_id, arguments=delta.partial_json),
|
||||
tool_calls=ToolCallDelta(name=name, tool_call_id=call_id, arguments=delta.partial_json),
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
|
||||
run_id=self.run_id,
|
||||
|
||||
@@ -3,7 +3,12 @@ import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span
|
||||
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
|
||||
from anthropic import AsyncStream
|
||||
from anthropic.types.beta import (
|
||||
@@ -116,7 +121,7 @@ class AnthropicStreamingInterface:
|
||||
# Attempt to use OptimisticJSONParser to handle incomplete/malformed JSON
|
||||
try:
|
||||
tool_input = self.json_parser.parse(args_str)
|
||||
except:
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Failed to decode tool call arguments for tool_call_id={self.tool_call_id}, "
|
||||
f"name={self.tool_call_name}. Raw input: {args_str!r}. Error: {e}"
|
||||
@@ -263,7 +268,13 @@ class AnthropicStreamingInterface:
|
||||
attributes={"stop_reason": StopReasonType.error.value, "error": str(e), "stacktrace": traceback.format_exc()},
|
||||
)
|
||||
yield LettaStopReason(stop_reason=StopReasonType.error)
|
||||
raise e
|
||||
|
||||
# Transform Anthropic errors into our custom error types for consistent handling
|
||||
from letta.llm_api.anthropic_client import AnthropicClient
|
||||
|
||||
client = AnthropicClient()
|
||||
transformed_error = client.handle_llm_error(e)
|
||||
raise transformed_error
|
||||
finally:
|
||||
logger.info("AnthropicStreamingInterface: Stream processing complete.")
|
||||
|
||||
@@ -424,16 +435,19 @@ class AnthropicStreamingInterface:
|
||||
if current_inner_thoughts:
|
||||
tool_call_args = tool_call_args.replace(f'"{INNER_THOUGHTS_KWARG}": "{current_inner_thoughts}"', "")
|
||||
|
||||
tool_call_delta = ToolCallDelta(
|
||||
name=self.tool_call_name,
|
||||
tool_call_id=self.tool_call_id,
|
||||
arguments=tool_call_args,
|
||||
)
|
||||
|
||||
approval_msg = ApprovalRequestMessage(
|
||||
id=self.letta_message_id,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
name=self.tool_call_name,
|
||||
tool_call=ToolCallDelta(
|
||||
name=self.tool_call_name,
|
||||
tool_call_id=self.tool_call_id,
|
||||
arguments=tool_call_args,
|
||||
),
|
||||
tool_call=tool_call_delta,
|
||||
tool_calls=tool_call_delta,
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
@@ -493,6 +507,9 @@ class AnthropicStreamingInterface:
|
||||
tool_call_msg = ApprovalRequestMessage(
|
||||
id=self.letta_message_id,
|
||||
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json),
|
||||
tool_calls=ToolCallDelta(
|
||||
name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json
|
||||
),
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
@@ -576,9 +593,15 @@ class AnthropicStreamingInterface:
|
||||
pass
|
||||
elif isinstance(event, BetaRawContentBlockStopEvent):
|
||||
# If we're exiting a tool use block and there are still buffered messages,
|
||||
# we should flush them now
|
||||
# we should flush them now.
|
||||
# Ensure each flushed chunk has an otid before yielding.
|
||||
if self.anthropic_mode == EventMode.TOOL_USE and self.tool_call_buffer:
|
||||
for buffered_msg in self.tool_call_buffer:
|
||||
if not buffered_msg.otid:
|
||||
if prev_message_type and prev_message_type != buffered_msg.message_type:
|
||||
message_index += 1
|
||||
buffered_msg.otid = Message.generate_otid_from_id(buffered_msg.id, message_index)
|
||||
prev_message_type = buffered_msg.message_type
|
||||
yield buffered_msg
|
||||
self.tool_call_buffer = []
|
||||
|
||||
@@ -644,7 +667,7 @@ class SimpleAnthropicStreamingInterface:
|
||||
# Attempt to use OptimisticJSONParser to handle incomplete/malformed JSON
|
||||
try:
|
||||
tool_input = self.json_parser.parse(args_str)
|
||||
except:
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Failed to decode tool call arguments for tool_call_id={self.tool_call_id}, "
|
||||
f"name={self.tool_call_name}. Raw input: {args_str!r}. Error: {e}"
|
||||
@@ -827,6 +850,7 @@ class SimpleAnthropicStreamingInterface:
|
||||
tool_call_msg = ApprovalRequestMessage(
|
||||
id=self.letta_message_id,
|
||||
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
|
||||
tool_calls=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
@@ -911,6 +935,7 @@ class SimpleAnthropicStreamingInterface:
|
||||
tool_call_msg = ApprovalRequestMessage(
|
||||
id=self.letta_message_id,
|
||||
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json),
|
||||
tool_calls=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json),
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
|
||||
@@ -3,7 +3,12 @@ import base64
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from typing import AsyncIterator, List, Optional
|
||||
from typing import TYPE_CHECKING, AsyncIterator, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span
|
||||
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
|
||||
from google.genai.types import (
|
||||
GenerateContentResponse,
|
||||
@@ -97,9 +102,11 @@ class SimpleGeminiStreamingInterface:
|
||||
|
||||
def get_content(self) -> List[ReasoningContent | TextContent | ToolCallContent]:
|
||||
"""This is (unusually) in chunked format, instead of merged"""
|
||||
has_reasoning = any(isinstance(c, ReasoningContent) for c in self.content_parts)
|
||||
for content in self.content_parts:
|
||||
if isinstance(content, ReasoningContent):
|
||||
# This assumes there is only one signature per turn
|
||||
content.signature = self.thinking_signature
|
||||
elif isinstance(content, TextContent) and not has_reasoning and self.thinking_signature:
|
||||
content.signature = self.thinking_signature
|
||||
return self.content_parts
|
||||
|
||||
@@ -322,15 +329,18 @@ class SimpleGeminiStreamingInterface:
|
||||
self.collected_tool_calls.append(ToolCall(id=call_id, function=FunctionCall(name=name, arguments=arguments_str)))
|
||||
|
||||
if self.tool_call_name and self.tool_call_name in self.requires_approval_tools:
|
||||
tool_call_delta = ToolCallDelta(
|
||||
name=name,
|
||||
arguments=arguments_str,
|
||||
tool_call_id=call_id,
|
||||
)
|
||||
|
||||
yield ApprovalRequestMessage(
|
||||
id=decrement_message_uuid(self.letta_message_id),
|
||||
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=ToolCallDelta(
|
||||
name=name,
|
||||
arguments=arguments_str,
|
||||
tool_call_id=call_id,
|
||||
),
|
||||
tool_call=tool_call_delta,
|
||||
tool_calls=tool_call_delta,
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span
|
||||
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
@@ -14,6 +19,7 @@ from openai.types.responses import (
|
||||
ResponseFunctionCallArgumentsDeltaEvent,
|
||||
ResponseFunctionCallArgumentsDoneEvent,
|
||||
ResponseFunctionToolCall,
|
||||
ResponseIncompleteEvent,
|
||||
ResponseInProgressEvent,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
@@ -314,9 +320,6 @@ class OpenAIStreamingInterface:
|
||||
# Track events for diagnostics
|
||||
self.total_events_received += 1
|
||||
self.last_event_type = "ChatCompletionChunk"
|
||||
# Track events for diagnostics
|
||||
self.total_events_received += 1
|
||||
self.last_event_type = "ChatCompletionChunk"
|
||||
|
||||
if not self.model or not self.message_id:
|
||||
self.model = chunk.model
|
||||
@@ -414,25 +417,22 @@ class OpenAIStreamingInterface:
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
self.tool_call_name = str(self._get_function_name_buffer())
|
||||
tool_call_delta = ToolCallDelta(
|
||||
name=self._get_function_name_buffer(),
|
||||
arguments=None,
|
||||
tool_call_id=self._get_current_function_id(),
|
||||
)
|
||||
if self.tool_call_name in self.requires_approval_tools:
|
||||
tool_call_msg = ApprovalRequestMessage(
|
||||
id=decrement_message_uuid(self.letta_message_id),
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=ToolCallDelta(
|
||||
name=self._get_function_name_buffer(),
|
||||
arguments=None,
|
||||
tool_call_id=self._get_current_function_id(),
|
||||
),
|
||||
tool_call=tool_call_delta,
|
||||
tool_calls=tool_call_delta,
|
||||
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
else:
|
||||
tool_call_delta = ToolCallDelta(
|
||||
name=self._get_function_name_buffer(),
|
||||
arguments=None,
|
||||
tool_call_id=self._get_current_function_id(),
|
||||
)
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.letta_message_id,
|
||||
date=datetime.now(timezone.utc),
|
||||
@@ -471,7 +471,7 @@ class OpenAIStreamingInterface:
|
||||
# Minimal, robust extraction: only emit the value of "message".
|
||||
# If we buffered a prefix while name was streaming, feed it first.
|
||||
if self._function_args_buffer_parts:
|
||||
payload = "".join(self._function_args_buffer_parts + [tool_call.function.arguments])
|
||||
payload = "".join([*self._function_args_buffer_parts, tool_call.function.arguments])
|
||||
self._function_args_buffer_parts = None
|
||||
else:
|
||||
payload = tool_call.function.arguments
|
||||
@@ -498,29 +498,26 @@ class OpenAIStreamingInterface:
|
||||
# if the previous chunk had arguments but we needed to flush name
|
||||
if self._function_args_buffer_parts:
|
||||
# In this case, we should release the buffer + new data at once
|
||||
combined_chunk = "".join(self._function_args_buffer_parts + [updates_main_json])
|
||||
combined_chunk = "".join([*self._function_args_buffer_parts, updates_main_json])
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
tool_call_delta = ToolCallDelta(
|
||||
name=self._get_function_name_buffer(),
|
||||
arguments=combined_chunk,
|
||||
tool_call_id=self._get_current_function_id(),
|
||||
)
|
||||
if self._get_function_name_buffer() in self.requires_approval_tools:
|
||||
tool_call_msg = ApprovalRequestMessage(
|
||||
id=decrement_message_uuid(self.letta_message_id),
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=ToolCallDelta(
|
||||
name=self._get_function_name_buffer(),
|
||||
arguments=combined_chunk,
|
||||
tool_call_id=self._get_current_function_id(),
|
||||
),
|
||||
tool_call=tool_call_delta,
|
||||
tool_calls=tool_call_delta,
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
else:
|
||||
tool_call_delta = ToolCallDelta(
|
||||
name=self._get_function_name_buffer(),
|
||||
arguments=combined_chunk,
|
||||
tool_call_id=self._get_current_function_id(),
|
||||
)
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.letta_message_id,
|
||||
date=datetime.now(timezone.utc),
|
||||
@@ -540,26 +537,23 @@ class OpenAIStreamingInterface:
|
||||
# If there's no buffer to clear, just output a new chunk with new data
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
tool_call_delta = ToolCallDelta(
|
||||
name=None,
|
||||
arguments=updates_main_json,
|
||||
tool_call_id=self._get_current_function_id(),
|
||||
)
|
||||
if self._get_function_name_buffer() in self.requires_approval_tools:
|
||||
tool_call_msg = ApprovalRequestMessage(
|
||||
id=decrement_message_uuid(self.letta_message_id),
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=ToolCallDelta(
|
||||
name=None,
|
||||
arguments=updates_main_json,
|
||||
tool_call_id=self._get_current_function_id(),
|
||||
),
|
||||
tool_call=tool_call_delta,
|
||||
tool_calls=tool_call_delta,
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
else:
|
||||
tool_call_delta = ToolCallDelta(
|
||||
name=None,
|
||||
arguments=updates_main_json,
|
||||
tool_call_id=self._get_current_function_id(),
|
||||
)
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.letta_message_id,
|
||||
date=datetime.now(timezone.utc),
|
||||
@@ -588,7 +582,7 @@ class SimpleOpenAIStreamingInterface:
|
||||
messages: Optional[list] = None,
|
||||
tools: Optional[list] = None,
|
||||
requires_approval_tools: list = [],
|
||||
model: str = None,
|
||||
model: str | None = None,
|
||||
run_id: str | None = None,
|
||||
step_id: str | None = None,
|
||||
cancellation_event: Optional["asyncio.Event"] = None,
|
||||
@@ -639,7 +633,6 @@ class SimpleOpenAIStreamingInterface:
|
||||
|
||||
def get_content(self) -> list[TextContent | OmittedReasoningContent | ReasoningContent]:
|
||||
shown_omitted = False
|
||||
concat_content = ""
|
||||
merged_messages = []
|
||||
reasoning_content = []
|
||||
concat_content_parts: list[str] = []
|
||||
@@ -837,6 +830,10 @@ class SimpleOpenAIStreamingInterface:
|
||||
prev_message_type: Optional[str] = None,
|
||||
message_index: int = 0,
|
||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||
# Track events for diagnostics
|
||||
self.total_events_received += 1
|
||||
self.last_event_type = "ChatCompletionChunk"
|
||||
|
||||
if not self.model or not self.message_id:
|
||||
self.model = chunk.model
|
||||
self.message_id = chunk.id
|
||||
@@ -887,14 +884,10 @@ class SimpleOpenAIStreamingInterface:
|
||||
prev_message_type = assistant_msg.message_type
|
||||
yield assistant_msg
|
||||
|
||||
if (
|
||||
hasattr(chunk, "choices")
|
||||
and len(chunk.choices) > 0
|
||||
and hasattr(chunk.choices[0], "delta")
|
||||
and hasattr(chunk.choices[0].delta, "reasoning_content")
|
||||
):
|
||||
if hasattr(chunk, "choices") and len(chunk.choices) > 0 and hasattr(chunk.choices[0], "delta"):
|
||||
delta = chunk.choices[0].delta
|
||||
reasoning_content = getattr(delta, "reasoning_content", None)
|
||||
# Check for reasoning_content (standard) or reasoning (OpenRouter)
|
||||
reasoning_content = getattr(delta, "reasoning_content", None) or getattr(delta, "reasoning", None)
|
||||
if reasoning_content is not None and reasoning_content != "":
|
||||
if prev_message_type and prev_message_type != "reasoning_message":
|
||||
message_index += 1
|
||||
@@ -945,7 +938,7 @@ class SimpleOpenAIStreamingInterface:
|
||||
if resolved_id is None:
|
||||
continue
|
||||
|
||||
delta = ToolCallDelta(
|
||||
tool_call_delta = ToolCallDelta(
|
||||
name=tool_call.function.name if (tool_call.function and tool_call.function.name) else None,
|
||||
arguments=tool_call.function.arguments if (tool_call.function and tool_call.function.arguments) else None,
|
||||
tool_call_id=resolved_id,
|
||||
@@ -956,7 +949,8 @@ class SimpleOpenAIStreamingInterface:
|
||||
tool_call_msg = ApprovalRequestMessage(
|
||||
id=decrement_message_uuid(self.letta_message_id),
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=delta,
|
||||
tool_call=tool_call_delta,
|
||||
tool_calls=tool_call_delta,
|
||||
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
@@ -967,8 +961,8 @@ class SimpleOpenAIStreamingInterface:
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.letta_message_id,
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=delta,
|
||||
tool_calls=delta,
|
||||
tool_call=tool_call_delta,
|
||||
tool_calls=tool_call_delta,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
@@ -988,7 +982,7 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
messages: Optional[list] = None,
|
||||
tools: Optional[list] = None,
|
||||
requires_approval_tools: list = [],
|
||||
model: str = None,
|
||||
model: str | None = None,
|
||||
run_id: str | None = None,
|
||||
step_id: str | None = None,
|
||||
cancellation_event: Optional["asyncio.Event"] = None,
|
||||
@@ -1029,6 +1023,9 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
self.last_event_type: str | None = None
|
||||
self.total_events_received: int = 0
|
||||
self.stream_was_cancelled: bool = False
|
||||
# For downstream finish_reason mapping (e.g. max_output_tokens -> "length")
|
||||
# None means no incomplete reason was observed.
|
||||
self.incomplete_reason: str | None = None
|
||||
|
||||
# -------- Mapping helpers (no broad try/except) --------
|
||||
def _record_tool_mapping(self, event: object, item: object) -> tuple[str | None, str | None, int | None, str | None]:
|
||||
@@ -1089,6 +1086,10 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
text=response.content[0].text,
|
||||
)
|
||||
)
|
||||
elif len(response.content) == 0:
|
||||
# Incomplete responses may have an output message with no content parts
|
||||
# (model started the message item but hit max_output_tokens before producing text)
|
||||
logger.warning("ResponseOutputMessage has 0 content parts (likely from an incomplete response), skipping.")
|
||||
else:
|
||||
raise ValueError(f"Got {len(response.content)} content parts, expected 1")
|
||||
|
||||
@@ -1254,8 +1255,6 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
if isinstance(new_event_item, ResponseReasoningItem):
|
||||
# Look for summary delta, or encrypted_content
|
||||
summary = new_event_item.summary
|
||||
content = new_event_item.content # NOTE: always none
|
||||
encrypted_content = new_event_item.encrypted_content
|
||||
# TODO change to summarize reasoning message, but we need to figure out the streaming indices of summary problem
|
||||
concat_summary = "".join([s.text for s in summary])
|
||||
if concat_summary != "":
|
||||
@@ -1283,27 +1282,24 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
self.tool_call_name = name
|
||||
# Record mapping so subsequent argument deltas can be associated
|
||||
self._record_tool_mapping(event, new_event_item)
|
||||
tool_call_delta = ToolCallDelta(
|
||||
name=name,
|
||||
arguments=arguments if arguments != "" else None,
|
||||
tool_call_id=call_id,
|
||||
)
|
||||
if self.tool_call_name and self.tool_call_name in self.requires_approval_tools:
|
||||
yield ApprovalRequestMessage(
|
||||
id=decrement_message_uuid(self.letta_message_id),
|
||||
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=ToolCallDelta(
|
||||
name=name,
|
||||
arguments=arguments if arguments != "" else None,
|
||||
tool_call_id=call_id,
|
||||
),
|
||||
tool_call=tool_call_delta,
|
||||
tool_calls=tool_call_delta,
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
else:
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
tool_call_delta = ToolCallDelta(
|
||||
name=name,
|
||||
arguments=arguments if arguments != "" else None,
|
||||
tool_call_id=call_id,
|
||||
)
|
||||
yield ToolCallMessage(
|
||||
id=self.letta_message_id,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
@@ -1394,7 +1390,6 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
# NOTE: is this inclusive of the deltas?
|
||||
# If not, we should add it to the rolling
|
||||
summary_index = event.summary_index
|
||||
text = event.text
|
||||
return
|
||||
|
||||
# Reasoning summary streaming
|
||||
@@ -1436,7 +1431,6 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
# Assistant message streaming
|
||||
elif isinstance(event, ResponseTextDoneEvent):
|
||||
# NOTE: inclusive, can skip
|
||||
text = event.text
|
||||
return
|
||||
|
||||
# Assistant message done
|
||||
@@ -1451,7 +1445,7 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
delta = event.delta
|
||||
|
||||
# Resolve tool_call_id/name using output_index or item_id
|
||||
resolved_call_id, resolved_name, out_idx, item_id = self._resolve_mapping_for_delta(event)
|
||||
resolved_call_id, resolved_name, _out_idx, _item_id = self._resolve_mapping_for_delta(event)
|
||||
|
||||
# Fallback to last seen tool name for approval routing if mapping name missing
|
||||
if not resolved_name:
|
||||
@@ -1462,27 +1456,24 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
return
|
||||
|
||||
# We have a call id; emit approval or tool-call message accordingly
|
||||
tool_call_delta = ToolCallDelta(
|
||||
name=None,
|
||||
arguments=delta,
|
||||
tool_call_id=resolved_call_id,
|
||||
)
|
||||
if resolved_name and resolved_name in self.requires_approval_tools:
|
||||
yield ApprovalRequestMessage(
|
||||
id=decrement_message_uuid(self.letta_message_id),
|
||||
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=ToolCallDelta(
|
||||
name=None,
|
||||
arguments=delta,
|
||||
tool_call_id=resolved_call_id,
|
||||
),
|
||||
tool_call=tool_call_delta,
|
||||
tool_calls=tool_call_delta,
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
else:
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
tool_call_delta = ToolCallDelta(
|
||||
name=None,
|
||||
arguments=delta,
|
||||
tool_call_id=resolved_call_id,
|
||||
)
|
||||
yield ToolCallMessage(
|
||||
id=self.letta_message_id,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
@@ -1497,7 +1488,6 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
# Function calls
|
||||
elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent):
|
||||
# NOTE: inclusive
|
||||
full_args = event.arguments
|
||||
return
|
||||
|
||||
# Generic
|
||||
@@ -1506,31 +1496,55 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
return
|
||||
|
||||
# Generic finish
|
||||
elif isinstance(event, ResponseCompletedEvent):
|
||||
# NOTE we can "rebuild" the final state of the stream using the values in here, instead of relying on the accumulators
|
||||
elif isinstance(event, (ResponseCompletedEvent, ResponseIncompleteEvent)):
|
||||
# ResponseIncompleteEvent has the same response structure as ResponseCompletedEvent,
|
||||
# but indicates the response was cut short (e.g. due to max_output_tokens).
|
||||
# We still extract the partial response and usage data so they aren't silently lost.
|
||||
if isinstance(event, ResponseIncompleteEvent):
|
||||
self.incomplete_reason = (
|
||||
getattr(event.response.incomplete_details, "reason", None) if event.response.incomplete_details else None
|
||||
)
|
||||
reason = self.incomplete_reason or "unknown"
|
||||
logger.warning(
|
||||
f"OpenAI Responses API returned an incomplete response (reason: {reason}). "
|
||||
f"Model: {event.response.model}, output_tokens: {event.response.usage.output_tokens if event.response.usage else 'N/A'}. "
|
||||
f"The partial response content will still be used."
|
||||
)
|
||||
|
||||
self.final_response = event.response
|
||||
self.model = event.response.model
|
||||
self.input_tokens = event.response.usage.input_tokens
|
||||
self.output_tokens = event.response.usage.output_tokens
|
||||
self.message_id = event.response.id
|
||||
# Store raw usage for transparent provider trace logging
|
||||
try:
|
||||
self.raw_usage = event.response.usage.model_dump(exclude_none=True)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to capture raw_usage from OpenAI Responses API: {e}")
|
||||
self.raw_usage = None
|
||||
# Capture cache token details (Responses API uses input_tokens_details)
|
||||
# Use `is not None` to capture 0 values (meaning "provider reported 0 cached tokens")
|
||||
if hasattr(event.response.usage, "input_tokens_details") and event.response.usage.input_tokens_details:
|
||||
details = event.response.usage.input_tokens_details
|
||||
if hasattr(details, "cached_tokens") and details.cached_tokens is not None:
|
||||
self.cached_tokens = details.cached_tokens
|
||||
# Capture reasoning token details (Responses API uses output_tokens_details)
|
||||
# Use `is not None` to capture 0 values (meaning "provider reported 0 reasoning tokens")
|
||||
if hasattr(event.response.usage, "output_tokens_details") and event.response.usage.output_tokens_details:
|
||||
details = event.response.usage.output_tokens_details
|
||||
if hasattr(details, "reasoning_tokens") and details.reasoning_tokens is not None:
|
||||
self.reasoning_tokens = details.reasoning_tokens
|
||||
|
||||
usage = event.response.usage
|
||||
if usage is not None:
|
||||
self.input_tokens = usage.input_tokens
|
||||
self.output_tokens = usage.output_tokens
|
||||
|
||||
# Store raw usage for transparent provider trace logging
|
||||
try:
|
||||
self.raw_usage = usage.model_dump(exclude_none=True)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to capture raw_usage from OpenAI Responses API: {e}")
|
||||
self.raw_usage = None
|
||||
|
||||
# Capture cache token details (Responses API uses input_tokens_details)
|
||||
# Use `is not None` to capture 0 values (meaning "provider reported 0 cached tokens")
|
||||
if hasattr(usage, "input_tokens_details") and usage.input_tokens_details:
|
||||
details = usage.input_tokens_details
|
||||
if hasattr(details, "cached_tokens") and details.cached_tokens is not None:
|
||||
self.cached_tokens = details.cached_tokens
|
||||
|
||||
# Capture reasoning token details (Responses API uses output_tokens_details)
|
||||
# Use `is not None` to capture 0 values (meaning "provider reported 0 reasoning tokens")
|
||||
if hasattr(usage, "output_tokens_details") and usage.output_tokens_details:
|
||||
details = usage.output_tokens_details
|
||||
if hasattr(details, "reasoning_tokens") and details.reasoning_tokens is not None:
|
||||
self.reasoning_tokens = details.reasoning_tokens
|
||||
else:
|
||||
logger.warning(
|
||||
"OpenAI Responses API finish event had no usage payload. "
|
||||
"Proceeding with partial response but token metrics may be incomplete."
|
||||
)
|
||||
return
|
||||
|
||||
else:
|
||||
|
||||
@@ -94,7 +94,7 @@ async def _try_acquire_lock_and_start_scheduler(server: SyncServer) -> bool:
|
||||
if scheduler.running:
|
||||
try:
|
||||
scheduler.shutdown(wait=False)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
finally:
|
||||
|
||||
@@ -19,6 +19,7 @@ from letta.errors import (
|
||||
LLMAuthenticationError,
|
||||
LLMBadRequestError,
|
||||
LLMConnectionError,
|
||||
LLMInsufficientCreditsError,
|
||||
LLMNotFoundError,
|
||||
LLMPermissionDeniedError,
|
||||
LLMProviderOverloaded,
|
||||
@@ -29,13 +30,16 @@ from letta.errors import (
|
||||
)
|
||||
from letta.helpers.datetime_helpers import get_utc_time_int
|
||||
from letta.helpers.decorators import deprecated
|
||||
from letta.helpers.json_helpers import sanitize_unicode_surrogates
|
||||
from letta.llm_api.anthropic_constants import ANTHROPIC_MAX_STRICT_TOOLS, ANTHROPIC_STRICT_MODE_ALLOWLIST
|
||||
from letta.llm_api.error_utils import is_insufficient_credits_message
|
||||
from letta.llm_api.helpers import add_inner_thoughts_to_functions, unpack_all_inner_thoughts_from_kwargs
|
||||
from letta.llm_api.llm_client_base import LLMClientBase
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentType
|
||||
from letta.schemas.enums import ProviderCategory
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
|
||||
@@ -45,9 +49,7 @@ from letta.schemas.openai.chat_completion_response import (
|
||||
FunctionCall,
|
||||
Message as ChoiceMessage,
|
||||
ToolCall,
|
||||
UsageStatistics,
|
||||
)
|
||||
from letta.schemas.response_format import JsonSchemaResponseFormat
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.settings import model_settings
|
||||
|
||||
@@ -62,10 +64,16 @@ class AnthropicClient(LLMClientBase):
|
||||
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
||||
client = self._get_anthropic_client(llm_config, async_client=False)
|
||||
betas: list[str] = []
|
||||
# Interleaved thinking for reasoner (sync path parity)
|
||||
|
||||
# Opus 4.6 / Sonnet 4.6 Auto Thinking
|
||||
if llm_config.enable_reasoner:
|
||||
betas.append("interleaved-thinking-2025-05-14")
|
||||
# 1M context beta for Sonnet 4/4.5 when enabled
|
||||
if llm_config.model.startswith("claude-opus-4-6") or llm_config.model.startswith("claude-sonnet-4-6"):
|
||||
betas.append("adaptive-thinking-2026-01-28")
|
||||
# Interleaved thinking for other reasoners (sync path parity)
|
||||
else:
|
||||
betas.append("interleaved-thinking-2025-05-14")
|
||||
|
||||
# 1M context beta for Sonnet 4/4.5 or Opus 4.6 when enabled
|
||||
try:
|
||||
from letta.settings import model_settings
|
||||
|
||||
@@ -73,12 +81,23 @@ class AnthropicClient(LLMClientBase):
|
||||
llm_config.model.startswith("claude-sonnet-4") or llm_config.model.startswith("claude-sonnet-4-5")
|
||||
):
|
||||
betas.append("context-1m-2025-08-07")
|
||||
elif model_settings.anthropic_opus_1m and llm_config.model.startswith("claude-opus-4-6"):
|
||||
betas.append("context-1m-2025-08-07")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Opus 4.5 effort parameter - to extend to other models, modify the model check
|
||||
if llm_config.model.startswith("claude-opus-4-5") and llm_config.effort is not None:
|
||||
# Effort parameter for Opus 4.5, Opus 4.6, and Sonnet 4.6 - to extend to other models, modify the model check
|
||||
if (
|
||||
llm_config.model.startswith("claude-opus-4-5")
|
||||
or llm_config.model.startswith("claude-opus-4-6")
|
||||
or llm_config.model.startswith("claude-sonnet-4-6")
|
||||
) and llm_config.effort is not None:
|
||||
betas.append("effort-2025-11-24")
|
||||
# Max effort beta for Opus 4.6 / Sonnet 4.6
|
||||
if (
|
||||
llm_config.model.startswith("claude-opus-4-6") or llm_config.model.startswith("claude-sonnet-4-6")
|
||||
) and llm_config.effort == "max":
|
||||
betas.append("max-effort-2026-01-24")
|
||||
|
||||
# Context management for Opus 4.5 to preserve thinking blocks (improves cache hits)
|
||||
if llm_config.model.startswith("claude-opus-4-5") and llm_config.enable_reasoner:
|
||||
@@ -88,21 +107,46 @@ class AnthropicClient(LLMClientBase):
|
||||
if llm_config.strict and _supports_structured_outputs(llm_config.model):
|
||||
betas.append("structured-outputs-2025-11-13")
|
||||
|
||||
if betas:
|
||||
response = client.beta.messages.create(**request_data, betas=betas)
|
||||
else:
|
||||
response = client.beta.messages.create(**request_data)
|
||||
return response.model_dump()
|
||||
try:
|
||||
if betas:
|
||||
response = client.beta.messages.create(**request_data, betas=betas)
|
||||
else:
|
||||
response = client.beta.messages.create(**request_data)
|
||||
return response.model_dump()
|
||||
except ValueError as e:
|
||||
# Anthropic SDK raises ValueError when streaming is required for long-running operations
|
||||
# See: https://github.com/anthropics/anthropic-sdk-python#streaming
|
||||
if "streaming is required" in str(e).lower():
|
||||
logger.warning(
|
||||
"[Anthropic] Non-streaming request rejected due to potential long duration. Error: %s. "
|
||||
"Note: Synchronous fallback to streaming is not supported. Use async API instead.",
|
||||
str(e),
|
||||
)
|
||||
# Re-raise as LLMBadRequestError (maps to 502 Bad Gateway) since this is a downstream provider constraint
|
||||
raise LLMBadRequestError(
|
||||
message="This operation may take longer than 10 minutes and requires streaming. "
|
||||
"Please use the async API (request_async) instead of the deprecated sync API. "
|
||||
f"Original error: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
) from e
|
||||
raise
|
||||
|
||||
@trace_method
|
||||
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
||||
request_data = sanitize_unicode_surrogates(request_data)
|
||||
|
||||
client = await self._get_anthropic_client_async(llm_config, async_client=True)
|
||||
betas: list[str] = []
|
||||
# interleaved thinking for reasoner
|
||||
if llm_config.enable_reasoner:
|
||||
betas.append("interleaved-thinking-2025-05-14")
|
||||
|
||||
# 1M context beta for Sonnet 4/4.5 when enabled
|
||||
# Opus 4.6 / Sonnet 4.6 Auto Thinking
|
||||
if llm_config.enable_reasoner:
|
||||
if llm_config.model.startswith("claude-opus-4-6") or llm_config.model.startswith("claude-sonnet-4-6"):
|
||||
betas.append("adaptive-thinking-2026-01-28")
|
||||
# Interleaved thinking for other reasoners (sync path parity)
|
||||
else:
|
||||
betas.append("interleaved-thinking-2025-05-14")
|
||||
|
||||
# 1M context beta for Sonnet 4/4.5 or Opus 4.6 when enabled
|
||||
try:
|
||||
from letta.settings import model_settings
|
||||
|
||||
@@ -110,12 +154,23 @@ class AnthropicClient(LLMClientBase):
|
||||
llm_config.model.startswith("claude-sonnet-4") or llm_config.model.startswith("claude-sonnet-4-5")
|
||||
):
|
||||
betas.append("context-1m-2025-08-07")
|
||||
elif model_settings.anthropic_opus_1m and llm_config.model.startswith("claude-opus-4-6"):
|
||||
betas.append("context-1m-2025-08-07")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Opus 4.5 effort parameter - to extend to other models, modify the model check
|
||||
if llm_config.model.startswith("claude-opus-4-5") and llm_config.effort is not None:
|
||||
# Effort parameter for Opus 4.5, Opus 4.6, and Sonnet 4.6 - to extend to other models, modify the model check
|
||||
if (
|
||||
llm_config.model.startswith("claude-opus-4-5")
|
||||
or llm_config.model.startswith("claude-opus-4-6")
|
||||
or llm_config.model.startswith("claude-sonnet-4-6")
|
||||
) and llm_config.effort is not None:
|
||||
betas.append("effort-2025-11-24")
|
||||
# Max effort beta for Opus 4.6 / Sonnet 4.6
|
||||
if (
|
||||
llm_config.model.startswith("claude-opus-4-6") or llm_config.model.startswith("claude-sonnet-4-6")
|
||||
) and llm_config.effort == "max":
|
||||
betas.append("max-effort-2026-01-24")
|
||||
|
||||
# Context management for Opus 4.5 to preserve thinking blocks (improves cache hits)
|
||||
if llm_config.model.startswith("claude-opus-4-5") and llm_config.enable_reasoner:
|
||||
@@ -254,6 +309,8 @@ class AnthropicClient(LLMClientBase):
|
||||
|
||||
@trace_method
|
||||
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[BetaRawMessageStreamEvent]:
|
||||
request_data = sanitize_unicode_surrogates(request_data)
|
||||
|
||||
client = await self._get_anthropic_client_async(llm_config, async_client=True)
|
||||
request_data["stream"] = True
|
||||
|
||||
@@ -262,12 +319,15 @@ class AnthropicClient(LLMClientBase):
|
||||
# See: https://docs.anthropic.com/en/docs/build-with-claude/tool-use/fine-grained-streaming
|
||||
betas = ["fine-grained-tool-streaming-2025-05-14"]
|
||||
|
||||
# If extended thinking, turn on interleaved header
|
||||
# https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#interleaved-thinking
|
||||
# Opus 4.6 / Sonnet 4.6 Auto Thinking
|
||||
if llm_config.enable_reasoner:
|
||||
betas.append("interleaved-thinking-2025-05-14")
|
||||
if llm_config.model.startswith("claude-opus-4-6") or llm_config.model.startswith("claude-sonnet-4-6"):
|
||||
betas.append("adaptive-thinking-2026-01-28")
|
||||
# Interleaved thinking for other reasoners (sync path parity)
|
||||
else:
|
||||
betas.append("interleaved-thinking-2025-05-14")
|
||||
|
||||
# 1M context beta for Sonnet 4/4.5 when enabled
|
||||
# 1M context beta for Sonnet 4/4.5 or Opus 4.6 when enabled
|
||||
try:
|
||||
from letta.settings import model_settings
|
||||
|
||||
@@ -275,12 +335,23 @@ class AnthropicClient(LLMClientBase):
|
||||
llm_config.model.startswith("claude-sonnet-4") or llm_config.model.startswith("claude-sonnet-4-5")
|
||||
):
|
||||
betas.append("context-1m-2025-08-07")
|
||||
elif model_settings.anthropic_opus_1m and llm_config.model.startswith("claude-opus-4-6"):
|
||||
betas.append("context-1m-2025-08-07")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Opus 4.5 effort parameter - to extend to other models, modify the model check
|
||||
if llm_config.model.startswith("claude-opus-4-5") and llm_config.effort is not None:
|
||||
# Effort parameter for Opus 4.5, Opus 4.6, and Sonnet 4.6 - to extend to other models, modify the model check
|
||||
if (
|
||||
llm_config.model.startswith("claude-opus-4-5")
|
||||
or llm_config.model.startswith("claude-opus-4-6")
|
||||
or llm_config.model.startswith("claude-sonnet-4-6")
|
||||
) and llm_config.effort is not None:
|
||||
betas.append("effort-2025-11-24")
|
||||
# Max effort beta for Opus 4.6 / Sonnet 4.6
|
||||
if (
|
||||
llm_config.model.startswith("claude-opus-4-6") or llm_config.model.startswith("claude-sonnet-4-6")
|
||||
) and llm_config.effort == "max":
|
||||
betas.append("max-effort-2026-01-24")
|
||||
|
||||
# Context management for Opus 4.5 to preserve thinking blocks (improves cache hits)
|
||||
if llm_config.model.startswith("claude-opus-4-5") and llm_config.enable_reasoner:
|
||||
@@ -335,7 +406,7 @@ class AnthropicClient(LLMClientBase):
|
||||
for agent_id in agent_messages_mapping
|
||||
}
|
||||
|
||||
client = await self._get_anthropic_client_async(list(agent_llm_config_mapping.values())[0], async_client=True)
|
||||
client = await self._get_anthropic_client_async(next(iter(agent_llm_config_mapping.values())), async_client=True)
|
||||
|
||||
anthropic_requests = [
|
||||
Request(custom_id=agent_id, params=MessageCreateParamsNonStreaming(**params)) for agent_id, params in requests.items()
|
||||
@@ -461,25 +532,43 @@ class AnthropicClient(LLMClientBase):
|
||||
}
|
||||
|
||||
# Extended Thinking
|
||||
if self.is_reasoning_model(llm_config) and llm_config.enable_reasoner:
|
||||
thinking_budget = max(llm_config.max_reasoning_tokens, 1024)
|
||||
if thinking_budget != llm_config.max_reasoning_tokens:
|
||||
logger.warning(
|
||||
f"Max reasoning tokens must be at least 1024 for Claude. Setting max_reasoning_tokens to 1024 for model {llm_config.model}."
|
||||
)
|
||||
data["thinking"] = {
|
||||
"type": "enabled",
|
||||
"budget_tokens": thinking_budget,
|
||||
}
|
||||
# Note: Anthropic does not allow thinking when forcing tool use with split_thread_agent
|
||||
should_enable_thinking = (
|
||||
self.is_reasoning_model(llm_config)
|
||||
and llm_config.enable_reasoner
|
||||
and not (agent_type == AgentType.split_thread_agent and force_tool_call is not None)
|
||||
)
|
||||
|
||||
if should_enable_thinking:
|
||||
# Opus 4.6 / Sonnet 4.6 uses Auto Thinking (no budget tokens)
|
||||
if llm_config.model.startswith("claude-opus-4-6") or llm_config.model.startswith("claude-sonnet-4-6"):
|
||||
data["thinking"] = {
|
||||
"type": "adaptive",
|
||||
}
|
||||
else:
|
||||
# Traditional extended thinking with budget tokens
|
||||
thinking_budget = max(llm_config.max_reasoning_tokens, 1024)
|
||||
if thinking_budget != llm_config.max_reasoning_tokens:
|
||||
logger.warning(
|
||||
f"Max reasoning tokens must be at least 1024 for Claude. Setting max_reasoning_tokens to 1024 for model {llm_config.model}."
|
||||
)
|
||||
data["thinking"] = {
|
||||
"type": "enabled",
|
||||
"budget_tokens": thinking_budget,
|
||||
}
|
||||
# `temperature` may only be set to 1 when thinking is enabled. Please consult our documentation at https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking'
|
||||
data["temperature"] = 1.0
|
||||
|
||||
# Silently disable prefix_fill for now
|
||||
prefix_fill = False
|
||||
|
||||
# Effort configuration for Opus 4.5 (controls token spending)
|
||||
# Effort configuration for Opus 4.5, Opus 4.6, and Sonnet 4.6 (controls token spending)
|
||||
# To extend to other models, modify the model check
|
||||
if llm_config.model.startswith("claude-opus-4-5") and llm_config.effort is not None:
|
||||
if (
|
||||
llm_config.model.startswith("claude-opus-4-5")
|
||||
or llm_config.model.startswith("claude-opus-4-6")
|
||||
or llm_config.model.startswith("claude-sonnet-4-6")
|
||||
) and llm_config.effort is not None:
|
||||
data["output_config"] = {"effort": llm_config.effort}
|
||||
|
||||
# Context management for Opus 4.5 to preserve thinking blocks and improve cache hits
|
||||
@@ -510,11 +599,16 @@ class AnthropicClient(LLMClientBase):
|
||||
# Special case for summarization path
|
||||
tools_for_request = None
|
||||
tool_choice = None
|
||||
elif self.is_reasoning_model(llm_config) and llm_config.enable_reasoner or agent_type == AgentType.letta_v1_agent:
|
||||
elif (self.is_reasoning_model(llm_config) and llm_config.enable_reasoner) or agent_type == AgentType.letta_v1_agent:
|
||||
# NOTE: reasoning models currently do not allow for `any`
|
||||
# NOTE: react agents should always have auto on, since the precense/absense of tool calls controls chaining
|
||||
tool_choice = {"type": "auto", "disable_parallel_tool_use": True}
|
||||
tools_for_request = [OpenAITool(function=f) for f in tools]
|
||||
# NOTE: react agents should always have at least auto on, since the precense/absense of tool calls controls chaining
|
||||
if agent_type == AgentType.split_thread_agent and force_tool_call is not None:
|
||||
tool_choice = {"type": "tool", "name": force_tool_call, "disable_parallel_tool_use": True}
|
||||
# When forcing a specific tool, only include that tool
|
||||
tools_for_request = [OpenAITool(function=f) for f in tools if f["name"] == force_tool_call]
|
||||
else:
|
||||
tool_choice = {"type": "auto", "disable_parallel_tool_use": True}
|
||||
tools_for_request = [OpenAITool(function=f) for f in tools]
|
||||
elif force_tool_call is not None:
|
||||
tool_choice = {"type": "tool", "name": force_tool_call, "disable_parallel_tool_use": True}
|
||||
tools_for_request = [OpenAITool(function=f) for f in tools if f["name"] == force_tool_call]
|
||||
@@ -691,7 +785,9 @@ class AnthropicClient(LLMClientBase):
|
||||
|
||||
return data
|
||||
|
||||
async def count_tokens(self, messages: List[dict] = None, model: str = None, tools: List[OpenAITool] = None) -> int:
|
||||
async def count_tokens(
|
||||
self, messages: List[dict] | None = None, model: str | None = None, tools: List[OpenAITool] | None = None
|
||||
) -> int:
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||
# Use the default client; token counting is lightweight and does not require BYOK overrides
|
||||
client = anthropic.AsyncAnthropic()
|
||||
@@ -811,6 +907,8 @@ class AnthropicClient(LLMClientBase):
|
||||
and (model.startswith("claude-sonnet-4") or model.startswith("claude-sonnet-4-5"))
|
||||
):
|
||||
betas.append("context-1m-2025-08-07")
|
||||
elif model and model_settings.anthropic_opus_1m and model.startswith("claude-opus-4-6"):
|
||||
betas.append("context-1m-2025-08-07")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -851,10 +949,16 @@ class AnthropicClient(LLMClientBase):
|
||||
or llm_config.model.startswith("claude-haiku-4-5")
|
||||
# Opus 4.5 support - to extend effort parameter to other models, modify this check
|
||||
or llm_config.model.startswith("claude-opus-4-5")
|
||||
# Opus 4.6 support - uses Auto Thinking
|
||||
or llm_config.model.startswith("claude-opus-4-6")
|
||||
# Sonnet 4.6 support - same API as Opus 4.6
|
||||
or llm_config.model.startswith("claude-sonnet-4-6")
|
||||
)
|
||||
|
||||
@trace_method
|
||||
def handle_llm_error(self, e: Exception) -> Exception:
|
||||
def handle_llm_error(self, e: Exception, llm_config: Optional[LLMConfig] = None) -> Exception:
|
||||
is_byok = (llm_config.provider_category == ProviderCategory.byok) if llm_config else None
|
||||
|
||||
# make sure to check for overflow errors, regardless of error type
|
||||
error_str = str(e).lower()
|
||||
if (
|
||||
@@ -869,6 +973,7 @@ class AnthropicClient(LLMClientBase):
|
||||
logger.warning(f"[Anthropic] Context window exceeded: {str(e)}")
|
||||
return ContextWindowExceededError(
|
||||
message=f"Context window exceeded for Anthropic: {str(e)}",
|
||||
details={"is_byok": is_byok},
|
||||
)
|
||||
|
||||
if isinstance(e, anthropic.APITimeoutError):
|
||||
@@ -876,7 +981,7 @@ class AnthropicClient(LLMClientBase):
|
||||
return LLMTimeoutError(
|
||||
message=f"Request to Anthropic timed out: {str(e)}",
|
||||
code=ErrorCode.TIMEOUT,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None},
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
|
||||
)
|
||||
|
||||
if isinstance(e, anthropic.APIConnectionError):
|
||||
@@ -884,7 +989,7 @@ class AnthropicClient(LLMClientBase):
|
||||
return LLMConnectionError(
|
||||
message=f"Failed to connect to Anthropic: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None},
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
|
||||
)
|
||||
|
||||
# Handle httpx.RemoteProtocolError which can occur during streaming
|
||||
@@ -895,7 +1000,7 @@ class AnthropicClient(LLMClientBase):
|
||||
return LLMConnectionError(
|
||||
message=f"Connection error during Anthropic streaming: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None},
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
|
||||
)
|
||||
|
||||
# Handle httpx network errors which can occur during streaming
|
||||
@@ -905,7 +1010,7 @@ class AnthropicClient(LLMClientBase):
|
||||
return LLMConnectionError(
|
||||
message=f"Network error during Anthropic streaming: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__},
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__, "is_byok": is_byok},
|
||||
)
|
||||
|
||||
if isinstance(e, anthropic.RateLimitError):
|
||||
@@ -913,6 +1018,7 @@ class AnthropicClient(LLMClientBase):
|
||||
return LLMRateLimitError(
|
||||
message=f"Rate limited by Anthropic: {str(e)}",
|
||||
code=ErrorCode.RATE_LIMIT_EXCEEDED,
|
||||
details={"is_byok": is_byok},
|
||||
)
|
||||
|
||||
if isinstance(e, anthropic.BadRequestError):
|
||||
@@ -930,11 +1036,13 @@ class AnthropicClient(LLMClientBase):
|
||||
# 400 - {'type': 'error', 'error': {'type': 'invalid_request_error', 'message': 'input length and `max_tokens` exceed context limit: 173298 + 32000 > 200000, decrease input length or `max_tokens` and try again'}}
|
||||
return ContextWindowExceededError(
|
||||
message=f"Bad request to Anthropic (context window exceeded): {str(e)}",
|
||||
details={"is_byok": is_byok},
|
||||
)
|
||||
else:
|
||||
return LLMBadRequestError(
|
||||
message=f"Bad request to Anthropic: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"is_byok": is_byok},
|
||||
)
|
||||
|
||||
if isinstance(e, anthropic.AuthenticationError):
|
||||
@@ -942,6 +1050,7 @@ class AnthropicClient(LLMClientBase):
|
||||
return LLMAuthenticationError(
|
||||
message=f"Authentication failed with Anthropic: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"is_byok": is_byok},
|
||||
)
|
||||
|
||||
if isinstance(e, anthropic.PermissionDeniedError):
|
||||
@@ -949,6 +1058,7 @@ class AnthropicClient(LLMClientBase):
|
||||
return LLMPermissionDeniedError(
|
||||
message=f"Permission denied by Anthropic: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"is_byok": is_byok},
|
||||
)
|
||||
|
||||
if isinstance(e, anthropic.NotFoundError):
|
||||
@@ -956,6 +1066,7 @@ class AnthropicClient(LLMClientBase):
|
||||
return LLMNotFoundError(
|
||||
message=f"Resource not found in Anthropic: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"is_byok": is_byok},
|
||||
)
|
||||
|
||||
if isinstance(e, anthropic.UnprocessableEntityError):
|
||||
@@ -963,23 +1074,29 @@ class AnthropicClient(LLMClientBase):
|
||||
return LLMUnprocessableEntityError(
|
||||
message=f"Invalid request content for Anthropic: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"is_byok": is_byok},
|
||||
)
|
||||
|
||||
if isinstance(e, anthropic.APIStatusError):
|
||||
logger.warning(f"[Anthropic] API status error: {str(e)}")
|
||||
# Handle 413 Request Entity Too Large - request payload exceeds size limits
|
||||
if hasattr(e, "status_code") and e.status_code == 413:
|
||||
logger.warning(f"[Anthropic] Request too large (413): {str(e)}")
|
||||
return ContextWindowExceededError(
|
||||
message=f"Request too large for Anthropic (413): {str(e)}",
|
||||
if isinstance(e, anthropic.InternalServerError):
|
||||
error_str = str(e).lower()
|
||||
if "overflow" in error_str or "upstream connect error" in error_str:
|
||||
logger.warning(f"[Anthropic] Upstream infrastructure error (transient): {str(e)}")
|
||||
return LLMServerError(
|
||||
message=f"Anthropic upstream infrastructure error (transient, may resolve on retry): {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={
|
||||
"status_code": e.status_code if hasattr(e, "status_code") else None,
|
||||
"transient": True,
|
||||
},
|
||||
)
|
||||
if "overloaded" in str(e).lower():
|
||||
if "overloaded" in error_str:
|
||||
return LLMProviderOverloaded(
|
||||
message=f"Anthropic API is overloaded: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
logger.warning(f"[Anthropic] Internal server error: {str(e)}")
|
||||
return LLMServerError(
|
||||
message=f"Anthropic API error: {str(e)}",
|
||||
message=f"Anthropic internal server error: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={
|
||||
"status_code": e.status_code if hasattr(e, "status_code") else None,
|
||||
@@ -987,7 +1104,38 @@ class AnthropicClient(LLMClientBase):
|
||||
},
|
||||
)
|
||||
|
||||
return super().handle_llm_error(e)
|
||||
if isinstance(e, anthropic.APIStatusError):
|
||||
logger.warning(f"[Anthropic] API status error: {str(e)}")
|
||||
if (hasattr(e, "status_code") and e.status_code == 402) or is_insufficient_credits_message(str(e)):
|
||||
msg = str(e)
|
||||
return LLMInsufficientCreditsError(
|
||||
message=f"Insufficient credits (BYOK): {msg}" if is_byok else f"Insufficient credits: {msg}",
|
||||
code=ErrorCode.PAYMENT_REQUIRED,
|
||||
details={"status_code": getattr(e, "status_code", None), "is_byok": is_byok},
|
||||
)
|
||||
if hasattr(e, "status_code") and e.status_code == 413:
|
||||
logger.warning(f"[Anthropic] Request too large (413): {str(e)}")
|
||||
return ContextWindowExceededError(
|
||||
message=f"Request too large for Anthropic (413): {str(e)}",
|
||||
details={"is_byok": is_byok},
|
||||
)
|
||||
if "overloaded" in str(e).lower():
|
||||
return LLMProviderOverloaded(
|
||||
message=f"Anthropic API is overloaded: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"is_byok": is_byok},
|
||||
)
|
||||
return LLMServerError(
|
||||
message=f"Anthropic API error: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={
|
||||
"status_code": e.status_code if hasattr(e, "status_code") else None,
|
||||
"response": str(e.response) if hasattr(e, "response") else None,
|
||||
"is_byok": is_byok,
|
||||
},
|
||||
)
|
||||
|
||||
return super().handle_llm_error(e, llm_config=llm_config)
|
||||
|
||||
def extract_usage_statistics(self, response_data: dict | None, llm_config: LLMConfig) -> LettaUsageStatistics:
|
||||
"""Extract usage statistics from Anthropic response and return as LettaUsageStatistics."""
|
||||
@@ -1027,6 +1175,11 @@ class AnthropicClient(LLMClientBase):
|
||||
input_messages: List[PydanticMessage],
|
||||
llm_config: LLMConfig,
|
||||
) -> ChatCompletionResponse:
|
||||
if isinstance(response_data, str):
|
||||
raise LLMServerError(
|
||||
message="Anthropic endpoint returned a raw string instead of a JSON object. This usually indicates the endpoint URL is incorrect or returned an error page.",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
"""
|
||||
Example response from Claude 3:
|
||||
response.json = {
|
||||
@@ -1096,7 +1249,7 @@ class AnthropicClient(LLMClientBase):
|
||||
args_json = json.loads(arguments)
|
||||
if not isinstance(args_json, dict):
|
||||
raise LLMServerError("Expected parseable json object for arguments")
|
||||
except:
|
||||
except Exception:
|
||||
arguments = str(tool_input["function"]["arguments"])
|
||||
else:
|
||||
arguments = json.dumps(tool_input, indent=2)
|
||||
@@ -1117,7 +1270,23 @@ class AnthropicClient(LLMClientBase):
|
||||
redacted_reasoning_content = content_part.data
|
||||
|
||||
else:
|
||||
raise RuntimeError("Unexpected empty content in response")
|
||||
# Log the full response for debugging
|
||||
logger.error(
|
||||
"[Anthropic] Received response with empty content. Response ID: %s, Model: %s, Stop reason: %s, Full response: %s",
|
||||
response.id,
|
||||
response.model,
|
||||
response.stop_reason,
|
||||
json.dumps(response_data),
|
||||
)
|
||||
raise LLMServerError(
|
||||
message=f"LLM provider returned empty content in response (ID: {response.id}, model: {response.model}, stop_reason: {response.stop_reason})",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={
|
||||
"response_id": response.id,
|
||||
"model": response.model,
|
||||
"stop_reason": response.stop_reason,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.role == "assistant"
|
||||
choice = Choice(
|
||||
@@ -1372,7 +1541,7 @@ def is_heartbeat(message: dict, is_ping: bool = False) -> bool:
|
||||
|
||||
try:
|
||||
message_json = json.loads(message["content"])
|
||||
except:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# Check if message_json is a dict (not int, str, list, etc.)
|
||||
|
||||
@@ -1,18 +1,31 @@
|
||||
import json
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from openai import AsyncAzureOpenAI, AzureOpenAI
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream, AzureOpenAI, OpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.responses.response_stream_event import ResponseStreamEvent
|
||||
|
||||
from letta.helpers.json_helpers import sanitize_unicode_surrogates
|
||||
from letta.llm_api.openai_client import OpenAIClient
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ProviderCategory
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.settings import model_settings
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AzureClient(OpenAIClient):
|
||||
@staticmethod
|
||||
def _is_v1_endpoint(base_url: str) -> bool:
|
||||
if not base_url:
|
||||
return False
|
||||
return base_url.rstrip("/").endswith("/openai/v1")
|
||||
|
||||
def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
if llm_config.provider_category == ProviderCategory.byok:
|
||||
from letta.services.provider_manager import ProviderManager
|
||||
@@ -29,38 +42,99 @@ class AzureClient(OpenAIClient):
|
||||
|
||||
return None, None, None
|
||||
|
||||
def _resolve_credentials(self, api_key, base_url, api_version):
|
||||
"""Resolve credentials, falling back to env vars. For v1 endpoints, api_version is not required."""
|
||||
if not api_key:
|
||||
api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
|
||||
if not base_url:
|
||||
base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
|
||||
if not api_version and not self._is_v1_endpoint(base_url):
|
||||
api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
|
||||
return api_key, base_url, api_version
|
||||
|
||||
@trace_method
|
||||
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
||||
"""
|
||||
Performs underlying synchronous request to OpenAI API and returns raw response dict.
|
||||
"""
|
||||
api_key, base_url, api_version = self.get_byok_overrides(llm_config)
|
||||
if not api_key or not base_url or not api_version:
|
||||
api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
|
||||
base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
|
||||
api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
|
||||
api_key, base_url, api_version = self._resolve_credentials(api_key, base_url, api_version)
|
||||
|
||||
client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
|
||||
response: ChatCompletion = client.chat.completions.create(**request_data)
|
||||
return response.model_dump()
|
||||
if self._is_v1_endpoint(base_url):
|
||||
client = OpenAI(api_key=api_key, base_url=base_url)
|
||||
else:
|
||||
client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
|
||||
|
||||
# Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
|
||||
if "input" in request_data and "messages" not in request_data:
|
||||
resp = client.responses.create(**request_data)
|
||||
return resp.model_dump()
|
||||
else:
|
||||
response: ChatCompletion = client.chat.completions.create(**request_data)
|
||||
return response.model_dump()
|
||||
|
||||
@trace_method
|
||||
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
||||
"""
|
||||
Performs underlying asynchronous request to OpenAI API and returns raw response dict.
|
||||
"""
|
||||
api_key, base_url, api_version = await self.get_byok_overrides_async(llm_config)
|
||||
if not api_key or not base_url or not api_version:
|
||||
api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
|
||||
base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
|
||||
api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
|
||||
try:
|
||||
client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
|
||||
response: ChatCompletion = await client.chat.completions.create(**request_data)
|
||||
except Exception as e:
|
||||
raise self.handle_llm_error(e)
|
||||
request_data = sanitize_unicode_surrogates(request_data)
|
||||
|
||||
return response.model_dump()
|
||||
api_key, base_url, api_version = await self.get_byok_overrides_async(llm_config)
|
||||
api_key, base_url, api_version = self._resolve_credentials(api_key, base_url, api_version)
|
||||
|
||||
try:
|
||||
if self._is_v1_endpoint(base_url):
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
else:
|
||||
client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
|
||||
|
||||
# Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
|
||||
if "input" in request_data and "messages" not in request_data:
|
||||
resp = await client.responses.create(**request_data)
|
||||
return resp.model_dump()
|
||||
else:
|
||||
response: ChatCompletion = await client.chat.completions.create(**request_data)
|
||||
return response.model_dump()
|
||||
except Exception as e:
|
||||
raise self.handle_llm_error(e, llm_config=llm_config)
|
||||
|
||||
@trace_method
|
||||
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk | ResponseStreamEvent]:
|
||||
"""
|
||||
Performs underlying asynchronous streaming request to Azure/OpenAI and returns the async stream iterator.
|
||||
"""
|
||||
request_data = sanitize_unicode_surrogates(request_data)
|
||||
|
||||
api_key, base_url, api_version = await self.get_byok_overrides_async(llm_config)
|
||||
api_key, base_url, api_version = self._resolve_credentials(api_key, base_url, api_version)
|
||||
|
||||
if self._is_v1_endpoint(base_url):
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
else:
|
||||
client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
|
||||
|
||||
# Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
|
||||
if "input" in request_data and "messages" not in request_data:
|
||||
try:
|
||||
response_stream: AsyncStream[ResponseStreamEvent] = await client.responses.create(
|
||||
**request_data,
|
||||
stream=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error streaming Azure Responses request: {e} with request data: {json.dumps(request_data)}")
|
||||
raise e
|
||||
else:
|
||||
try:
|
||||
response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
|
||||
**request_data,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error streaming Azure Chat Completions request: {e} with request data: {json.dumps(request_data)}")
|
||||
raise e
|
||||
return response_stream
|
||||
|
||||
@trace_method
|
||||
async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
|
||||
@@ -68,7 +142,12 @@ class AzureClient(OpenAIClient):
|
||||
api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
|
||||
base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
|
||||
api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
|
||||
client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=base_url)
|
||||
|
||||
if self._is_v1_endpoint(base_url):
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
else:
|
||||
client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=base_url)
|
||||
|
||||
response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)
|
||||
|
||||
# TODO: add total usage
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
"""ChatGPT OAuth Client - handles requests to chatgpt.com/backend-api/codex/responses."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.responses import (
|
||||
Response,
|
||||
ResponseCompletedEvent,
|
||||
@@ -32,6 +32,7 @@ from openai.types.responses.response_stream_event import ResponseStreamEvent
|
||||
from letta.errors import (
|
||||
ContextWindowExceededError,
|
||||
ErrorCode,
|
||||
LettaError,
|
||||
LLMAuthenticationError,
|
||||
LLMBadRequestError,
|
||||
LLMConnectionError,
|
||||
@@ -39,6 +40,7 @@ from letta.errors import (
|
||||
LLMServerError,
|
||||
LLMTimeoutError,
|
||||
)
|
||||
from letta.helpers.json_helpers import sanitize_unicode_surrogates
|
||||
from letta.llm_api.llm_client_base import LLMClientBase
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import trace_method
|
||||
@@ -47,11 +49,6 @@ from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.openai.chat_completion_response import (
|
||||
ChatCompletionResponse,
|
||||
Choice,
|
||||
FunctionCall,
|
||||
Message as ChoiceMessage,
|
||||
ToolCall,
|
||||
UsageStatistics,
|
||||
)
|
||||
from letta.schemas.providers.chatgpt_oauth import ChatGPTOAuthCredentials, ChatGPTOAuthProvider
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
@@ -100,6 +97,10 @@ class ChatGPTOAuthClient(LLMClientBase):
|
||||
4. Transforms responses back to OpenAI ChatCompletion format
|
||||
"""
|
||||
|
||||
MAX_RETRIES = 3
|
||||
# Transient httpx errors that are safe to retry (connection drops, transport-level failures)
|
||||
_RETRYABLE_ERRORS = (httpx.ReadError, httpx.WriteError, httpx.ConnectError, httpx.RemoteProtocolError, LLMConnectionError)
|
||||
|
||||
@trace_method
|
||||
async def _get_provider_and_credentials_async(self, llm_config: LLMConfig) -> tuple[ChatGPTOAuthProvider, ChatGPTOAuthCredentials]:
|
||||
"""Get the ChatGPT OAuth provider and credentials with automatic refresh if needed.
|
||||
@@ -153,6 +154,11 @@ class ChatGPTOAuthClient(LLMClientBase):
|
||||
Returns:
|
||||
Dictionary of HTTP headers.
|
||||
"""
|
||||
if not creds.access_token:
|
||||
raise LLMAuthenticationError(
|
||||
message="ChatGPT OAuth access_token is empty or missing",
|
||||
code=ErrorCode.UNAUTHENTICATED,
|
||||
)
|
||||
return {
|
||||
"Authorization": f"Bearer {creds.access_token}",
|
||||
"ChatGPT-Account-Id": creds.account_id,
|
||||
@@ -356,38 +362,68 @@ class ChatGPTOAuthClient(LLMClientBase):
|
||||
Returns:
|
||||
Response data in OpenAI ChatCompletion format.
|
||||
"""
|
||||
request_data = sanitize_unicode_surrogates(request_data)
|
||||
|
||||
_, creds = await self._get_provider_and_credentials_async(llm_config)
|
||||
headers = self._build_headers(creds)
|
||||
|
||||
endpoint = llm_config.model_endpoint or CHATGPT_CODEX_ENDPOINT
|
||||
|
||||
# ChatGPT backend requires streaming, so we use client.stream() to handle SSE
|
||||
async with httpx.AsyncClient() as client:
|
||||
# Retry on transient network errors with exponential backoff
|
||||
for attempt in range(self.MAX_RETRIES):
|
||||
try:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
endpoint,
|
||||
json=request_data,
|
||||
headers=headers,
|
||||
timeout=120.0,
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
# Accumulate SSE events into a final response
|
||||
return await self._accumulate_sse_response(response)
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
endpoint,
|
||||
json=request_data,
|
||||
headers=headers,
|
||||
timeout=120.0,
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
# Accumulate SSE events into a final response
|
||||
return await self._accumulate_sse_response(response)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise self._handle_http_error(e)
|
||||
mapped = self._handle_http_error(e)
|
||||
if isinstance(mapped, tuple(self._RETRYABLE_ERRORS)) and attempt < self.MAX_RETRIES - 1:
|
||||
wait = 2**attempt
|
||||
logger.warning(
|
||||
f"[ChatGPT] Retryable HTTP error on request (attempt {attempt + 1}/{self.MAX_RETRIES}), "
|
||||
f"retrying in {wait}s: {type(mapped).__name__}: {mapped}"
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
continue
|
||||
raise mapped
|
||||
except httpx.TimeoutException:
|
||||
raise LLMTimeoutError(
|
||||
message="ChatGPT backend request timed out",
|
||||
code=ErrorCode.TIMEOUT,
|
||||
)
|
||||
except self._RETRYABLE_ERRORS as e:
|
||||
if attempt < self.MAX_RETRIES - 1:
|
||||
wait = 2**attempt
|
||||
logger.warning(
|
||||
f"[ChatGPT] Transient error on request (attempt {attempt + 1}/{self.MAX_RETRIES}), "
|
||||
f"retrying in {wait}s: {type(e).__name__}: {e}"
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
continue
|
||||
raise LLMConnectionError(
|
||||
message=f"Failed to connect to ChatGPT backend after {self.MAX_RETRIES} attempts: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__},
|
||||
)
|
||||
except httpx.RequestError as e:
|
||||
raise LLMConnectionError(
|
||||
message=f"Failed to connect to ChatGPT backend: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
# Should not be reached, but satisfy type checker
|
||||
raise LLMConnectionError(message="ChatGPT request failed after all retries", code=ErrorCode.INTERNAL_SERVER_ERROR)
|
||||
|
||||
async def _accumulate_sse_response(self, response: httpx.Response) -> dict:
|
||||
"""Accumulate SSE stream into a final response.
|
||||
|
||||
@@ -550,64 +586,102 @@ class ChatGPTOAuthClient(LLMClientBase):
|
||||
Returns:
|
||||
Async generator yielding ResponseStreamEvent objects.
|
||||
"""
|
||||
request_data = sanitize_unicode_surrogates(request_data)
|
||||
|
||||
_, creds = await self._get_provider_and_credentials_async(llm_config)
|
||||
headers = self._build_headers(creds)
|
||||
|
||||
endpoint = llm_config.model_endpoint or CHATGPT_CODEX_ENDPOINT
|
||||
|
||||
async def stream_generator():
|
||||
event_count = 0
|
||||
# Track output item index for proper event construction
|
||||
output_index = 0
|
||||
# Track sequence_number in case backend doesn't provide it
|
||||
# (OpenAI SDK expects incrementing sequence numbers starting at 0)
|
||||
sequence_counter = 0
|
||||
# Track whether we've yielded any events — once we have, we can't
|
||||
# transparently retry because the caller has already consumed partial data.
|
||||
has_yielded = False
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
endpoint,
|
||||
json=request_data,
|
||||
headers=headers,
|
||||
timeout=120.0,
|
||||
) as response:
|
||||
# Check for error status
|
||||
if response.status_code != 200:
|
||||
error_body = await response.aread()
|
||||
logger.error(f"ChatGPT SSE error: {response.status_code} - {error_body}")
|
||||
raise self._handle_http_error_from_status(response.status_code, error_body.decode())
|
||||
for attempt in range(self.MAX_RETRIES):
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
endpoint,
|
||||
json=request_data,
|
||||
headers=headers,
|
||||
timeout=120.0,
|
||||
) as response:
|
||||
# Check for error status
|
||||
if response.status_code != 200:
|
||||
error_body = await response.aread()
|
||||
logger.error(f"ChatGPT SSE error: {response.status_code} - {error_body}")
|
||||
raise self._handle_http_error_from_status(response.status_code, error_body.decode())
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
async for line in response.aiter_lines():
|
||||
if not line or not line.startswith("data: "):
|
||||
continue
|
||||
|
||||
data_str = line[6:]
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
data_str = line[6:]
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
|
||||
try:
|
||||
raw_event = json.loads(data_str)
|
||||
event_type = raw_event.get("type")
|
||||
event_count += 1
|
||||
try:
|
||||
raw_event = json.loads(data_str)
|
||||
event_type = raw_event.get("type")
|
||||
|
||||
# Use backend-provided sequence_number if available, else use counter
|
||||
# This ensures proper ordering even if backend doesn't provide it
|
||||
if "sequence_number" not in raw_event:
|
||||
raw_event["sequence_number"] = sequence_counter
|
||||
sequence_counter = raw_event["sequence_number"] + 1
|
||||
# Check for error events from the API (context window, rate limit, etc.)
|
||||
if event_type == "error":
|
||||
logger.error(f"ChatGPT SSE error event: {json.dumps(raw_event, default=str)[:1000]}")
|
||||
raise self._handle_sse_error_event(raw_event)
|
||||
|
||||
# Track output index for output_item.added events
|
||||
if event_type == "response.output_item.added":
|
||||
output_index = raw_event.get("output_index", output_index)
|
||||
# Check for response.failed or response.incomplete events
|
||||
if event_type in ("response.failed", "response.incomplete"):
|
||||
logger.error(f"ChatGPT SSE {event_type} event: {json.dumps(raw_event, default=str)[:1000]}")
|
||||
resp_obj = raw_event.get("response", {})
|
||||
error_info = resp_obj.get("error", {})
|
||||
if error_info:
|
||||
raise self._handle_sse_error_event({"error": error_info, "type": event_type})
|
||||
else:
|
||||
raise LLMBadRequestError(
|
||||
message=f"ChatGPT request failed with status '{event_type}' (no error details provided)",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
# Convert to OpenAI SDK ResponseStreamEvent
|
||||
sdk_event = self._convert_to_sdk_event(raw_event, output_index)
|
||||
if sdk_event:
|
||||
yield sdk_event
|
||||
# Use backend-provided sequence_number if available, else use counter
|
||||
# This ensures proper ordering even if backend doesn't provide it
|
||||
if "sequence_number" not in raw_event:
|
||||
raw_event["sequence_number"] = sequence_counter
|
||||
sequence_counter = raw_event["sequence_number"] + 1
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse SSE event: {data_str[:100]}")
|
||||
continue
|
||||
# Track output index for output_item.added events
|
||||
if event_type == "response.output_item.added":
|
||||
output_index = raw_event.get("output_index", output_index)
|
||||
|
||||
# Convert to OpenAI SDK ResponseStreamEvent
|
||||
sdk_event = self._convert_to_sdk_event(raw_event, output_index)
|
||||
if sdk_event:
|
||||
yield sdk_event
|
||||
has_yielded = True
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse SSE event: {data_str[:100]}")
|
||||
continue
|
||||
|
||||
# Stream completed successfully
|
||||
return
|
||||
|
||||
except self._RETRYABLE_ERRORS as e:
|
||||
if has_yielded or attempt >= self.MAX_RETRIES - 1:
|
||||
# Already yielded partial data or exhausted retries — must propagate
|
||||
raise
|
||||
wait = 2**attempt
|
||||
logger.warning(
|
||||
f"[ChatGPT] Transient error on stream (attempt {attempt + 1}/{self.MAX_RETRIES}), "
|
||||
f"retrying in {wait}s: {type(e).__name__}: {e}"
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
|
||||
# Wrap the async generator in AsyncStreamWrapper to provide context manager protocol
|
||||
return AsyncStreamWrapper(stream_generator())
|
||||
@@ -944,10 +1018,16 @@ class ChatGPTOAuthClient(LLMClientBase):
|
||||
part=part,
|
||||
)
|
||||
|
||||
# Unhandled event types - log for debugging
|
||||
logger.debug(f"Unhandled SSE event type: {event_type}")
|
||||
# Unhandled event types
|
||||
logger.warning(f"Unhandled ChatGPT SSE event type: {event_type}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _is_upstream_connection_error(error_body: str) -> bool:
|
||||
"""Check if an error body indicates an upstream connection/proxy failure."""
|
||||
lower = error_body.lower()
|
||||
return "upstream connect error" in lower or "reset before headers" in lower or "connection termination" in lower
|
||||
|
||||
def _handle_http_error_from_status(self, status_code: int, error_body: str) -> Exception:
|
||||
"""Create appropriate exception from HTTP status code.
|
||||
|
||||
@@ -968,9 +1048,14 @@ class ChatGPTOAuthClient(LLMClientBase):
|
||||
message=f"ChatGPT rate limit exceeded: {error_body}",
|
||||
code=ErrorCode.RATE_LIMIT_EXCEEDED,
|
||||
)
|
||||
elif status_code == 502 or (status_code >= 500 and self._is_upstream_connection_error(error_body)):
|
||||
return LLMConnectionError(
|
||||
message=f"ChatGPT upstream connection error: {error_body}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
elif status_code >= 500:
|
||||
return LLMServerError(
|
||||
message=f"ChatGPT server error: {error_body}",
|
||||
message=f"ChatGPT API error: {error_body}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
else:
|
||||
@@ -992,25 +1077,43 @@ class ChatGPTOAuthClient(LLMClientBase):
|
||||
return "o1" in model or "o3" in model or "o4" in model or "gpt-5" in model
|
||||
|
||||
@trace_method
|
||||
def handle_llm_error(self, e: Exception) -> Exception:
|
||||
def handle_llm_error(self, e: Exception, llm_config: Optional[LLMConfig] = None) -> Exception:
|
||||
"""Map ChatGPT-specific errors to common LLMError types.
|
||||
|
||||
Args:
|
||||
e: Original exception.
|
||||
llm_config: Optional LLM config to determine if this is a BYOK key.
|
||||
|
||||
Returns:
|
||||
Mapped LLMError subclass.
|
||||
"""
|
||||
is_byok = (llm_config.provider_category == ProviderCategory.byok) if llm_config else None
|
||||
|
||||
# Already a typed LLM/Letta error (e.g. from SSE error handling) — pass through
|
||||
if isinstance(e, LettaError):
|
||||
return e
|
||||
|
||||
if isinstance(e, httpx.HTTPStatusError):
|
||||
return self._handle_http_error(e)
|
||||
return self._handle_http_error(e, is_byok=is_byok)
|
||||
|
||||
return super().handle_llm_error(e)
|
||||
# Handle httpx network errors which can occur during streaming
|
||||
# when the connection is unexpectedly closed while reading/writing
|
||||
if isinstance(e, (httpx.ReadError, httpx.WriteError, httpx.ConnectError)):
|
||||
logger.warning(f"[ChatGPT] Network error during streaming: {type(e).__name__}: {e}")
|
||||
return LLMConnectionError(
|
||||
message=f"Network error during ChatGPT streaming: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__, "is_byok": is_byok},
|
||||
)
|
||||
|
||||
def _handle_http_error(self, e: httpx.HTTPStatusError) -> Exception:
|
||||
return super().handle_llm_error(e, llm_config=llm_config)
|
||||
|
||||
def _handle_http_error(self, e: httpx.HTTPStatusError, is_byok: bool | None = None) -> Exception:
|
||||
"""Handle HTTP status errors from ChatGPT backend.
|
||||
|
||||
Args:
|
||||
e: HTTP status error.
|
||||
is_byok: Whether the request used a BYOK key.
|
||||
|
||||
Returns:
|
||||
Appropriate LLMError subclass.
|
||||
@@ -1028,28 +1131,86 @@ class ChatGPTOAuthClient(LLMClientBase):
|
||||
return LLMAuthenticationError(
|
||||
message=f"ChatGPT authentication failed: {error_message}",
|
||||
code=ErrorCode.UNAUTHENTICATED,
|
||||
details={"is_byok": is_byok},
|
||||
)
|
||||
elif status_code == 429:
|
||||
return LLMRateLimitError(
|
||||
message=f"ChatGPT rate limit exceeded: {error_message}",
|
||||
code=ErrorCode.RATE_LIMIT_EXCEEDED,
|
||||
details={"is_byok": is_byok},
|
||||
)
|
||||
elif status_code == 400:
|
||||
if "context" in error_message.lower() or "token" in error_message.lower():
|
||||
return ContextWindowExceededError(
|
||||
message=f"ChatGPT context window exceeded: {error_message}",
|
||||
details={"is_byok": is_byok},
|
||||
)
|
||||
return LLMBadRequestError(
|
||||
message=f"ChatGPT bad request: {error_message}",
|
||||
code=ErrorCode.INVALID_ARGUMENT,
|
||||
details={"is_byok": is_byok},
|
||||
)
|
||||
elif status_code == 502 or (status_code >= 500 and self._is_upstream_connection_error(error_message)):
|
||||
return LLMConnectionError(
|
||||
message=f"ChatGPT upstream connection error: {error_message}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"is_byok": is_byok},
|
||||
)
|
||||
elif status_code >= 500:
|
||||
return LLMServerError(
|
||||
message=f"ChatGPT server error: {error_message}",
|
||||
message=f"ChatGPT API error: {error_message}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"is_byok": is_byok},
|
||||
)
|
||||
else:
|
||||
return LLMBadRequestError(
|
||||
message=f"ChatGPT request failed ({status_code}): {error_message}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"is_byok": is_byok},
|
||||
)
|
||||
|
||||
def _handle_sse_error_event(self, raw_event: dict) -> Exception:
|
||||
"""Create appropriate exception from an SSE error or response.failed event.
|
||||
|
||||
The ChatGPT backend can return errors as SSE events within a 200 OK stream,
|
||||
e.g. {"type": "error", "error": {"type": "invalid_request_error",
|
||||
"code": "context_length_exceeded", "message": "..."}}.
|
||||
|
||||
Args:
|
||||
raw_event: Raw SSE event data containing an error.
|
||||
|
||||
Returns:
|
||||
Appropriate LLM exception.
|
||||
"""
|
||||
error_obj = raw_event.get("error", {})
|
||||
if isinstance(error_obj, str):
|
||||
error_message = error_obj
|
||||
error_code = None
|
||||
else:
|
||||
error_message = error_obj.get("message", "Unknown ChatGPT SSE error")
|
||||
error_code = error_obj.get("code") or None
|
||||
|
||||
if error_code == "context_length_exceeded":
|
||||
return ContextWindowExceededError(
|
||||
message=f"ChatGPT context window exceeded: {error_message}",
|
||||
)
|
||||
elif error_code == "rate_limit_exceeded":
|
||||
return LLMRateLimitError(
|
||||
message=f"ChatGPT rate limit exceeded: {error_message}",
|
||||
code=ErrorCode.RATE_LIMIT_EXCEEDED,
|
||||
)
|
||||
elif error_code == "authentication_error":
|
||||
return LLMAuthenticationError(
|
||||
message=f"ChatGPT authentication failed: {error_message}",
|
||||
code=ErrorCode.UNAUTHENTICATED,
|
||||
)
|
||||
elif error_code == "server_error":
|
||||
return LLMServerError(
|
||||
message=f"ChatGPT API error: {error_message}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
else:
|
||||
return LLMBadRequestError(
|
||||
message=f"ChatGPT SSE error ({error_code or 'unknown'}): {error_message}",
|
||||
code=ErrorCode.INVALID_ARGUMENT,
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ from openai import AsyncOpenAI, AsyncStream, OpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
|
||||
from letta.helpers.json_helpers import sanitize_unicode_surrogates
|
||||
from letta.llm_api.openai_client import OpenAIClient
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import trace_method
|
||||
@@ -97,6 +98,8 @@ class DeepseekClient(OpenAIClient):
|
||||
"""
|
||||
Performs underlying asynchronous request to OpenAI API and returns raw response dict.
|
||||
"""
|
||||
request_data = sanitize_unicode_surrogates(request_data)
|
||||
|
||||
api_key = model_settings.deepseek_api_key or os.environ.get("DEEPSEEK_API_KEY")
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
|
||||
|
||||
@@ -108,6 +111,8 @@ class DeepseekClient(OpenAIClient):
|
||||
"""
|
||||
Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator.
|
||||
"""
|
||||
request_data = sanitize_unicode_surrogates(request_data)
|
||||
|
||||
api_key = model_settings.deepseek_api_key or os.environ.get("DEEPSEEK_API_KEY")
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
|
||||
response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
|
||||
|
||||
@@ -20,3 +20,21 @@ def is_context_window_overflow_message(msg: str) -> bool:
|
||||
or "context_length_exceeded" in msg
|
||||
or "Input tokens exceed the configured limit" in msg
|
||||
)
|
||||
|
||||
|
||||
def is_insufficient_credits_message(msg: str) -> bool:
|
||||
"""Best-effort detection for insufficient credits/quota/billing errors.
|
||||
|
||||
BYOK users on OpenRouter, OpenAI, etc. may exhaust their credits mid-stream
|
||||
or get rejected pre-flight. We detect these so they map to 402 instead of 400/500.
|
||||
"""
|
||||
lower = msg.lower()
|
||||
return (
|
||||
"insufficient credits" in lower
|
||||
or "requires more credits" in lower
|
||||
or "add more credits" in lower
|
||||
or "exceeded your current quota" in lower
|
||||
or "you've exceeded your budget" in lower
|
||||
or ("billing" in lower and "hard limit" in lower)
|
||||
or "can only afford" in lower
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ from letta.errors import ErrorCode, LLMAuthenticationError, LLMError
|
||||
from letta.llm_api.google_constants import GOOGLE_MODEL_FOR_API_KEY_CHECK
|
||||
from letta.llm_api.google_vertex_client import GoogleVertexClient
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.settings import model_settings, settings
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -16,10 +17,27 @@ logger = get_logger(__name__)
|
||||
class GoogleAIClient(GoogleVertexClient):
|
||||
provider_label = "Google AI"
|
||||
|
||||
def _get_client(self):
|
||||
def _get_client(self, llm_config: Optional[LLMConfig] = None):
|
||||
timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
|
||||
api_key = None
|
||||
if llm_config:
|
||||
api_key, _, _ = self.get_byok_overrides(llm_config)
|
||||
if not api_key:
|
||||
api_key = model_settings.gemini_api_key
|
||||
return genai.Client(
|
||||
api_key=model_settings.gemini_api_key,
|
||||
api_key=api_key,
|
||||
http_options=HttpOptions(timeout=timeout_ms),
|
||||
)
|
||||
|
||||
async def _get_client_async(self, llm_config: Optional[LLMConfig] = None):
|
||||
timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
|
||||
api_key = None
|
||||
if llm_config:
|
||||
api_key, _, _ = await self.get_byok_overrides_async(llm_config)
|
||||
if not api_key:
|
||||
api_key = model_settings.gemini_api_key
|
||||
return genai.Client(
|
||||
api_key=api_key,
|
||||
http_options=HttpOptions(timeout=timeout_ms),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
GOOGLE_MODEL_TO_CONTEXT_LENGTH = {
|
||||
"gemini-3-pro-preview": 1048576,
|
||||
"gemini-3.1-pro-preview": 1048576,
|
||||
"gemini-3-flash-preview": 1048576,
|
||||
"gemini-2.5-pro": 1048576,
|
||||
"gemini-2.5-flash": 1048576,
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import base64
|
||||
import copy
|
||||
import json
|
||||
import uuid
|
||||
from typing import AsyncIterator, List, Optional
|
||||
|
||||
import httpx
|
||||
import pydantic_core
|
||||
from google.genai import Client, errors
|
||||
from google.genai.types import (
|
||||
FunctionCallingConfig,
|
||||
@@ -21,6 +23,7 @@ from letta.errors import (
|
||||
LLMAuthenticationError,
|
||||
LLMBadRequestError,
|
||||
LLMConnectionError,
|
||||
LLMInsufficientCreditsError,
|
||||
LLMNotFoundError,
|
||||
LLMPermissionDeniedError,
|
||||
LLMRateLimitError,
|
||||
@@ -29,12 +32,14 @@ from letta.errors import (
|
||||
LLMUnprocessableEntityError,
|
||||
)
|
||||
from letta.helpers.datetime_helpers import get_utc_time_int
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads, sanitize_unicode_surrogates
|
||||
from letta.llm_api.error_utils import is_insufficient_credits_message
|
||||
from letta.llm_api.llm_client_base import LLMClientBase
|
||||
from letta.local_llm.json_parser import clean_json_string_extra_backslash
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentType
|
||||
from letta.schemas.enums import ProviderCategory
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.openai.chat_completion_request import Tool, Tool as OpenAITool
|
||||
@@ -50,8 +55,31 @@ class GoogleVertexClient(LLMClientBase):
|
||||
MAX_RETRIES = model_settings.gemini_max_retries
|
||||
provider_label = "Google Vertex"
|
||||
|
||||
def _get_client(self):
|
||||
def _get_client(self, llm_config: Optional[LLMConfig] = None):
|
||||
timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
|
||||
if llm_config:
|
||||
api_key, _, _ = self.get_byok_overrides(llm_config)
|
||||
if api_key:
|
||||
return Client(
|
||||
api_key=api_key,
|
||||
http_options=HttpOptions(timeout=timeout_ms),
|
||||
)
|
||||
return Client(
|
||||
vertexai=True,
|
||||
project=model_settings.google_cloud_project,
|
||||
location=model_settings.google_cloud_location,
|
||||
http_options=HttpOptions(api_version="v1", timeout=timeout_ms),
|
||||
)
|
||||
|
||||
async def _get_client_async(self, llm_config: Optional[LLMConfig] = None):
|
||||
timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
|
||||
if llm_config:
|
||||
api_key, _, _ = await self.get_byok_overrides_async(llm_config)
|
||||
if api_key:
|
||||
return Client(
|
||||
api_key=api_key,
|
||||
http_options=HttpOptions(timeout=timeout_ms),
|
||||
)
|
||||
return Client(
|
||||
vertexai=True,
|
||||
project=model_settings.google_cloud_project,
|
||||
@@ -71,22 +99,36 @@ class GoogleVertexClient(LLMClientBase):
|
||||
Performs underlying request to llm and returns raw response.
|
||||
"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
client = self._get_client(llm_config)
|
||||
response = client.models.generate_content(
|
||||
model=llm_config.model,
|
||||
contents=request_data["contents"],
|
||||
config=request_data["config"],
|
||||
)
|
||||
return response.model_dump()
|
||||
except pydantic_core._pydantic_core.ValidationError as e:
|
||||
# Handle Pydantic validation errors from the Google SDK
|
||||
# This occurs when tool schemas contain unsupported fields
|
||||
logger.error(
|
||||
f"Pydantic validation error when calling {self._provider_name()} API. Tool schema contains unsupported fields. Error: {e}"
|
||||
)
|
||||
raise LLMBadRequestError(
|
||||
message=f"Invalid tool schema for {self._provider_name()}: Tool parameters contain unsupported fields. "
|
||||
f"Common issues: 'const', 'default', 'additionalProperties' are not supported by Google AI. "
|
||||
f"Please check your tool definitions. Error: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
except Exception as e:
|
||||
raise self.handle_llm_error(e)
|
||||
raise self.handle_llm_error(e, llm_config=llm_config)
|
||||
|
||||
@trace_method
|
||||
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
||||
"""
|
||||
Performs underlying request to llm and returns raw response.
|
||||
"""
|
||||
client = self._get_client()
|
||||
request_data = sanitize_unicode_surrogates(request_data)
|
||||
|
||||
client = await self._get_client_async(llm_config)
|
||||
|
||||
# Gemini 2.5 models will often return MALFORMED_FUNCTION_CALL, force a retry
|
||||
# https://github.com/googleapis/python-aiplatform/issues/4472
|
||||
@@ -100,17 +142,30 @@ class GoogleVertexClient(LLMClientBase):
|
||||
contents=request_data["contents"],
|
||||
config=request_data["config"],
|
||||
)
|
||||
except pydantic_core._pydantic_core.ValidationError as e:
|
||||
# Handle Pydantic validation errors from the Google SDK
|
||||
# This occurs when tool schemas contain unsupported fields
|
||||
logger.error(
|
||||
f"Pydantic validation error when calling {self._provider_name()} API. "
|
||||
f"Tool schema contains unsupported fields. Error: {e}"
|
||||
)
|
||||
raise LLMBadRequestError(
|
||||
message=f"Invalid tool schema for {self._provider_name()}: Tool parameters contain unsupported fields. "
|
||||
f"Common issues: 'const', 'default', 'additionalProperties' are not supported by Google AI. "
|
||||
f"Please check your tool definitions. Error: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
except errors.APIError as e:
|
||||
# Retry on 503 and 500 errors as well, usually ephemeral from Gemini
|
||||
if e.code == 503 or e.code == 500 or e.code == 504:
|
||||
logger.warning(f"Received {e}, retrying {retry_count}/{self.MAX_RETRIES}")
|
||||
retry_count += 1
|
||||
if retry_count > self.MAX_RETRIES:
|
||||
raise self.handle_llm_error(e)
|
||||
raise self.handle_llm_error(e, llm_config=llm_config)
|
||||
continue
|
||||
raise self.handle_llm_error(e)
|
||||
raise self.handle_llm_error(e, llm_config=llm_config)
|
||||
except Exception as e:
|
||||
raise self.handle_llm_error(e)
|
||||
raise self.handle_llm_error(e, llm_config=llm_config)
|
||||
response_data = response.model_dump()
|
||||
is_malformed_function_call = self.is_malformed_function_call(response_data)
|
||||
if is_malformed_function_call:
|
||||
@@ -148,7 +203,9 @@ class GoogleVertexClient(LLMClientBase):
|
||||
|
||||
@trace_method
|
||||
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncIterator[GenerateContentResponse]:
|
||||
client = self._get_client()
|
||||
request_data = sanitize_unicode_surrogates(request_data)
|
||||
|
||||
client = await self._get_client_async(llm_config)
|
||||
|
||||
try:
|
||||
response = await client.aio.models.generate_content_stream(
|
||||
@@ -156,13 +213,35 @@ class GoogleVertexClient(LLMClientBase):
|
||||
contents=request_data["contents"],
|
||||
config=request_data["config"],
|
||||
)
|
||||
except pydantic_core._pydantic_core.ValidationError as e:
|
||||
# Handle Pydantic validation errors from the Google SDK
|
||||
# This occurs when tool schemas contain unsupported fields
|
||||
logger.error(
|
||||
f"Pydantic validation error when calling {self._provider_name()} API. Tool schema contains unsupported fields. Error: {e}"
|
||||
)
|
||||
raise LLMBadRequestError(
|
||||
message=f"Invalid tool schema for {self._provider_name()}: Tool parameters contain unsupported fields. "
|
||||
f"Common issues: 'const', 'default', 'additionalProperties' are not supported by Google AI. "
|
||||
f"Please check your tool definitions. Error: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
except errors.APIError as e:
|
||||
raise self.handle_llm_error(e)
|
||||
except Exception as e:
|
||||
logger.error(f"Error streaming {self._provider_name()} request: {e} with request data: {json.dumps(request_data)}")
|
||||
raise e
|
||||
# Direct yield - keeps response alive in generator's local scope throughout iteration
|
||||
# This is required because the SDK's connection lifecycle is tied to the response object
|
||||
async for chunk in response:
|
||||
yield chunk
|
||||
try:
|
||||
async for chunk in response:
|
||||
yield chunk
|
||||
except errors.ClientError as e:
|
||||
if e.code == 499:
|
||||
logger.info(f"{self._provider_prefix()} Stream cancelled by client (499): {e}")
|
||||
return
|
||||
raise self.handle_llm_error(e, llm_config=llm_config)
|
||||
except errors.APIError as e:
|
||||
raise self.handle_llm_error(e, llm_config=llm_config)
|
||||
|
||||
@staticmethod
|
||||
def add_dummy_model_messages(messages: List[dict]) -> List[dict]:
|
||||
@@ -196,7 +275,7 @@ class GoogleVertexClient(LLMClientBase):
|
||||
# Per https://ai.google.dev/gemini-api/docs/function-calling?example=meeting#notes_and_limitations
|
||||
# * Only a subset of the OpenAPI schema is supported.
|
||||
# * Supported parameter types in Python are limited.
|
||||
unsupported_keys = ["default", "exclusiveMaximum", "exclusiveMinimum", "additionalProperties", "$schema"]
|
||||
unsupported_keys = ["default", "exclusiveMaximum", "exclusiveMinimum", "additionalProperties", "$schema", "const", "$ref"]
|
||||
keys_to_remove_at_this_level = [key for key in unsupported_keys if key in schema_part]
|
||||
for key_to_remove in keys_to_remove_at_this_level:
|
||||
logger.debug(f"Removing unsupported keyword '{key_to_remove}' from schema part.")
|
||||
@@ -223,6 +302,49 @@ class GoogleVertexClient(LLMClientBase):
|
||||
for item_schema in schema_part[key]:
|
||||
self._clean_google_ai_schema_properties(item_schema)
|
||||
|
||||
def _resolve_json_schema_refs(self, schema: dict, defs: dict | None = None) -> dict:
|
||||
"""
|
||||
Recursively resolve $ref in JSON schema by inlining definitions.
|
||||
Google GenAI SDK does not support $ref.
|
||||
"""
|
||||
if defs is None:
|
||||
# Look for definitions at the top level
|
||||
defs = schema.get("$defs") or schema.get("definitions") or {}
|
||||
|
||||
if not isinstance(schema, dict):
|
||||
return schema
|
||||
|
||||
# If this is a ref, resolve it
|
||||
if "$ref" in schema:
|
||||
ref = schema["$ref"]
|
||||
if isinstance(ref, str):
|
||||
for prefix in ("#/$defs/", "#/definitions/"):
|
||||
if ref.startswith(prefix):
|
||||
ref_name = ref.split("/")[-1]
|
||||
if ref_name in defs:
|
||||
resolved = defs[ref_name].copy()
|
||||
return self._resolve_json_schema_refs(resolved, defs)
|
||||
break
|
||||
|
||||
logger.warning(f"Could not resolve $ref '{ref}' in schema — will be stripped by schema cleaner")
|
||||
|
||||
# Recursively process children
|
||||
new_schema = schema.copy()
|
||||
|
||||
# We need to remove $defs/definitions from the output schema as Google doesn't support them
|
||||
if "$defs" in new_schema:
|
||||
del new_schema["$defs"]
|
||||
if "definitions" in new_schema:
|
||||
del new_schema["definitions"]
|
||||
|
||||
for k, v in new_schema.items():
|
||||
if isinstance(v, dict):
|
||||
new_schema[k] = self._resolve_json_schema_refs(v, defs)
|
||||
elif isinstance(v, list):
|
||||
new_schema[k] = [self._resolve_json_schema_refs(i, defs) if isinstance(i, dict) else i for i in v]
|
||||
|
||||
return new_schema
|
||||
|
||||
def convert_tools_to_google_ai_format(self, tools: List[Tool], llm_config: LLMConfig) -> List[dict]:
|
||||
"""
|
||||
OpenAI style:
|
||||
@@ -273,7 +395,8 @@ class GoogleVertexClient(LLMClientBase):
|
||||
dict(
|
||||
name=t.function.name,
|
||||
description=t.function.description,
|
||||
parameters=t.function.parameters, # TODO need to unpack
|
||||
# Deep copy parameters to avoid modifying the original Tool object
|
||||
parameters=copy.deepcopy(t.function.parameters) if t.function.parameters else {},
|
||||
)
|
||||
for t in tools
|
||||
]
|
||||
@@ -284,6 +407,8 @@ class GoogleVertexClient(LLMClientBase):
|
||||
|
||||
# Google AI API only supports a subset of OpenAPI 3.0, so unsupported params must be cleaned
|
||||
if "parameters" in func and isinstance(func["parameters"], dict):
|
||||
# Resolve $ref in schema because Google AI SDK doesn't support them
|
||||
func["parameters"] = self._resolve_json_schema_refs(func["parameters"])
|
||||
self._clean_google_ai_schema_properties(func["parameters"])
|
||||
|
||||
# Add inner thoughts
|
||||
@@ -549,6 +674,9 @@ class GoogleVertexClient(LLMClientBase):
|
||||
content=inner_thoughts,
|
||||
tool_calls=[tool_call],
|
||||
)
|
||||
if response_message.thought_signature:
|
||||
thought_signature = base64.b64encode(response_message.thought_signature).decode("utf-8")
|
||||
openai_response_message.reasoning_content_signature = thought_signature
|
||||
else:
|
||||
openai_response_message.content = inner_thoughts
|
||||
if openai_response_message.tool_calls is None:
|
||||
@@ -670,6 +798,7 @@ class GoogleVertexClient(LLMClientBase):
|
||||
# "candidatesTokenCount": 27,
|
||||
# "totalTokenCount": 36
|
||||
# }
|
||||
usage = None
|
||||
if response.usage_metadata:
|
||||
# Extract usage via centralized method
|
||||
from letta.schemas.enums import ProviderType
|
||||
@@ -750,54 +879,80 @@ class GoogleVertexClient(LLMClientBase):
|
||||
return False
|
||||
|
||||
@trace_method
|
||||
def handle_llm_error(self, e: Exception) -> Exception:
|
||||
def handle_llm_error(self, e: Exception, llm_config: Optional[LLMConfig] = None) -> Exception:
|
||||
is_byok = (llm_config.provider_category == ProviderCategory.byok) if llm_config else None
|
||||
|
||||
# Handle Google GenAI specific errors
|
||||
if isinstance(e, errors.ClientError):
|
||||
if e.code == 499:
|
||||
logger.info(f"{self._provider_prefix()} Request cancelled by client (499): {e}")
|
||||
return LLMConnectionError(
|
||||
message=f"Request to {self._provider_name()} was cancelled (client disconnected): {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"status_code": 499, "cause": "client_cancelled", "is_byok": is_byok},
|
||||
)
|
||||
|
||||
logger.warning(f"{self._provider_prefix()} Client error ({e.code}): {e}")
|
||||
|
||||
# Handle specific error codes
|
||||
if e.code == 400:
|
||||
error_str = str(e).lower()
|
||||
if "context" in error_str and ("exceed" in error_str or "limit" in error_str or "too long" in error_str):
|
||||
if ("context" in error_str or "token count" in error_str or "tokens allowed" in error_str) and (
|
||||
"exceed" in error_str or "limit" in error_str or "too long" in error_str
|
||||
):
|
||||
return ContextWindowExceededError(
|
||||
message=f"Bad request to {self._provider_name()} (context window exceeded): {str(e)}",
|
||||
details={"is_byok": is_byok},
|
||||
)
|
||||
else:
|
||||
return LLMBadRequestError(
|
||||
message=f"Bad request to {self._provider_name()}: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
code=ErrorCode.INVALID_ARGUMENT,
|
||||
details={"is_byok": is_byok},
|
||||
)
|
||||
elif e.code == 401:
|
||||
return LLMAuthenticationError(
|
||||
message=f"Authentication failed with {self._provider_name()}: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"is_byok": is_byok},
|
||||
)
|
||||
elif e.code == 403:
|
||||
return LLMPermissionDeniedError(
|
||||
message=f"Permission denied by {self._provider_name()}: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"is_byok": is_byok},
|
||||
)
|
||||
elif e.code == 404:
|
||||
return LLMNotFoundError(
|
||||
message=f"Resource not found in {self._provider_name()}: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"is_byok": is_byok},
|
||||
)
|
||||
elif e.code == 408:
|
||||
return LLMTimeoutError(
|
||||
message=f"Request to {self._provider_name()} timed out: {str(e)}",
|
||||
code=ErrorCode.TIMEOUT,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None},
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
|
||||
)
|
||||
elif e.code == 402 or is_insufficient_credits_message(str(e)):
|
||||
msg = str(e)
|
||||
return LLMInsufficientCreditsError(
|
||||
message=f"Insufficient credits (BYOK): {msg}" if is_byok else f"Insufficient credits: {msg}",
|
||||
code=ErrorCode.PAYMENT_REQUIRED,
|
||||
details={"status_code": e.code, "is_byok": is_byok},
|
||||
)
|
||||
elif e.code == 422:
|
||||
return LLMUnprocessableEntityError(
|
||||
message=f"Invalid request content for {self._provider_name()}: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"is_byok": is_byok},
|
||||
)
|
||||
elif e.code == 429:
|
||||
logger.warning(f"{self._provider_prefix()} Rate limited (429). Consider backoff.")
|
||||
return LLMRateLimitError(
|
||||
message=f"Rate limited by {self._provider_name()}: {str(e)}",
|
||||
code=ErrorCode.RATE_LIMIT_EXCEEDED,
|
||||
details={"is_byok": is_byok},
|
||||
)
|
||||
else:
|
||||
return LLMServerError(
|
||||
@@ -806,6 +961,7 @@ class GoogleVertexClient(LLMClientBase):
|
||||
details={
|
||||
"status_code": e.code,
|
||||
"response_json": getattr(e, "response_json", None),
|
||||
"is_byok": is_byok,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -820,13 +976,14 @@ class GoogleVertexClient(LLMClientBase):
|
||||
details={
|
||||
"status_code": e.code,
|
||||
"response_json": getattr(e, "response_json", None),
|
||||
"is_byok": is_byok,
|
||||
},
|
||||
)
|
||||
elif e.code == 502:
|
||||
return LLMConnectionError(
|
||||
message=f"Bad gateway from {self._provider_name()}: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None},
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
|
||||
)
|
||||
elif e.code == 503:
|
||||
return LLMServerError(
|
||||
@@ -835,13 +992,14 @@ class GoogleVertexClient(LLMClientBase):
|
||||
details={
|
||||
"status_code": e.code,
|
||||
"response_json": getattr(e, "response_json", None),
|
||||
"is_byok": is_byok,
|
||||
},
|
||||
)
|
||||
elif e.code == 504:
|
||||
return LLMTimeoutError(
|
||||
message=f"Gateway timeout from {self._provider_name()}: {str(e)}",
|
||||
code=ErrorCode.TIMEOUT,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None},
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
|
||||
)
|
||||
else:
|
||||
return LLMServerError(
|
||||
@@ -850,6 +1008,7 @@ class GoogleVertexClient(LLMClientBase):
|
||||
details={
|
||||
"status_code": e.code,
|
||||
"response_json": getattr(e, "response_json", None),
|
||||
"is_byok": is_byok,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -861,6 +1020,7 @@ class GoogleVertexClient(LLMClientBase):
|
||||
details={
|
||||
"status_code": e.code,
|
||||
"response_json": getattr(e, "response_json", None),
|
||||
"is_byok": is_byok,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -872,7 +1032,7 @@ class GoogleVertexClient(LLMClientBase):
|
||||
return LLMConnectionError(
|
||||
message=f"Connection error during {self._provider_name()} streaming: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None},
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
|
||||
)
|
||||
|
||||
# Handle httpx network errors which can occur during streaming
|
||||
@@ -882,7 +1042,7 @@ class GoogleVertexClient(LLMClientBase):
|
||||
return LLMConnectionError(
|
||||
message=f"Network error during {self._provider_name()} streaming: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__},
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__, "is_byok": is_byok},
|
||||
)
|
||||
|
||||
# Handle connection-related errors
|
||||
@@ -891,13 +1051,15 @@ class GoogleVertexClient(LLMClientBase):
|
||||
return LLMConnectionError(
|
||||
message=f"Failed to connect to {self._provider_name()}: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None},
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
|
||||
)
|
||||
|
||||
# Fallback to base implementation for other errors
|
||||
return super().handle_llm_error(e)
|
||||
return super().handle_llm_error(e, llm_config=llm_config)
|
||||
|
||||
async def count_tokens(self, messages: List[dict] = None, model: str = None, tools: List[OpenAITool] = None) -> int:
|
||||
async def count_tokens(
|
||||
self, messages: List[dict] | None = None, model: str | None = None, tools: List[OpenAITool] | None = None
|
||||
) -> int:
|
||||
"""
|
||||
Count tokens for the given messages and tools using the Gemini token counting API.
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from openai import AsyncOpenAI, AsyncStream, OpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
|
||||
from letta.helpers.json_helpers import sanitize_unicode_surrogates
|
||||
from letta.llm_api.openai_client import OpenAIClient
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
@@ -34,6 +35,11 @@ class GroqClient(OpenAIClient):
|
||||
) -> dict:
|
||||
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)
|
||||
|
||||
# Groq only supports string values for tool_choice: "none", "auto", "required"
|
||||
# Convert object-format tool_choice (used for force_tool_call) to "required"
|
||||
if "tool_choice" in data and isinstance(data["tool_choice"], dict):
|
||||
data["tool_choice"] = "required"
|
||||
|
||||
# Groq validation - these fields are not supported and will cause 400 errors
|
||||
# https://console.groq.com/docs/openai
|
||||
if "top_logprobs" in data:
|
||||
@@ -69,6 +75,8 @@ class GroqClient(OpenAIClient):
|
||||
"""
|
||||
Performs underlying asynchronous request to Groq API and returns raw response dict.
|
||||
"""
|
||||
request_data = sanitize_unicode_surrogates(request_data)
|
||||
|
||||
api_key = model_settings.groq_api_key or os.environ.get("GROQ_API_KEY")
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
|
||||
|
||||
|
||||
@@ -1,23 +1,17 @@
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from typing import Any, List, Optional, Union
|
||||
from typing import List, Optional
|
||||
|
||||
from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice
|
||||
from letta.schemas.response_format import (
|
||||
JsonObjectResponseFormat,
|
||||
JsonSchemaResponseFormat,
|
||||
ResponseFormatType,
|
||||
ResponseFormatUnion,
|
||||
TextResponseFormat,
|
||||
)
|
||||
from letta.settings import summarizer_settings
|
||||
from letta.utils import printd
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
||||
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
|
||||
from letta.orm.user import User
|
||||
from letta.otel.tracing import log_event, trace_method
|
||||
from letta.schemas.enums import ProviderCategory
|
||||
from letta.schemas.enums import LLMCallType, ProviderCategory
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
@@ -245,6 +245,7 @@ def create(
|
||||
request_json=prepare_openai_payload(data),
|
||||
response_json=response.model_json_schema(),
|
||||
step_id=step_id,
|
||||
call_type=LLMCallType.agent_step,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from letta.errors import ErrorCode, LLMConnectionError, LLMError
|
||||
from letta.otel.tracing import log_event, trace_method
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import AgentType, ProviderCategory
|
||||
from letta.schemas.enums import AgentType, LLMCallType, ProviderCategory
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
@@ -61,8 +61,11 @@ class LLMClientBase:
|
||||
user_id: Optional[str] = None,
|
||||
compaction_settings: Optional[Dict] = None,
|
||||
llm_config: Optional[Dict] = None,
|
||||
actor: Optional["User"] = None,
|
||||
) -> None:
|
||||
"""Set telemetry context for provider trace logging."""
|
||||
if actor is not None:
|
||||
self.actor = actor
|
||||
self._telemetry_manager = telemetry_manager
|
||||
self._telemetry_agent_id = agent_id
|
||||
self._telemetry_agent_tags = agent_tags
|
||||
@@ -82,6 +85,10 @@ class LLMClientBase:
|
||||
"""Wrapper around request_async that logs telemetry for all requests including errors.
|
||||
|
||||
Call set_telemetry_context() first to set agent_id, run_id, etc.
|
||||
|
||||
Telemetry is logged via TelemetryManager which supports multiple backends
|
||||
(postgres, clickhouse, socket, etc.) configured via
|
||||
LETTA_TELEMETRY_PROVIDER_TRACE_BACKEND.
|
||||
"""
|
||||
from letta.log import get_logger
|
||||
|
||||
@@ -97,6 +104,7 @@ class LLMClientBase:
|
||||
error_type = type(e).__name__
|
||||
raise
|
||||
finally:
|
||||
# Log telemetry via configured backends
|
||||
if self._telemetry_manager and settings.track_provider_trace:
|
||||
if self.actor is None:
|
||||
logger.warning(f"Skipping telemetry: actor is None (call_type={self._telemetry_call_type})")
|
||||
@@ -116,24 +124,33 @@ class LLMClientBase:
|
||||
org_id=self._telemetry_org_id,
|
||||
user_id=self._telemetry_user_id,
|
||||
compaction_settings=self._telemetry_compaction_settings,
|
||||
llm_config=self._telemetry_llm_config,
|
||||
llm_config=llm_config.model_dump() if llm_config else self._telemetry_llm_config,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to log telemetry: {e}")
|
||||
|
||||
async def stream_async_with_telemetry(self, request_data: dict, llm_config: LLMConfig):
|
||||
"""Returns raw stream. Caller should log telemetry after processing via log_provider_trace_async().
|
||||
|
||||
Call set_telemetry_context() first to set agent_id, run_id, etc.
|
||||
After consuming the stream, call log_provider_trace_async() with the response data.
|
||||
"""
|
||||
return await self.stream_async(request_data, llm_config)
|
||||
|
||||
async def log_provider_trace_async(self, request_data: dict, response_json: dict) -> None:
|
||||
async def log_provider_trace_async(
|
||||
self,
|
||||
request_data: dict,
|
||||
response_json: Optional[dict],
|
||||
llm_config: Optional[LLMConfig] = None,
|
||||
latency_ms: Optional[int] = None,
|
||||
error_msg: Optional[str] = None,
|
||||
error_type: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Log provider trace telemetry. Call after processing LLM response.
|
||||
|
||||
Uses telemetry context set via set_telemetry_context().
|
||||
Telemetry is logged via TelemetryManager which supports multiple backends.
|
||||
|
||||
Args:
|
||||
request_data: The request payload sent to the LLM
|
||||
response_json: The response payload from the LLM
|
||||
llm_config: LLMConfig for extracting provider/model info
|
||||
latency_ms: Latency in milliseconds (not used currently, kept for API compatibility)
|
||||
error_msg: Error message if request failed (not used currently)
|
||||
error_type: Error type if request failed (not used currently)
|
||||
"""
|
||||
from letta.log import get_logger
|
||||
|
||||
@@ -146,6 +163,13 @@ class LLMClientBase:
|
||||
logger.warning(f"Skipping telemetry: actor is None (call_type={self._telemetry_call_type})")
|
||||
return
|
||||
|
||||
if response_json is None:
|
||||
if error_msg:
|
||||
response_json = {"error": error_msg, "error_type": error_type}
|
||||
else:
|
||||
logger.warning(f"Skipping telemetry: no response_json or error_msg (call_type={self._telemetry_call_type})")
|
||||
return
|
||||
|
||||
try:
|
||||
pydantic_actor = self.actor.to_pydantic() if hasattr(self.actor, "to_pydantic") else self.actor
|
||||
await self._telemetry_manager.create_provider_trace_async(
|
||||
@@ -161,7 +185,7 @@ class LLMClientBase:
|
||||
org_id=self._telemetry_org_id,
|
||||
user_id=self._telemetry_user_id,
|
||||
compaction_settings=self._telemetry_compaction_settings,
|
||||
llm_config=self._telemetry_llm_config,
|
||||
llm_config=llm_config.model_dump() if llm_config else self._telemetry_llm_config,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -204,11 +228,12 @@ class LLMClientBase:
|
||||
request_json=request_data,
|
||||
response_json=response_data,
|
||||
step_id=step_id,
|
||||
call_type=LLMCallType.agent_step,
|
||||
),
|
||||
)
|
||||
log_event(name="llm_response_received", attributes=response_data)
|
||||
except Exception as e:
|
||||
raise self.handle_llm_error(e)
|
||||
raise self.handle_llm_error(e, llm_config=llm_config)
|
||||
|
||||
return await self.convert_response_to_chat_completion(response_data, messages, llm_config)
|
||||
|
||||
@@ -237,12 +262,13 @@ class LLMClientBase:
|
||||
request_json=request_data,
|
||||
response_json=response_data,
|
||||
step_id=step_id,
|
||||
call_type=LLMCallType.agent_step,
|
||||
),
|
||||
)
|
||||
|
||||
log_event(name="llm_response_received", attributes=response_data)
|
||||
except Exception as e:
|
||||
raise self.handle_llm_error(e)
|
||||
raise self.handle_llm_error(e, llm_config=llm_config)
|
||||
|
||||
return await self.convert_response_to_chat_completion(response_data, messages, llm_config)
|
||||
|
||||
@@ -334,17 +360,20 @@ class LLMClientBase:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def handle_llm_error(self, e: Exception) -> Exception:
|
||||
def handle_llm_error(self, e: Exception, llm_config: Optional["LLMConfig"] = None) -> Exception:
|
||||
"""
|
||||
Maps provider-specific errors to common LLMError types.
|
||||
Each LLM provider should implement this to translate their specific errors.
|
||||
|
||||
Args:
|
||||
e: The original provider-specific exception
|
||||
llm_config: Optional LLM config to determine if this is a BYOK key
|
||||
|
||||
Returns:
|
||||
An LLMError subclass that represents the error in a provider-agnostic way
|
||||
"""
|
||||
is_byok = (llm_config.provider_category == ProviderCategory.byok) if llm_config else None
|
||||
|
||||
# Handle httpx.RemoteProtocolError which can occur during streaming
|
||||
# when the remote server closes the connection unexpectedly
|
||||
# (e.g., "peer closed connection without sending complete message body")
|
||||
@@ -356,10 +385,10 @@ class LLMClientBase:
|
||||
return LLMConnectionError(
|
||||
message=f"Connection error during streaming: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None},
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
|
||||
)
|
||||
|
||||
return LLMError(f"Unhandled LLM error: {str(e)}")
|
||||
return LLMError(message=f"Unhandled LLM error: {str(e)}", details={"is_byok": is_byok})
|
||||
|
||||
def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
"""
|
||||
|
||||
@@ -4,6 +4,7 @@ import anthropic
|
||||
from anthropic import AsyncStream
|
||||
from anthropic.types.beta import BetaMessage, BetaRawMessageStreamEvent
|
||||
|
||||
from letta.helpers.json_helpers import sanitize_unicode_surrogates
|
||||
from letta.llm_api.anthropic_client import AnthropicClient
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import trace_method
|
||||
@@ -83,6 +84,8 @@ class MiniMaxClient(AnthropicClient):
|
||||
|
||||
Uses beta messages API for compatibility with Anthropic streaming interfaces.
|
||||
"""
|
||||
request_data = sanitize_unicode_surrogates(request_data)
|
||||
|
||||
client = await self._get_anthropic_client_async(llm_config, async_client=True)
|
||||
|
||||
try:
|
||||
@@ -105,6 +108,8 @@ class MiniMaxClient(AnthropicClient):
|
||||
|
||||
Uses beta messages API for compatibility with Anthropic streaming interfaces.
|
||||
"""
|
||||
request_data = sanitize_unicode_surrogates(request_data)
|
||||
|
||||
client = await self._get_anthropic_client_async(llm_config, async_client=True)
|
||||
request_data["stream"] = True
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ async def mistral_get_model_list_async(url: str, api_key: str) -> dict:
|
||||
url = smart_urljoin(url, "models")
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if api_key is not None:
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
logger.debug("Sending request to %s", url)
|
||||
|
||||
@@ -59,7 +59,7 @@ async def openai_get_model_list_async(
|
||||
url = smart_urljoin(url, "models")
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if api_key is not None:
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
if "openrouter.ai" in url:
|
||||
if model_settings.openrouter_referer:
|
||||
@@ -86,7 +86,7 @@ async def openai_get_model_list_async(
|
||||
# Handle HTTP errors (e.g., response 4XX, 5XX)
|
||||
try:
|
||||
error_response = http_err.response.json()
|
||||
except:
|
||||
except Exception:
|
||||
error_response = {"status_code": http_err.response.status_code, "text": http_err.response.text}
|
||||
logger.debug(f"Got HTTPError, exception={http_err}, response={error_response}")
|
||||
raise http_err
|
||||
@@ -478,7 +478,7 @@ def openai_chat_completions_request_stream(
|
||||
|
||||
data = prepare_openai_payload(chat_completion_request)
|
||||
data["stream"] = True
|
||||
kwargs = {"api_key": api_key, "base_url": url, "max_retries": 0}
|
||||
kwargs = {"api_key": api_key or "DUMMY_API_KEY", "base_url": url, "max_retries": 0}
|
||||
if "openrouter.ai" in url:
|
||||
headers = {}
|
||||
if model_settings.openrouter_referer:
|
||||
@@ -511,7 +511,7 @@ def openai_chat_completions_request(
|
||||
https://platform.openai.com/docs/guides/text-generation?lang=curl
|
||||
"""
|
||||
data = prepare_openai_payload(chat_completion_request)
|
||||
kwargs = {"api_key": api_key, "base_url": url, "max_retries": 0}
|
||||
kwargs = {"api_key": api_key or "DUMMY_API_KEY", "base_url": url, "max_retries": 0}
|
||||
if "openrouter.ai" in url:
|
||||
headers = {}
|
||||
if model_settings.openrouter_referer:
|
||||
@@ -524,7 +524,17 @@ def openai_chat_completions_request(
|
||||
log_event(name="llm_request_sent", attributes=data)
|
||||
chat_completion = client.chat.completions.create(**data)
|
||||
log_event(name="llm_response_received", attributes=chat_completion.model_dump())
|
||||
return ChatCompletionResponse(**chat_completion.model_dump())
|
||||
response = ChatCompletionResponse(**chat_completion.model_dump())
|
||||
|
||||
# Override tool_call IDs to ensure cross-provider compatibility (matches streaming path behavior)
|
||||
# Some models (e.g. Kimi via OpenRouter) generate IDs like 'Read:93' which violate Anthropic's pattern
|
||||
for choice in response.choices:
|
||||
if choice.message.tool_calls:
|
||||
for tool_call in choice.message.tool_calls:
|
||||
if tool_call.id is not None:
|
||||
tool_call.id = get_tool_call_id()
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def prepare_openai_payload(chat_completion_request: ChatCompletionRequest):
|
||||
|
||||
@@ -20,6 +20,7 @@ from letta.errors import (
|
||||
LLMAuthenticationError,
|
||||
LLMBadRequestError,
|
||||
LLMConnectionError,
|
||||
LLMInsufficientCreditsError,
|
||||
LLMNotFoundError,
|
||||
LLMPermissionDeniedError,
|
||||
LLMRateLimitError,
|
||||
@@ -27,7 +28,8 @@ from letta.errors import (
|
||||
LLMTimeoutError,
|
||||
LLMUnprocessableEntityError,
|
||||
)
|
||||
from letta.llm_api.error_utils import is_context_window_overflow_message
|
||||
from letta.helpers.json_helpers import sanitize_unicode_surrogates
|
||||
from letta.llm_api.error_utils import is_context_window_overflow_message, is_insufficient_credits_message
|
||||
from letta.llm_api.helpers import (
|
||||
add_inner_thoughts_to_functions,
|
||||
convert_response_format_to_responses_api,
|
||||
@@ -39,13 +41,13 @@ from letta.log import get_logger
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentType
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ProviderCategory
|
||||
from letta.schemas.letta_message_content import MessageContentType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.openai.chat_completion_request import (
|
||||
ChatCompletionRequest,
|
||||
FunctionCall as ToolFunctionChoiceFunctionCall,
|
||||
FunctionSchema,
|
||||
Tool as OpenAITool,
|
||||
ToolFunctionChoice,
|
||||
cast_message_to_subtype,
|
||||
@@ -56,7 +58,6 @@ from letta.schemas.openai.chat_completion_response import (
|
||||
FunctionCall,
|
||||
Message as ChoiceMessage,
|
||||
ToolCall,
|
||||
UsageStatistics,
|
||||
)
|
||||
from letta.schemas.openai.responses_request import ResponsesRequest
|
||||
from letta.schemas.response_format import JsonSchemaResponseFormat
|
||||
@@ -105,7 +106,7 @@ def accepts_developer_role(model: str) -> bool:
|
||||
|
||||
See: https://community.openai.com/t/developer-role-not-accepted-for-o1-o1-mini-o3-mini/1110750/7
|
||||
"""
|
||||
if is_openai_reasoning_model(model) and "o1-mini" not in model or "o1-preview" in model:
|
||||
if (is_openai_reasoning_model(model) and "o1-mini" not in model) or "o1-preview" in model:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
@@ -244,6 +245,64 @@ class OpenAIClient(LLMClientBase):
|
||||
def supports_structured_output(self, llm_config: LLMConfig) -> bool:
|
||||
return supports_structured_output(llm_config)
|
||||
|
||||
def _is_openrouter_request(self, llm_config: LLMConfig) -> bool:
|
||||
return (llm_config.model_endpoint and "openrouter.ai" in llm_config.model_endpoint) or (llm_config.provider_name == "openrouter")
|
||||
|
||||
def _is_true_openai_request(self, llm_config: LLMConfig) -> bool:
|
||||
if llm_config.model_endpoint_type != "openai":
|
||||
return False
|
||||
|
||||
if self._is_openrouter_request(llm_config):
|
||||
return False
|
||||
|
||||
# Keep Letta inference endpoint behavior unchanged.
|
||||
if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT:
|
||||
return False
|
||||
|
||||
# If provider_name is explicitly set and not openai, don't apply OpenAI-specific prompt caching fields.
|
||||
if llm_config.provider_name and llm_config.provider_name != "openai":
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _normalize_model_name(self, model: Optional[str]) -> Optional[str]:
|
||||
if not model:
|
||||
return None
|
||||
return model.split("/", 1)[-1]
|
||||
|
||||
def _supports_extended_prompt_cache_retention(self, model: Optional[str]) -> bool:
|
||||
normalized_model = self._normalize_model_name(model)
|
||||
if not normalized_model:
|
||||
return False
|
||||
|
||||
# Per OpenAI docs: extended retention is available on gpt-4.1 and gpt-5 family models.
|
||||
# gpt-5-mini is excluded (not listed in docs).
|
||||
return normalized_model == "gpt-4.1" or (normalized_model.startswith("gpt-5") and normalized_model != "gpt-5-mini")
|
||||
|
||||
def _apply_prompt_cache_settings(
|
||||
self,
|
||||
llm_config: LLMConfig,
|
||||
model: Optional[str],
|
||||
messages: List[PydanticMessage],
|
||||
request_obj: Any,
|
||||
) -> None:
|
||||
"""Apply OpenAI prompt cache settings to the request.
|
||||
|
||||
We intentionally do NOT set prompt_cache_key. OpenAI's default routing
|
||||
(based on a hash of the first ~256 tokens of the prompt) already provides
|
||||
good cache affinity for Letta agents, since each agent has a unique system
|
||||
prompt. Setting an explicit key can disrupt existing warm caches and reduce
|
||||
hit rates.
|
||||
|
||||
We only set prompt_cache_retention to "24h" for models that support extended
|
||||
retention, which keeps cached prefixes active longer (up to 24h vs 5-10min).
|
||||
"""
|
||||
if not self._is_true_openai_request(llm_config):
|
||||
return
|
||||
|
||||
if self._supports_extended_prompt_cache_retention(model):
|
||||
request_obj.prompt_cache_retention = "24h"
|
||||
|
||||
@trace_method
|
||||
def build_request_data_responses(
|
||||
self,
|
||||
@@ -384,6 +443,13 @@ class OpenAIClient(LLMClientBase):
|
||||
|
||||
data.model = "memgpt-openai"
|
||||
|
||||
self._apply_prompt_cache_settings(
|
||||
llm_config=llm_config,
|
||||
model=model,
|
||||
messages=messages,
|
||||
request_obj=data,
|
||||
)
|
||||
|
||||
request_data = data.model_dump(exclude_unset=True)
|
||||
# print("responses request data", request_data)
|
||||
return request_data
|
||||
@@ -452,13 +518,11 @@ class OpenAIClient(LLMClientBase):
|
||||
model = None
|
||||
|
||||
# TODO: we may need to extend this to more models using proxy?
|
||||
is_openrouter = (llm_config.model_endpoint and "openrouter.ai" in llm_config.model_endpoint) or (
|
||||
llm_config.provider_name == "openrouter"
|
||||
)
|
||||
is_openrouter = self._is_openrouter_request(llm_config)
|
||||
if is_openrouter:
|
||||
try:
|
||||
model = llm_config.handle.split("/", 1)[-1]
|
||||
except:
|
||||
except Exception:
|
||||
# don't raise error since this isn't robust against edge cases
|
||||
pass
|
||||
|
||||
@@ -468,7 +532,12 @@ class OpenAIClient(LLMClientBase):
|
||||
tool_choice = None
|
||||
if tools: # only set tool_choice if tools exist
|
||||
if force_tool_call is not None:
|
||||
tool_choice = ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=force_tool_call))
|
||||
# OpenRouter proxies to providers that may not support object-format tool_choice
|
||||
# Use "required" instead which achieves similar effect
|
||||
if is_openrouter:
|
||||
tool_choice = "required"
|
||||
else:
|
||||
tool_choice = ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=force_tool_call))
|
||||
elif requires_subsequent_tool_call:
|
||||
tool_choice = "required"
|
||||
elif self.requires_auto_tool_choice(llm_config) or agent_type == AgentType.letta_v1_agent:
|
||||
@@ -505,6 +574,12 @@ class OpenAIClient(LLMClientBase):
|
||||
if llm_config.frequency_penalty is not None:
|
||||
data.frequency_penalty = llm_config.frequency_penalty
|
||||
|
||||
# Add logprobs configuration for RL training
|
||||
if llm_config.return_logprobs:
|
||||
data.logprobs = True
|
||||
if llm_config.top_logprobs is not None:
|
||||
data.top_logprobs = llm_config.top_logprobs
|
||||
|
||||
if tools and supports_parallel_tool_calling(model):
|
||||
data.parallel_tool_calls = False
|
||||
|
||||
@@ -546,6 +621,13 @@ class OpenAIClient(LLMClientBase):
|
||||
new_tools.append(tool.model_copy(deep=True))
|
||||
data.tools = new_tools
|
||||
|
||||
self._apply_prompt_cache_settings(
|
||||
llm_config=llm_config,
|
||||
model=model,
|
||||
messages=messages,
|
||||
request_obj=data,
|
||||
)
|
||||
|
||||
# Note: Tools are already processed by enable_strict_mode() in the workflow/agent code
|
||||
# (temporal_letta_v1_agent_workflow.py or letta_agent_v3.py) before reaching here.
|
||||
# enable_strict_mode() handles: strict flag, additionalProperties, required array, nullable fields
|
||||
@@ -564,6 +646,17 @@ class OpenAIClient(LLMClientBase):
|
||||
# If set, then in the backend "medium" thinking is turned on
|
||||
# request_data["reasoning_effort"] = "medium"
|
||||
|
||||
# Add OpenRouter reasoning configuration via extra_body
|
||||
if is_openrouter and llm_config.enable_reasoner:
|
||||
reasoning_config = {}
|
||||
if llm_config.reasoning_effort:
|
||||
reasoning_config["effort"] = llm_config.reasoning_effort
|
||||
if llm_config.max_reasoning_tokens and llm_config.max_reasoning_tokens > 0:
|
||||
reasoning_config["max_tokens"] = llm_config.max_reasoning_tokens
|
||||
if not reasoning_config:
|
||||
reasoning_config = {"enabled": True}
|
||||
request_data["extra_body"] = {"reasoning": reasoning_config}
|
||||
|
||||
return request_data
|
||||
|
||||
@trace_method
|
||||
@@ -571,29 +664,51 @@ class OpenAIClient(LLMClientBase):
|
||||
"""
|
||||
Performs underlying synchronous request to OpenAI API and returns raw response dict.
|
||||
"""
|
||||
# Sanitize Unicode surrogates to prevent encoding errors
|
||||
request_data = sanitize_unicode_surrogates(request_data)
|
||||
|
||||
client = OpenAI(**self._prepare_client_kwargs(llm_config))
|
||||
# Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
|
||||
if "input" in request_data and "messages" not in request_data:
|
||||
resp = client.responses.create(**request_data)
|
||||
return resp.model_dump()
|
||||
else:
|
||||
response: ChatCompletion = client.chat.completions.create(**request_data)
|
||||
return response.model_dump()
|
||||
try:
|
||||
if "input" in request_data and "messages" not in request_data:
|
||||
resp = client.responses.create(**request_data)
|
||||
return resp.model_dump()
|
||||
else:
|
||||
response: ChatCompletion = client.chat.completions.create(**request_data)
|
||||
return response.model_dump()
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"[OpenAI] Failed to parse API response as JSON: {e}")
|
||||
raise LLMServerError(
|
||||
message=f"OpenAI API returned invalid JSON response (likely an HTML error page): {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"json_error": str(e), "error_position": f"line {e.lineno} column {e.colno}"},
|
||||
)
|
||||
|
||||
@trace_method
|
||||
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
||||
"""
|
||||
Performs underlying asynchronous request to OpenAI API and returns raw response dict.
|
||||
"""
|
||||
# Sanitize Unicode surrogates to prevent encoding errors
|
||||
request_data = sanitize_unicode_surrogates(request_data)
|
||||
|
||||
kwargs = await self._prepare_client_kwargs_async(llm_config)
|
||||
client = AsyncOpenAI(**kwargs)
|
||||
# Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
|
||||
if "input" in request_data and "messages" not in request_data:
|
||||
resp = await client.responses.create(**request_data)
|
||||
return resp.model_dump()
|
||||
else:
|
||||
response: ChatCompletion = await client.chat.completions.create(**request_data)
|
||||
return response.model_dump()
|
||||
try:
|
||||
if "input" in request_data and "messages" not in request_data:
|
||||
resp = await client.responses.create(**request_data)
|
||||
return resp.model_dump()
|
||||
else:
|
||||
response: ChatCompletion = await client.chat.completions.create(**request_data)
|
||||
return response.model_dump()
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"[OpenAI] Failed to parse API response as JSON: {e}")
|
||||
raise LLMServerError(
|
||||
message=f"OpenAI API returned invalid JSON response (likely an HTML error page): {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"json_error": str(e), "error_position": f"line {e.lineno} column {e.colno}"},
|
||||
)
|
||||
|
||||
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
|
||||
return is_openai_reasoning_model(llm_config.model)
|
||||
@@ -669,6 +784,12 @@ class OpenAIClient(LLMClientBase):
|
||||
Converts raw OpenAI response dict into the ChatCompletionResponse Pydantic model.
|
||||
Handles potential extraction of inner thoughts if they were added via kwargs.
|
||||
"""
|
||||
if isinstance(response_data, str):
|
||||
raise LLMServerError(
|
||||
message="LLM endpoint returned a raw string instead of a JSON object. This usually indicates the endpoint URL is incorrect or returned an error page.",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"raw_response": response_data[:500]},
|
||||
)
|
||||
if "object" in response_data and response_data["object"] == "response":
|
||||
# Map Responses API shape to Chat Completions shape
|
||||
# See example payload in tests/integration_test_send_message_v2.py
|
||||
@@ -696,7 +817,6 @@ class OpenAIClient(LLMClientBase):
|
||||
finish_reason = None
|
||||
|
||||
# Optionally capture reasoning presence
|
||||
found_reasoning = False
|
||||
for out in outputs:
|
||||
out_type = (out or {}).get("type")
|
||||
if out_type == "message":
|
||||
@@ -707,7 +827,6 @@ class OpenAIClient(LLMClientBase):
|
||||
if text_val:
|
||||
assistant_text_parts.append(text_val)
|
||||
elif out_type == "reasoning":
|
||||
found_reasoning = True
|
||||
reasoning_summary_parts = [part.get("text") for part in out.get("summary")]
|
||||
reasoning_content_signature = out.get("encrypted_content")
|
||||
elif out_type == "function_call":
|
||||
@@ -765,12 +884,12 @@ class OpenAIClient(LLMClientBase):
|
||||
):
|
||||
if "choices" in response_data and len(response_data["choices"]) > 0:
|
||||
choice_data = response_data["choices"][0]
|
||||
if "message" in choice_data and "reasoning_content" in choice_data["message"]:
|
||||
reasoning_content = choice_data["message"]["reasoning_content"]
|
||||
if reasoning_content:
|
||||
chat_completion_response.choices[0].message.reasoning_content = reasoning_content
|
||||
|
||||
chat_completion_response.choices[0].message.reasoning_content_signature = None
|
||||
message_data = choice_data.get("message", {})
|
||||
# Check for reasoning_content (standard) or reasoning (OpenRouter)
|
||||
reasoning_content = message_data.get("reasoning_content") or message_data.get("reasoning")
|
||||
if reasoning_content:
|
||||
chat_completion_response.choices[0].message.reasoning_content = reasoning_content
|
||||
chat_completion_response.choices[0].message.reasoning_content_signature = None
|
||||
|
||||
# Unpack inner thoughts if they were embedded in function arguments
|
||||
if llm_config.put_inner_thoughts_in_kwargs:
|
||||
@@ -789,6 +908,9 @@ class OpenAIClient(LLMClientBase):
|
||||
"""
|
||||
Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator.
|
||||
"""
|
||||
# Sanitize Unicode surrogates to prevent encoding errors
|
||||
request_data = sanitize_unicode_surrogates(request_data)
|
||||
|
||||
kwargs = await self._prepare_client_kwargs_async(llm_config)
|
||||
client = AsyncOpenAI(**kwargs)
|
||||
|
||||
@@ -820,6 +942,9 @@ class OpenAIClient(LLMClientBase):
|
||||
"""
|
||||
Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator.
|
||||
"""
|
||||
# Sanitize Unicode surrogates to prevent encoding errors
|
||||
request_data = sanitize_unicode_surrogates(request_data)
|
||||
|
||||
kwargs = await self._prepare_client_kwargs_async(llm_config)
|
||||
client = AsyncOpenAI(**kwargs)
|
||||
response_stream: AsyncStream[ResponseStreamEvent] = await client.responses.create(**request_data, stream=True)
|
||||
@@ -958,10 +1083,11 @@ class OpenAIClient(LLMClientBase):
|
||||
return results
|
||||
|
||||
@trace_method
|
||||
def handle_llm_error(self, e: Exception) -> Exception:
|
||||
def handle_llm_error(self, e: Exception, llm_config: Optional[LLMConfig] = None) -> Exception:
|
||||
"""
|
||||
Maps OpenAI-specific errors to common LLMError types.
|
||||
"""
|
||||
is_byok = (llm_config.provider_category == ProviderCategory.byok) if llm_config else None
|
||||
if isinstance(e, openai.APITimeoutError):
|
||||
timeout_duration = getattr(e, "timeout", "unknown")
|
||||
logger.warning(f"[OpenAI] Request timeout after {timeout_duration} seconds: {e}")
|
||||
@@ -971,6 +1097,7 @@ class OpenAIClient(LLMClientBase):
|
||||
details={
|
||||
"timeout_duration": timeout_duration,
|
||||
"cause": str(e.__cause__) if e.__cause__ else None,
|
||||
"is_byok": is_byok,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -979,7 +1106,7 @@ class OpenAIClient(LLMClientBase):
|
||||
return LLMConnectionError(
|
||||
message=f"Failed to connect to OpenAI: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None},
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
|
||||
)
|
||||
|
||||
# Handle httpx.RemoteProtocolError which can occur during streaming
|
||||
@@ -990,7 +1117,7 @@ class OpenAIClient(LLMClientBase):
|
||||
return LLMConnectionError(
|
||||
message=f"Connection error during OpenAI streaming: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None},
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
|
||||
)
|
||||
|
||||
# Handle httpx network errors which can occur during streaming
|
||||
@@ -1000,37 +1127,47 @@ class OpenAIClient(LLMClientBase):
|
||||
return LLMConnectionError(
|
||||
message=f"Network error during OpenAI streaming: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__},
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__, "is_byok": is_byok},
|
||||
)
|
||||
|
||||
if isinstance(e, openai.RateLimitError):
|
||||
logger.warning(f"[OpenAI] Rate limited (429). Consider backoff. Error: {e}")
|
||||
body_details = e.body if isinstance(e.body, dict) else {"body": e.body}
|
||||
return LLMRateLimitError(
|
||||
message=f"Rate limited by OpenAI: {str(e)}",
|
||||
code=ErrorCode.RATE_LIMIT_EXCEEDED,
|
||||
details=e.body, # Include body which often has rate limit details
|
||||
details={**body_details, "is_byok": is_byok},
|
||||
)
|
||||
|
||||
if isinstance(e, openai.BadRequestError):
|
||||
logger.warning(f"[OpenAI] Bad request (400): {str(e)}")
|
||||
# BadRequestError can signify different issues (e.g., invalid args, context length)
|
||||
# Check for context_length_exceeded error code in the error body
|
||||
error_str = str(e)
|
||||
|
||||
if "<html" in error_str.lower() or (e.body and isinstance(e.body, str) and "<html" in e.body.lower()):
|
||||
logger.warning("[OpenAI] Received HTML error response from upstream endpoint (likely ALB or reverse proxy)")
|
||||
return LLMBadRequestError(
|
||||
message="Upstream endpoint returned HTML error (400 Bad Request). This usually indicates the configured API endpoint is not an OpenAI-compatible API or the request was rejected by a load balancer.",
|
||||
code=ErrorCode.INVALID_ARGUMENT,
|
||||
details={"raw_body_preview": error_str[:500]},
|
||||
)
|
||||
|
||||
error_code = None
|
||||
if e.body and isinstance(e.body, dict):
|
||||
error_details = e.body.get("error", {})
|
||||
if isinstance(error_details, dict):
|
||||
error_code = error_details.get("code")
|
||||
|
||||
# Check both the error code and message content for context length issues
|
||||
if error_code == "context_length_exceeded" or is_context_window_overflow_message(str(e)):
|
||||
if error_code == "context_length_exceeded" or is_context_window_overflow_message(error_str):
|
||||
return ContextWindowExceededError(
|
||||
message=f"Bad request to OpenAI (context window exceeded): {str(e)}",
|
||||
message=f"Bad request to OpenAI (context window exceeded): {error_str}",
|
||||
details={"is_byok": is_byok},
|
||||
)
|
||||
else:
|
||||
body_details = e.body if isinstance(e.body, dict) else {"body": e.body}
|
||||
return LLMBadRequestError(
|
||||
message=f"Bad request to OpenAI: {str(e)}",
|
||||
code=ErrorCode.INVALID_ARGUMENT, # Or more specific if detectable
|
||||
details=e.body,
|
||||
message=f"Bad request to OpenAI-compatible endpoint: {str(e)}",
|
||||
code=ErrorCode.INVALID_ARGUMENT,
|
||||
details={**body_details, "is_byok": is_byok},
|
||||
)
|
||||
|
||||
# NOTE: The OpenAI Python SDK may raise a generic `openai.APIError` while *iterating*
|
||||
@@ -1040,46 +1177,91 @@ class OpenAIClient(LLMClientBase):
|
||||
#
|
||||
# Example message:
|
||||
# "Your input exceeds the context window of this model. Please adjust your input and try again."
|
||||
if isinstance(e, openai.APIError):
|
||||
if isinstance(e, openai.APIError) and not isinstance(e, openai.APIStatusError):
|
||||
msg = str(e)
|
||||
if is_context_window_overflow_message(msg):
|
||||
return ContextWindowExceededError(
|
||||
message=f"OpenAI request exceeded the context window: {msg}",
|
||||
details={
|
||||
"provider_exception_type": type(e).__name__,
|
||||
# Best-effort extraction (may not exist on APIError)
|
||||
"body": getattr(e, "body", None),
|
||||
"is_byok": is_byok,
|
||||
},
|
||||
)
|
||||
if is_insufficient_credits_message(msg):
|
||||
return LLMInsufficientCreditsError(
|
||||
message=f"Insufficient credits (BYOK): {msg}" if is_byok else f"Insufficient credits: {msg}",
|
||||
code=ErrorCode.PAYMENT_REQUIRED,
|
||||
details={
|
||||
"provider_exception_type": type(e).__name__,
|
||||
"body": getattr(e, "body", None),
|
||||
"is_byok": is_byok,
|
||||
},
|
||||
)
|
||||
return LLMBadRequestError(
|
||||
message=f"OpenAI API error: {msg}",
|
||||
code=ErrorCode.INVALID_ARGUMENT,
|
||||
details={
|
||||
"provider_exception_type": type(e).__name__,
|
||||
"body": getattr(e, "body", None),
|
||||
"is_byok": is_byok,
|
||||
},
|
||||
)
|
||||
|
||||
if isinstance(e, openai.AuthenticationError):
|
||||
logger.error(f"[OpenAI] Authentication error (401): {str(e)}") # More severe log level
|
||||
body_details = e.body if isinstance(e.body, dict) else {"body": e.body}
|
||||
return LLMAuthenticationError(
|
||||
message=f"Authentication failed with OpenAI: {str(e)}", code=ErrorCode.UNAUTHENTICATED, details=e.body
|
||||
message=f"Authentication failed with OpenAI: {str(e)}",
|
||||
code=ErrorCode.UNAUTHENTICATED,
|
||||
details={**body_details, "is_byok": is_byok},
|
||||
)
|
||||
|
||||
if isinstance(e, openai.PermissionDeniedError):
|
||||
logger.error(f"[OpenAI] Permission denied (403): {str(e)}") # More severe log level
|
||||
body_details = e.body if isinstance(e.body, dict) else {"body": e.body}
|
||||
return LLMPermissionDeniedError(
|
||||
message=f"Permission denied by OpenAI: {str(e)}", code=ErrorCode.PERMISSION_DENIED, details=e.body
|
||||
message=f"Permission denied by OpenAI: {str(e)}",
|
||||
code=ErrorCode.PERMISSION_DENIED,
|
||||
details={**body_details, "is_byok": is_byok},
|
||||
)
|
||||
|
||||
if isinstance(e, openai.NotFoundError):
|
||||
logger.warning(f"[OpenAI] Resource not found (404): {str(e)}")
|
||||
# Could be invalid model name, etc.
|
||||
return LLMNotFoundError(message=f"Resource not found in OpenAI: {str(e)}", code=ErrorCode.NOT_FOUND, details=e.body)
|
||||
body_details = e.body if isinstance(e.body, dict) else {"body": e.body}
|
||||
return LLMNotFoundError(
|
||||
message=f"Resource not found in OpenAI: {str(e)}",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
details={**body_details, "is_byok": is_byok},
|
||||
)
|
||||
|
||||
if isinstance(e, openai.UnprocessableEntityError):
|
||||
logger.warning(f"[OpenAI] Unprocessable entity (422): {str(e)}")
|
||||
body_details = e.body if isinstance(e.body, dict) else {"body": e.body}
|
||||
return LLMUnprocessableEntityError(
|
||||
message=f"Invalid request content for OpenAI: {str(e)}",
|
||||
code=ErrorCode.INVALID_ARGUMENT, # Usually validation errors
|
||||
details=e.body,
|
||||
code=ErrorCode.INVALID_ARGUMENT,
|
||||
details={**body_details, "is_byok": is_byok},
|
||||
)
|
||||
|
||||
# General API error catch-all
|
||||
if isinstance(e, openai.APIStatusError):
|
||||
logger.warning(f"[OpenAI] API status error ({e.status_code}): {str(e)}")
|
||||
# Handle 413 Request Entity Too Large - request payload exceeds size limits
|
||||
if e.status_code == 413:
|
||||
return ContextWindowExceededError(
|
||||
message=f"Request too large for OpenAI (413): {str(e)}",
|
||||
details={"is_byok": is_byok},
|
||||
)
|
||||
# Handle 402 Payment Required or credit-related messages
|
||||
if e.status_code == 402 or is_insufficient_credits_message(str(e)):
|
||||
msg = str(e)
|
||||
return LLMInsufficientCreditsError(
|
||||
message=f"Insufficient credits (BYOK): {msg}" if is_byok else f"Insufficient credits: {msg}",
|
||||
code=ErrorCode.PAYMENT_REQUIRED,
|
||||
details={"status_code": e.status_code, "body": e.body, "is_byok": is_byok},
|
||||
)
|
||||
# Map based on status code potentially
|
||||
if e.status_code >= 500:
|
||||
error_cls = LLMServerError
|
||||
@@ -1096,11 +1278,12 @@ class OpenAIClient(LLMClientBase):
|
||||
"status_code": e.status_code,
|
||||
"response": str(e.response),
|
||||
"body": e.body,
|
||||
"is_byok": is_byok,
|
||||
},
|
||||
)
|
||||
|
||||
# Fallback for unexpected errors
|
||||
return super().handle_llm_error(e)
|
||||
return super().handle_llm_error(e, llm_config=llm_config)
|
||||
|
||||
|
||||
def fill_image_content_in_messages(openai_message_list: List[dict], pydantic_message_list: List[PydanticMessage]) -> List[dict]:
|
||||
|
||||
108
letta/llm_api/sglang_native_client.py
Normal file
108
letta/llm_api/sglang_native_client.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
SGLang Native Client for Letta.
|
||||
|
||||
This client uses SGLang's native /generate endpoint instead of the OpenAI-compatible
|
||||
/v1/chat/completions endpoint. The native endpoint returns token IDs and per-token
|
||||
logprobs, which are essential for multi-turn RL training.
|
||||
|
||||
The OpenAI-compatible endpoint only returns token strings, not IDs, making it
|
||||
impossible to accurately reconstruct the token sequence for training.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from letta.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SGLangNativeClient:
    """Client for SGLang's native /generate endpoint.

    Unlike the OpenAI-compatible endpoint, the native endpoint returns:
    - output_ids: list of token IDs
    - output_token_logprobs: list of [logprob, token_id, top_logprob] tuples

    This is essential for RL training, where the exact token IDs are needed
    rather than re-tokenized text.
    """

    def __init__(self, base_url: str, api_key: Optional[str] = None):
        """
        Initialize the SGLang native client.

        Args:
            base_url: Base URL for the SGLang server (e.g., http://localhost:30000)
            api_key: Optional API key for authentication
        """
        trimmed = base_url.rstrip("/")
        # The native endpoint lives at the server root, so drop any /v1 suffix
        # left over from an OpenAI-compatible configuration.
        if trimmed.endswith("/v1"):
            trimmed = trimmed[: -len("/v1")]
        self.base_url = trimmed
        self.api_key = api_key

    async def generate(
        self,
        text: str,
        sampling_params: Optional[Dict[str, Any]] = None,
        return_logprob: bool = True,
    ) -> Dict[str, Any]:
        """
        Call SGLang's native /generate endpoint.

        Args:
            text: The formatted prompt text (with chat template already applied)
            sampling_params: Sampling parameters (temperature, max_new_tokens, etc.)
            return_logprob: Whether to return logprobs (default True for RL training)

        Returns:
            Response dict with:
            - text: Generated text
            - output_ids: List of token IDs
            - output_token_logprobs: List of [logprob, token_id, top_logprob] tuples
            - meta_info: Metadata including finish_reason, prompt_tokens, etc.

        Example response:
            {
                "text": "Hello! How can I help?",
                "output_ids": [9707, 0, 2585, 646, 358, 1492, 30],
                "output_token_logprobs": [
                    [-0.005, 9707, null],
                    [0.0, 0, null],
                    ...
                ],
                "meta_info": {
                    "finish_reason": {"type": "stop", "matched": 151645},
                    "prompt_tokens": 42,
                    ...
                }
            }
        """
        request_headers = {"Content-Type": "application/json"}
        if self.api_key:
            request_headers["Authorization"] = f"Bearer {self.api_key}"

        body = {
            "text": text,
            "sampling_params": sampling_params or {},
            "return_logprob": return_logprob,
        }

        # Generation can be slow for long completions; allow a generous timeout.
        async with httpx.AsyncClient(timeout=300.0) as http:
            resp = await http.post(
                f"{self.base_url}/generate",
                json=body,
                headers=request_headers,
            )
            resp.raise_for_status()
            return resp.json()

    async def health_check(self) -> bool:
        """Return True if the SGLang server responds 200 on /health, else False."""
        try:
            async with httpx.AsyncClient(timeout=10.0) as http:
                resp = await http.get(f"{self.base_url}/health")
                return resp.status_code == 200
        except Exception:
            # Any connection/timeout failure simply means "not healthy".
            return False
|
||||
@@ -4,6 +4,7 @@ from typing import List
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
|
||||
from letta.helpers.json_helpers import sanitize_unicode_surrogates
|
||||
from letta.llm_api.openai_client import OpenAIClient
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
@@ -34,6 +35,8 @@ class TogetherClient(OpenAIClient):
|
||||
"""
|
||||
Performs underlying asynchronous request to OpenAI API and returns raw response dict.
|
||||
"""
|
||||
request_data = sanitize_unicode_surrogates(request_data)
|
||||
|
||||
api_key, _, _ = await self.get_byok_overrides_async(llm_config)
|
||||
|
||||
if not api_key:
|
||||
|
||||
@@ -5,6 +5,7 @@ from openai import AsyncOpenAI, AsyncStream, OpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
|
||||
from letta.helpers.json_helpers import sanitize_unicode_surrogates
|
||||
from letta.llm_api.openai_client import OpenAIClient
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
@@ -59,6 +60,8 @@ class XAIClient(OpenAIClient):
|
||||
"""
|
||||
Performs underlying asynchronous request to OpenAI API and returns raw response dict.
|
||||
"""
|
||||
request_data = sanitize_unicode_surrogates(request_data)
|
||||
|
||||
api_key = model_settings.xai_api_key or os.environ.get("XAI_API_KEY")
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
|
||||
|
||||
@@ -70,6 +73,8 @@ class XAIClient(OpenAIClient):
|
||||
"""
|
||||
Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator.
|
||||
"""
|
||||
request_data = sanitize_unicode_surrogates(request_data)
|
||||
|
||||
api_key = model_settings.xai_api_key or os.environ.get("XAI_API_KEY")
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
|
||||
response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
|
||||
|
||||
@@ -1,19 +1,30 @@
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from openai import AsyncOpenAI, AsyncStream, OpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
|
||||
from letta.helpers.json_helpers import sanitize_unicode_surrogates
|
||||
from letta.llm_api.openai_client import OpenAIClient
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import AgentType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
from letta.settings import model_settings
|
||||
|
||||
|
||||
def is_zai_reasoning_model(model_name: str) -> bool:
    """Check if the model is a ZAI reasoning model (GLM-4.5+)."""
    # str.startswith accepts a tuple of prefixes, collapsing the or-chain.
    reasoning_prefixes = ("glm-4.5", "glm-4.6", "glm-4.7", "glm-5")
    return model_name.startswith(reasoning_prefixes)
|
||||
|
||||
|
||||
class ZAIClient(OpenAIClient):
|
||||
"""Z.ai (ZhipuAI) client - uses OpenAI-compatible API."""
|
||||
|
||||
@@ -23,6 +34,10 @@ class ZAIClient(OpenAIClient):
|
||||
def supports_structured_output(self, llm_config: LLMConfig) -> bool:
|
||||
return False
|
||||
|
||||
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
|
||||
"""Returns True if the model is a ZAI reasoning model (GLM-4.5+)."""
|
||||
return is_zai_reasoning_model(llm_config.model)
|
||||
|
||||
@trace_method
|
||||
def build_request_data(
|
||||
self,
|
||||
@@ -35,6 +50,50 @@ class ZAIClient(OpenAIClient):
|
||||
tool_return_truncation_chars: Optional[int] = None,
|
||||
) -> dict:
|
||||
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)
|
||||
|
||||
# Add thinking configuration for ZAI GLM-4.5+ models
|
||||
# Must explicitly send type: "disabled" when reasoning is off, as GLM-4.7 has thinking on by default
|
||||
if self.is_reasoning_model(llm_config):
|
||||
if llm_config.enable_reasoner:
|
||||
data["extra_body"] = {
|
||||
"thinking": {
|
||||
"type": "enabled",
|
||||
"clear_thinking": False, # Preserved thinking for agents
|
||||
}
|
||||
}
|
||||
else:
|
||||
data["extra_body"] = {
|
||||
"thinking": {
|
||||
"type": "disabled",
|
||||
}
|
||||
}
|
||||
|
||||
# Sanitize empty text content — ZAI rejects empty text blocks
|
||||
if "messages" in data:
|
||||
for msg in data["messages"]:
|
||||
content = msg.get("content") if isinstance(msg, dict) else getattr(msg, "content", None)
|
||||
# String content: replace empty with None (assistant+tool_calls) or "."
|
||||
if isinstance(content, str) and not content.strip():
|
||||
role = msg.get("role") if isinstance(msg, dict) else getattr(msg, "role", None)
|
||||
has_tool_calls = msg.get("tool_calls") if isinstance(msg, dict) else getattr(msg, "tool_calls", None)
|
||||
if role == "assistant" and has_tool_calls:
|
||||
# assistant + tool_calls: null content is valid in OpenAI format
|
||||
if isinstance(msg, dict):
|
||||
msg["content"] = None
|
||||
else:
|
||||
msg.content = None
|
||||
else:
|
||||
if isinstance(msg, dict):
|
||||
msg["content"] = "."
|
||||
else:
|
||||
msg.content = "."
|
||||
# List content: fix empty text blocks within arrays
|
||||
elif isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
if not block.get("text", "").strip():
|
||||
block["text"] = "."
|
||||
|
||||
return data
|
||||
|
||||
@trace_method
|
||||
@@ -53,6 +112,8 @@ class ZAIClient(OpenAIClient):
|
||||
"""
|
||||
Performs underlying asynchronous request to Z.ai API and returns raw response dict.
|
||||
"""
|
||||
request_data = sanitize_unicode_surrogates(request_data)
|
||||
|
||||
api_key = model_settings.zai_api_key
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
|
||||
|
||||
@@ -64,6 +125,8 @@ class ZAIClient(OpenAIClient):
|
||||
"""
|
||||
Performs underlying asynchronous streaming request to Z.ai and returns the async stream iterator.
|
||||
"""
|
||||
request_data = sanitize_unicode_surrogates(request_data)
|
||||
|
||||
api_key = model_settings.zai_api_key
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
|
||||
response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
|
||||
@@ -79,3 +142,39 @@ class ZAIClient(OpenAIClient):
|
||||
response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)
|
||||
|
||||
return [r.embedding for r in response.data]
|
||||
|
||||
@trace_method
|
||||
async def convert_response_to_chat_completion(
|
||||
self,
|
||||
response_data: dict,
|
||||
input_messages: List[PydanticMessage],
|
||||
llm_config: LLMConfig,
|
||||
) -> ChatCompletionResponse:
|
||||
"""
|
||||
Converts raw ZAI response dict into the ChatCompletionResponse Pydantic model.
|
||||
Handles extraction of reasoning_content from ZAI GLM-4.5+ responses.
|
||||
"""
|
||||
# Use parent class conversion first
|
||||
chat_completion_response = await super().convert_response_to_chat_completion(response_data, input_messages, llm_config)
|
||||
|
||||
# Parse reasoning_content from ZAI responses (similar to OpenAI pattern)
|
||||
# ZAI returns reasoning_content in delta.reasoning_content (streaming) or message.reasoning_content
|
||||
if (
|
||||
chat_completion_response.choices
|
||||
and len(chat_completion_response.choices) > 0
|
||||
and chat_completion_response.choices[0].message
|
||||
and not chat_completion_response.choices[0].message.reasoning_content
|
||||
):
|
||||
if "choices" in response_data and len(response_data["choices"]) > 0:
|
||||
choice_data = response_data["choices"][0]
|
||||
if "message" in choice_data and "reasoning_content" in choice_data["message"]:
|
||||
reasoning_content = choice_data["message"]["reasoning_content"]
|
||||
if reasoning_content:
|
||||
chat_completion_response.choices[0].message.reasoning_content = reasoning_content
|
||||
chat_completion_response.choices[0].message.reasoning_content_signature = None
|
||||
|
||||
# If we used a reasoning model, mark that reasoning content was used
|
||||
if self.is_reasoning_model(llm_config) and llm_config.enable_reasoner:
|
||||
chat_completion_response.choices[0].message.omitted_reasoning_content = True
|
||||
|
||||
return chat_completion_response
|
||||
|
||||
@@ -16,7 +16,7 @@ from letta.local_llm.llamacpp.api import get_llamacpp_completion
|
||||
from letta.local_llm.llm_chat_completion_wrappers import simple_summary_wrapper
|
||||
from letta.local_llm.lmstudio.api import get_lmstudio_completion, get_lmstudio_completion_chatcompletions
|
||||
from letta.local_llm.ollama.api import get_ollama_completion
|
||||
from letta.local_llm.utils import count_tokens, get_available_wrappers
|
||||
from letta.local_llm.utils import get_available_wrappers
|
||||
from letta.local_llm.vllm.api import get_vllm_completion
|
||||
from letta.local_llm.webui.api import get_webui_completion
|
||||
from letta.local_llm.webui.legacy_api import get_webui_completion as get_webui_completion_legacy
|
||||
@@ -177,7 +177,7 @@ def get_chat_completion(
|
||||
raise LocalLLMError(
|
||||
f"Invalid endpoint type {endpoint_type}, please set variable depending on your backend (webui, lmstudio, llamacpp, koboldcpp)"
|
||||
)
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
except requests.exceptions.ConnectionError:
|
||||
raise LocalLLMConnectionError(f"Unable to connect to endpoint {endpoint}")
|
||||
|
||||
attributes = usage if isinstance(usage, dict) else {"usage": usage}
|
||||
@@ -207,10 +207,12 @@ def get_chat_completion(
|
||||
|
||||
if usage["prompt_tokens"] is None:
|
||||
printd("usage dict was missing prompt_tokens, computing on-the-fly...")
|
||||
usage["prompt_tokens"] = count_tokens(prompt)
|
||||
# Approximate token count: bytes / 4
|
||||
usage["prompt_tokens"] = len(prompt.encode("utf-8")) // 4
|
||||
|
||||
# NOTE: we should compute on-the-fly anyways since we might have to correct for errors during JSON parsing
|
||||
usage["completion_tokens"] = count_tokens(json_dumps(chat_completion_result))
|
||||
# Approximate token count: bytes / 4
|
||||
usage["completion_tokens"] = len(json_dumps(chat_completion_result).encode("utf-8")) // 4
|
||||
"""
|
||||
if usage["completion_tokens"] is None:
|
||||
printd(f"usage dict was missing completion_tokens, computing on-the-fly...")
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# (settings.py imports from this module indirectly through log.py)
|
||||
# Import this here to avoid circular dependency at module level
|
||||
from letta.local_llm.llm_chat_completion_wrappers.chatml import ChatMLInnerMonologueWrapper
|
||||
from letta.settings import DEFAULT_WRAPPER_NAME, INNER_THOUGHTS_KWARG
|
||||
from letta.settings import INNER_THOUGHTS_KWARG
|
||||
|
||||
DEFAULT_WRAPPER = ChatMLInnerMonologueWrapper
|
||||
INNER_THOUGHTS_KWARG_VERTEX = "thinking"
|
||||
|
||||
@@ -5,7 +5,7 @@ from copy import copy
|
||||
from enum import Enum
|
||||
from inspect import getdoc, isclass
|
||||
from types import NoneType
|
||||
from typing import Any, Callable, List, Optional, Tuple, Type, Union, _GenericAlias, get_args, get_origin
|
||||
from typing import Any, Callable, List, Optional, Tuple, Type, Union, _GenericAlias, get_args, get_origin # type: ignore[attr-defined]
|
||||
|
||||
from docstring_parser import parse
|
||||
from pydantic import BaseModel, create_model
|
||||
@@ -58,13 +58,13 @@ def map_pydantic_type_to_gbnf(pydantic_type: Type[Any]) -> str:
|
||||
|
||||
elif isclass(pydantic_type) and issubclass(pydantic_type, BaseModel):
|
||||
return format_model_and_field_name(pydantic_type.__name__)
|
||||
elif get_origin(pydantic_type) == list:
|
||||
elif get_origin(pydantic_type) is list:
|
||||
element_type = get_args(pydantic_type)[0]
|
||||
return f"{map_pydantic_type_to_gbnf(element_type)}-list"
|
||||
elif get_origin(pydantic_type) == set:
|
||||
elif get_origin(pydantic_type) is set:
|
||||
element_type = get_args(pydantic_type)[0]
|
||||
return f"{map_pydantic_type_to_gbnf(element_type)}-set"
|
||||
elif get_origin(pydantic_type) == Union:
|
||||
elif get_origin(pydantic_type) is Union:
|
||||
union_types = get_args(pydantic_type)
|
||||
union_rules = [map_pydantic_type_to_gbnf(ut) for ut in union_types]
|
||||
return f"union-{'-or-'.join(union_rules)}"
|
||||
@@ -73,7 +73,7 @@ def map_pydantic_type_to_gbnf(pydantic_type: Type[Any]) -> str:
|
||||
return f"optional-{map_pydantic_type_to_gbnf(element_type)}"
|
||||
elif isclass(pydantic_type):
|
||||
return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(pydantic_type.__name__)}"
|
||||
elif get_origin(pydantic_type) == dict:
|
||||
elif get_origin(pydantic_type) is dict:
|
||||
key_type, value_type = get_args(pydantic_type)
|
||||
return f"custom-dict-key-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(key_type))}-value-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(value_type))}"
|
||||
else:
|
||||
@@ -299,7 +299,7 @@ def generate_gbnf_rule_for_type(
|
||||
enum_rule = f"{model_name}-{field_name} ::= {' | '.join(enum_values)}"
|
||||
rules.append(enum_rule)
|
||||
gbnf_type, rules = model_name + "-" + field_name, rules
|
||||
elif get_origin(field_type) == list: # Array
|
||||
elif get_origin(field_type) is list: # Array
|
||||
element_type = get_args(field_type)[0]
|
||||
element_rule_name, additional_rules = generate_gbnf_rule_for_type(
|
||||
model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules
|
||||
@@ -309,7 +309,7 @@ def generate_gbnf_rule_for_type(
|
||||
rules.append(array_rule)
|
||||
gbnf_type, rules = model_name + "-" + field_name, rules
|
||||
|
||||
elif get_origin(field_type) == set or field_type == set: # Array
|
||||
elif get_origin(field_type) is set or field_type is set: # Array
|
||||
element_type = get_args(field_type)[0]
|
||||
element_rule_name, additional_rules = generate_gbnf_rule_for_type(
|
||||
model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules
|
||||
@@ -320,7 +320,7 @@ def generate_gbnf_rule_for_type(
|
||||
gbnf_type, rules = model_name + "-" + field_name, rules
|
||||
|
||||
elif gbnf_type.startswith("custom-class-"):
|
||||
nested_model_rules, field_types = get_members_structure(field_type, gbnf_type)
|
||||
nested_model_rules, _field_types = get_members_structure(field_type, gbnf_type)
|
||||
rules.append(nested_model_rules)
|
||||
elif gbnf_type.startswith("custom-dict-"):
|
||||
key_type, value_type = get_args(field_type)
|
||||
@@ -502,15 +502,15 @@ def generate_gbnf_grammar(model: Type[BaseModel], processed_models: set, created
|
||||
model_rule += '"\\n" ws "}"'
|
||||
model_rule += '"\\n" markdown-code-block'
|
||||
has_special_string = True
|
||||
all_rules = [model_rule] + nested_rules
|
||||
all_rules = [model_rule, *nested_rules]
|
||||
|
||||
return all_rules, has_special_string
|
||||
|
||||
|
||||
def generate_gbnf_grammar_from_pydantic_models(
|
||||
models: List[Type[BaseModel]],
|
||||
outer_object_name: str = None,
|
||||
outer_object_content: str = None,
|
||||
outer_object_name: str | None = None,
|
||||
outer_object_content: str | None = None,
|
||||
list_of_outputs: bool = False,
|
||||
add_inner_thoughts: bool = False,
|
||||
allow_only_inner_thoughts: bool = False,
|
||||
@@ -704,11 +704,11 @@ def generate_markdown_documentation(
|
||||
# continue
|
||||
if isclass(field_type) and issubclass(field_type, BaseModel):
|
||||
pyd_models.append((field_type, False))
|
||||
if get_origin(field_type) == list:
|
||||
if get_origin(field_type) is list:
|
||||
element_type = get_args(field_type)[0]
|
||||
if isclass(element_type) and issubclass(element_type, BaseModel):
|
||||
pyd_models.append((element_type, False))
|
||||
if get_origin(field_type) == Union:
|
||||
if get_origin(field_type) is Union:
|
||||
element_types = get_args(field_type)
|
||||
for element_type in element_types:
|
||||
if isclass(element_type) and issubclass(element_type, BaseModel):
|
||||
@@ -747,14 +747,14 @@ def generate_field_markdown(
|
||||
field_info = model.model_fields.get(field_name)
|
||||
field_description = field_info.description if field_info and field_info.description else ""
|
||||
|
||||
if get_origin(field_type) == list:
|
||||
if get_origin(field_type) is list:
|
||||
element_type = get_args(field_type)[0]
|
||||
field_text = f"{indent}{field_name} ({field_type.__name__} of {element_type.__name__})"
|
||||
if field_description != "":
|
||||
field_text += ": "
|
||||
else:
|
||||
field_text += "\n"
|
||||
elif get_origin(field_type) == Union:
|
||||
elif get_origin(field_type) is Union:
|
||||
element_types = get_args(field_type)
|
||||
types = []
|
||||
for element_type in element_types:
|
||||
@@ -857,11 +857,11 @@ def generate_text_documentation(
|
||||
for name, field_type in model.__annotations__.items():
|
||||
# if name == "markdown_code_block":
|
||||
# continue
|
||||
if get_origin(field_type) == list:
|
||||
if get_origin(field_type) is list:
|
||||
element_type = get_args(field_type)[0]
|
||||
if isclass(element_type) and issubclass(element_type, BaseModel):
|
||||
pyd_models.append((element_type, False))
|
||||
if get_origin(field_type) == Union:
|
||||
if get_origin(field_type) is Union:
|
||||
element_types = get_args(field_type)
|
||||
for element_type in element_types:
|
||||
if isclass(element_type) and issubclass(element_type, BaseModel):
|
||||
@@ -905,14 +905,14 @@ def generate_field_text(
|
||||
field_info = model.model_fields.get(field_name)
|
||||
field_description = field_info.description if field_info and field_info.description else ""
|
||||
|
||||
if get_origin(field_type) == list:
|
||||
if get_origin(field_type) is list:
|
||||
element_type = get_args(field_type)[0]
|
||||
field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)} of {format_model_and_field_name(element_type.__name__)})"
|
||||
if field_description != "":
|
||||
field_text += ":\n"
|
||||
else:
|
||||
field_text += "\n"
|
||||
elif get_origin(field_type) == Union:
|
||||
elif get_origin(field_type) is Union:
|
||||
element_types = get_args(field_type)
|
||||
types = []
|
||||
for element_type in element_types:
|
||||
@@ -1015,8 +1015,8 @@ def generate_and_save_gbnf_grammar_and_documentation(
|
||||
pydantic_model_list,
|
||||
grammar_file_path="./generated_grammar.gbnf",
|
||||
documentation_file_path="./generated_grammar_documentation.md",
|
||||
outer_object_name: str = None,
|
||||
outer_object_content: str = None,
|
||||
outer_object_name: str | None = None,
|
||||
outer_object_content: str | None = None,
|
||||
model_prefix: str = "Output Model",
|
||||
fields_prefix: str = "Output Fields",
|
||||
list_of_outputs: bool = False,
|
||||
@@ -1049,8 +1049,8 @@ def generate_and_save_gbnf_grammar_and_documentation(
|
||||
|
||||
def generate_gbnf_grammar_and_documentation(
|
||||
pydantic_model_list,
|
||||
outer_object_name: str = None,
|
||||
outer_object_content: str = None,
|
||||
outer_object_name: str | None = None,
|
||||
outer_object_content: str | None = None,
|
||||
model_prefix: str = "Output Model",
|
||||
fields_prefix: str = "Output Fields",
|
||||
list_of_outputs: bool = False,
|
||||
@@ -1087,8 +1087,8 @@ def generate_gbnf_grammar_and_documentation(
|
||||
|
||||
def generate_gbnf_grammar_and_documentation_from_dictionaries(
|
||||
dictionaries: List[dict],
|
||||
outer_object_name: str = None,
|
||||
outer_object_content: str = None,
|
||||
outer_object_name: str | None = None,
|
||||
outer_object_content: str | None = None,
|
||||
model_prefix: str = "Output Model",
|
||||
fields_prefix: str = "Output Fields",
|
||||
list_of_outputs: bool = False,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from letta.local_llm.settings.settings import get_completions_settings
|
||||
from letta.local_llm.utils import count_tokens, post_json_auth_request
|
||||
from letta.local_llm.utils import post_json_auth_request
|
||||
|
||||
KOBOLDCPP_API_SUFFIX = "/api/v1/generate"
|
||||
|
||||
@@ -10,7 +10,8 @@ def get_koboldcpp_completion(endpoint, auth_type, auth_key, prompt, context_wind
|
||||
"""See https://lite.koboldai.net/koboldcpp_api for API spec"""
|
||||
from letta.utils import printd
|
||||
|
||||
prompt_tokens = count_tokens(prompt)
|
||||
# Approximate token count: bytes / 4
|
||||
prompt_tokens = len(prompt.encode("utf-8")) // 4
|
||||
if prompt_tokens > context_window:
|
||||
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from letta.local_llm.settings.settings import get_completions_settings
|
||||
from letta.local_llm.utils import count_tokens, post_json_auth_request
|
||||
from letta.local_llm.utils import post_json_auth_request
|
||||
|
||||
LLAMACPP_API_SUFFIX = "/completion"
|
||||
|
||||
@@ -10,7 +10,8 @@ def get_llamacpp_completion(endpoint, auth_type, auth_key, prompt, context_windo
|
||||
"""See https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md for instructions on how to run the LLM web server"""
|
||||
from letta.utils import printd
|
||||
|
||||
prompt_tokens = count_tokens(prompt)
|
||||
# Approximate token count: bytes / 4
|
||||
prompt_tokens = len(prompt.encode("utf-8")) // 4
|
||||
if prompt_tokens > context_window:
|
||||
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user