import asyncio
import json
import os
import traceback
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import httpx
from anthropic import AsyncAnthropic

import letta.constants as constants
import letta.server.utils as server_utils
from letta.config import LettaConfig
from letta.constants import LETTA_TOOL_EXECUTION_DIR
from letta.data_sources.connectors import DataConnector, load_data
from letta.errors import (
    HandleNotFoundError,
    LettaInvalidArgumentError,
    LettaMCPConnectionError,
)
from letta.functions.mcp_client.types import MCPServerType, MCPTool, MCPToolHealth, SSEServerConfig, StdioServerConfig
from letta.functions.schema_validator import validate_complete_json_schema
from letta.helpers.datetime_helpers import get_utc_time

# TODO use custom interface
from letta.interface import (
    CLIInterface,  # for printing to terminal
)
from letta.log import get_logger
from letta.orm.errors import NoResultFound
from letta.otel.tracing import log_event, trace_method
from letta.prompts.gpt_system import get_system_text
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
from letta.schemas.block import Block, BlockUpdate, CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig

# openai schemas
from letta.schemas.enums import AgentType, JobStatus, ProviderCategory, ProviderType, ToolSourceType
from letta.schemas.group import GroupCreate, SleeptimeManager, VoiceSleeptimeManager
from letta.schemas.job import Job, JobUpdate
from letta.schemas.letta_message import LettaMessage, ToolReturnMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import Memory
from letta.schemas.message import Message
from letta.schemas.passage import Passage
from letta.schemas.pip_requirement import PipRequirement
from letta.schemas.providers import (
    AnthropicProvider,
    AzureProvider,
    BedrockProvider,
    DeepSeekProvider,
    GoogleAIProvider,
    GoogleVertexProvider,
    GroqProvider,
    LettaProvider,
    LMStudioOpenAIProvider,
    MiniMaxProvider,
    OllamaProvider,
    OpenAIProvider,
    OpenRouterProvider,
    Provider,
    SGLangProvider,
    TogetherProvider,
    VLLMProvider,
    XAIProvider,
    ZAIProvider,
)
from letta.schemas.sandbox_config import LocalSandboxConfig, SandboxConfigCreate
from letta.schemas.secret import Secret
from letta.schemas.source import Source
from letta.schemas.tool import Tool
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
from letta.services.agent_serialization_manager import AgentSerializationManager
from letta.services.archive_manager import ArchiveManager
from letta.services.block_manager import BlockManager
from letta.services.block_manager_git import GIT_MEMORY_ENABLED_TAG, GitEnabledBlockManager
from letta.services.file_manager import FileManager
from letta.services.files_agents_manager import FileAgentManager
from letta.services.group_manager import GroupManager
from letta.services.helpers.tool_execution_helper import prepare_local_sandbox
from letta.services.identity_manager import IdentityManager
from letta.services.job_manager import JobManager
from letta.services.llm_batch_manager import LLMBatchManager
from letta.services.mcp.base_client import AsyncBaseMCPClient
from letta.services.mcp.fastmcp_client import AsyncFastMCPSSEClient
from letta.services.mcp.sse_client import MCP_CONFIG_TOPLEVEL_KEY
from letta.services.mcp.stdio_client import AsyncStdioMCPClient
from letta.services.mcp_manager import MCPManager
from letta.services.mcp_server_manager import MCPServerManager
from letta.services.memory_repo import MemfsClient
from letta.services.message_manager import MessageManager
from letta.services.organization_manager import OrganizationManager
from letta.services.passage_manager import PassageManager
from letta.services.provider_manager import ProviderManager
from letta.services.run_manager import RunManager
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.source_manager import SourceManager
from letta.services.step_manager import StepManager
from letta.services.telemetry_manager import TelemetryManager
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
from letta.settings import DatabaseChoice, model_settings, settings, tool_settings
from letta.streaming_interface import AgentChunkStreamingInterface
from letta.utils import get_friendly_error_msg, get_persona_text

config = LettaConfig.load()
logger = get_logger(__name__)


class SyncServer(object):
    """Simple single-threaded / blocking server process"""

    def __init__(
        self,
        chaining: bool = True,
        max_chaining_steps: Optional[int] = 100,
        default_interface_factory: Callable[[], AgentChunkStreamingInterface] = lambda: CLIInterface(),
        init_with_default_org_and_user: bool = True,
        # default_interface: AgentInterface = CLIInterface(),
        # default_persistence_manager_cls: PersistenceManager = LocalStateManager,
        # auth_mode: str = "none",  # "none", "jwt", "external"
    ):
        """Server process holds in-memory agents that are being run"""
        # chaining = whether or not to run again if request_heartbeat=true
        self.chaining = chaining

        # if chaining == true, what's the max number of times we'll chain before yielding?
        # None = no limit, can go on forever
        self.max_chaining_steps = max_chaining_steps
        # The default interface that will get assigned to agents ON LOAD
        self.default_interface_factory = default_interface_factory

        # Initialize the metadata store
        config = LettaConfig.load()
        if settings.database_engine is DatabaseChoice.POSTGRES:
            config.recall_storage_type = "postgres"
            config.recall_storage_uri = settings.letta_pg_uri_no_default
            config.archival_storage_type = "postgres"
            config.archival_storage_uri = settings.letta_pg_uri_no_default
        config.save()
        self.config = config

        # Managers that interface with data models
        self.organization_manager = OrganizationManager()
        self.passage_manager = PassageManager()
        self.user_manager = UserManager()
        self.tool_manager = ToolManager()
        self.mcp_manager = MCPManager()
        self.mcp_server_manager = MCPServerManager()
        self.memory_repo_manager = self._init_memory_repo_manager()

        # Use the git-enabled block manager if a memory repo is configured.
        # It falls back to standard PostgreSQL behavior when git isn't enabled for an agent.
        if self.memory_repo_manager:
            self.block_manager = GitEnabledBlockManager(memory_repo_manager=self.memory_repo_manager)
        else:
            self.block_manager = BlockManager()
        self.source_manager = SourceManager()
        self.sandbox_config_manager = SandboxConfigManager()
        self.message_manager = MessageManager()
        self.job_manager = JobManager()
        self.run_manager = RunManager()
        self.agent_manager = AgentManager(block_manager=self.block_manager)
        self.archive_manager = ArchiveManager()
        self.provider_manager = ProviderManager()
        self.step_manager = StepManager()
        self.identity_manager = IdentityManager()
        self.group_manager = GroupManager()
        self.batch_manager = LLMBatchManager()
        self.telemetry_manager = TelemetryManager()
        self.file_agent_manager = FileAgentManager()
        self.file_manager = FileManager()

        # Import and initialize the agent generate completion manager
        from letta.services.agent_generate_completion_manager import AgentGenerateCompletionManager

        self.agent_generate_completion_manager = AgentGenerateCompletionManager(server=self)

        self.agent_serialization_manager = AgentSerializationManager(
            agent_manager=self.agent_manager,
            tool_manager=self.tool_manager,
            source_manager=self.source_manager,
            block_manager=self.block_manager,
            group_manager=self.group_manager,
            mcp_manager=self.mcp_manager,
            file_manager=self.file_manager,
            file_agent_manager=self.file_agent_manager,
            message_manager=self.message_manager,
        )

        if settings.enable_batch_job_polling:
            # A reusable httpx client
            timeout = httpx.Timeout(connect=10.0, read=20.0, write=10.0, pool=10.0)
            limits = httpx.Limits(max_connections=100, max_keepalive_connections=80, keepalive_expiry=300)
            self.httpx_client = httpx.AsyncClient(timeout=timeout, follow_redirects=True, limits=limits)
            # TODO: Replace this with the Anthropic client we have in house
            # Reuse the shared httpx client to prevent duplicate SSL contexts and connection pools
            self.anthropic_async_client = AsyncAnthropic(http_client=self.httpx_client)
        else:
            self.httpx_client = None
            self.anthropic_async_client = None

        # For MCP
        # TODO: remove this
        """Initialize the MCP clients (there may be multiple)"""
        self.mcp_clients: Dict[str, AsyncBaseMCPClient] = {}

        # collect providers (always has Letta as a default)
        from letta.constants import LETTA_MODEL_ENDPOINT

        self._enabled_providers: List[Provider] = [LettaProvider(name="letta", base_url=LETTA_MODEL_ENDPOINT)]
        if model_settings.openai_api_key:
            self._enabled_providers.append(
                OpenAIProvider(
                    name="openai",
                    api_key_enc=Secret.from_plaintext(model_settings.openai_api_key),
                    base_url=model_settings.openai_api_base,
                )
            )
        if model_settings.anthropic_api_key:
            self._enabled_providers.append(
                AnthropicProvider(
                    name="anthropic",
                    api_key_enc=Secret.from_plaintext(model_settings.anthropic_api_key),
                )
            )
        if model_settings.ollama_base_url:
            self._enabled_providers.append(
                OllamaProvider(
                    name="ollama",
                    base_url=model_settings.ollama_base_url,
                    default_prompt_formatter=model_settings.default_prompt_formatter,
                )
            )
        if model_settings.gemini_api_key:
            self._enabled_providers.append(
                GoogleAIProvider(
                    name="google_ai",
                    api_key_enc=Secret.from_plaintext(model_settings.gemini_api_key),
                )
            )
        if model_settings.google_cloud_location and model_settings.google_cloud_project:
            self._enabled_providers.append(
                GoogleVertexProvider(
                    name="google_vertex",
                    google_cloud_project=model_settings.google_cloud_project,
                    google_cloud_location=model_settings.google_cloud_location,
                )
            )
        if model_settings.azure_api_key and model_settings.azure_base_url:
            assert model_settings.azure_api_version, "AZURE_API_VERSION is required"
            self._enabled_providers.append(
                AzureProvider(
                    name="azure",
                    api_key_enc=Secret.from_plaintext(model_settings.azure_api_key),
                    base_url=model_settings.azure_base_url,
                    api_version=model_settings.azure_api_version,
                )
            )
        if model_settings.groq_api_key:
            self._enabled_providers.append(
                GroqProvider(
                    name="groq",
                    api_key_enc=Secret.from_plaintext(model_settings.groq_api_key),
                )
            )
        if model_settings.together_api_key:
            self._enabled_providers.append(
                TogetherProvider(
                    name="together",
                    api_key_enc=Secret.from_plaintext(model_settings.together_api_key),
                    default_prompt_formatter=model_settings.default_prompt_formatter,
                )
            )
        if model_settings.vllm_api_base:
            # vLLM exposes both a /chat/completions and a /completions endpoint
            # NOTE: to use the /chat/completions endpoint, you need to specify extra flags on vLLM startup
            # see: https://docs.vllm.ai/en/stable/features/tool_calling.html
            # e.g. "... --enable-auto-tool-choice --tool-call-parser hermes"
            # Auto-append /v1 to the base URL
            vllm_url = (
                model_settings.vllm_api_base
                if model_settings.vllm_api_base.endswith("/v1")
                else model_settings.vllm_api_base + "/v1"
            )
            self._enabled_providers.append(
                VLLMProvider(
                    name="vllm",
                    base_url=vllm_url,
                    default_prompt_formatter=model_settings.default_prompt_formatter,
                    handle_base=model_settings.vllm_handle_base,
                )
            )
        if model_settings.sglang_api_base:
            # Auto-append /v1 to the base URL
            sglang_url = (
                model_settings.sglang_api_base
                if model_settings.sglang_api_base.endswith("/v1")
                else model_settings.sglang_api_base + "/v1"
            )
            self._enabled_providers.append(
                SGLangProvider(
                    name="sglang",
                    base_url=sglang_url,
                    default_prompt_formatter=model_settings.default_prompt_formatter,
                    handle_base=model_settings.sglang_handle_base,
                )
            )
        if model_settings.aws_access_key_id and model_settings.aws_secret_access_key and model_settings.aws_default_region:
            self._enabled_providers.append(
                BedrockProvider(
                    name="bedrock",
                    access_key=model_settings.aws_access_key_id,
                    api_key=model_settings.aws_secret_access_key,
                    region=model_settings.aws_default_region,
                )
            )
        # Attempt to enable LM Studio by default
        if model_settings.lmstudio_base_url:
            # Auto-append /v1 to the base URL
            lmstudio_url = (
                model_settings.lmstudio_base_url
                if model_settings.lmstudio_base_url.endswith("/v1")
                else model_settings.lmstudio_base_url + "/v1"
            )
            self._enabled_providers.append(LMStudioOpenAIProvider(name="lmstudio_openai", base_url=lmstudio_url))
        if model_settings.deepseek_api_key:
            self._enabled_providers.append(
                DeepSeekProvider(
                    name="deepseek",
                    api_key_enc=Secret.from_plaintext(model_settings.deepseek_api_key),
                )
            )
        if model_settings.xai_api_key:
            self._enabled_providers.append(
                XAIProvider(
                    name="xai",
                    api_key_enc=Secret.from_plaintext(model_settings.xai_api_key),
                )
            )
        if model_settings.minimax_api_key:
            self._enabled_providers.append(
                MiniMaxProvider(
                    name="minimax",
                    api_key_enc=Secret.from_plaintext(model_settings.minimax_api_key),
                )
            )
        if model_settings.zai_api_key:
            self._enabled_providers.append(
                ZAIProvider(
                    name="zai",
                    api_key_enc=Secret.from_plaintext(model_settings.zai_api_key),
                    base_url=model_settings.zai_base_url,
                )
            )
        if model_settings.openrouter_api_key:
            self._enabled_providers.append(
                OpenRouterProvider(
                    name=model_settings.openrouter_handle_base if model_settings.openrouter_handle_base else "openrouter",
                    api_key_enc=Secret.from_plaintext(model_settings.openrouter_api_key),
                )
            )
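
    # A minimal usage sketch (illustrative, not part of the class): construct the
    # server synchronously, then run the async initialization step. `init_async`
    # is defined below with this signature; everything else here assumes defaults.
    #
    #     server = SyncServer()
    #     asyncio.run(server.init_async(init_with_default_org_and_user=True))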

    async def init_async(self, init_with_default_org_and_user: bool = True):
        # Make default user and org
        if init_with_default_org_and_user:
            self.default_org = await self.organization_manager.create_default_organization_async()
            self.default_user = await self.user_manager.create_default_actor_async()
            print(f"Default user: {self.default_user} and org: {self.default_org}")
            await self.tool_manager.upsert_base_tools_async(actor=self.default_user)

            # Sync environment-based providers to database (idempotent, safe for multi-pod startup)
            await self.provider_manager.sync_base_providers(base_providers=self._enabled_providers, actor=self.default_user)

            # Sync provider models to database
            await self._sync_provider_models_async()

        # For OSS users, create a local sandbox config
        oss_default_user = await self.user_manager.get_default_actor_async()
        use_venv = bool(tool_settings.tool_exec_venv_name)
        venv_name = tool_settings.tool_exec_venv_name or "venv"
        tool_dir = tool_settings.tool_exec_dir or LETTA_TOOL_EXECUTION_DIR
        venv_dir = Path(tool_dir) / venv_name
        tool_path = Path(tool_dir)
        if tool_path.exists() and not tool_path.is_dir():
            logger.error(f"LETTA_TOOL_SANDBOX_DIR exists but is not a directory: {tool_dir}")
        else:
            if not tool_path.exists():
                logger.warning(f"LETTA_TOOL_SANDBOX_DIR does not exist, creating now: {tool_dir}")
                tool_path.mkdir(parents=True, exist_ok=True)
            if tool_settings.tool_exec_venv_name and not venv_dir.is_dir():
                logger.warning(
                    f"Provided LETTA_TOOL_SANDBOX_VENV_NAME is not a valid venv ({venv_dir}), one will be created for you during tool execution."
                )
            sandbox_config_create = SandboxConfigCreate(
                config=LocalSandboxConfig(sandbox_dir=tool_settings.tool_exec_dir, use_venv=use_venv, venv_name=venv_name)
            )
            sandbox_config = await self.sandbox_config_manager.create_or_update_sandbox_config_async(
                sandbox_config_create=sandbox_config_create, actor=oss_default_user
            )
            logger.debug(f"Successfully created default local sandbox config:\n{sandbox_config.get_local_config().model_dump()}")
            if use_venv and tool_settings.tool_exec_autoreload_venv:
                prepare_local_sandbox(
                    sandbox_config.get_local_config(),
                    env=os.environ.copy(),
                    force_recreate=True,
                )

    def _init_memory_repo_manager(self) -> Optional[MemfsClient]:
        """Initialize the memory repository manager if configured.

        Requires LETTA_MEMFS_SERVICE_URL to be set to the external memfs service URL.

        Returns:
            MemfsClient if configured, None otherwise
        """
        from letta.settings import settings

        if not settings.memfs_service_url:
            logger.debug("Memory repo manager not configured (memfs_service_url not set)")
            return None

        logger.info("Memory repo manager using memfs service: %s", settings.memfs_service_url)
        return MemfsClient(base_url=settings.memfs_service_url)

    def _get_enabled_provider(self, provider_name: str) -> Optional[Provider]:
        """Find and return an enabled provider by name.

        Args:
            provider_name: The name of the provider to find

        Returns:
            The matching enabled provider, or None if not found
        """
        for provider in self._enabled_providers:
            if provider.name == provider_name:
                return provider
        return None

    async def _sync_provider_models_async(self):
        """Sync all provider models to database at startup."""
        logger.info("Syncing provider models to database")

        # Get persisted providers from database (they now have IDs)
        persisted_providers = await self.provider_manager.list_providers_async(actor=self.default_user)

        for persisted_provider in persisted_providers:
            try:
                # Find the matching enabled provider instance to call list_models on
                enabled_provider = self._get_enabled_provider(persisted_provider.name)
                if not enabled_provider:
                    # Only delete base providers that are no longer enabled.
                    # BYOK providers are user-created and should not be automatically deleted.
                    if persisted_provider.provider_category == ProviderCategory.base:
                        logger.info(f"Base provider {persisted_provider.name} is no longer enabled, deleting from database")
                        try:
                            await self.provider_manager.delete_provider_by_id_async(
                                provider_id=persisted_provider.id, actor=self.default_user
                            )
                        except NoResultFound:
                            # Provider was already deleted (race condition in multi-pod startup)
                            logger.debug(f"Provider {persisted_provider.name} was already deleted, skipping")
                    else:
                        logger.debug(f"No enabled provider for BYOK provider {persisted_provider.name}, skipping model sync")
                    continue

                # Fetch models from provider
                llm_models = await enabled_provider.list_llm_models_async()
                embedding_models = await enabled_provider.list_embedding_models_async()

                # Save to database with the persisted provider (which has an ID)
                await self.provider_manager.sync_provider_models_async(
                    provider=persisted_provider,
                    llm_models=llm_models,
                    embedding_models=embedding_models,
                    organization_id=None,  # Global models
                )

                # Update last_synced timestamp
                await self.provider_manager.update_provider_last_synced_async(persisted_provider.id)

                logger.info(
                    f"Synced {len(llm_models)} LLM models and {len(embedding_models)} embedding models for provider {persisted_provider.name}"
                )
            except Exception as e:
                logger.error(f"Failed to sync models for provider {persisted_provider.name}: {e}", exc_info=True)

    async def init_mcp_clients(self):
        # TODO: remove this
        mcp_server_configs = await self.get_mcp_servers()  # get_mcp_servers is a coroutine, so it must be awaited
        for server_name, server_config in mcp_server_configs.items():
            if server_config.type == MCPServerType.SSE:
                self.mcp_clients[server_name] = AsyncFastMCPSSEClient(server_config)
            elif server_config.type == MCPServerType.STDIO:
                self.mcp_clients[server_name] = AsyncStdioMCPClient(server_config)
            else:
                raise LettaInvalidArgumentError(f"Invalid MCP server config: {server_config}", argument_name="server_config")

            try:
                await self.mcp_clients[server_name].connect_to_server()
            except Exception as e:
                logger.error(e)
                self.mcp_clients.pop(server_name)

        logger.info(f"MCP clients initialized: {len(self.mcp_clients)} active connections")

        # Print out the tools that are connected
        for server_name, client in self.mcp_clients.items():
            logger.info(f"Attempting to fetch tools from MCP server: {server_name}")
            mcp_tools = await client.list_tools()
            logger.info(f"MCP tools connected: {', '.join([t.name for t in mcp_tools])}")
            logger.debug(f"MCP tools: {', '.join([str(t) for t in mcp_tools])}")
LettaInvalidArgumentError("Must specify either model or llm_config in request", argument_name="model") else: handle = settings.default_llm_handle else: if isinstance(request.model, str): handle = request.model elif isinstance(request.model, list): raise LettaInvalidArgumentError("Multiple models are not supported yet") else: # EXTREMELEY HACKY, TEMPORARY WORKAROUND handle = f"{request.model.provider}/{request.model.model}" # TODO: figure out how to override various params additional_config_params = request.model._to_legacy_config_params() additional_config_params["model"] = request.model.model additional_config_params["provider_name"] = request.model.provider config_params = { "handle": handle, "context_window_limit": request.context_window_limit, "max_tokens": request.max_tokens, "max_reasoning_tokens": request.max_reasoning_tokens, "enable_reasoner": request.enable_reasoner, } log_event(name="start get_llm_config_from_handle", attributes=config_params) request.llm_config = await self.get_llm_config_from_handle_async(actor=actor, **config_params) log_event(name="end get_llm_config_from_handle", attributes=config_params) if request.model and isinstance(request.model, str): assert request.llm_config.handle == request.model, ( f"LLM config handle {request.llm_config.handle} does not match request handle {request.model}" ) # update with model_settings if request.model_settings is not None: update_llm_config_params = request.model_settings._to_legacy_config_params() # Don't clobber max_tokens with the Pydantic default when the caller # didn't explicitly provide max_output_tokens in the request. if "max_output_tokens" not in request.model_settings.model_fields_set: update_llm_config_params.pop("max_tokens", None) request.llm_config = request.llm_config.model_copy(update=update_llm_config_params) # Copy parallel_tool_calls from request to llm_config if provided if request.parallel_tool_calls is not None: request.llm_config.parallel_tool_calls = request.parallel_tool_calls if request.reasoning is None: request.reasoning = request.llm_config.enable_reasoner or request.llm_config.put_inner_thoughts_in_kwargs if request.embedding_config is None: if request.embedding is None: if settings.default_embedding_handle is not None: request.embedding = settings.default_embedding_handle # Only resolve embedding config if we have an embedding handle if request.embedding is not None: embedding_config_params = { "handle": request.embedding, "embedding_chunk_size": request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE, } log_event(name="start get_embedding_config_from_handle", attributes=embedding_config_params) request.embedding_config = await self.get_embedding_config_from_handle_async(actor=actor, **embedding_config_params) log_event(name="end get_embedding_config_from_handle", attributes=embedding_config_params) # If git-backed memory is requested on create, we enable it *after* agent creation. # We strip the tag during creation so `enable_git_memory_for_agent` can be the # single place that both creates the repo and writes the tag. wants_git_memory = bool(request.tags and GIT_MEMORY_ENABLED_TAG in request.tags) create_request = request if wants_git_memory: filtered_tags = [t for t in (request.tags or []) if t != GIT_MEMORY_ENABLED_TAG] updates: dict = {"tags": filtered_tags} # Transform block labels to path-based for git-memory agents. # Blocks without a "/" prefix go under system/ (rendered in system prompt). # e.g. 
"human" -> "system/human", "persona" -> "system/persona" # Blocks with an explicit path (e.g. "notes/project") keep their label. if request.memory_blocks: transformed_blocks = [] for block in request.memory_blocks: if not block.label.startswith("system/"): block = block.model_copy(update={"label": f"system/{block.label}"}) transformed_blocks.append(block) updates["memory_blocks"] = transformed_blocks create_request = request.model_copy(update=updates) log_event(name="start create_agent db") main_agent = await self.agent_manager.create_agent_async( agent_create=create_request, actor=actor, ) log_event(name="end create_agent db") # Enable git-backed memory (creates repo + commits initial blocks + adds tag) if wants_git_memory and isinstance(self.block_manager, GitEnabledBlockManager): await self.block_manager.enable_git_memory_for_agent(agent_id=main_agent.id, actor=actor) # Preserve the user's requested tags and git_enabled flag in the response model. try: main_agent.tags = list(request.tags or []) main_agent.memory.git_enabled = True except Exception: pass log_event(name="start insert_files_into_context_window db") # Use folder_ids if provided, otherwise fall back to deprecated source_ids for backwards compatibility folder_ids_to_attach = request.folder_ids if request.folder_ids else request.source_ids if folder_ids_to_attach: for folder_id in folder_ids_to_attach: files = await self.file_manager.list_files(folder_id, actor, include_content=True) await self.agent_manager.insert_files_into_context_window( agent_state=main_agent, file_metadata_with_content=files, actor=actor ) main_agent = await self.agent_manager.refresh_file_blocks(agent_state=main_agent, actor=actor) main_agent = await self.agent_manager.attach_missing_files_tools_async(agent_state=main_agent, actor=actor) log_event(name="end insert_files_into_context_window db") if request.enable_sleeptime: if request.agent_type == AgentType.voice_convo_agent: main_agent = await self.create_voice_sleeptime_agent_async(main_agent=main_agent, actor=actor) else: main_agent = await self.create_sleeptime_agent_async(main_agent=main_agent, actor=actor) return main_agent async def update_agent_async( self, agent_id: str, request: UpdateAgent, actor: User, ) -> AgentState: # Build llm_config from convenience fields if llm_config is not provided if request.llm_config is None and ( request.model is not None or request.context_window_limit is not None or request.max_tokens is not None ): if request.model is None: agent = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor) request.model = agent.llm_config.handle config_params = { "handle": request.model, "context_window_limit": request.context_window_limit, "max_tokens": request.max_tokens, } log_event(name="start get_llm_config_from_handle", attributes=config_params) request.llm_config = await self.get_llm_config_from_handle_async(actor=actor, **config_params) log_event(name="end get_llm_config_from_handle", attributes=config_params) # update with model_settings if request.model_settings is not None: if request.llm_config is None: # Get the current agent's llm_config if not already set agent = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor) request.llm_config = agent.llm_config.model_copy() else: # TODO: Refactor update_agent to accept partial llm_config so we # don't need to fetch the full agent just to preserve max_tokens. 

    async def update_agent_async(
        self,
        agent_id: str,
        request: UpdateAgent,
        actor: User,
    ) -> AgentState:
        # Build llm_config from convenience fields if llm_config is not provided
        if request.llm_config is None and (
            request.model is not None or request.context_window_limit is not None or request.max_tokens is not None
        ):
            if request.model is None:
                agent = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
                request.model = agent.llm_config.handle
            config_params = {
                "handle": request.model,
                "context_window_limit": request.context_window_limit,
                "max_tokens": request.max_tokens,
            }
            log_event(name="start get_llm_config_from_handle", attributes=config_params)
            request.llm_config = await self.get_llm_config_from_handle_async(actor=actor, **config_params)
            log_event(name="end get_llm_config_from_handle", attributes=config_params)

        # update with model_settings
        if request.model_settings is not None:
            if request.llm_config is None:
                # Get the current agent's llm_config if not already set
                agent = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
                request.llm_config = agent.llm_config.model_copy()
            else:
                # TODO: Refactor update_agent to accept partial llm_config so we
                # don't need to fetch the full agent just to preserve max_tokens.
                if request.max_tokens is None and "max_output_tokens" not in request.model_settings.model_fields_set:
                    agent = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
                    request.llm_config.max_tokens = agent.llm_config.max_tokens
            update_llm_config_params = request.model_settings._to_legacy_config_params()
            # Don't clobber max_tokens with the Pydantic default when the caller
            # didn't explicitly provide max_output_tokens in the request.
            if "max_output_tokens" not in request.model_settings.model_fields_set:
                update_llm_config_params.pop("max_tokens", None)
            request.llm_config = request.llm_config.model_copy(update=update_llm_config_params)

        # Copy parallel_tool_calls from request to llm_config if provided
        if request.parallel_tool_calls is not None:
            if request.llm_config is None:
                # Get the current agent's llm_config and update it
                agent = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
                request.llm_config = agent.llm_config.model_copy()
            request.llm_config.parallel_tool_calls = request.parallel_tool_calls

        if request.embedding is not None:
            request.embedding_config = await self.get_embedding_config_from_handle_async(handle=request.embedding, actor=actor)

        if request.enable_sleeptime:
            agent = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
            if agent.multi_agent_group is None:
                if agent.agent_type == AgentType.voice_convo_agent:
                    await self.create_voice_sleeptime_agent_async(main_agent=agent, actor=actor)
                else:
                    await self.create_sleeptime_agent_async(main_agent=agent, actor=actor)

        # If git-backed memory is requested via tag update, initialize/backfill the repo.
        wants_git_memory = bool(request.tags and GIT_MEMORY_ENABLED_TAG in request.tags)

        updated_agent = await self.agent_manager.update_agent_async(
            agent_id=agent_id,
            agent_update=request,
            actor=actor,
        )

        # Ensure repo exists and initial blocks are committed when the tag is present.
        if wants_git_memory and isinstance(self.block_manager, GitEnabledBlockManager):
            await self.block_manager.enable_git_memory_for_agent(agent_id=agent_id, actor=actor)
            # Preserve the user's requested tags in the response model.
            try:
                updated_agent.tags = list(request.tags or [])
            except Exception:
                pass

        return updated_agent
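
    # Illustrative call for `update_agent_async` above: switching an agent to a new
    # model handle rebuilds the llm_config via the convenience-field path.
    # The handle is a hypothetical placeholder.
    #
    #     updated = await server.update_agent_async(
    #         agent_id=agent_state.id,
    #         request=UpdateAgent(model="anthropic/claude-3-5-sonnet"),  # hypothetical handle
    #         actor=server.default_user,
    #     )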

    async def create_sleeptime_agent_async(self, main_agent: AgentState, actor: User) -> Optional[AgentState]:
        if main_agent.embedding_config is None:
            logger.warning(f"Skipping sleeptime agent creation for agent {main_agent.id}: no embedding config provided")
            return None
        request = CreateAgent(
            name=main_agent.name + "-sleeptime",
            agent_type=AgentType.sleeptime_agent,
            block_ids=[block.id for block in main_agent.memory.blocks],
            memory_blocks=[
                CreateBlock(
                    label="memory_persona",
                    value=get_persona_text("sleeptime_memory_persona"),
                ),
            ],
            llm_config=main_agent.llm_config,
            embedding_config=main_agent.embedding_config,
            project_id=main_agent.project_id,
        )
        sleeptime_agent = await self.agent_manager.create_agent_async(
            agent_create=request,
            actor=actor,
        )
        await self.group_manager.create_group_async(
            group=GroupCreate(
                description="",
                agent_ids=[sleeptime_agent.id],
                manager_config=SleeptimeManager(
                    manager_agent_id=main_agent.id,
                    sleeptime_agent_frequency=5,
                ),
            ),
            actor=actor,
        )
        return await self.agent_manager.get_agent_by_id_async(agent_id=main_agent.id, actor=actor)

    async def create_voice_sleeptime_agent_async(self, main_agent: AgentState, actor: User) -> Optional[AgentState]:
        if main_agent.embedding_config is None:
            logger.warning(f"Skipping voice sleeptime agent creation for agent {main_agent.id}: no embedding config provided")
            return None
        # TODO: Inject system
        request = CreateAgent(
            name=main_agent.name + "-sleeptime",
            agent_type=AgentType.voice_sleeptime_agent,
            block_ids=[block.id for block in main_agent.memory.blocks],
            memory_blocks=[
                CreateBlock(
                    label="memory_persona",
                    value=get_persona_text("voice_memory_persona"),
                ),
            ],
            llm_config=LLMConfig.default_config("gpt-4.1"),
            embedding_config=main_agent.embedding_config,
            project_id=main_agent.project_id,
        )
        voice_sleeptime_agent = await self.agent_manager.create_agent_async(
            agent_create=request,
            actor=actor,
        )
        await self.group_manager.create_group_async(
            group=GroupCreate(
                description="Low latency voice chat with async memory management.",
                agent_ids=[voice_sleeptime_agent.id],
                manager_config=VoiceSleeptimeManager(
                    manager_agent_id=main_agent.id,
                    max_message_buffer_length=constants.DEFAULT_MAX_MESSAGE_BUFFER_LENGTH,
                    min_message_buffer_length=constants.DEFAULT_MIN_MESSAGE_BUFFER_LENGTH,
                ),
            ),
            actor=actor,
        )
        return await self.agent_manager.get_agent_by_id_async(agent_id=main_agent.id, actor=actor)

    async def get_agent_memory_async(self, agent_id: str, actor: User) -> Memory:
        """Return the memory of an agent (core memory)"""
        agent = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
        return agent.memory

    async def get_agent_archival_async(
        self,
        agent_id: str,
        actor: User,
        after: Optional[str] = None,
        before: Optional[str] = None,
        limit: Optional[int] = 100,
        query_text: Optional[str] = None,
        ascending: Optional[bool] = True,
    ) -> List[Passage]:
        # iterate over records
        records = await self.agent_manager.query_agent_passages_async(
            actor=actor,
            agent_id=agent_id,
            after=after,
            query_text=query_text,
            before=before,
            ascending=ascending,
            limit=limit,
        )
        # Extract just the passages (SQL path returns empty metadata)
        return [passage for passage, _, _ in records]

    async def insert_archival_memory_async(
        self, agent_id: str, memory_contents: str, actor: User, tags: Optional[List[str]], created_at: Optional[datetime]
    ) -> List[Passage]:
        from letta.services.context_window_calculator.token_counter import create_token_counter
        from letta.settings import settings

        # Get the agent object (loaded in memory)
        agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)

        # Check token count against limit
        token_counter = create_token_counter(
            model_endpoint_type=agent_state.llm_config.model_endpoint_type,
            model=agent_state.llm_config.model,
            actor=actor,
            agent_id=agent_id,
        )
        token_count = await token_counter.count_text_tokens(memory_contents)
        if token_count > settings.archival_memory_token_limit:
            raise LettaInvalidArgumentError(
                message=f"Archival memory content exceeds token limit of {settings.archival_memory_token_limit} tokens (found {token_count} tokens)",
                argument_name="memory_contents",
            )

        # Use passage manager which handles dual-write to Turbopuffer if enabled
        passages = await self.passage_manager.insert_passage(
            agent_state=agent_state, text=memory_contents, tags=tags, actor=actor, created_at=created_at
        )
        return passages

    async def delete_archival_memory_async(self, memory_id: str, actor: User):
        # TODO check if it exists first, and throw error if not
        # TODO: need to also rebuild the prompt here
        await self.passage_manager.get_passage_by_id_async(passage_id=memory_id, actor=actor)

        # delete the passage
        await self.passage_manager.delete_passage_by_id_async(passage_id=memory_id, actor=actor)
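
    # Illustrative call for `insert_archival_memory_async` above; the content must
    # fit under settings.archival_memory_token_limit or a LettaInvalidArgumentError
    # is raised. The memory text and tag are placeholder values.
    #
    #     passages = await server.insert_archival_memory_async(
    #         agent_id=agent_state.id,
    #         memory_contents="User prefers short answers.",
    #         actor=server.default_user,
    #         tags=["preferences"],
    #         created_at=None,
    #     )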

    async def get_agent_recall(
        self,
        user_id: str,
        agent_id: str,
        after: Optional[str] = None,
        before: Optional[str] = None,
        limit: Optional[int] = 100,
        group_id: Optional[str] = None,
        reverse: Optional[bool] = False,
        return_message_object: bool = True,
        use_assistant_message: bool = True,
        assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
        assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
    ) -> Union[List[Message], List[LettaMessage]]:
        # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
        actor = await self.user_manager.get_actor_or_default_async(actor_id=user_id)

        records = await self.message_manager.list_messages(
            agent_id=agent_id,
            actor=actor,
            after=after,
            before=before,
            limit=limit,
            ascending=not reverse,
            group_id=group_id,
        )

        if not return_message_object:
            records = Message.to_letta_messages_from_list(
                messages=records,
                use_assistant_message=use_assistant_message,
                assistant_message_tool_name=assistant_message_tool_name,
                assistant_message_tool_kwarg=assistant_message_tool_kwarg,
                reverse=reverse,
            )

        if reverse:
            records = records[::-1]

        return records

    async def get_agent_recall_async(
        self,
        agent_id: str,
        actor: User,
        after: Optional[str] = None,
        before: Optional[str] = None,
        limit: Optional[int] = 100,
        group_id: Optional[str] = None,
        reverse: Optional[bool] = False,
        return_message_object: bool = True,
        use_assistant_message: bool = True,
        assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
        assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
        include_err: Optional[bool] = None,
        conversation_id: Optional[str] = None,
    ) -> Union[List[Message], List[LettaMessage]]:
        records = await self.message_manager.list_messages(
            agent_id=agent_id,
            actor=actor,
            after=after,
            before=before,
            limit=limit,
            ascending=not reverse,
            group_id=group_id,
            include_err=include_err,
            conversation_id=conversation_id,
        )

        if not return_message_object:
            # Get agent state to determine if it's a react agent
            agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
            text_is_assistant_message = agent_state.agent_type == AgentType.letta_v1_agent

            records = Message.to_letta_messages_from_list(
                messages=records,
                use_assistant_message=use_assistant_message,
                assistant_message_tool_name=assistant_message_tool_name,
                assistant_message_tool_kwarg=assistant_message_tool_kwarg,
                reverse=reverse,
                include_err=include_err,
                text_is_assistant_message=text_is_assistant_message,
            )

        if reverse:
            records = records[::-1]

        return records

    async def get_all_messages_recall_async(
        self,
        actor: User,
        after: Optional[str] = None,
        before: Optional[str] = None,
        limit: Optional[int] = 100,
        group_id: Optional[str] = None,
        reverse: Optional[bool] = False,
        return_message_object: bool = True,
        use_assistant_message: bool = True,
        assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
        assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
        include_err: Optional[bool] = None,
        conversation_id: Optional[str] = None,
    ) -> Union[List[Message], List[LettaMessage]]:
        records = await self.message_manager.list_messages(
            agent_id=None,
            actor=actor,
            after=after,
            before=before,
            limit=limit,
            ascending=not reverse,
            group_id=group_id,
            include_err=include_err,
            conversation_id=conversation_id,
        )

        if not return_message_object:
            # NOTE: We are assuming all messages are coming from letta_v1_agent. This may lead to slightly incorrect assistant message handling.
            # text_is_assistant_message = agent_state.agent_type == AgentType.letta_v1_agent
            text_is_assistant_message = True

            records = Message.to_letta_messages_from_list(
                messages=records,
                use_assistant_message=use_assistant_message,
                assistant_message_tool_name=assistant_message_tool_name,
                assistant_message_tool_kwarg=assistant_message_tool_kwarg,
                reverse=reverse,
                include_err=include_err,
                text_is_assistant_message=text_is_assistant_message,
            )

        if reverse:
            records = records[::-1]

        return records

    def get_server_config(self, include_defaults: bool = False) -> dict:
        """Return the base config"""

        def clean_keys(config):
            config_copy = config.copy()
            for k, v in config.items():
                if k == "key" or "_key" in k:
                    config_copy[k] = server_utils.shorten_key_middle(v, chars_each_side=5)
            return config_copy

        # TODO: do we need a separate server config?
        base_config = vars(self.config)
        clean_base_config = clean_keys(base_config)
        response = {"config": clean_base_config}

        if include_defaults:
            default_config = vars(LettaConfig())
            clean_default_config = clean_keys(default_config)
            response["defaults"] = clean_default_config

        return response

    def update_agent_core_memory(self, agent_id: str, label: str, value: str, actor: User) -> Memory:
        """Update the value of a block in the agent's memory"""
        # get the block id
        block = self.agent_manager.get_block_with_label(agent_id=agent_id, block_label=label, actor=actor)

        # update the block
        self.block_manager.update_block(block_id=block.id, block_update=BlockUpdate(value=value), actor=actor)

        # rebuild system prompt for agent, potentially changed
        return self.agent_manager.rebuild_system_prompt(agent_id=agent_id, actor=actor).memory

    async def delete_source(self, source_id: str, actor: User):
        """Delete a data source"""
        await self.source_manager.delete_source(source_id=source_id, actor=actor)

        # delete data from passage store
        passages_to_be_deleted = await self.agent_manager.query_source_passages_async(actor=actor, source_id=source_id, limit=None)
        await self.passage_manager.delete_source_passages_async(actor=actor, passages=passages_to_be_deleted)

        # TODO: delete data from agent passage stores (?)

    async def load_file_to_source(self, source_id: str, file_path: str, job_id: str, actor: User) -> Job:
        # update job
        job = await self.job_manager.get_job_by_id_async(job_id, actor=actor)
        job.status = JobStatus.running
        await self.job_manager.update_job_by_id_async(job_id=job_id, job_update=JobUpdate(**job.model_dump()), actor=actor)

        # try:
        from letta.data_sources.connectors import DirectoryConnector

        # TODO: move this into a thread
        source = await self.source_manager.get_source_by_id(source_id=source_id, actor=actor)
        connector = DirectoryConnector(input_files=[file_path])
        num_passages, num_documents = await self.load_data(user_id=source.created_by_id, source_name=source.name, connector=connector)

        # update all agents who have this source attached
        agent_states = await self.source_manager.list_attached_agents(source_id=source_id, actor=actor)
        for agent_state in agent_states:
            agent_id = agent_state.id

            # Attach source to agent
            curr_passage_size = await self.agent_manager.passage_size_async(actor=actor, agent_id=agent_id)
            agent_state = await self.agent_manager.attach_source_async(agent_id=agent_state.id, source_id=source_id, actor=actor)
            new_passage_size = await self.agent_manager.passage_size_async(actor=actor, agent_id=agent_id)
            assert new_passage_size >= curr_passage_size  # in case empty files are added

        # update job status
        job.status = JobStatus.completed
        job.metadata["num_passages"] = num_passages
        job.metadata["num_documents"] = num_documents
        await self.job_manager.update_job_by_id_async(job_id=job_id, job_update=JobUpdate(**job.model_dump()), actor=actor)

        return job

    async def load_file_to_source_via_mistral(self):
        pass

    async def sleeptime_document_ingest_async(
        self, main_agent: AgentState, source: Source, actor: User, clear_history: bool = False
    ) -> None:
        pass

    async def _remove_file_from_agent(self, agent_id: str, file_id: str, actor: User) -> None:
        """
        Internal method to remove a document block for an agent.
        """
        try:
            await self.file_agent_manager.detach_file(
                agent_id=agent_id,
                file_id=file_id,
                actor=actor,
            )
        except NoResultFound:
            logger.info(f"File {file_id} already removed from agent {agent_id}, skipping...")

    async def remove_file_from_context_windows(self, source_id: str, file_id: str, actor: User) -> None:
        """
        Remove the document from the context window of all agents attached to the given source.
        """
        # Use the optimized ids_only parameter
        agent_ids = await self.source_manager.list_attached_agents(source_id=source_id, actor=actor, ids_only=True)

        # Return early if no agents
        if not agent_ids:
            return

        logger.info(f"Removing file from context window for source: {source_id}")
        logger.info(f"Attached agents: {agent_ids}")

        # Create agent-file pairs for bulk deletion
        agent_file_pairs = [(agent_id, file_id) for agent_id in agent_ids]

        # Bulk delete in a single query
        deleted_count = await self.file_agent_manager.detach_file_bulk(agent_file_pairs=agent_file_pairs, actor=actor)

        logger.info(f"Removed file {file_id} from {deleted_count} agent context windows")
""" logger.info(f"Removing files from context window for agent_state: {agent_state.id}") logger.info(f"Files to remove: {file_ids}") # Create agent-file pairs for bulk deletion agent_file_pairs = [(agent_state.id, file_id) for file_id in file_ids] # Bulk delete in a single query deleted_count = await self.file_agent_manager.detach_file_bulk(agent_file_pairs=agent_file_pairs, actor=actor) logger.info(f"Removed {deleted_count} files from agent {agent_state.id}") async def create_document_sleeptime_agent_async( self, main_agent: AgentState, source: Source, actor: User, clear_history: bool = False ) -> Optional[AgentState]: if main_agent.embedding_config is None: logger.warning(f"Skipping document sleeptime agent creation for agent {main_agent.id}: no embedding config provided") return None try: block = await self.agent_manager.get_block_with_label_async(agent_id=main_agent.id, block_label=source.name, actor=actor) except Exception: block = await self.block_manager.create_or_update_block_async(Block(label=source.name, value=""), actor=actor) await self.agent_manager.attach_block_async(agent_id=main_agent.id, block_id=block.id, actor=actor) if clear_history and block.value != "": block = await self.block_manager.update_block_async(block_id=block.id, block_update=BlockUpdate(value=""), actor=actor) request = CreateAgent( name=main_agent.name + "-doc-sleeptime", system=get_system_text("sleeptime_doc_ingest"), agent_type=AgentType.sleeptime_agent, block_ids=[block.id], memory_blocks=[ CreateBlock( label="persona", value=get_persona_text("sleeptime_doc_persona"), ), CreateBlock( label="instructions", value=source.instructions, ), ], llm_config=main_agent.llm_config, embedding_config=main_agent.embedding_config, project_id=main_agent.project_id, include_base_tools=False, tools=constants.BASE_SLEEPTIME_TOOLS, ) return await self.agent_manager.create_agent_async( agent_create=request, actor=actor, ) async def load_data( self, user_id: str, connector: DataConnector, source_name: str, ) -> Tuple[int, int]: """Load data from a DataConnector into a source for a specified user_id""" # TODO: this should be implemented as a batch job or at least async, since it may take a long time # load data from a data source into the document store actor = await self.user_manager.get_actor_by_id_async(actor_id=user_id) source = await self.source_manager.get_source_by_name(source_name=source_name, actor=actor) if source is None: raise NoResultFound(f"Data source {source_name} does not exist for user {user_id}") # load data into the document store passage_count, document_count = await load_data(connector, source, self.passage_manager, self.file_manager, actor=actor) return passage_count, document_count def _get_provider_sort_key(self, model: LLMConfig) -> Tuple[int, str, str]: """Get sort key for a model: (provider_priority, provider_name, model_name)""" provider_priority = constants.PROVIDER_ORDER.get(model.provider_name, 999) return (provider_priority, model.provider_name or "", model.model or "") def _get_embedding_provider_sort_key(self, model: EmbeddingConfig) -> Tuple[int, str, str]: """Get sort key for an embedding model: (provider_priority, provider_name, model_name)""" # Extract provider name from handle (format: "provider_name/model_name") provider_name = model.handle.split("/")[0] if model.handle and "/" in model.handle else "" provider_priority = constants.PROVIDER_ORDER.get(provider_name, 999) return (provider_priority, provider_name, model.embedding_model or "") @trace_method async def list_llm_models_async( 

    @trace_method
    async def list_llm_models_async(
        self,
        actor: User,
        provider_category: Optional[List[ProviderCategory]] = None,
        provider_name: Optional[str] = None,
        provider_type: Optional[ProviderType] = None,
    ) -> List[LLMConfig]:
        """List available LLM models - base from DB, BYOK from provider endpoints"""
        llm_models = []

        # Determine which categories to include
        include_base = not provider_category or ProviderCategory.base in provider_category
        include_byok = not provider_category or ProviderCategory.byok in provider_category

        # Get base provider models from database
        if include_base:
            provider_models = await self.provider_manager.list_models_async(
                actor=actor,
                model_type="llm",
                enabled=True,
            )

            # Build LLMConfig objects from database
            provider_cache: Dict[str, Provider] = {}
            typed_provider_cache: Dict[str, Any] = {}
            for model in provider_models:
                # Get provider details (with caching to avoid N+1 queries)
                if model.provider_id not in provider_cache:
                    provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor)
                    typed_provider_cache[model.provider_id] = provider_cache[model.provider_id].cast_to_subtype()
                provider = provider_cache[model.provider_id]
                typed_provider = typed_provider_cache[model.provider_id]

                # Skip non-base providers (they're handled separately)
                if provider.provider_category != ProviderCategory.base:
                    continue

                # Apply provider_name/provider_type filters if specified
                if provider_name and provider.name != provider_name:
                    continue
                if provider_type and provider.provider_type != provider_type:
                    continue

                # For bedrock, use schema default for base_url since DB may have NULL
                # TODO: can maybe do this for all models but want to isolate change so we don't break any other providers
                if provider.provider_type == ProviderType.bedrock:
                    model_endpoint = typed_provider.base_url
                else:
                    model_endpoint = provider.base_url

                # Get provider-specific default max_tokens
                max_tokens = typed_provider.get_default_max_output_tokens(model.name)

                llm_config = LLMConfig(
                    model=model.name,
                    model_endpoint_type=model.model_endpoint_type,
                    model_endpoint=model_endpoint,
                    context_window=model.max_context_window or 16384,
                    handle=model.handle,
                    provider_name=provider.name,
                    provider_category=provider.provider_category,
                    max_tokens=max_tokens,
                )
                llm_models.append(llm_config)

        # Get BYOK provider models - sync if not synced yet, then read from DB
        if include_byok:
            byok_providers = await self.provider_manager.list_providers_async(
                actor=actor,
                name=provider_name,
                provider_type=provider_type,
                provider_category=[ProviderCategory.byok],
            )
            for provider in byok_providers:
                try:
                    # Get typed provider to access schema defaults (e.g., base_url)
                    typed_provider = provider.cast_to_subtype()

                    provider_llm_models = None
                    should_sync_models = provider.last_synced is None
                    # ChatGPT OAuth uses a hardcoded model list. If that list changes,
                    # backfill already-synced providers that are missing new handles.
                    if provider.provider_type == ProviderType.chatgpt_oauth and not should_sync_models:
                        expected_models = await typed_provider.list_llm_models_async()
                        expected_handles = {model.handle for model in expected_models}

                        provider_llm_models = await self.provider_manager.list_models_async(
                            actor=actor,
                            model_type="llm",
                            provider_id=provider.id,
                            enabled=True,
                        )
                        existing_handles = {model.handle for model in provider_llm_models}
                        should_sync_models = not expected_handles.issubset(existing_handles)

                    if should_sync_models:
                        models = await typed_provider.list_llm_models_async()
                        embedding_models = await typed_provider.list_embedding_models_async()
                        await self.provider_manager.sync_provider_models_async(
                            provider=provider,
                            llm_models=models,
                            embedding_models=embedding_models,
                            organization_id=provider.organization_id,
                        )
                        await self.provider_manager.update_provider_last_synced_async(provider.id, actor=actor)
                        # Discard any list fetched before the sync so the re-read below picks up the new handles
                        provider_llm_models = None

                    # Read from database
                    if provider_llm_models is None:
                        provider_llm_models = await self.provider_manager.list_models_async(
                            actor=actor,
                            model_type="llm",
                            provider_id=provider.id,
                            enabled=True,
                        )
                    for model in provider_llm_models:
                        max_tokens = typed_provider.get_default_max_output_tokens(model.name)
                        llm_config = LLMConfig(
                            model=model.name,
                            model_endpoint_type=model.model_endpoint_type,
                            model_endpoint=typed_provider.base_url,
                            context_window=model.max_context_window or constants.DEFAULT_CONTEXT_WINDOW,
                            handle=model.handle,
                            provider_name=provider.name,
                            provider_category=ProviderCategory.byok,
                            max_tokens=max_tokens,
                        )
                        llm_models.append(llm_config)
                except Exception as e:
                    logger.warning(f"Failed to fetch models from BYOK provider {provider.name}: {e}")

        # Sort by provider order (matching old _enabled_providers order), then by model name
        llm_models.sort(key=self._get_provider_sort_key)

        return llm_models
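
    # Illustrative call for `list_llm_models_async` above: restrict the listing to
    # BYOK providers of a given name (the provider name is a hypothetical value).
    #
    #     byok_configs = await server.list_llm_models_async(
    #         actor=server.default_user,
    #         provider_category=[ProviderCategory.byok],
    #         provider_name="my-openai",  # hypothetical BYOK provider name
    #     )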

    async def list_embedding_models_async(self, actor: User) -> List[EmbeddingConfig]:
        """List available embedding models - base from DB, BYOK from provider endpoints"""
        embedding_models = []

        # Get base provider models from database
        provider_models = await self.provider_manager.list_models_async(
            actor=actor,
            model_type="embedding",
            enabled=True,
        )

        # Build EmbeddingConfig objects from database (base providers only)
        provider_cache: Dict[str, Provider] = {}
        for model in provider_models:
            # Get provider details (with caching to avoid N+1 queries)
            if model.provider_id not in provider_cache:
                provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor)
            provider = provider_cache[model.provider_id]

            # Skip non-base providers (they're handled separately)
            if provider.provider_category != ProviderCategory.base:
                continue

            embedding_config = EmbeddingConfig(
                embedding_model=model.name,
                embedding_endpoint_type=model.model_endpoint_type,
                embedding_endpoint=provider.base_url or model.model_endpoint_type,
                embedding_dim=model.embedding_dim or 1536,
                embedding_chunk_size=constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
                handle=model.handle,
            )
            embedding_models.append(embedding_config)

        # Get BYOK provider models - sync if not synced yet, then read from DB
        byok_providers = await self.provider_manager.list_providers_async(
            actor=actor,
            provider_category=[ProviderCategory.byok],
        )
        for provider in byok_providers:
            try:
                # Get typed provider to access schema defaults (e.g., base_url)
                typed_provider = provider.cast_to_subtype()

                # Sync models if not synced yet
                if provider.last_synced is None:
                    llm_models = await typed_provider.list_llm_models_async()
                    emb_models = await typed_provider.list_embedding_models_async()
                    await self.provider_manager.sync_provider_models_async(
                        provider=provider,
                        llm_models=llm_models,
                        embedding_models=emb_models,
                        organization_id=provider.organization_id,
                    )
                    await self.provider_manager.update_provider_last_synced_async(provider.id, actor=actor)

                # Read from database
                provider_embedding_models = await self.provider_manager.list_models_async(
                    actor=actor,
                    model_type="embedding",
                    provider_id=provider.id,
                    enabled=True,
                )
                for model in provider_embedding_models:
                    embedding_config = EmbeddingConfig(
                        embedding_model=model.name,
                        embedding_endpoint_type=model.model_endpoint_type,
                        embedding_endpoint=typed_provider.base_url,
                        embedding_dim=model.embedding_dim or 1536,
                        embedding_chunk_size=constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
                        handle=model.handle,
                    )
                    embedding_models.append(embedding_config)
            except Exception as e:
                logger.warning(f"Failed to fetch embedding models from BYOK provider {provider.name}: {e}")

        # Sort by provider order (matching old _enabled_providers order), then by model name
        embedding_models.sort(key=self._get_embedding_provider_sort_key)

        return embedding_models

    async def get_enabled_providers_async(
        self,
        actor: User,
        provider_category: Optional[List[ProviderCategory]] = None,
        provider_name: Optional[str] = None,
        provider_type: Optional[ProviderType] = None,
    ) -> List[Provider]:
        # Query all persisted providers from database
        persisted_providers = await self.provider_manager.list_providers_async(
            name=provider_name,
            provider_type=provider_type,
            actor=actor,
        )
        providers = [p.cast_to_subtype() for p in persisted_providers]

        # Filter by category if specified
        if provider_category:
            providers = [p for p in providers if p.provider_category in provider_category]

        return providers

    @trace_method
    async def get_llm_config_from_handle_async(
        self,
        actor: User,
        handle: str,
        context_window_limit: Optional[int] = None,
        max_tokens: Optional[int] = None,
        max_reasoning_tokens: Optional[int] = None,
        enable_reasoner: Optional[bool] = None,
    ) -> LLMConfig:
        # Use provider_manager to get LLMConfig from handle
        try:
            llm_config = await self.provider_manager.get_llm_config_from_handle(
                handle=handle,
                actor=actor,
            )
        except Exception as e:
            # Convert to HandleNotFoundError for backwards compatibility
            from letta.orm.errors import NoResultFound

            if isinstance(e, NoResultFound):
                raise HandleNotFoundError(handle, [])
            raise

        if context_window_limit is not None:
            if context_window_limit > llm_config.context_window:
                raise LettaInvalidArgumentError(
                    f"Context window limit ({context_window_limit}) is greater than maximum of ({llm_config.context_window})",
                    argument_name="context_window_limit",
                )
            llm_config.context_window = context_window_limit
        else:
            llm_config.context_window = min(llm_config.context_window, model_settings.global_max_context_window_limit)

        if max_tokens is not None:
            llm_config.max_tokens = max_tokens
        if max_reasoning_tokens is not None:
            if not max_tokens or max_reasoning_tokens > max_tokens:
                raise LettaInvalidArgumentError(
                    f"Max reasoning tokens ({max_reasoning_tokens}) must be less than max tokens ({max_tokens})",
                    argument_name="max_reasoning_tokens",
                )
            llm_config.max_reasoning_tokens = max_reasoning_tokens
        if enable_reasoner is not None:
            llm_config.enable_reasoner = enable_reasoner
            if enable_reasoner and llm_config.model_endpoint_type == "anthropic":
                llm_config.put_inner_thoughts_in_kwargs = False

        return llm_config

    @trace_method
    async def get_embedding_config_from_handle_async(
        self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
    ) -> EmbeddingConfig:
        # Use provider_manager to get EmbeddingConfig from handle
        try:
            embedding_config = await self.provider_manager.get_embedding_config_from_handle(
                handle=handle,
                actor=actor,
            )
        except Exception as e:
            # Convert to LettaInvalidArgumentError for backwards compatibility
            from letta.orm.errors import NoResultFound

            if isinstance(e, NoResultFound):
                raise LettaInvalidArgumentError(f"Embedding model {handle} not found", argument_name="handle")
            raise

        # Override chunk size if provided
        embedding_config.embedding_chunk_size = embedding_chunk_size

        return embedding_config

    async def get_provider_from_name_async(self, provider_name: str, actor: User) -> Provider:
        all_providers = await self.get_enabled_providers_async(actor)
        providers = [provider for provider in all_providers if provider.name == provider_name]
        if not providers:
            raise LettaInvalidArgumentError(
                f"Provider {provider_name} is not supported (supported providers: {', '.join([provider.name for provider in all_providers])})",
                argument_name="provider_name",
            )
        if len(providers) > 1:
            logger.warning(f"Multiple enabled providers share the name {provider_name}; using the first match")
        return providers[0]

    def add_llm_model(self, request: LLMConfig) -> LLMConfig:
        """Add a new LLM model"""

    def add_embedding_model(self, request: EmbeddingConfig) -> EmbeddingConfig:
        """Add a new embedding model"""

    async def run_tool_from_source(
        self,
        actor: User,
        tool_args: Dict[str, str],
        tool_source: str,
        tool_env_vars: Optional[Dict[str, str]] = None,
        tool_source_type: Optional[str] = None,
        tool_name: Optional[str] = None,
        tool_args_json_schema: Optional[Dict[str, Any]] = None,
        tool_json_schema: Optional[Dict[str, Any]] = None,
        pip_requirements: Optional[List[PipRequirement]] = None,
    ) -> ToolReturnMessage:
        """Run a tool from source code"""
        from letta.services.tool_schema_generator import generate_schema_for_tool_creation

        if tool_source_type not in (None, ToolSourceType.python, ToolSourceType.typescript):
            raise LettaInvalidArgumentError(
                f"Tool source type is not supported at this time. Found {tool_source_type}", argument_name="tool_source_type"
            )

        # If tool_json_schema is explicitly passed in, override it on the created Tool object
        if tool_json_schema:
            tool = Tool(
                name=tool_name,
                source_code=tool_source,
                json_schema=tool_json_schema,
                pip_requirements=pip_requirements,
                source_type=tool_source_type,
            )
        else:
            # NOTE: we're creating a floating Tool object and NOT persisting to DB
            tool = Tool(
                name=tool_name,
                source_code=tool_source,
                args_json_schema=tool_args_json_schema,
                pip_requirements=pip_requirements,
                source_type=tool_source_type,
            )

        # try to get the schema
        if not tool.name:
            if not tool.json_schema:
                tool.json_schema = generate_schema_for_tool_creation(tool)
            tool.name = tool.json_schema.get("name")

        assert tool.name is not None, "Failed to create tool object"

        # TODO eventually allow using agent state in tools
        agent_state = None

        # Next, attempt to run the tool with the sandbox
        try:
            tool_execution_manager = ToolExecutionManager(
                agent_state=agent_state,
                message_manager=self.message_manager,
                agent_manager=self.agent_manager,
                block_manager=self.block_manager,
                run_manager=self.run_manager,
                passage_manager=self.passage_manager,
                actor=actor,
                sandbox_env_vars=tool_env_vars,
            )
            # TODO: Integrate sandbox result
            tool_execution_result = await tool_execution_manager.execute_tool_async(
                function_name=tool_name,
                function_args=tool_args,
                tool=tool,
            )
            from letta.schemas.letta_message import ToolReturn as ToolReturnSchema

            tool_return_obj = ToolReturnSchema(
                tool_return=str(tool_execution_result.func_return),
                status=tool_execution_result.status,
                tool_call_id="null",
                stdout=tool_execution_result.stdout,
                stderr=tool_execution_result.stderr,
            )
            return ToolReturnMessage(
                id="null",
                tool_call_id="null",
                date=get_utc_time(),
                name=tool_name,
                status=tool_execution_result.status,
                tool_return=str(tool_execution_result.func_return),
                stdout=tool_execution_result.stdout,
                stderr=tool_execution_result.stderr,
                tool_returns=[tool_return_obj],
            )
        except Exception as e:
            func_return = get_friendly_error_msg(function_name=tool.name, exception_name=type(e).__name__, exception_message=str(e))
            from letta.schemas.letta_message import ToolReturn as ToolReturnSchema

            tool_return_obj = ToolReturnSchema(
                tool_return=func_return,
                status="error",
                tool_call_id="null",
                stdout=[],
                stderr=[traceback.format_exc()],
            )
            return ToolReturnMessage(
                id="null",
                tool_call_id="null",
                date=get_utc_time(),
                name=tool.name,
                status="error",
                tool_return=func_return,
                stdout=[],
                stderr=[traceback.format_exc()],
                tool_returns=[tool_return_obj],
            )
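
    # Illustrative call for `run_tool_from_source` above, running a floating tool in
    # the sandbox without persisting it. The tool source is a toy example; note that
    # `tool_args` values arrive as a Dict[str, str] per the signature.
    #
    #     msg = await server.run_tool_from_source(
    #         actor=server.default_user,
    #         tool_name="add",
    #         tool_source="def add(a: int, b: int) -> int:\n    return a + b",
    #         tool_args={"a": "1", "b": "2"},
    #     )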
Found {tool_source_type}", argument_name="tool_source_type" ) # If tools_json_schema is explicitly passed in, override it on the created Tool object if tool_json_schema: tool = Tool( name=tool_name, source_code=tool_source, json_schema=tool_json_schema, pip_requirements=pip_requirements, source_type=tool_source_type, ) else: # NOTE: we're creating a floating Tool object and NOT persisting to DB tool = Tool( name=tool_name, source_code=tool_source, args_json_schema=tool_args_json_schema, pip_requirements=pip_requirements, source_type=tool_source_type, ) # try to get the schema if not tool.name: if not tool.json_schema: tool.json_schema = generate_schema_for_tool_creation(tool) tool.name = tool.json_schema.get("name") assert tool.name is not None, "Failed to create tool object" # TODO eventually allow using agent state in tools agent_state = None # Next, attempt to run the tool with the sandbox try: tool_execution_manager = ToolExecutionManager( agent_state=agent_state, message_manager=self.message_manager, agent_manager=self.agent_manager, block_manager=self.block_manager, run_manager=self.run_manager, passage_manager=self.passage_manager, actor=actor, sandbox_env_vars=tool_env_vars, ) # TODO: Integrate sandbox result tool_execution_result = await tool_execution_manager.execute_tool_async( function_name=tool_name, function_args=tool_args, tool=tool, ) from letta.schemas.letta_message import ToolReturn as ToolReturnSchema tool_return_obj = ToolReturnSchema( tool_return=str(tool_execution_result.func_return), status=tool_execution_result.status, tool_call_id="null", stdout=tool_execution_result.stdout, stderr=tool_execution_result.stderr, ) return ToolReturnMessage( id="null", tool_call_id="null", date=get_utc_time(), name=tool_name, status=tool_execution_result.status, tool_return=str(tool_execution_result.func_return), stdout=tool_execution_result.stdout, stderr=tool_execution_result.stderr, tool_returns=[tool_return_obj], ) except Exception as e: func_return = get_friendly_error_msg(function_name=tool.name, exception_name=type(e).__name__, exception_message=str(e)) from letta.schemas.letta_message import ToolReturn as ToolReturnSchema tool_return_obj = ToolReturnSchema( tool_return=func_return, status="error", tool_call_id="null", stdout=[], stderr=[traceback.format_exc()], ) return ToolReturnMessage( id="null", tool_call_id="null", date=get_utc_time(), name=tool.name, status="error", tool_return=func_return, stdout=[], stderr=[traceback.format_exc()], tool_returns=[tool_return_obj], ) # MCP wrappers # TODO support both command + SSE servers (via config) async def get_mcp_servers(self) -> dict[str, Union[SSEServerConfig, StdioServerConfig]]: """List the MCP servers in the config (doesn't test that they are actually working)""" # TODO implement non-flatfile mechanism if not tool_settings.mcp_read_from_config: return {} # raise RuntimeError("MCP config file disabled. 
Enable it in settings.") mcp_server_list = {} # Attempt to read from ~/.letta/mcp_config.json mcp_config_path = os.path.join(constants.LETTA_DIR, constants.MCP_CONFIG_NAME) if os.path.exists(mcp_config_path): def _read_config(): with open(mcp_config_path, "r") as f: return json.load(f) try: mcp_config = await asyncio.to_thread(_read_config) except Exception as e: logger.error(f"Failed to parse MCP config file ({mcp_config_path}) as json: {e}") return mcp_server_list # Proper formatting is "mcpServers" key at the top level, # then a dict with the MCP server name as the key, # with the value being the schema from StdioServerParameters if MCP_CONFIG_TOPLEVEL_KEY in mcp_config: for server_name, server_params_raw in mcp_config[MCP_CONFIG_TOPLEVEL_KEY].items(): # No support for duplicate server names if server_name in mcp_server_list: logger.error(f"Duplicate MCP server name found (skipping): {server_name}") continue if "url" in server_params_raw: # Attempt to parse the server params as an SSE server try: server_params = SSEServerConfig( server_name=server_name, server_url=server_params_raw["url"], ) mcp_server_list[server_name] = server_params except Exception as e: logger.error(f"Failed to parse server params for MCP server {server_name} (skipping): {e}") continue else: # Attempt to parse the server params as a StdioServerParameters try: server_params = StdioServerConfig( server_name=server_name, command=server_params_raw["command"], args=server_params_raw.get("args", []), env=server_params_raw.get("env", {}), ) mcp_server_list[server_name] = server_params except Exception as e: logger.error(f"Failed to parse server params for MCP server {server_name} (skipping): {e}") continue # If the file doesn't exist, return empty dictionary return mcp_server_list async def get_tools_from_mcp_server(self, mcp_server_name: str) -> List[MCPTool]: """List the tools in an MCP server. Requires a client to be created.""" if mcp_server_name not in self.mcp_clients: raise LettaInvalidArgumentError(f"No client was created for MCP server: {mcp_server_name}", argument_name="mcp_server_name") tools = await self.mcp_clients[mcp_server_name].list_tools() # Add health information to each tool for tool in tools: if tool.inputSchema: health_status, reasons = validate_complete_json_schema(tool.inputSchema) tool.health = MCPToolHealth(status=health_status.value, reasons=reasons) return tools async def add_mcp_server_to_config( self, server_config: Union[SSEServerConfig, StdioServerConfig], allow_upsert: bool = True ) -> List[Union[SSEServerConfig, StdioServerConfig]]: """Add a new server config to the MCP config file""" # TODO implement non-flatfile mechanism if not tool_settings.mcp_read_from_config: raise RuntimeError("MCP config file disabled. Enable it in settings.") # If the config file doesn't exist, throw an error. 
    async def add_mcp_server_to_config(
        self, server_config: Union[SSEServerConfig, StdioServerConfig], allow_upsert: bool = True
    ) -> List[Union[SSEServerConfig, StdioServerConfig]]:
        """Add a new server config to the MCP config file"""
        # TODO implement non-flatfile mechanism
        if not tool_settings.mcp_read_from_config:
            raise RuntimeError("MCP config file disabled. Enable it in settings.")

        mcp_config_path = os.path.join(constants.LETTA_DIR, constants.MCP_CONFIG_NAME)
        if not os.path.exists(mcp_config_path):
            # Create the file if it doesn't exist
            logger.debug(f"MCP config file not found, creating new file at: {mcp_config_path}")

        # If the file does exist, attempt to parse it by calling get_mcp_servers
        try:
            current_mcp_servers = await self.get_mcp_servers()
        except Exception as e:
            # Raise an error telling the user to fix the config file
            logger.error(f"Failed to parse MCP config file at {mcp_config_path}: {e}")
            raise LettaInvalidArgumentError(f"Failed to parse MCP config file {mcp_config_path}")

        # Check if the server name is already in the config
        if server_config.server_name in current_mcp_servers and not allow_upsert:
            raise LettaInvalidArgumentError(
                f"Server name {server_config.server_name} is already in the config file", argument_name="server_name"
            )

        # Attempt to initialize the connection to the server
        if server_config.type == MCPServerType.SSE:
            new_mcp_client = AsyncFastMCPSSEClient(server_config)
        elif server_config.type == MCPServerType.STDIO:
            new_mcp_client = AsyncStdioMCPClient(server_config)
        else:
            raise LettaInvalidArgumentError(f"Invalid MCP server config: {server_config}", argument_name="server_config")

        try:
            await new_mcp_client.connect_to_server()
        except LettaMCPConnectionError:
            raise
        except Exception:
            logger.exception(f"Failed to connect to MCP server: {server_config.server_name}")
            raise LettaMCPConnectionError(
                message=f"Failed to connect to MCP server: {server_config.server_name}",
                server_name=server_config.server_name,
            )

        # Log the tools that are connected
        logger.info(f"Attempting to fetch tools from MCP server: {server_config.server_name}")
        new_mcp_tools = await new_mcp_client.list_tools()
        logger.info(f"MCP tools connected: {', '.join([t.name for t in new_mcp_tools])}")
        logger.debug(f"MCP tools: {', '.join([str(t) for t in new_mcp_tools])}")

        # Now that we've confirmed the config is working, add it to the client list
        self.mcp_clients[server_config.server_name] = new_mcp_client

        # Add to the server file
        current_mcp_servers[server_config.server_name] = server_config

        # Write out the file, making sure to include the top-level "mcpServers" key
        # (the write is wrapped in a thread to avoid blocking the event loop)
        try:
            new_mcp_file = {MCP_CONFIG_TOPLEVEL_KEY: {k: v.to_dict() for k, v in current_mcp_servers.items()}}

            def _write_config():
                with open(mcp_config_path, "w") as f:
                    json.dump(new_mcp_file, f, indent=4)

            await asyncio.to_thread(_write_config)
        except Exception as e:
            logger.error(f"Failed to write MCP config file at {mcp_config_path}: {e}")
            raise LettaInvalidArgumentError(f"Failed to write MCP config file {mcp_config_path}")

        return list(current_mcp_servers.values())
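    # Illustrative usage sketch for delete_mcp_server_from_config (defined below);
    # the server name is hypothetical and must already be present in the config:
    #
    #     remaining = await server.delete_mcp_server_from_config(server_name="weather")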
    async def delete_mcp_server_from_config(self, server_name: str) -> List[Union[SSEServerConfig, StdioServerConfig]]:
        """Delete a server config from the MCP config file"""
        # TODO implement non-flatfile mechanism
        if not tool_settings.mcp_read_from_config:
            raise RuntimeError("MCP config file disabled. Enable it in settings.")

        # If the config file doesn't exist, throw an error
        mcp_config_path = os.path.join(constants.LETTA_DIR, constants.MCP_CONFIG_NAME)
        if not os.path.exists(mcp_config_path):
            raise FileNotFoundError(f"MCP config file not found: {mcp_config_path}")

        # If the file does exist, attempt to parse it by calling get_mcp_servers
        try:
            current_mcp_servers = await self.get_mcp_servers()
        except Exception as e:
            # Raise an error telling the user to fix the config file
            logger.error(f"Failed to parse MCP config file at {mcp_config_path}: {e}")
            raise LettaInvalidArgumentError(f"Failed to parse MCP config file {mcp_config_path}")

        # If the server name is not in the config, throw an error
        if server_name not in current_mcp_servers:
            raise LettaInvalidArgumentError(f"Server name {server_name} not found in MCP config file", argument_name="server_name")

        # Remove from the server file
        del current_mcp_servers[server_name]

        # Write out the file, making sure to include the top-level "mcpServers" key
        # (the write is wrapped in a thread to avoid blocking the event loop)
        try:
            new_mcp_file = {MCP_CONFIG_TOPLEVEL_KEY: {k: v.to_dict() for k, v in current_mcp_servers.items()}}

            def _write_config():
                with open(mcp_config_path, "w") as f:
                    json.dump(new_mcp_file, f, indent=4)

            await asyncio.to_thread(_write_config)
        except Exception as e:
            logger.error(f"Failed to write MCP config file at {mcp_config_path}: {e}")
            raise LettaInvalidArgumentError(f"Failed to write MCP config file {mcp_config_path}")

        return list(current_mcp_servers.values())