MCP server connection failures were raising Python's builtin ConnectionError, which bypassed the LettaMCPConnectionError FastAPI exception handler and hit Datadog as unhandled 500 errors. Now all MCP client classes convert ConnectionError to LettaMCPConnectionError at the source, which the existing exception handler returns as a user-friendly 502. Datadog: https://us5.datadoghq.com/error-tracking/issue/93db4a82-fe5a-11f0-85f0-da7ad0900000 🐛 Generated with [Letta Code](https://letta.com) Co-authored-by: Letta <noreply@letta.com>
1880 lines
86 KiB
Python
1880 lines
86 KiB
Python
import asyncio
|
|
import json
|
|
import os
|
|
import traceback
|
|
from abc import abstractmethod
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
|
|
import httpx
|
|
from anthropic import AsyncAnthropic
|
|
from fastapi import HTTPException
|
|
from fastapi.responses import StreamingResponse
|
|
|
|
import letta.constants as constants
|
|
import letta.server.utils as server_utils
|
|
import letta.system as system
|
|
from letta.config import LettaConfig
|
|
from letta.constants import LETTA_TOOL_EXECUTION_DIR
|
|
from letta.data_sources.connectors import DataConnector, load_data
|
|
from letta.errors import (
|
|
HandleNotFoundError,
|
|
LettaInvalidArgumentError,
|
|
LettaMCPConnectionError,
|
|
LettaMCPTimeoutError,
|
|
)
|
|
from letta.functions.mcp_client.types import MCPServerType, MCPTool, MCPToolHealth, SSEServerConfig, StdioServerConfig
|
|
from letta.functions.schema_validator import validate_complete_json_schema
|
|
from letta.groups.helpers import load_multi_agent
|
|
from letta.helpers.datetime_helpers import get_utc_time
|
|
from letta.helpers.json_helpers import json_dumps, json_loads
|
|
|
|
# TODO use custom interface
|
|
from letta.interface import (
|
|
AgentInterface, # abstract
|
|
CLIInterface, # for printing to terminal
|
|
)
|
|
from letta.log import get_logger
|
|
from letta.orm.errors import NoResultFound
|
|
from letta.otel.tracing import log_event, trace_method
|
|
from letta.prompts.gpt_system import get_system_text
|
|
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
|
|
from letta.schemas.block import Block, BlockUpdate, CreateBlock
|
|
from letta.schemas.embedding_config import EmbeddingConfig
|
|
|
|
# openai schemas
|
|
from letta.schemas.enums import AgentType, JobStatus, MessageStreamStatus, ProviderCategory, ProviderType, SandboxType, ToolSourceType
|
|
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate
|
|
from letta.schemas.group import GroupCreate, ManagerType, SleeptimeManager, VoiceSleeptimeManager
|
|
from letta.schemas.job import Job, JobUpdate
|
|
from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, MessageType, ToolReturnMessage
|
|
from letta.schemas.letta_message_content import TextContent
|
|
from letta.schemas.letta_response import LettaResponse
|
|
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
|
from letta.schemas.llm_config import LLMConfig
|
|
from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
|
|
from letta.schemas.message import Message, MessageCreate, MessageUpdate
|
|
from letta.schemas.passage import Passage
|
|
from letta.schemas.pip_requirement import PipRequirement
|
|
from letta.schemas.providers import (
|
|
AnthropicProvider,
|
|
AzureProvider,
|
|
BedrockProvider,
|
|
DeepSeekProvider,
|
|
GoogleAIProvider,
|
|
GoogleVertexProvider,
|
|
GroqProvider,
|
|
LettaProvider,
|
|
LMStudioOpenAIProvider,
|
|
MiniMaxProvider,
|
|
OllamaProvider,
|
|
OpenAIProvider,
|
|
OpenRouterProvider,
|
|
Provider,
|
|
SGLangProvider,
|
|
TogetherProvider,
|
|
VLLMProvider,
|
|
XAIProvider,
|
|
ZAIProvider,
|
|
)
|
|
from letta.schemas.sandbox_config import LocalSandboxConfig, SandboxConfigCreate
|
|
from letta.schemas.secret import Secret
|
|
from letta.schemas.source import Source
|
|
from letta.schemas.tool import Tool
|
|
from letta.schemas.usage import LettaUsageStatistics
|
|
from letta.schemas.user import User
|
|
from letta.server.rest_api.chat_completions_interface import ChatCompletionsStreamingInterface
|
|
from letta.server.rest_api.interface import StreamingServerInterface
|
|
from letta.server.rest_api.utils import sse_async_generator
|
|
from letta.services.agent_manager import AgentManager
|
|
from letta.services.agent_serialization_manager import AgentSerializationManager
|
|
from letta.services.archive_manager import ArchiveManager
|
|
from letta.services.block_manager import BlockManager
|
|
from letta.services.block_manager_git import GIT_MEMORY_ENABLED_TAG, GitEnabledBlockManager
|
|
from letta.services.file_manager import FileManager
|
|
from letta.services.files_agents_manager import FileAgentManager
|
|
from letta.services.group_manager import GroupManager
|
|
from letta.services.helpers.tool_execution_helper import prepare_local_sandbox
|
|
from letta.services.identity_manager import IdentityManager
|
|
from letta.services.job_manager import JobManager
|
|
from letta.services.llm_batch_manager import LLMBatchManager
|
|
from letta.services.mcp.base_client import AsyncBaseMCPClient
|
|
from letta.services.mcp.fastmcp_client import AsyncFastMCPSSEClient
|
|
from letta.services.mcp.sse_client import MCP_CONFIG_TOPLEVEL_KEY
|
|
from letta.services.mcp.stdio_client import AsyncStdioMCPClient
|
|
from letta.services.mcp_manager import MCPManager
|
|
from letta.services.mcp_server_manager import MCPServerManager
|
|
from letta.services.memory_repo import MemoryRepoManager
|
|
from letta.services.memory_repo.storage.gcs import GCSStorageBackend
|
|
from letta.services.message_manager import MessageManager
|
|
from letta.services.organization_manager import OrganizationManager
|
|
from letta.services.passage_manager import PassageManager
|
|
from letta.services.provider_manager import ProviderManager
|
|
from letta.services.run_manager import RunManager
|
|
from letta.services.sandbox_config_manager import SandboxConfigManager
|
|
from letta.services.source_manager import SourceManager
|
|
from letta.services.step_manager import StepManager
|
|
from letta.services.telemetry_manager import TelemetryManager
|
|
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
|
|
from letta.services.tool_manager import ToolManager
|
|
from letta.services.user_manager import UserManager
|
|
from letta.settings import DatabaseChoice, model_settings, settings, tool_settings
|
|
from letta.streaming_interface import AgentChunkStreamingInterface
|
|
from letta.utils import get_friendly_error_msg, get_persona_text, safe_create_task
|
|
|
|
config = LettaConfig.load()
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
class SyncServer(object):
|
|
"""Simple single-threaded / blocking server process"""
|
|
|
|
def __init__(
|
|
self,
|
|
chaining: bool = True,
|
|
max_chaining_steps: Optional[int] = 100,
|
|
default_interface_factory: Callable[[], AgentChunkStreamingInterface] = lambda: CLIInterface(),
|
|
init_with_default_org_and_user: bool = True,
|
|
# default_interface: AgentInterface = CLIInterface(),
|
|
# default_persistence_manager_cls: PersistenceManager = LocalStateManager,
|
|
# auth_mode: str = "none", # "none, "jwt", "external"
|
|
):
|
|
"""Server process holds in-memory agents that are being run"""
|
|
# chaining = whether or not to run again if request_heartbeat=true
|
|
self.chaining = chaining
|
|
|
|
# if chaining == true, what's the max number of times we'll chain before yielding?
|
|
# none = no limit, can go on forever
|
|
self.max_chaining_steps = max_chaining_steps
|
|
|
|
# The default interface that will get assigned to agents ON LOAD
|
|
self.default_interface_factory = default_interface_factory
|
|
|
|
# Initialize the metadata store
|
|
config = LettaConfig.load()
|
|
if settings.database_engine is DatabaseChoice.POSTGRES:
|
|
config.recall_storage_type = "postgres"
|
|
config.recall_storage_uri = settings.letta_pg_uri_no_default
|
|
config.archival_storage_type = "postgres"
|
|
config.archival_storage_uri = settings.letta_pg_uri_no_default
|
|
config.save()
|
|
self.config = config
|
|
|
|
# Managers that interface with data models
|
|
self.organization_manager = OrganizationManager()
|
|
self.passage_manager = PassageManager()
|
|
self.user_manager = UserManager()
|
|
self.tool_manager = ToolManager()
|
|
self.mcp_manager = MCPManager()
|
|
self.mcp_server_manager = MCPServerManager()
|
|
self.memory_repo_manager = self._init_memory_repo_manager()
|
|
# Use git-enabled block manager if memory repo is configured
|
|
# It falls back to standard PostgreSQL behavior when git isn't enabled for an agent
|
|
if self.memory_repo_manager:
|
|
self.block_manager = GitEnabledBlockManager(memory_repo_manager=self.memory_repo_manager)
|
|
else:
|
|
self.block_manager = BlockManager()
|
|
self.source_manager = SourceManager()
|
|
self.sandbox_config_manager = SandboxConfigManager()
|
|
self.message_manager = MessageManager()
|
|
self.job_manager = JobManager()
|
|
self.run_manager = RunManager()
|
|
self.agent_manager = AgentManager(block_manager=self.block_manager)
|
|
self.archive_manager = ArchiveManager()
|
|
self.provider_manager = ProviderManager()
|
|
self.step_manager = StepManager()
|
|
self.identity_manager = IdentityManager()
|
|
self.group_manager = GroupManager()
|
|
self.batch_manager = LLMBatchManager()
|
|
self.telemetry_manager = TelemetryManager()
|
|
self.file_agent_manager = FileAgentManager()
|
|
self.file_manager = FileManager()
|
|
|
|
# Import and initialize the agent generate completion manager
|
|
from letta.services.agent_generate_completion_manager import AgentGenerateCompletionManager
|
|
|
|
self.agent_generate_completion_manager = AgentGenerateCompletionManager(server=self)
|
|
|
|
self.agent_serialization_manager = AgentSerializationManager(
|
|
agent_manager=self.agent_manager,
|
|
tool_manager=self.tool_manager,
|
|
source_manager=self.source_manager,
|
|
block_manager=self.block_manager,
|
|
group_manager=self.group_manager,
|
|
mcp_manager=self.mcp_manager,
|
|
file_manager=self.file_manager,
|
|
file_agent_manager=self.file_agent_manager,
|
|
message_manager=self.message_manager,
|
|
)
|
|
|
|
if settings.enable_batch_job_polling:
|
|
# A resusable httpx client
|
|
timeout = httpx.Timeout(connect=10.0, read=20.0, write=10.0, pool=10.0)
|
|
limits = httpx.Limits(max_connections=100, max_keepalive_connections=80, keepalive_expiry=300)
|
|
self.httpx_client = httpx.AsyncClient(timeout=timeout, follow_redirects=True, limits=limits)
|
|
|
|
# TODO: Replace this with the Anthropic client we have in house
|
|
# Reuse the shared httpx client to prevent duplicate SSL contexts and connection pools
|
|
self.anthropic_async_client = AsyncAnthropic(http_client=self.httpx_client)
|
|
else:
|
|
self.httpx_client = None
|
|
self.anthropic_async_client = None
|
|
|
|
# For MCP
|
|
# TODO: remove this
|
|
"""Initialize the MCP clients (there may be multiple)"""
|
|
self.mcp_clients: Dict[str, AsyncBaseMCPClient] = {}
|
|
|
|
# collect providers (always has Letta as a default)
|
|
from letta.constants import LETTA_MODEL_ENDPOINT
|
|
|
|
self._enabled_providers: List[Provider] = [LettaProvider(name="letta", base_url=LETTA_MODEL_ENDPOINT)]
|
|
if model_settings.openai_api_key:
|
|
self._enabled_providers.append(
|
|
OpenAIProvider(
|
|
name="openai",
|
|
api_key_enc=Secret.from_plaintext(model_settings.openai_api_key),
|
|
base_url=model_settings.openai_api_base,
|
|
)
|
|
)
|
|
if model_settings.anthropic_api_key:
|
|
self._enabled_providers.append(
|
|
AnthropicProvider(
|
|
name="anthropic",
|
|
api_key_enc=Secret.from_plaintext(model_settings.anthropic_api_key),
|
|
)
|
|
)
|
|
if model_settings.ollama_base_url:
|
|
self._enabled_providers.append(
|
|
OllamaProvider(
|
|
name="ollama",
|
|
base_url=model_settings.ollama_base_url,
|
|
default_prompt_formatter=model_settings.default_prompt_formatter,
|
|
)
|
|
)
|
|
if model_settings.gemini_api_key:
|
|
self._enabled_providers.append(
|
|
GoogleAIProvider(
|
|
name="google_ai",
|
|
api_key_enc=Secret.from_plaintext(model_settings.gemini_api_key),
|
|
)
|
|
)
|
|
if model_settings.google_cloud_location and model_settings.google_cloud_project:
|
|
self._enabled_providers.append(
|
|
GoogleVertexProvider(
|
|
name="google_vertex",
|
|
google_cloud_project=model_settings.google_cloud_project,
|
|
google_cloud_location=model_settings.google_cloud_location,
|
|
)
|
|
)
|
|
if model_settings.azure_api_key and model_settings.azure_base_url:
|
|
assert model_settings.azure_api_version, "AZURE_API_VERSION is required"
|
|
self._enabled_providers.append(
|
|
AzureProvider(
|
|
name="azure",
|
|
api_key_enc=Secret.from_plaintext(model_settings.azure_api_key),
|
|
base_url=model_settings.azure_base_url,
|
|
api_version=model_settings.azure_api_version,
|
|
)
|
|
)
|
|
if model_settings.groq_api_key:
|
|
self._enabled_providers.append(
|
|
GroqProvider(
|
|
name="groq",
|
|
api_key_enc=Secret.from_plaintext(model_settings.groq_api_key),
|
|
)
|
|
)
|
|
if model_settings.together_api_key:
|
|
self._enabled_providers.append(
|
|
TogetherProvider(
|
|
name="together",
|
|
api_key_enc=Secret.from_plaintext(model_settings.together_api_key),
|
|
default_prompt_formatter=model_settings.default_prompt_formatter,
|
|
)
|
|
)
|
|
if model_settings.vllm_api_base:
|
|
# vLLM exposes both a /chat/completions and a /completions endpoint
|
|
# NOTE: to use the /chat/completions endpoint, you need to specify extra flags on vLLM startup
|
|
# see: https://docs.vllm.ai/en/stable/features/tool_calling.html
|
|
# e.g. "... --enable-auto-tool-choice --tool-call-parser hermes"
|
|
# Auto-append /v1 to the base URL
|
|
vllm_url = (
|
|
model_settings.vllm_api_base if model_settings.vllm_api_base.endswith("/v1") else model_settings.vllm_api_base + "/v1"
|
|
)
|
|
self._enabled_providers.append(
|
|
VLLMProvider(
|
|
name="vllm",
|
|
base_url=vllm_url,
|
|
default_prompt_formatter=model_settings.default_prompt_formatter,
|
|
handle_base=model_settings.vllm_handle_base,
|
|
)
|
|
)
|
|
|
|
if model_settings.sglang_api_base:
|
|
# Auto-append /v1 to the base URL
|
|
sglang_url = (
|
|
model_settings.sglang_api_base if model_settings.sglang_api_base.endswith("/v1") else model_settings.sglang_api_base + "/v1"
|
|
)
|
|
self._enabled_providers.append(
|
|
SGLangProvider(
|
|
name="sglang",
|
|
base_url=sglang_url,
|
|
default_prompt_formatter=model_settings.default_prompt_formatter,
|
|
handle_base=model_settings.sglang_handle_base,
|
|
)
|
|
)
|
|
|
|
if model_settings.aws_access_key_id and model_settings.aws_secret_access_key and model_settings.aws_default_region:
|
|
self._enabled_providers.append(
|
|
BedrockProvider(
|
|
name="bedrock",
|
|
access_key=model_settings.aws_access_key_id,
|
|
api_key=model_settings.aws_secret_access_key,
|
|
region=model_settings.aws_default_region,
|
|
)
|
|
)
|
|
# Attempt to enable LM Studio by default
|
|
if model_settings.lmstudio_base_url:
|
|
# Auto-append v1 to the base URL
|
|
lmstudio_url = (
|
|
model_settings.lmstudio_base_url
|
|
if model_settings.lmstudio_base_url.endswith("/v1")
|
|
else model_settings.lmstudio_base_url + "/v1"
|
|
)
|
|
self._enabled_providers.append(LMStudioOpenAIProvider(name="lmstudio_openai", base_url=lmstudio_url))
|
|
if model_settings.deepseek_api_key:
|
|
self._enabled_providers.append(
|
|
DeepSeekProvider(
|
|
name="deepseek",
|
|
api_key_enc=Secret.from_plaintext(model_settings.deepseek_api_key),
|
|
)
|
|
)
|
|
if model_settings.xai_api_key:
|
|
self._enabled_providers.append(
|
|
XAIProvider(
|
|
name="xai",
|
|
api_key_enc=Secret.from_plaintext(model_settings.xai_api_key),
|
|
)
|
|
)
|
|
if model_settings.minimax_api_key:
|
|
self._enabled_providers.append(
|
|
MiniMaxProvider(
|
|
name="minimax",
|
|
api_key_enc=Secret.from_plaintext(model_settings.minimax_api_key),
|
|
)
|
|
)
|
|
if model_settings.zai_api_key:
|
|
self._enabled_providers.append(
|
|
ZAIProvider(
|
|
name="zai",
|
|
api_key_enc=Secret.from_plaintext(model_settings.zai_api_key),
|
|
base_url=model_settings.zai_base_url,
|
|
)
|
|
)
|
|
if model_settings.openrouter_api_key:
|
|
self._enabled_providers.append(
|
|
OpenRouterProvider(
|
|
name=model_settings.openrouter_handle_base if model_settings.openrouter_handle_base else "openrouter",
|
|
api_key_enc=Secret.from_plaintext(model_settings.openrouter_api_key),
|
|
)
|
|
)
|
|
|
|
async def init_async(self, init_with_default_org_and_user: bool = True):
|
|
# Make default user and org
|
|
if init_with_default_org_and_user:
|
|
self.default_org = await self.organization_manager.create_default_organization_async()
|
|
self.default_user = await self.user_manager.create_default_actor_async()
|
|
print(f"Default user: {self.default_user} and org: {self.default_org}")
|
|
await self.tool_manager.upsert_base_tools_async(actor=self.default_user)
|
|
|
|
# Sync environment-based providers to database (idempotent, safe for multi-pod startup)
|
|
await self.provider_manager.sync_base_providers(base_providers=self._enabled_providers, actor=self.default_user)
|
|
|
|
# Sync provider models to database
|
|
await self._sync_provider_models_async()
|
|
|
|
# For OSS users, create a local sandbox config
|
|
oss_default_user = await self.user_manager.get_default_actor_async()
|
|
use_venv = False if not tool_settings.tool_exec_venv_name else True
|
|
venv_name = tool_settings.tool_exec_venv_name or "venv"
|
|
tool_dir = tool_settings.tool_exec_dir or LETTA_TOOL_EXECUTION_DIR
|
|
|
|
venv_dir = Path(tool_dir) / venv_name
|
|
tool_path = Path(tool_dir)
|
|
|
|
if tool_path.exists() and not tool_path.is_dir():
|
|
logger.error(f"LETTA_TOOL_SANDBOX_DIR exists but is not a directory: {tool_dir}")
|
|
else:
|
|
if not tool_path.exists():
|
|
logger.warning(f"LETTA_TOOL_SANDBOX_DIR does not exist, creating now: {tool_dir}")
|
|
tool_path.mkdir(parents=True, exist_ok=True)
|
|
|
|
if tool_settings.tool_exec_venv_name and not venv_dir.is_dir():
|
|
logger.warning(
|
|
f"Provided LETTA_TOOL_SANDBOX_VENV_NAME is not a valid venv ({venv_dir}), one will be created for you during tool execution."
|
|
)
|
|
|
|
sandbox_config_create = SandboxConfigCreate(
|
|
config=LocalSandboxConfig(sandbox_dir=tool_settings.tool_exec_dir, use_venv=use_venv, venv_name=venv_name)
|
|
)
|
|
sandbox_config = await self.sandbox_config_manager.create_or_update_sandbox_config_async(
|
|
sandbox_config_create=sandbox_config_create, actor=oss_default_user
|
|
)
|
|
logger.debug(f"Successfully created default local sandbox config:\n{sandbox_config.get_local_config().model_dump()}")
|
|
|
|
if use_venv and tool_settings.tool_exec_autoreload_venv:
|
|
prepare_local_sandbox(
|
|
sandbox_config.get_local_config(),
|
|
env=os.environ.copy(),
|
|
force_recreate=True,
|
|
)
|
|
|
|
def _init_memory_repo_manager(self) -> Optional[MemoryRepoManager]:
|
|
"""Initialize the memory repository manager if configured.
|
|
|
|
If LETTA_MEMFS_SERVICE_URL is set, uses the external memfs service.
|
|
Otherwise, configure the object store via settings (recommended):
|
|
|
|
LETTA_OBJECT_STORE_URI="gs://my-bucket/repository?project=my-gcp-project"
|
|
|
|
Supported schemes:
|
|
- gs:// (or gcs://) -> Google Cloud Storage
|
|
|
|
Returns:
|
|
MemoryRepoManager (or MemfsClient) if configured, None otherwise
|
|
"""
|
|
|
|
# Keep import local to avoid import/circular issues during server bootstrap.
|
|
from urllib.parse import parse_qs, urlparse
|
|
|
|
from letta.settings import settings
|
|
|
|
# Check if memfs service is configured (takes priority over local object store)
|
|
if settings.memfs_service_url:
|
|
from letta.services.memory_repo import MemfsClient
|
|
|
|
logger.info("Memory repo manager using memfs service: %s", settings.memfs_service_url)
|
|
return MemfsClient(base_url=settings.memfs_service_url)
|
|
|
|
uri = settings.object_store_uri
|
|
if not uri:
|
|
logger.debug("Memory repo manager not configured (object_store_uri not set)")
|
|
return None
|
|
|
|
try:
|
|
parsed = urlparse(uri)
|
|
scheme = (parsed.scheme or "").lower()
|
|
|
|
if scheme in {"gs", "gcs"}:
|
|
bucket = parsed.netloc
|
|
if not bucket:
|
|
raise ValueError(f"Invalid GCS object store URI (missing bucket): {uri}")
|
|
|
|
# URI path is treated as the storage prefix
|
|
prefix = parsed.path.lstrip("/") or "repository"
|
|
qs = parse_qs(parsed.query)
|
|
|
|
# Allow settings-level overrides (handy for templated URIs).
|
|
project = settings.object_store_project or (qs.get("project") or [None])[0]
|
|
|
|
storage = GCSStorageBackend(bucket=bucket, prefix=prefix, project=project)
|
|
logger.info("Memory repo manager initialized with object store: %s", uri)
|
|
return MemoryRepoManager(storage=storage)
|
|
|
|
raise ValueError(f"Unsupported object store scheme '{scheme}' in URI: {uri}")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to initialize memory repo manager: {e}")
|
|
return None
|
|
|
|
def _get_enabled_provider(self, provider_name: str) -> Optional[Provider]:
|
|
"""Find and return an enabled provider by name.
|
|
|
|
Args:
|
|
provider_name: The name of the provider to find
|
|
|
|
Returns:
|
|
The matching enabled provider, or None if not found
|
|
"""
|
|
for provider in self._enabled_providers:
|
|
if provider.name == provider_name:
|
|
return provider
|
|
return None
|
|
|
|
async def _sync_provider_models_async(self):
|
|
"""Sync all provider models to database at startup."""
|
|
logger.info("Syncing provider models to database")
|
|
|
|
# Get persisted providers from database (they now have IDs)
|
|
persisted_providers = await self.provider_manager.list_providers_async(actor=self.default_user)
|
|
|
|
for persisted_provider in persisted_providers:
|
|
try:
|
|
# Find the matching enabled provider instance to call list_models on
|
|
enabled_provider = self._get_enabled_provider(persisted_provider.name)
|
|
|
|
if not enabled_provider:
|
|
# Only delete base providers that are no longer enabled
|
|
# BYOK providers are user-created and should not be automatically deleted
|
|
if persisted_provider.provider_category == ProviderCategory.base:
|
|
logger.info(f"Base provider {persisted_provider.name} is no longer enabled, deleting from database")
|
|
try:
|
|
await self.provider_manager.delete_provider_by_id_async(
|
|
provider_id=persisted_provider.id, actor=self.default_user
|
|
)
|
|
except NoResultFound:
|
|
# Provider was already deleted (race condition in multi-pod startup)
|
|
logger.debug(f"Provider {persisted_provider.name} was already deleted, skipping")
|
|
else:
|
|
logger.debug(f"No enabled provider for BYOK provider {persisted_provider.name}, skipping model sync")
|
|
continue
|
|
|
|
# Fetch models from provider
|
|
llm_models = await enabled_provider.list_llm_models_async()
|
|
embedding_models = await enabled_provider.list_embedding_models_async()
|
|
|
|
# Save to database with the persisted provider (which has an ID)
|
|
await self.provider_manager.sync_provider_models_async(
|
|
provider=persisted_provider,
|
|
llm_models=llm_models,
|
|
embedding_models=embedding_models,
|
|
organization_id=None, # Global models
|
|
)
|
|
# Update last_synced timestamp
|
|
await self.provider_manager.update_provider_last_synced_async(persisted_provider.id)
|
|
logger.info(
|
|
f"Synced {len(llm_models)} LLM models and {len(embedding_models)} embedding models for provider {persisted_provider.name}"
|
|
)
|
|
except Exception as e:
|
|
logger.error(f"Failed to sync models for provider {persisted_provider.name}: {e}", exc_info=True)
|
|
|
|
async def init_mcp_clients(self):
|
|
# TODO: remove this
|
|
mcp_server_configs = self.get_mcp_servers()
|
|
|
|
for server_name, server_config in mcp_server_configs.items():
|
|
if server_config.type == MCPServerType.SSE:
|
|
self.mcp_clients[server_name] = AsyncFastMCPSSEClient(server_config)
|
|
elif server_config.type == MCPServerType.STDIO:
|
|
self.mcp_clients[server_name] = AsyncStdioMCPClient(server_config)
|
|
else:
|
|
raise LettaInvalidArgumentError(f"Invalid MCP server config: {server_config}", argument_name="server_config")
|
|
|
|
try:
|
|
await self.mcp_clients[server_name].connect_to_server()
|
|
except Exception as e:
|
|
logger.error(e)
|
|
self.mcp_clients.pop(server_name)
|
|
|
|
logger.info(f"MCP clients initialized: {len(self.mcp_clients)} active connections")
|
|
|
|
# Print out the tools that are connected
|
|
for server_name, client in self.mcp_clients.items():
|
|
logger.info(f"Attempting to fetch tools from MCP server: {server_name}")
|
|
mcp_tools = await client.list_tools()
|
|
logger.info(f"MCP tools connected: {', '.join([t.name for t in mcp_tools])}")
|
|
logger.debug(f"MCP tools: {', '.join([str(t) for t in mcp_tools])}")
|
|
|
|
@trace_method
|
|
async def create_agent_async(
|
|
self,
|
|
request: CreateAgent,
|
|
actor: User,
|
|
) -> AgentState:
|
|
if request.llm_config is None:
|
|
additional_config_params = {}
|
|
if request.model is None:
|
|
if settings.default_llm_handle is None:
|
|
raise LettaInvalidArgumentError("Must specify either model or llm_config in request", argument_name="model")
|
|
else:
|
|
handle = settings.default_llm_handle
|
|
else:
|
|
if isinstance(request.model, str):
|
|
handle = request.model
|
|
elif isinstance(request.model, list):
|
|
raise LettaInvalidArgumentError("Multiple models are not supported yet")
|
|
else:
|
|
# EXTREMELEY HACKY, TEMPORARY WORKAROUND
|
|
handle = f"{request.model.provider}/{request.model.model}"
|
|
# TODO: figure out how to override various params
|
|
additional_config_params = request.model._to_legacy_config_params()
|
|
additional_config_params["model"] = request.model.model
|
|
additional_config_params["provider_name"] = request.model.provider
|
|
|
|
config_params = {
|
|
"handle": handle,
|
|
"context_window_limit": request.context_window_limit,
|
|
"max_tokens": request.max_tokens,
|
|
"max_reasoning_tokens": request.max_reasoning_tokens,
|
|
"enable_reasoner": request.enable_reasoner,
|
|
}
|
|
log_event(name="start get_llm_config_from_handle", attributes=config_params)
|
|
request.llm_config = await self.get_llm_config_from_handle_async(actor=actor, **config_params)
|
|
log_event(name="end get_llm_config_from_handle", attributes=config_params)
|
|
if request.model and isinstance(request.model, str):
|
|
assert request.llm_config.handle == request.model, (
|
|
f"LLM config handle {request.llm_config.handle} does not match request handle {request.model}"
|
|
)
|
|
|
|
# update with model_settings
|
|
if request.model_settings is not None:
|
|
update_llm_config_params = request.model_settings._to_legacy_config_params()
|
|
request.llm_config = request.llm_config.model_copy(update=update_llm_config_params)
|
|
|
|
# Copy parallel_tool_calls from request to llm_config if provided
|
|
if request.parallel_tool_calls is not None:
|
|
request.llm_config.parallel_tool_calls = request.parallel_tool_calls
|
|
|
|
if request.reasoning is None:
|
|
request.reasoning = request.llm_config.enable_reasoner or request.llm_config.put_inner_thoughts_in_kwargs
|
|
|
|
if request.embedding_config is None:
|
|
if request.embedding is None:
|
|
if settings.default_embedding_handle is not None:
|
|
request.embedding = settings.default_embedding_handle
|
|
# Only resolve embedding config if we have an embedding handle
|
|
if request.embedding is not None:
|
|
embedding_config_params = {
|
|
"handle": request.embedding,
|
|
"embedding_chunk_size": request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
|
|
}
|
|
log_event(name="start get_embedding_config_from_handle", attributes=embedding_config_params)
|
|
request.embedding_config = await self.get_embedding_config_from_handle_async(actor=actor, **embedding_config_params)
|
|
log_event(name="end get_embedding_config_from_handle", attributes=embedding_config_params)
|
|
|
|
# If git-backed memory is requested on create, we enable it *after* agent creation.
|
|
# We strip the tag during creation so `enable_git_memory_for_agent` can be the
|
|
# single place that both creates the repo and writes the tag.
|
|
wants_git_memory = bool(request.tags and GIT_MEMORY_ENABLED_TAG in request.tags)
|
|
create_request = request
|
|
if wants_git_memory:
|
|
filtered_tags = [t for t in (request.tags or []) if t != GIT_MEMORY_ENABLED_TAG]
|
|
updates: dict = {"tags": filtered_tags}
|
|
|
|
# Transform block labels to path-based for git-memory agents.
|
|
# Blocks without a "/" prefix go under system/ (rendered in system prompt).
|
|
# e.g. "human" -> "system/human", "persona" -> "system/persona"
|
|
# Blocks with an explicit path (e.g. "notes/project") keep their label.
|
|
if request.memory_blocks:
|
|
transformed_blocks = []
|
|
for block in request.memory_blocks:
|
|
if not block.label.startswith("system/"):
|
|
block = block.model_copy(update={"label": f"system/{block.label}"})
|
|
transformed_blocks.append(block)
|
|
updates["memory_blocks"] = transformed_blocks
|
|
|
|
create_request = request.model_copy(update=updates)
|
|
|
|
log_event(name="start create_agent db")
|
|
main_agent = await self.agent_manager.create_agent_async(
|
|
agent_create=create_request,
|
|
actor=actor,
|
|
)
|
|
log_event(name="end create_agent db")
|
|
|
|
# Enable git-backed memory (creates repo + commits initial blocks + adds tag)
|
|
if wants_git_memory and isinstance(self.block_manager, GitEnabledBlockManager):
|
|
await self.block_manager.enable_git_memory_for_agent(agent_id=main_agent.id, actor=actor)
|
|
# Preserve the user's requested tags and git_enabled flag in the response model.
|
|
try:
|
|
main_agent.tags = list(request.tags or [])
|
|
main_agent.memory.git_enabled = True
|
|
except Exception:
|
|
pass
|
|
|
|
log_event(name="start insert_files_into_context_window db")
|
|
# Use folder_ids if provided, otherwise fall back to deprecated source_ids for backwards compatibility
|
|
folder_ids_to_attach = request.folder_ids if request.folder_ids else request.source_ids
|
|
if folder_ids_to_attach:
|
|
for folder_id in folder_ids_to_attach:
|
|
files = await self.file_manager.list_files(folder_id, actor, include_content=True)
|
|
await self.agent_manager.insert_files_into_context_window(
|
|
agent_state=main_agent, file_metadata_with_content=files, actor=actor
|
|
)
|
|
|
|
main_agent = await self.agent_manager.refresh_file_blocks(agent_state=main_agent, actor=actor)
|
|
main_agent = await self.agent_manager.attach_missing_files_tools_async(agent_state=main_agent, actor=actor)
|
|
log_event(name="end insert_files_into_context_window db")
|
|
|
|
if request.enable_sleeptime:
|
|
if request.agent_type == AgentType.voice_convo_agent:
|
|
main_agent = await self.create_voice_sleeptime_agent_async(main_agent=main_agent, actor=actor)
|
|
else:
|
|
main_agent = await self.create_sleeptime_agent_async(main_agent=main_agent, actor=actor)
|
|
|
|
return main_agent
|
|
|
|
async def update_agent_async(
|
|
self,
|
|
agent_id: str,
|
|
request: UpdateAgent,
|
|
actor: User,
|
|
) -> AgentState:
|
|
# Build llm_config from convenience fields if llm_config is not provided
|
|
if request.llm_config is None and (
|
|
request.model is not None or request.context_window_limit is not None or request.max_tokens is not None
|
|
):
|
|
if request.model is None:
|
|
agent = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
|
request.model = agent.llm_config.handle
|
|
config_params = {
|
|
"handle": request.model,
|
|
"context_window_limit": request.context_window_limit,
|
|
"max_tokens": request.max_tokens,
|
|
}
|
|
log_event(name="start get_llm_config_from_handle", attributes=config_params)
|
|
request.llm_config = await self.get_llm_config_from_handle_async(actor=actor, **config_params)
|
|
log_event(name="end get_llm_config_from_handle", attributes=config_params)
|
|
|
|
# update with model_settings
|
|
if request.model_settings is not None:
|
|
if request.llm_config is None:
|
|
# Get the current agent's llm_config if not already set
|
|
agent = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
|
request.llm_config = agent.llm_config.model_copy()
|
|
update_llm_config_params = request.model_settings._to_legacy_config_params()
|
|
request.llm_config = request.llm_config.model_copy(update=update_llm_config_params)
|
|
|
|
# Copy parallel_tool_calls from request to llm_config if provided
|
|
if request.parallel_tool_calls is not None:
|
|
if request.llm_config is None:
|
|
# Get the current agent's llm_config and update it
|
|
agent = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
|
request.llm_config = agent.llm_config.model_copy()
|
|
request.llm_config.parallel_tool_calls = request.parallel_tool_calls
|
|
|
|
if request.embedding is not None:
|
|
request.embedding_config = await self.get_embedding_config_from_handle_async(handle=request.embedding, actor=actor)
|
|
|
|
if request.enable_sleeptime:
|
|
agent = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
|
if agent.multi_agent_group is None:
|
|
if agent.agent_type == AgentType.voice_convo_agent:
|
|
await self.create_voice_sleeptime_agent_async(main_agent=agent, actor=actor)
|
|
else:
|
|
await self.create_sleeptime_agent_async(main_agent=agent, actor=actor)
|
|
|
|
# If git-backed memory is requested via tag update, initialize/backfill the repo.
|
|
wants_git_memory = bool(request.tags and GIT_MEMORY_ENABLED_TAG in request.tags)
|
|
|
|
updated_agent = await self.agent_manager.update_agent_async(
|
|
agent_id=agent_id,
|
|
agent_update=request,
|
|
actor=actor,
|
|
)
|
|
|
|
# Ensure repo exists and initial blocks are committed when the tag is present.
|
|
if wants_git_memory and isinstance(self.block_manager, GitEnabledBlockManager):
|
|
await self.block_manager.enable_git_memory_for_agent(agent_id=agent_id, actor=actor)
|
|
# Preserve the user's requested tags in the response model.
|
|
try:
|
|
updated_agent.tags = list(request.tags or [])
|
|
except Exception:
|
|
pass
|
|
|
|
return updated_agent
|
|
|
|
async def create_sleeptime_agent_async(self, main_agent: AgentState, actor: User) -> Optional[AgentState]:
|
|
if main_agent.embedding_config is None:
|
|
logger.warning(f"Skipping sleeptime agent creation for agent {main_agent.id}: no embedding config provided")
|
|
return None
|
|
request = CreateAgent(
|
|
name=main_agent.name + "-sleeptime",
|
|
agent_type=AgentType.sleeptime_agent,
|
|
block_ids=[block.id for block in main_agent.memory.blocks],
|
|
memory_blocks=[
|
|
CreateBlock(
|
|
label="memory_persona",
|
|
value=get_persona_text("sleeptime_memory_persona"),
|
|
),
|
|
],
|
|
llm_config=main_agent.llm_config,
|
|
embedding_config=main_agent.embedding_config,
|
|
project_id=main_agent.project_id,
|
|
)
|
|
sleeptime_agent = await self.agent_manager.create_agent_async(
|
|
agent_create=request,
|
|
actor=actor,
|
|
)
|
|
await self.group_manager.create_group_async(
|
|
group=GroupCreate(
|
|
description="",
|
|
agent_ids=[sleeptime_agent.id],
|
|
manager_config=SleeptimeManager(
|
|
manager_agent_id=main_agent.id,
|
|
sleeptime_agent_frequency=5,
|
|
),
|
|
),
|
|
actor=actor,
|
|
)
|
|
return await self.agent_manager.get_agent_by_id_async(agent_id=main_agent.id, actor=actor)
|
|
|
|
async def create_voice_sleeptime_agent_async(self, main_agent: AgentState, actor: User) -> Optional[AgentState]:
|
|
if main_agent.embedding_config is None:
|
|
logger.warning(f"Skipping voice sleeptime agent creation for agent {main_agent.id}: no embedding config provided")
|
|
return None
|
|
# TODO: Inject system
|
|
request = CreateAgent(
|
|
name=main_agent.name + "-sleeptime",
|
|
agent_type=AgentType.voice_sleeptime_agent,
|
|
block_ids=[block.id for block in main_agent.memory.blocks],
|
|
memory_blocks=[
|
|
CreateBlock(
|
|
label="memory_persona",
|
|
value=get_persona_text("voice_memory_persona"),
|
|
),
|
|
],
|
|
llm_config=LLMConfig.default_config("gpt-4.1"),
|
|
embedding_config=main_agent.embedding_config,
|
|
project_id=main_agent.project_id,
|
|
)
|
|
voice_sleeptime_agent = await self.agent_manager.create_agent_async(
|
|
agent_create=request,
|
|
actor=actor,
|
|
)
|
|
await self.group_manager.create_group_async(
|
|
group=GroupCreate(
|
|
description="Low latency voice chat with async memory management.",
|
|
agent_ids=[voice_sleeptime_agent.id],
|
|
manager_config=VoiceSleeptimeManager(
|
|
manager_agent_id=main_agent.id,
|
|
max_message_buffer_length=constants.DEFAULT_MAX_MESSAGE_BUFFER_LENGTH,
|
|
min_message_buffer_length=constants.DEFAULT_MIN_MESSAGE_BUFFER_LENGTH,
|
|
),
|
|
),
|
|
actor=actor,
|
|
)
|
|
return await self.agent_manager.get_agent_by_id_async(agent_id=main_agent.id, actor=actor)
|
|
|
|
async def get_agent_memory_async(self, agent_id: str, actor: User) -> Memory:
|
|
"""Return the memory of an agent (core memory)"""
|
|
agent = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
|
return agent.memory
|
|
|
|
async def get_agent_archival_async(
|
|
self,
|
|
agent_id: str,
|
|
actor: User,
|
|
after: Optional[str] = None,
|
|
before: Optional[str] = None,
|
|
limit: Optional[int] = 100,
|
|
query_text: Optional[str] = None,
|
|
ascending: Optional[bool] = True,
|
|
) -> List[Passage]:
|
|
# iterate over records
|
|
records = await self.agent_manager.query_agent_passages_async(
|
|
actor=actor,
|
|
agent_id=agent_id,
|
|
after=after,
|
|
query_text=query_text,
|
|
before=before,
|
|
ascending=ascending,
|
|
limit=limit,
|
|
)
|
|
# Extract just the passages (SQL path returns empty metadata)
|
|
return [passage for passage, _, _ in records]
|
|
|
|
async def insert_archival_memory_async(
|
|
self, agent_id: str, memory_contents: str, actor: User, tags: Optional[List[str]], created_at: Optional[datetime]
|
|
) -> List[Passage]:
|
|
from letta.services.context_window_calculator.token_counter import create_token_counter
|
|
from letta.settings import settings
|
|
|
|
# Get the agent object (loaded in memory)
|
|
agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
|
|
|
# Check token count against limit
|
|
token_counter = create_token_counter(
|
|
model_endpoint_type=agent_state.llm_config.model_endpoint_type,
|
|
model=agent_state.llm_config.model,
|
|
actor=actor,
|
|
agent_id=agent_id,
|
|
)
|
|
token_count = await token_counter.count_text_tokens(memory_contents)
|
|
if token_count > settings.archival_memory_token_limit:
|
|
raise LettaInvalidArgumentError(
|
|
message=f"Archival memory content exceeds token limit of {settings.archival_memory_token_limit} tokens (found {token_count} tokens)",
|
|
argument_name="memory_contents",
|
|
)
|
|
|
|
# Use passage manager which handles dual-write to Turbopuffer if enabled
|
|
passages = await self.passage_manager.insert_passage(
|
|
agent_state=agent_state, text=memory_contents, tags=tags, actor=actor, created_at=created_at
|
|
)
|
|
|
|
return passages
|
|
|
|
async def delete_archival_memory_async(self, memory_id: str, actor: User):
|
|
# TODO check if it exists first, and throw error if not
|
|
# TODO: need to also rebuild the prompt here
|
|
passage = await self.passage_manager.get_passage_by_id_async(passage_id=memory_id, actor=actor)
|
|
|
|
# delete the passage
|
|
await self.passage_manager.delete_passage_by_id_async(passage_id=memory_id, actor=actor)
|
|
|
|
async def get_agent_recall(
|
|
self,
|
|
user_id: str,
|
|
agent_id: str,
|
|
after: Optional[str] = None,
|
|
before: Optional[str] = None,
|
|
limit: Optional[int] = 100,
|
|
group_id: Optional[str] = None,
|
|
reverse: Optional[bool] = False,
|
|
return_message_object: bool = True,
|
|
use_assistant_message: bool = True,
|
|
assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
|
|
assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
|
|
) -> Union[List[Message], List[LettaMessage]]:
|
|
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
|
|
|
|
actor = await self.user_manager.get_actor_or_default_async(actor_id=user_id)
|
|
|
|
records = await self.message_manager.list_messages(
|
|
agent_id=agent_id,
|
|
actor=actor,
|
|
after=after,
|
|
before=before,
|
|
limit=limit,
|
|
ascending=not reverse,
|
|
group_id=group_id,
|
|
)
|
|
|
|
if not return_message_object:
|
|
records = Message.to_letta_messages_from_list(
|
|
messages=records,
|
|
use_assistant_message=use_assistant_message,
|
|
assistant_message_tool_name=assistant_message_tool_name,
|
|
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
|
reverse=reverse,
|
|
)
|
|
|
|
if reverse:
|
|
records = records[::-1]
|
|
|
|
return records
|
|
|
|
async def get_agent_recall_async(
|
|
self,
|
|
agent_id: str,
|
|
actor: User,
|
|
after: Optional[str] = None,
|
|
before: Optional[str] = None,
|
|
limit: Optional[int] = 100,
|
|
group_id: Optional[str] = None,
|
|
reverse: Optional[bool] = False,
|
|
return_message_object: bool = True,
|
|
use_assistant_message: bool = True,
|
|
assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
|
|
assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
|
|
include_err: Optional[bool] = None,
|
|
conversation_id: Optional[str] = None,
|
|
) -> Union[List[Message], List[LettaMessage]]:
|
|
records = await self.message_manager.list_messages(
|
|
agent_id=agent_id,
|
|
actor=actor,
|
|
after=after,
|
|
before=before,
|
|
limit=limit,
|
|
ascending=not reverse,
|
|
group_id=group_id,
|
|
include_err=include_err,
|
|
conversation_id=conversation_id,
|
|
)
|
|
|
|
if not return_message_object:
|
|
# Get agent state to determine if it's a react agent
|
|
agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
|
text_is_assistant_message = agent_state.agent_type == AgentType.letta_v1_agent
|
|
|
|
records = Message.to_letta_messages_from_list(
|
|
messages=records,
|
|
use_assistant_message=use_assistant_message,
|
|
assistant_message_tool_name=assistant_message_tool_name,
|
|
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
|
reverse=reverse,
|
|
include_err=include_err,
|
|
text_is_assistant_message=text_is_assistant_message,
|
|
)
|
|
|
|
if reverse:
|
|
records = records[::-1]
|
|
|
|
return records
|
|
|
|
async def get_all_messages_recall_async(
|
|
self,
|
|
actor: User,
|
|
after: Optional[str] = None,
|
|
before: Optional[str] = None,
|
|
limit: Optional[int] = 100,
|
|
group_id: Optional[str] = None,
|
|
reverse: Optional[bool] = False,
|
|
return_message_object: bool = True,
|
|
use_assistant_message: bool = True,
|
|
assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
|
|
assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
|
|
include_err: Optional[bool] = None,
|
|
conversation_id: Optional[str] = None,
|
|
) -> Union[List[Message], List[LettaMessage]]:
|
|
records = await self.message_manager.list_messages(
|
|
agent_id=None,
|
|
actor=actor,
|
|
after=after,
|
|
before=before,
|
|
limit=limit,
|
|
ascending=not reverse,
|
|
group_id=group_id,
|
|
include_err=include_err,
|
|
conversation_id=conversation_id,
|
|
)
|
|
|
|
if not return_message_object:
|
|
# NOTE: We are assuming all messages are coming from letta_v1_agent. This may lead to slightly incorrect assistant message handling.
|
|
# text_is_assistant_message = agent_state.agent_type == AgentType.letta_v1_agent
|
|
text_is_assistant_message = True
|
|
|
|
records = Message.to_letta_messages_from_list(
|
|
messages=records,
|
|
use_assistant_message=use_assistant_message,
|
|
assistant_message_tool_name=assistant_message_tool_name,
|
|
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
|
reverse=reverse,
|
|
include_err=include_err,
|
|
text_is_assistant_message=text_is_assistant_message,
|
|
)
|
|
|
|
if reverse:
|
|
records = records[::-1]
|
|
|
|
return records
|
|
|
|
def get_server_config(self, include_defaults: bool = False) -> dict:
|
|
"""Return the base config"""
|
|
|
|
def clean_keys(config):
|
|
config_copy = config.copy()
|
|
for k, v in config.items():
|
|
if k == "key" or "_key" in k:
|
|
config_copy[k] = server_utils.shorten_key_middle(v, chars_each_side=5)
|
|
return config_copy
|
|
|
|
# TODO: do we need a separate server config?
|
|
base_config = vars(self.config)
|
|
clean_base_config = clean_keys(base_config)
|
|
|
|
response = {"config": clean_base_config}
|
|
|
|
if include_defaults:
|
|
default_config = vars(LettaConfig())
|
|
clean_default_config = clean_keys(default_config)
|
|
response["defaults"] = clean_default_config
|
|
|
|
return response
|
|
|
|
def update_agent_core_memory(self, agent_id: str, label: str, value: str, actor: User) -> Memory:
|
|
"""Update the value of a block in the agent's memory"""
|
|
|
|
# get the block id
|
|
block = self.agent_manager.get_block_with_label(agent_id=agent_id, block_label=label, actor=actor)
|
|
|
|
# update the block
|
|
self.block_manager.update_block(block_id=block.id, block_update=BlockUpdate(value=value), actor=actor)
|
|
|
|
# rebuild system prompt for agent, potentially changed
|
|
return self.agent_manager.rebuild_system_prompt(agent_id=agent_id, actor=actor).memory
|
|
|
|
async def delete_source(self, source_id: str, actor: User):
|
|
"""Delete a data source"""
|
|
await self.source_manager.delete_source(source_id=source_id, actor=actor)
|
|
|
|
# delete data from passage store
|
|
passages_to_be_deleted = await self.agent_manager.query_source_passages_async(actor=actor, source_id=source_id, limit=None)
|
|
await self.passage_manager.delete_source_passages_async(actor=actor, passages=passages_to_be_deleted)
|
|
|
|
# TODO: delete data from agent passage stores (?)
|
|
|
|
async def load_file_to_source(self, source_id: str, file_path: str, job_id: str, actor: User) -> Job:
|
|
# update job
|
|
job = await self.job_manager.get_job_by_id_async(job_id, actor=actor)
|
|
job.status = JobStatus.running
|
|
await self.job_manager.update_job_by_id_async(job_id=job_id, job_update=JobUpdate(**job.model_dump()), actor=actor)
|
|
|
|
# try:
|
|
from letta.data_sources.connectors import DirectoryConnector
|
|
|
|
# TODO: move this into a thread
|
|
source = await self.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
|
connector = DirectoryConnector(input_files=[file_path])
|
|
num_passages, num_documents = await self.load_data(user_id=source.created_by_id, source_name=source.name, connector=connector)
|
|
|
|
# update all agents who have this source attached
|
|
agent_states = await self.source_manager.list_attached_agents(source_id=source_id, actor=actor)
|
|
for agent_state in agent_states:
|
|
agent_id = agent_state.id
|
|
|
|
# Attach source to agent
|
|
curr_passage_size = await self.agent_manager.passage_size_async(actor=actor, agent_id=agent_id)
|
|
agent_state = await self.agent_manager.attach_source_async(agent_id=agent_state.id, source_id=source_id, actor=actor)
|
|
new_passage_size = await self.agent_manager.passage_size_async(actor=actor, agent_id=agent_id)
|
|
assert new_passage_size >= curr_passage_size # in case empty files are added
|
|
|
|
# update job status
|
|
job.status = JobStatus.completed
|
|
job.metadata["num_passages"] = num_passages
|
|
job.metadata["num_documents"] = num_documents
|
|
await self.job_manager.update_job_by_id_async(job_id=job_id, job_update=JobUpdate(**job.model_dump()), actor=actor)
|
|
|
|
return job
|
|
|
|
async def load_file_to_source_via_mistral(self):
|
|
pass
|
|
|
|
async def sleeptime_document_ingest_async(
|
|
self, main_agent: AgentState, source: Source, actor: User, clear_history: bool = False
|
|
) -> None:
|
|
pass
|
|
|
|
async def _remove_file_from_agent(self, agent_id: str, file_id: str, actor: User) -> None:
|
|
"""
|
|
Internal method to remove a document block for an agent.
|
|
"""
|
|
try:
|
|
await self.file_agent_manager.detach_file(
|
|
agent_id=agent_id,
|
|
file_id=file_id,
|
|
actor=actor,
|
|
)
|
|
except NoResultFound:
|
|
logger.info(f"File {file_id} already removed from agent {agent_id}, skipping...")
|
|
|
|
async def remove_file_from_context_windows(self, source_id: str, file_id: str, actor: User) -> None:
|
|
"""
|
|
Remove the document from the context window of all agents
|
|
attached to the given source.
|
|
"""
|
|
# Use the optimized ids_only parameter
|
|
agent_ids = await self.source_manager.list_attached_agents(source_id=source_id, actor=actor, ids_only=True)
|
|
|
|
# Return early if no agents
|
|
if not agent_ids:
|
|
return
|
|
|
|
logger.info(f"Removing file from context window for source: {source_id}")
|
|
logger.info(f"Attached agents: {agent_ids}")
|
|
|
|
# Create agent-file pairs for bulk deletion
|
|
agent_file_pairs = [(agent_id, file_id) for agent_id in agent_ids]
|
|
|
|
# Bulk delete in a single query
|
|
deleted_count = await self.file_agent_manager.detach_file_bulk(agent_file_pairs=agent_file_pairs, actor=actor)
|
|
|
|
logger.info(f"Removed file {file_id} from {deleted_count} agent context windows")
|
|
|
|
async def remove_files_from_context_window(self, agent_state: AgentState, file_ids: List[str], actor: User) -> None:
|
|
"""
|
|
Remove multiple documents from the context window of an agent
|
|
attached to the given source.
|
|
"""
|
|
logger.info(f"Removing files from context window for agent_state: {agent_state.id}")
|
|
logger.info(f"Files to remove: {file_ids}")
|
|
|
|
# Create agent-file pairs for bulk deletion
|
|
agent_file_pairs = [(agent_state.id, file_id) for file_id in file_ids]
|
|
|
|
# Bulk delete in a single query
|
|
deleted_count = await self.file_agent_manager.detach_file_bulk(agent_file_pairs=agent_file_pairs, actor=actor)
|
|
|
|
logger.info(f"Removed {deleted_count} files from agent {agent_state.id}")
|
|
|
|
async def create_document_sleeptime_agent_async(
|
|
self, main_agent: AgentState, source: Source, actor: User, clear_history: bool = False
|
|
) -> Optional[AgentState]:
|
|
if main_agent.embedding_config is None:
|
|
logger.warning(f"Skipping document sleeptime agent creation for agent {main_agent.id}: no embedding config provided")
|
|
return None
|
|
try:
|
|
block = await self.agent_manager.get_block_with_label_async(agent_id=main_agent.id, block_label=source.name, actor=actor)
|
|
except:
|
|
block = await self.block_manager.create_or_update_block_async(Block(label=source.name, value=""), actor=actor)
|
|
await self.agent_manager.attach_block_async(agent_id=main_agent.id, block_id=block.id, actor=actor)
|
|
|
|
if clear_history and block.value != "":
|
|
block = await self.block_manager.update_block_async(block_id=block.id, block_update=BlockUpdate(value=""), actor=actor)
|
|
|
|
request = CreateAgent(
|
|
name=main_agent.name + "-doc-sleeptime",
|
|
system=get_system_text("sleeptime_doc_ingest"),
|
|
agent_type=AgentType.sleeptime_agent,
|
|
block_ids=[block.id],
|
|
memory_blocks=[
|
|
CreateBlock(
|
|
label="persona",
|
|
value=get_persona_text("sleeptime_doc_persona"),
|
|
),
|
|
CreateBlock(
|
|
label="instructions",
|
|
value=source.instructions,
|
|
),
|
|
],
|
|
llm_config=main_agent.llm_config,
|
|
embedding_config=main_agent.embedding_config,
|
|
project_id=main_agent.project_id,
|
|
include_base_tools=False,
|
|
tools=constants.BASE_SLEEPTIME_TOOLS,
|
|
)
|
|
return await self.agent_manager.create_agent_async(
|
|
agent_create=request,
|
|
actor=actor,
|
|
)
|
|
|
|
async def load_data(
|
|
self,
|
|
user_id: str,
|
|
connector: DataConnector,
|
|
source_name: str,
|
|
) -> Tuple[int, int]:
|
|
"""Load data from a DataConnector into a source for a specified user_id"""
|
|
# TODO: this should be implemented as a batch job or at least async, since it may take a long time
|
|
|
|
# load data from a data source into the document store
|
|
actor = await self.user_manager.get_actor_by_id_async(actor_id=user_id)
|
|
source = await self.source_manager.get_source_by_name(source_name=source_name, actor=actor)
|
|
if source is None:
|
|
raise NoResultFound(f"Data source {source_name} does not exist for user {user_id}")
|
|
|
|
# load data into the document store
|
|
passage_count, document_count = await load_data(connector, source, self.passage_manager, self.file_manager, actor=actor)
|
|
return passage_count, document_count
|
|
|
|
def _get_provider_sort_key(self, model: LLMConfig) -> Tuple[int, str, str]:
|
|
"""Get sort key for a model: (provider_priority, provider_name, model_name)"""
|
|
provider_priority = constants.PROVIDER_ORDER.get(model.provider_name, 999)
|
|
return (provider_priority, model.provider_name or "", model.model or "")
|
|
|
|
def _get_embedding_provider_sort_key(self, model: EmbeddingConfig) -> Tuple[int, str, str]:
|
|
"""Get sort key for an embedding model: (provider_priority, provider_name, model_name)"""
|
|
# Extract provider name from handle (format: "provider_name/model_name")
|
|
provider_name = model.handle.split("/")[0] if model.handle and "/" in model.handle else ""
|
|
provider_priority = constants.PROVIDER_ORDER.get(provider_name, 999)
|
|
return (provider_priority, provider_name, model.embedding_model or "")
|
|
|
|
@trace_method
|
|
async def list_llm_models_async(
|
|
self,
|
|
actor: User,
|
|
provider_category: Optional[List[ProviderCategory]] = None,
|
|
provider_name: Optional[str] = None,
|
|
provider_type: Optional[ProviderType] = None,
|
|
) -> List[LLMConfig]:
|
|
"""List available LLM models - base from DB, BYOK from provider endpoints"""
|
|
llm_models = []
|
|
|
|
# Determine which categories to include
|
|
include_base = not provider_category or ProviderCategory.base in provider_category
|
|
include_byok = not provider_category or ProviderCategory.byok in provider_category
|
|
|
|
# Get base provider models from database
|
|
if include_base:
|
|
provider_models = await self.provider_manager.list_models_async(
|
|
actor=actor,
|
|
model_type="llm",
|
|
enabled=True,
|
|
)
|
|
|
|
# Build LLMConfig objects from database
|
|
provider_cache: Dict[str, Provider] = {}
|
|
for model in provider_models:
|
|
# Get provider details (with caching to avoid N+1 queries)
|
|
if model.provider_id not in provider_cache:
|
|
provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor)
|
|
provider = provider_cache[model.provider_id]
|
|
|
|
# Skip non-base providers (they're handled separately)
|
|
if provider.provider_category != ProviderCategory.base:
|
|
continue
|
|
|
|
# Apply provider_name/provider_type filters if specified
|
|
if provider_name and provider.name != provider_name:
|
|
continue
|
|
if provider_type and provider.provider_type != provider_type:
|
|
continue
|
|
|
|
# For bedrock, use schema default for base_url since DB may have NULL
|
|
# TODO: can maybe do this for all models but want to isolate change so we don't break any other providers
|
|
if provider.provider_type == ProviderType.bedrock:
|
|
typed_provider = provider.cast_to_subtype()
|
|
model_endpoint = typed_provider.base_url
|
|
else:
|
|
model_endpoint = provider.base_url
|
|
|
|
llm_config = LLMConfig(
|
|
model=model.name,
|
|
model_endpoint_type=model.model_endpoint_type,
|
|
model_endpoint=model_endpoint,
|
|
context_window=model.max_context_window or 16384,
|
|
handle=model.handle,
|
|
provider_name=provider.name,
|
|
provider_category=provider.provider_category,
|
|
)
|
|
llm_models.append(llm_config)
|
|
|
|
# Get BYOK provider models - sync if not synced yet, then read from DB
|
|
if include_byok:
|
|
byok_providers = await self.provider_manager.list_providers_async(
|
|
actor=actor,
|
|
name=provider_name,
|
|
provider_type=provider_type,
|
|
provider_category=[ProviderCategory.byok],
|
|
)
|
|
|
|
for provider in byok_providers:
|
|
try:
|
|
# Get typed provider to access schema defaults (e.g., base_url)
|
|
typed_provider = provider.cast_to_subtype()
|
|
|
|
# Sync models if not synced yet
|
|
if provider.last_synced is None:
|
|
models = await typed_provider.list_llm_models_async()
|
|
embedding_models = await typed_provider.list_embedding_models_async()
|
|
await self.provider_manager.sync_provider_models_async(
|
|
provider=provider,
|
|
llm_models=models,
|
|
embedding_models=embedding_models,
|
|
organization_id=provider.organization_id,
|
|
)
|
|
await self.provider_manager.update_provider_last_synced_async(provider.id, actor=actor)
|
|
|
|
# Read from database
|
|
provider_llm_models = await self.provider_manager.list_models_async(
|
|
actor=actor,
|
|
model_type="llm",
|
|
provider_id=provider.id,
|
|
enabled=True,
|
|
)
|
|
for model in provider_llm_models:
|
|
llm_config = LLMConfig(
|
|
model=model.name,
|
|
model_endpoint_type=model.model_endpoint_type,
|
|
model_endpoint=typed_provider.base_url,
|
|
context_window=model.max_context_window or constants.DEFAULT_CONTEXT_WINDOW,
|
|
handle=model.handle,
|
|
provider_name=provider.name,
|
|
provider_category=ProviderCategory.byok,
|
|
)
|
|
llm_models.append(llm_config)
|
|
except Exception as e:
|
|
logger.warning(f"Failed to fetch models from BYOK provider {provider.name}: {e}")
|
|
|
|
# Sort by provider order (matching old _enabled_providers order), then by model name
|
|
llm_models.sort(key=self._get_provider_sort_key)
|
|
|
|
return llm_models
|
|
|
|
async def list_embedding_models_async(self, actor: User) -> List[EmbeddingConfig]:
|
|
"""List available embedding models - base from DB, BYOK from provider endpoints"""
|
|
embedding_models = []
|
|
|
|
# Get base provider models from database
|
|
provider_models = await self.provider_manager.list_models_async(
|
|
actor=actor,
|
|
model_type="embedding",
|
|
enabled=True,
|
|
)
|
|
|
|
# Build EmbeddingConfig objects from database (base providers only)
|
|
provider_cache: Dict[str, Provider] = {}
|
|
for model in provider_models:
|
|
# Get provider details (with caching to avoid N+1 queries)
|
|
if model.provider_id not in provider_cache:
|
|
provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor)
|
|
provider = provider_cache[model.provider_id]
|
|
|
|
# Skip non-base providers (they're handled separately)
|
|
if provider.provider_category != ProviderCategory.base:
|
|
continue
|
|
|
|
embedding_config = EmbeddingConfig(
|
|
embedding_model=model.name,
|
|
embedding_endpoint_type=model.model_endpoint_type,
|
|
embedding_endpoint=provider.base_url or model.model_endpoint_type,
|
|
embedding_dim=model.embedding_dim or 1536,
|
|
embedding_chunk_size=constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
|
|
handle=model.handle,
|
|
)
|
|
embedding_models.append(embedding_config)
|
|
|
|
# Get BYOK provider models - sync if not synced yet, then read from DB
|
|
byok_providers = await self.provider_manager.list_providers_async(
|
|
actor=actor,
|
|
provider_category=[ProviderCategory.byok],
|
|
)
|
|
|
|
for provider in byok_providers:
|
|
try:
|
|
# Get typed provider to access schema defaults (e.g., base_url)
|
|
typed_provider = provider.cast_to_subtype()
|
|
|
|
# Sync models if not synced yet
|
|
if provider.last_synced is None:
|
|
llm_models = await typed_provider.list_llm_models_async()
|
|
emb_models = await typed_provider.list_embedding_models_async()
|
|
await self.provider_manager.sync_provider_models_async(
|
|
provider=provider,
|
|
llm_models=llm_models,
|
|
embedding_models=emb_models,
|
|
organization_id=provider.organization_id,
|
|
)
|
|
await self.provider_manager.update_provider_last_synced_async(provider.id, actor=actor)
|
|
|
|
# Read from database
|
|
provider_embedding_models = await self.provider_manager.list_models_async(
|
|
actor=actor,
|
|
model_type="embedding",
|
|
provider_id=provider.id,
|
|
enabled=True,
|
|
)
|
|
for model in provider_embedding_models:
|
|
embedding_config = EmbeddingConfig(
|
|
embedding_model=model.name,
|
|
embedding_endpoint_type=model.model_endpoint_type,
|
|
embedding_endpoint=typed_provider.base_url,
|
|
embedding_dim=model.embedding_dim or 1536,
|
|
embedding_chunk_size=constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
|
|
handle=model.handle,
|
|
)
|
|
embedding_models.append(embedding_config)
|
|
except Exception as e:
|
|
logger.warning(f"Failed to fetch embedding models from BYOK provider {provider.name}: {e}")
|
|
|
|
# Sort by provider order (matching old _enabled_providers order), then by model name
|
|
embedding_models.sort(key=self._get_embedding_provider_sort_key)
|
|
|
|
return embedding_models
|
|
|
|
async def get_enabled_providers_async(
|
|
self,
|
|
actor: User,
|
|
provider_category: Optional[List[ProviderCategory]] = None,
|
|
provider_name: Optional[str] = None,
|
|
provider_type: Optional[ProviderType] = None,
|
|
) -> List[Provider]:
|
|
# Query all persisted providers from database
|
|
persisted_providers = await self.provider_manager.list_providers_async(
|
|
name=provider_name,
|
|
provider_type=provider_type,
|
|
actor=actor,
|
|
)
|
|
providers = [p.cast_to_subtype() for p in persisted_providers]
|
|
|
|
# Filter by category if specified
|
|
if provider_category:
|
|
providers = [p for p in providers if p.provider_category in provider_category]
|
|
|
|
return providers
|
|
|
|
@trace_method
|
|
async def get_llm_config_from_handle_async(
|
|
self,
|
|
actor: User,
|
|
handle: str,
|
|
context_window_limit: Optional[int] = None,
|
|
max_tokens: Optional[int] = None,
|
|
max_reasoning_tokens: Optional[int] = None,
|
|
enable_reasoner: Optional[bool] = None,
|
|
) -> LLMConfig:
|
|
# Use provider_manager to get LLMConfig from handle
|
|
try:
|
|
llm_config = await self.provider_manager.get_llm_config_from_handle(
|
|
handle=handle,
|
|
actor=actor,
|
|
)
|
|
except Exception as e:
|
|
# Convert to HandleNotFoundError for backwards compatibility
|
|
from letta.orm.errors import NoResultFound
|
|
|
|
if isinstance(e, NoResultFound):
|
|
raise HandleNotFoundError(handle, [])
|
|
raise
|
|
|
|
if context_window_limit is not None:
|
|
if context_window_limit > llm_config.context_window:
|
|
raise LettaInvalidArgumentError(
|
|
f"Context window limit ({context_window_limit}) is greater than maximum of ({llm_config.context_window})",
|
|
argument_name="context_window_limit",
|
|
)
|
|
llm_config.context_window = context_window_limit
|
|
else:
|
|
llm_config.context_window = min(llm_config.context_window, model_settings.global_max_context_window_limit)
|
|
|
|
if max_tokens is not None:
|
|
llm_config.max_tokens = max_tokens
|
|
if max_reasoning_tokens is not None:
|
|
if not max_tokens or max_reasoning_tokens > max_tokens:
|
|
raise LettaInvalidArgumentError(
|
|
f"Max reasoning tokens ({max_reasoning_tokens}) must be less than max tokens ({max_tokens})",
|
|
argument_name="max_reasoning_tokens",
|
|
)
|
|
llm_config.max_reasoning_tokens = max_reasoning_tokens
|
|
if enable_reasoner is not None:
|
|
llm_config.enable_reasoner = enable_reasoner
|
|
if enable_reasoner and llm_config.model_endpoint_type == "anthropic":
|
|
llm_config.put_inner_thoughts_in_kwargs = False
|
|
|
|
return llm_config
|
|
|
|
@trace_method
|
|
async def get_embedding_config_from_handle_async(
|
|
self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
|
|
) -> EmbeddingConfig:
|
|
# Use provider_manager to get EmbeddingConfig from handle
|
|
try:
|
|
embedding_config = await self.provider_manager.get_embedding_config_from_handle(
|
|
handle=handle,
|
|
actor=actor,
|
|
)
|
|
except Exception as e:
|
|
# Convert to LettaInvalidArgumentError for backwards compatibility
|
|
from letta.orm.errors import NoResultFound
|
|
|
|
if isinstance(e, NoResultFound):
|
|
raise LettaInvalidArgumentError(f"Embedding model {handle} not found", argument_name="handle")
|
|
raise
|
|
|
|
# Override chunk size if provided
|
|
embedding_config.embedding_chunk_size = embedding_chunk_size
|
|
|
|
return embedding_config
|
|
|
|
async def get_provider_from_name_async(self, provider_name: str, actor: User) -> Provider:
|
|
all_providers = await self.get_enabled_providers_async(actor)
|
|
providers = [provider for provider in all_providers if provider.name == provider_name]
|
|
if not providers:
|
|
raise LettaInvalidArgumentError(
|
|
f"Provider {provider_name} is not supported (supported providers: {', '.join([provider.name for provider in all_providers])})",
|
|
argument_name="provider_name",
|
|
)
|
|
elif len(providers) > 1:
|
|
logger.warning(f"Multiple providers with name {provider_name} supported")
|
|
provider = providers[0]
|
|
else:
|
|
provider = providers[0]
|
|
|
|
return provider
|
|
|
|
def add_llm_model(self, request: LLMConfig) -> LLMConfig:
|
|
"""Add a new LLM model"""
|
|
|
|
def add_embedding_model(self, request: EmbeddingConfig) -> EmbeddingConfig:
|
|
"""Add a new embedding model"""
|
|
|
|
async def run_tool_from_source(
|
|
self,
|
|
actor: User,
|
|
tool_args: Dict[str, str],
|
|
tool_source: str,
|
|
tool_env_vars: Optional[Dict[str, str]] = None,
|
|
tool_source_type: Optional[str] = None,
|
|
tool_name: Optional[str] = None,
|
|
tool_args_json_schema: Optional[Dict[str, Any]] = None,
|
|
tool_json_schema: Optional[Dict[str, Any]] = None,
|
|
pip_requirements: Optional[List[PipRequirement]] = None,
|
|
) -> ToolReturnMessage:
|
|
"""Run a tool from source code"""
|
|
|
|
from letta.services.tool_schema_generator import generate_schema_for_tool_creation, generate_schema_for_tool_update
|
|
|
|
if tool_source_type not in (None, ToolSourceType.python, ToolSourceType.typescript):
|
|
raise LettaInvalidArgumentError(
|
|
f"Tool source type is not supported at this time. Found {tool_source_type}", argument_name="tool_source_type"
|
|
)
|
|
|
|
# If tools_json_schema is explicitly passed in, override it on the created Tool object
|
|
if tool_json_schema:
|
|
tool = Tool(
|
|
name=tool_name,
|
|
source_code=tool_source,
|
|
json_schema=tool_json_schema,
|
|
pip_requirements=pip_requirements,
|
|
source_type=tool_source_type,
|
|
)
|
|
else:
|
|
# NOTE: we're creating a floating Tool object and NOT persisting to DB
|
|
tool = Tool(
|
|
name=tool_name,
|
|
source_code=tool_source,
|
|
args_json_schema=tool_args_json_schema,
|
|
pip_requirements=pip_requirements,
|
|
source_type=tool_source_type,
|
|
)
|
|
|
|
# try to get the schema
|
|
if not tool.name:
|
|
if not tool.json_schema:
|
|
tool.json_schema = generate_schema_for_tool_creation(tool)
|
|
tool.name = tool.json_schema.get("name")
|
|
assert tool.name is not None, "Failed to create tool object"
|
|
|
|
# TODO eventually allow using agent state in tools
|
|
agent_state = None
|
|
|
|
# Next, attempt to run the tool with the sandbox
|
|
try:
|
|
tool_execution_manager = ToolExecutionManager(
|
|
agent_state=agent_state,
|
|
message_manager=self.message_manager,
|
|
agent_manager=self.agent_manager,
|
|
block_manager=self.block_manager,
|
|
run_manager=self.run_manager,
|
|
passage_manager=self.passage_manager,
|
|
actor=actor,
|
|
sandbox_env_vars=tool_env_vars,
|
|
)
|
|
|
|
# TODO: Integrate sandbox result
|
|
tool_execution_result = await tool_execution_manager.execute_tool_async(
|
|
function_name=tool_name,
|
|
function_args=tool_args,
|
|
tool=tool,
|
|
)
|
|
from letta.schemas.letta_message import ToolReturn as ToolReturnSchema
|
|
|
|
tool_return_obj = ToolReturnSchema(
|
|
tool_return=str(tool_execution_result.func_return),
|
|
status=tool_execution_result.status,
|
|
tool_call_id="null",
|
|
stdout=tool_execution_result.stdout,
|
|
stderr=tool_execution_result.stderr,
|
|
)
|
|
|
|
return ToolReturnMessage(
|
|
id="null",
|
|
tool_call_id="null",
|
|
date=get_utc_time(),
|
|
name=tool_name,
|
|
status=tool_execution_result.status,
|
|
tool_return=str(tool_execution_result.func_return),
|
|
stdout=tool_execution_result.stdout,
|
|
stderr=tool_execution_result.stderr,
|
|
tool_returns=[tool_return_obj],
|
|
)
|
|
|
|
except Exception as e:
|
|
func_return = get_friendly_error_msg(function_name=tool.name, exception_name=type(e).__name__, exception_message=str(e))
|
|
from letta.schemas.letta_message import ToolReturn as ToolReturnSchema
|
|
|
|
tool_return_obj = ToolReturnSchema(
|
|
tool_return=func_return,
|
|
status="error",
|
|
tool_call_id="null",
|
|
stdout=[],
|
|
stderr=[traceback.format_exc()],
|
|
)
|
|
|
|
return ToolReturnMessage(
|
|
id="null",
|
|
tool_call_id="null",
|
|
date=get_utc_time(),
|
|
name=tool.name,
|
|
status="error",
|
|
tool_return=func_return,
|
|
stdout=[],
|
|
stderr=[traceback.format_exc()],
|
|
tool_returns=[tool_return_obj],
|
|
)
|
|
|
|
# MCP wrappers
|
|
# TODO support both command + SSE servers (via config)
|
|
async def get_mcp_servers(self) -> dict[str, Union[SSEServerConfig, StdioServerConfig]]:
|
|
"""List the MCP servers in the config (doesn't test that they are actually working)"""
|
|
|
|
# TODO implement non-flatfile mechanism
|
|
if not tool_settings.mcp_read_from_config:
|
|
return {}
|
|
# raise RuntimeError("MCP config file disabled. Enable it in settings.")
|
|
|
|
mcp_server_list = {}
|
|
|
|
# Attempt to read from ~/.letta/mcp_config.json
|
|
mcp_config_path = os.path.join(constants.LETTA_DIR, constants.MCP_CONFIG_NAME)
|
|
if os.path.exists(mcp_config_path):
|
|
|
|
def _read_config():
|
|
with open(mcp_config_path, "r") as f:
|
|
return json.load(f)
|
|
|
|
try:
|
|
mcp_config = await asyncio.to_thread(_read_config)
|
|
except Exception as e:
|
|
logger.error(f"Failed to parse MCP config file ({mcp_config_path}) as json: {e}")
|
|
return mcp_server_list
|
|
|
|
# Proper formatting is "mcpServers" key at the top level,
|
|
# then a dict with the MCP server name as the key,
|
|
# with the value being the schema from StdioServerParameters
|
|
if MCP_CONFIG_TOPLEVEL_KEY in mcp_config:
|
|
for server_name, server_params_raw in mcp_config[MCP_CONFIG_TOPLEVEL_KEY].items():
|
|
# No support for duplicate server names
|
|
if server_name in mcp_server_list:
|
|
logger.error(f"Duplicate MCP server name found (skipping): {server_name}")
|
|
continue
|
|
|
|
if "url" in server_params_raw:
|
|
# Attempt to parse the server params as an SSE server
|
|
try:
|
|
server_params = SSEServerConfig(
|
|
server_name=server_name,
|
|
server_url=server_params_raw["url"],
|
|
)
|
|
mcp_server_list[server_name] = server_params
|
|
except Exception as e:
|
|
logger.error(f"Failed to parse server params for MCP server {server_name} (skipping): {e}")
|
|
continue
|
|
else:
|
|
# Attempt to parse the server params as a StdioServerParameters
|
|
try:
|
|
server_params = StdioServerConfig(
|
|
server_name=server_name,
|
|
command=server_params_raw["command"],
|
|
args=server_params_raw.get("args", []),
|
|
env=server_params_raw.get("env", {}),
|
|
)
|
|
mcp_server_list[server_name] = server_params
|
|
except Exception as e:
|
|
logger.error(f"Failed to parse server params for MCP server {server_name} (skipping): {e}")
|
|
continue
|
|
|
|
# If the file doesn't exist, return empty dictionary
|
|
return mcp_server_list
|
|
|
|
async def get_tools_from_mcp_server(self, mcp_server_name: str) -> List[MCPTool]:
|
|
"""List the tools in an MCP server. Requires a client to be created."""
|
|
if mcp_server_name not in self.mcp_clients:
|
|
raise LettaInvalidArgumentError(f"No client was created for MCP server: {mcp_server_name}", argument_name="mcp_server_name")
|
|
|
|
tools = await self.mcp_clients[mcp_server_name].list_tools()
|
|
# Add health information to each tool
|
|
for tool in tools:
|
|
if tool.inputSchema:
|
|
health_status, reasons = validate_complete_json_schema(tool.inputSchema)
|
|
tool.health = MCPToolHealth(status=health_status.value, reasons=reasons)
|
|
|
|
return tools
|
|
|
|
async def add_mcp_server_to_config(
|
|
self, server_config: Union[SSEServerConfig, StdioServerConfig], allow_upsert: bool = True
|
|
) -> List[Union[SSEServerConfig, StdioServerConfig]]:
|
|
"""Add a new server config to the MCP config file"""
|
|
|
|
# TODO implement non-flatfile mechanism
|
|
if not tool_settings.mcp_read_from_config:
|
|
raise RuntimeError("MCP config file disabled. Enable it in settings.")
|
|
|
|
# If the config file doesn't exist, throw an error.
|
|
mcp_config_path = os.path.join(constants.LETTA_DIR, constants.MCP_CONFIG_NAME)
|
|
if not os.path.exists(mcp_config_path):
|
|
# Create the file if it doesn't exist
|
|
logger.debug(f"MCP config file not found, creating new file at: {mcp_config_path}")
|
|
|
|
# If the file does exist, attempt to parse it get calling get_mcp_servers
|
|
try:
|
|
current_mcp_servers = self.get_mcp_servers()
|
|
except Exception as e:
|
|
# Raise an error telling the user to fix the config file
|
|
logger.error(f"Failed to parse MCP config file at {mcp_config_path}: {e}")
|
|
raise LettaInvalidArgumentError(f"Failed to parse MCP config file {mcp_config_path}")
|
|
|
|
# Check if the server name is already in the config
|
|
if server_config.server_name in current_mcp_servers and not allow_upsert:
|
|
raise LettaInvalidArgumentError(
|
|
f"Server name {server_config.server_name} is already in the config file", argument_name="server_name"
|
|
)
|
|
|
|
# Attempt to initialize the connection to the server
|
|
if server_config.type == MCPServerType.SSE:
|
|
new_mcp_client = AsyncFastMCPSSEClient(server_config)
|
|
elif server_config.type == MCPServerType.STDIO:
|
|
new_mcp_client = AsyncStdioMCPClient(server_config)
|
|
else:
|
|
raise LettaInvalidArgumentError(f"Invalid MCP server config: {server_config}", argument_name="server_config")
|
|
try:
|
|
await new_mcp_client.connect_to_server()
|
|
except LettaMCPConnectionError:
|
|
raise
|
|
except Exception:
|
|
logger.exception(f"Failed to connect to MCP server: {server_config.server_name}")
|
|
raise LettaMCPConnectionError(
|
|
message=f"Failed to connect to MCP server: {server_config.server_name}",
|
|
server_name=server_config.server_name,
|
|
)
|
|
# Print out the tools that are connected
|
|
logger.info(f"Attempting to fetch tools from MCP server: {server_config.server_name}")
|
|
new_mcp_tools = await new_mcp_client.list_tools()
|
|
logger.info(f"MCP tools connected: {', '.join([t.name for t in new_mcp_tools])}")
|
|
logger.debug(f"MCP tools: {', '.join([str(t) for t in new_mcp_tools])}")
|
|
|
|
# Now that we've confirmed the config is working, let's add it to the client list
|
|
self.mcp_clients[server_config.server_name] = new_mcp_client
|
|
|
|
# Add to the server file
|
|
current_mcp_servers[server_config.server_name] = server_config
|
|
|
|
# Write out the file, and make sure to in include the top-level mcpConfig (wrapped to avoid blocking event loop)
|
|
try:
|
|
new_mcp_file = {MCP_CONFIG_TOPLEVEL_KEY: {k: v.to_dict() for k, v in current_mcp_servers.items()}}
|
|
|
|
def _write_config():
|
|
with open(mcp_config_path, "w") as f:
|
|
json.dump(new_mcp_file, f, indent=4)
|
|
|
|
await asyncio.to_thread(_write_config)
|
|
except Exception as e:
|
|
logger.error(f"Failed to write MCP config file at {mcp_config_path}: {e}")
|
|
raise LettaInvalidArgumentError(f"Failed to write MCP config file {mcp_config_path}")
|
|
|
|
return list(current_mcp_servers.values())
|
|
|
|
async def delete_mcp_server_from_config(self, server_name: str) -> dict[str, Union[SSEServerConfig, StdioServerConfig]]:
|
|
"""Delete a server config from the MCP config file"""
|
|
|
|
# TODO implement non-flatfile mechanism
|
|
if not tool_settings.mcp_read_from_config:
|
|
raise RuntimeError("MCP config file disabled. Enable it in settings.")
|
|
|
|
# If the config file doesn't exist, throw an error.
|
|
mcp_config_path = os.path.join(constants.LETTA_DIR, constants.MCP_CONFIG_NAME)
|
|
if not os.path.exists(mcp_config_path):
|
|
# If the file doesn't exist, raise an error
|
|
raise FileNotFoundError(f"MCP config file not found: {mcp_config_path}")
|
|
|
|
# If the file does exist, attempt to parse it get calling get_mcp_servers
|
|
try:
|
|
current_mcp_servers = await self.get_mcp_servers()
|
|
except Exception as e:
|
|
# Raise an error telling the user to fix the config file
|
|
logger.error(f"Failed to parse MCP config file at {mcp_config_path}: {e}")
|
|
raise LettaInvalidArgumentError(f"Failed to parse MCP config file {mcp_config_path}")
|
|
|
|
# Check if the server name is already in the config
|
|
# If it's not, throw an error
|
|
if server_name not in current_mcp_servers:
|
|
raise LettaInvalidArgumentError(f"Server name {server_name} not found in MCP config file", argument_name="server_name")
|
|
|
|
# Remove from the server file
|
|
del current_mcp_servers[server_name]
|
|
|
|
# Write out the file, and make sure to in include the top-level mcpConfig (wrapped to avoid blocking event loop)
|
|
try:
|
|
new_mcp_file = {MCP_CONFIG_TOPLEVEL_KEY: {k: v.to_dict() for k, v in current_mcp_servers.items()}}
|
|
|
|
def _write_config():
|
|
with open(mcp_config_path, "w") as f:
|
|
json.dump(new_mcp_file, f, indent=4)
|
|
|
|
await asyncio.to_thread(_write_config)
|
|
except Exception as e:
|
|
logger.error(f"Failed to write MCP config file at {mcp_config_path}: {e}")
|
|
raise LettaInvalidArgumentError(f"Failed to write MCP config file {mcp_config_path}")
|
|
|
|
return list(current_mcp_servers.values())
|