From 7aff9aa659d076b3f1ff3fc73e8e7ceb0e3375c0 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Sun, 5 Oct 2025 15:09:31 -0700 Subject: [PATCH] feat: error handling for rest api for agents + blocks [LET-4625] (#5142) --- letta/errors.py | 28 +++ letta/server/rest_api/app.py | 53 ++++- letta/server/rest_api/routers/v1/agents.py | 260 +++++++-------------- letta/server/rest_api/routers/v1/blocks.py | 34 ++- letta/server/server.py | 75 +++--- letta/services/agent_manager.py | 5 +- letta/services/block_manager.py | 84 ++++++- tests/managers/test_block_manager.py | 28 +-- 8 files changed, 327 insertions(+), 240 deletions(-) diff --git a/letta/errors.py b/letta/errors.py index 1d154d31..1d753d35 100644 --- a/letta/errors.py +++ b/letta/errors.py @@ -97,6 +97,34 @@ class LettaUserNotFoundError(LettaError): """Error raised when a user is not found.""" +class LettaInvalidArgumentError(LettaError): + """Error raised when an invalid argument is provided.""" + + def __init__(self, message: str, argument_name: Optional[str] = None): + details = {"argument_name": argument_name} if argument_name else {} + super().__init__(message=message, code=ErrorCode.INVALID_ARGUMENT, details=details) + + +class LettaMCPError(LettaError): + """Base error for MCP-related issues.""" + + +class LettaMCPConnectionError(LettaMCPError): + """Error raised when unable to connect to MCP server.""" + + def __init__(self, message: str, server_name: Optional[str] = None): + details = {"server_name": server_name} if server_name else {} + super().__init__(message=message, code=ErrorCode.INTERNAL_SERVER_ERROR, details=details) + + +class LettaMCPTimeoutError(LettaMCPError): + """Error raised when MCP server operation times out.""" + + def __init__(self, message: str, server_name: Optional[str] = None): + details = {"server_name": server_name} if server_name else {} + super().__init__(message=message, code=ErrorCode.TIMEOUT, details=details) + + class LettaUnexpectedStreamCancellationError(LettaError): """Error raised when a streaming request is terminated unexpectedly.""" diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py index 6123c193..3cda0a64 100644 --- a/letta/server/rest_api/app.py +++ b/letta/server/rest_api/app.py @@ -12,14 +12,23 @@ from typing import Optional import uvicorn from fastapi import FastAPI, Request from fastapi.responses import JSONResponse +from marshmallow import ValidationError +from sqlalchemy.exc import IntegrityError, OperationalError from starlette.middleware.cors import CORSMiddleware from letta.__init__ import __version__ as letta_version from letta.agents.exceptions import IncompatibleAgentType from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX from letta.errors import ( + AgentExportIdMappingError, + AgentExportProcessingError, + AgentFileImportError, + AgentNotFoundForExportError, BedrockPermissionError, LettaAgentNotFoundError, + LettaInvalidArgumentError, + LettaMCPConnectionError, + LettaMCPTimeoutError, LettaToolCreateError, LettaToolNameConflictError, LettaUserNotFoundError, @@ -234,16 +243,41 @@ def create_application() -> "FastAPI": _error_handler_404 = partial(error_handler_with_code, code=404) _error_handler_404_agent = partial(_error_handler_404, detail="Agent not found") _error_handler_404_user = partial(_error_handler_404, detail="User not found") + _error_handler_408 = partial(error_handler_with_code, code=408) _error_handler_409 = partial(error_handler_with_code, code=409) + _error_handler_422 = partial(error_handler_with_code, code=422) + _error_handler_500 = partial(error_handler_with_code, code=500) + _error_handler_503 = partial(error_handler_with_code, code=503) - app.add_exception_handler(ValueError, _error_handler_400) + # 400 Bad Request errors + app.add_exception_handler(LettaInvalidArgumentError, _error_handler_400) + app.add_exception_handler(LettaToolCreateError, _error_handler_400) + app.add_exception_handler(LettaToolNameConflictError, _error_handler_400) + app.add_exception_handler(AgentFileImportError, _error_handler_400) + + # 404 Not Found errors app.add_exception_handler(NoResultFound, _error_handler_404) app.add_exception_handler(LettaAgentNotFoundError, _error_handler_404_agent) app.add_exception_handler(LettaUserNotFoundError, _error_handler_404_user) + app.add_exception_handler(AgentNotFoundForExportError, _error_handler_404) + + # 408 Timeout errors + app.add_exception_handler(LettaMCPTimeoutError, _error_handler_408) + + # 409 Conflict errors app.add_exception_handler(ForeignKeyConstraintViolationError, _error_handler_409) app.add_exception_handler(UniqueConstraintViolationError, _error_handler_409) - app.add_exception_handler(LettaToolCreateError, _error_handler_400) - app.add_exception_handler(LettaToolNameConflictError, _error_handler_400) + app.add_exception_handler(IntegrityError, _error_handler_409) + + # 422 Validation errors + app.add_exception_handler(ValidationError, _error_handler_422) + + # 500 Internal Server errors + app.add_exception_handler(AgentExportIdMappingError, _error_handler_500) + app.add_exception_handler(AgentExportProcessingError, _error_handler_500) + + # 503 Service Unavailable errors + app.add_exception_handler(OperationalError, _error_handler_503) @app.exception_handler(IncompatibleAgentType) async def handle_incompatible_agent_type(request: Request, exc: IncompatibleAgentType): @@ -327,6 +361,19 @@ def create_application() -> "FastAPI": }, ) + @app.exception_handler(LettaMCPConnectionError) + async def mcp_connection_error_handler(request: Request, exc: LettaMCPConnectionError): + return JSONResponse( + status_code=502, + content={ + "error": { + "type": "mcp_connection_error", + "message": "Failed to connect to MCP server.", + "detail": str(exc), + } + }, + ) + @app.exception_handler(LLMError) async def llm_error_handler(request: Request, exc: LLMError): return JSONResponse( diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index f7860ca4..175da3b0 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -195,24 +195,12 @@ async def export_agent( if use_legacy_format: # Use the legacy serialization method - try: - agent = await server.agent_manager.serialize(agent_id=agent_id, actor=actor, max_steps=max_steps) - return agent.model_dump() - except NoResultFound: - raise HTTPException(status_code=404, detail=f"Agent with id={agent_id} not found for user_id={actor.id}.") + agent = await server.agent_manager.serialize(agent_id=agent_id, actor=actor, max_steps=max_steps) + return agent.model_dump() else: # Use the new multi-entity export format - try: - agent_file_schema = await server.agent_serialization_manager.export(agent_ids=[agent_id], actor=actor) - return agent_file_schema.model_dump() - except AgentNotFoundForExportError: - raise HTTPException(status_code=404, detail=f"Agent with id={agent_id} not found for user_id={actor.id}.") - except AgentExportIdMappingError as e: - raise HTTPException( - status_code=500, detail=f"Internal error during export: ID mapping failed for {e.entity_type} ID '{e.db_id}'" - ) - except AgentExportProcessingError as e: - raise HTTPException(status_code=500, detail=f"Export processing failed: {str(e.original_error)}") + agent_file_schema = await server.agent_serialization_manager.export(agent_ids=[agent_id], actor=actor) + return agent_file_schema.model_dump() class ImportedAgentsResponse(BaseModel): @@ -234,33 +222,19 @@ def import_agent_legacy( """ Import an agent using the legacy AgentSchema format. """ - try: - # Validate the JSON against AgentSchema before passing it to deserialize - agent_schema = AgentSchema.model_validate(agent_json) + # Validate the JSON against AgentSchema before passing it to deserialize + agent_schema = AgentSchema.model_validate(agent_json) - new_agent = server.agent_manager.deserialize( - serialized_agent=agent_schema, # Ensure we're passing a validated AgentSchema - actor=actor, - append_copy_suffix=append_copy_suffix, - override_existing_tools=override_existing_tools, - project_id=project_id, - strip_messages=strip_messages, - env_vars=env_vars, - ) - return [new_agent.id] - - except ValidationError as e: - raise HTTPException(status_code=422, detail=f"Invalid agent schema: {e!s}") - - except IntegrityError as e: - raise HTTPException(status_code=409, detail=f"Database integrity error: {e!s}") - - except OperationalError as e: - raise HTTPException(status_code=503, detail=f"Database connection error. Please try again later: {e!s}") - - except Exception as e: - traceback.print_exc() - raise HTTPException(status_code=500, detail=f"An unexpected error occurred while uploading the agent: {e!s}") + new_agent = server.agent_manager.deserialize( + serialized_agent=agent_schema, # Ensure we're passing a validated AgentSchema + actor=actor, + append_copy_suffix=append_copy_suffix, + override_existing_tools=override_existing_tools, + project_id=project_id, + strip_messages=strip_messages, + env_vars=env_vars, + ) + return [new_agent.id] async def _import_agent( @@ -278,46 +252,29 @@ async def _import_agent( """ Import an agent using the new AgentFileSchema format. """ - try: - agent_schema = AgentFileSchema.model_validate(agent_file_json) - except ValidationError as e: - raise HTTPException(status_code=422, detail=f"Invalid agent file schema: {e!s}") + agent_schema = AgentFileSchema.model_validate(agent_file_json) - try: - if override_embedding_handle: - embedding_config_override = await server.get_cached_embedding_config_async(actor=actor, handle=override_embedding_handle) - else: - embedding_config_override = None + if override_embedding_handle: + embedding_config_override = await server.get_cached_embedding_config_async(actor=actor, handle=override_embedding_handle) + else: + embedding_config_override = None - import_result = await server.agent_serialization_manager.import_file( - schema=agent_schema, - actor=actor, - append_copy_suffix=append_copy_suffix, - override_existing_tools=override_existing_tools, - env_vars=env_vars, - override_embedding_config=embedding_config_override, - project_id=project_id, - ) + import_result = await server.agent_serialization_manager.import_file( + schema=agent_schema, + actor=actor, + append_copy_suffix=append_copy_suffix, + override_existing_tools=override_existing_tools, + env_vars=env_vars, + override_embedding_config=embedding_config_override, + project_id=project_id, + ) - if not import_result.success: - raise HTTPException( - status_code=500, detail=f"Import failed: {import_result.message}. Errors: {', '.join(import_result.errors)}" - ) + if not import_result.success: + from letta.errors import AgentFileImportError - return import_result.imported_agent_ids + raise AgentFileImportError(f"Import failed: {import_result.message}. Errors: {', '.join(import_result.errors)}") - except AgentFileImportError as e: - raise HTTPException(status_code=400, detail=f"Agent file import error: {str(e)}") - - except IntegrityError as e: - raise HTTPException(status_code=409, detail=f"Database integrity error: {e!s}") - - except OperationalError as e: - raise HTTPException(status_code=503, detail=f"Database connection error. Please try again later: {e!s}") - - except Exception as e: - traceback.print_exc() - raise HTTPException(status_code=500, detail=f"An unexpected error occurred while importing agents: {e!s}") + return import_result.imported_agent_ids @router.post("/import", response_model=ImportedAgentsResponse, operation_id="import_agent") @@ -405,11 +362,7 @@ async def retrieve_agent_context_window( Retrieve the context window of a specific agent. """ actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - try: - return await server.agent_manager.get_context_window(agent_id=agent_id, actor=actor) - except Exception as e: - traceback.print_exc() - raise e + return await server.agent_manager.get_context_window(agent_id=agent_id, actor=actor) class CreateAgentRequest(CreateAgent): @@ -433,14 +386,10 @@ async def create_agent( """ Create an agent. """ - try: - actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - if headers.experimental_params.letta_v1_agent and agent.agent_type == AgentType.memgpt_v2_agent: - agent.agent_type = AgentType.letta_v1_agent - return await server.create_agent_async(agent, actor=actor) - except Exception as e: - traceback.print_exc() - raise HTTPException(status_code=500, detail=str(e)) + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + if headers.experimental_params.letta_v1_agent and agent.agent_type == AgentType.memgpt_v2_agent: + agent.agent_type = AgentType.letta_v1_agent + return await server.create_agent_async(agent, actor=actor) @router.patch("/{agent_id}", response_model=AgentState, operation_id="modify_agent") @@ -683,12 +632,9 @@ async def open_file( actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) # Get the agent to access files configuration - try: - per_file_view_window_char_limit, max_files_open = await server.agent_manager.get_agent_files_config_async( - agent_id=agent_id, actor=actor - ) - except ValueError: - raise HTTPException(status_code=404, detail=f"Agent with id={agent_id} not found") + per_file_view_window_char_limit, max_files_open = await server.agent_manager.get_agent_files_config_async( + agent_id=agent_id, actor=actor + ) # Get file metadata file_metadata = await server.file_manager.get_file_by_id(file_id=file_id, actor=actor, include_content=True) @@ -734,16 +680,13 @@ async def close_file( actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) # Use update_file_agent_by_id to close the file - try: - await server.file_agent_manager.update_file_agent_by_id( - agent_id=agent_id, - file_id=file_id, - actor=actor, - is_open=False, - ) - return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"File id={file_id} successfully closed"}) - except NoResultFound: - raise HTTPException(status_code=404, detail=f"File association for file_id={file_id} and agent_id={agent_id} not found") + await server.file_agent_manager.update_file_agent_by_id( + agent_id=agent_id, + file_id=file_id, + actor=actor, + is_open=False, + ) + return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"File id={file_id} successfully closed"}) @router.get("/{agent_id}", response_model=AgentState, operation_id="retrieve_agent") @@ -769,10 +712,7 @@ async def retrieve_agent( actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - try: - return await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, include_relationships=include_relationships, actor=actor) - except NoResultFound as e: - raise HTTPException(status_code=404, detail=str(e)) + return await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, include_relationships=include_relationships, actor=actor) @router.delete("/{agent_id}", response_model=None, operation_id="delete_agent") @@ -785,11 +725,8 @@ async def delete_agent( Delete an agent. """ actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - try: - await server.agent_manager.delete_agent_async(agent_id=agent_id, actor=actor) - return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Agent id={agent_id} successfully deleted"}) - except NoResultFound: - raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found for user_id={actor.id}.") + await server.agent_manager.delete_agent_async(agent_id=agent_id, actor=actor) + return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Agent id={agent_id} successfully deleted"}) @router.get("/{agent_id}/sources", response_model=list[Source], operation_id="list_agent_sources") @@ -889,10 +826,7 @@ async def retrieve_block( """ actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - try: - return await server.agent_manager.get_block_with_label_async(agent_id=agent_id, block_label=block_label, actor=actor) - except NoResultFound as e: - raise HTTPException(status_code=404, detail=str(e)) + return await server.agent_manager.get_block_with_label_async(agent_id=agent_id, block_label=block_label, actor=actor) @router.get("/{agent_id}/core-memory/blocks", response_model=list[Block], operation_id="list_core_memory_blocks") @@ -917,17 +851,14 @@ async def list_blocks( """ actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - try: - return await server.agent_manager.list_agent_blocks_async( - agent_id=agent_id, - actor=actor, - before=before, - after=after, - limit=limit, - ascending=(order == "asc"), - ) - except NoResultFound: - raise HTTPException(status_code=404, detail="Agent not found") + return await server.agent_manager.list_agent_blocks_async( + agent_id=agent_id, + actor=actor, + before=before, + after=after, + limit=limit, + ascending=(order == "asc"), + ) @router.patch("/{agent_id}/core-memory/blocks/{block_label}", response_model=Block, operation_id="modify_core_memory_block") @@ -1050,34 +981,26 @@ async def search_archival_memory( """ actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - try: - # convert datetime to string in ISO 8601 format - start_datetime = start_datetime.isoformat() if start_datetime else None - end_datetime = end_datetime.isoformat() if end_datetime else None + # convert datetime to string in ISO 8601 format + start_datetime = start_datetime.isoformat() if start_datetime else None + end_datetime = end_datetime.isoformat() if end_datetime else None - # Use the shared agent manager method - formatted_results = await server.agent_manager.search_agent_archival_memory_async( - agent_id=agent_id, - actor=actor, - query=query, - tags=tags, - tag_match_mode=tag_match_mode, - top_k=top_k, - start_datetime=start_datetime, - end_datetime=end_datetime, - ) + # Use the shared agent manager method + formatted_results = await server.agent_manager.search_agent_archival_memory_async( + agent_id=agent_id, + actor=actor, + query=query, + tags=tags, + tag_match_mode=tag_match_mode, + top_k=top_k, + start_datetime=start_datetime, + end_datetime=end_datetime, + ) - # Convert to proper response schema - search_results = [ArchivalMemorySearchResult(**result) for result in formatted_results] + # Convert to proper response schema + search_results = [ArchivalMemorySearchResult(**result) for result in formatted_results] - return ArchivalMemorySearchResponse(results=search_results, count=len(formatted_results)) - - except NoResultFound as e: - raise HTTPException(status_code=404, detail=f"Agent with id={agent_id} not found for user_id={actor.id}.") - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) - except Exception as e: - raise HTTPException(status_code=500, detail=f"Internal server error during archival memory search: {str(e)}") + return ArchivalMemorySearchResponse(results=search_results, count=len(formatted_results)) # TODO(ethan): query or path parameter for memory_id? @@ -1572,21 +1495,18 @@ async def search_messages( if agent_count == 0: raise HTTPException(status_code=400, detail="No agents found in organization to derive embedding configuration from") - try: - results = await server.message_manager.search_messages_org_async( - actor=actor, - query_text=request.query, - search_mode=request.search_mode, - roles=request.roles, - project_id=request.project_id, - template_id=request.template_id, - limit=request.limit, - start_date=request.start_date, - end_date=request.end_date, - ) - return results - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) + results = await server.message_manager.search_messages_org_async( + actor=actor, + query_text=request.query, + search_mode=request.search_mode, + roles=request.roles, + project_id=request.project_id, + template_id=request.template_id, + limit=request.limit, + start_date=request.start_date, + end_date=request.end_date, + ) + return results async def _process_message_background( diff --git a/letta/server/rest_api/routers/v1/blocks.py b/letta/server/rest_api/routers/v1/blocks.py index af9a8f86..6c8ac280 100644 --- a/letta/server/rest_api/routers/v1/blocks.py +++ b/letta/server/rest_api/routers/v1/blocks.py @@ -154,13 +154,10 @@ async def retrieve_block( headers: HeaderParams = Depends(get_headers), ): actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - try: - block = await server.block_manager.get_block_by_id_async(block_id=block_id, actor=actor) - if block is None: - raise HTTPException(status_code=404, detail="Block not found") - return block - except NoResultFound: - raise HTTPException(status_code=404, detail="Block not found") + block = await server.block_manager.get_block_by_id_async(block_id=block_id, actor=actor) + if block is None: + raise NoResultFound(f"Block with id '{block_id}' not found") + return block @router.get("/{block_id}/agents", response_model=List[AgentState], operation_id="list_agents_for_block") @@ -195,16 +192,13 @@ async def list_agents_for_block( Raises a 404 if the block does not exist. """ actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - try: - agents = await server.block_manager.get_agents_for_block_async( - block_id=block_id, - before=before, - after=after, - limit=limit, - ascending=(order == "asc"), - include_relationships=include_relationships, - actor=actor, - ) - return agents - except NoResultFound: - raise HTTPException(status_code=404, detail=f"Block with id={block_id} not found") + agents = await server.block_manager.get_agents_for_block_async( + block_id=block_id, + before=before, + after=after, + limit=limit, + ascending=(order == "asc"), + include_relationships=include_relationships, + actor=actor, + ) + return agents diff --git a/letta/server/server.py b/letta/server/server.py index b1be2dff..69c298a0 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -19,7 +19,7 @@ import letta.system as system from letta.config import LettaConfig from letta.constants import LETTA_TOOL_EXECUTION_DIR from letta.data_sources.connectors import DataConnector, load_data -from letta.errors import HandleNotFoundError +from letta.errors import HandleNotFoundError, LettaInvalidArgumentError, LettaMCPConnectionError, LettaMCPTimeoutError from letta.functions.mcp_client.types import MCPServerType, MCPTool, MCPToolHealth, SSEServerConfig, StdioServerConfig from letta.functions.schema_validator import validate_complete_json_schema from letta.groups.helpers import load_multi_agent @@ -363,7 +363,7 @@ class SyncServer(object): elif server_config.type == MCPServerType.STDIO: self.mcp_clients[server_name] = AsyncStdioMCPClient(server_config) else: - raise ValueError(f"Invalid MCP server config: {server_config}") + raise LettaInvalidArgumentError(f"Invalid MCP server config: {server_config}", argument_name="server_config") try: await self.mcp_clients[server_name].connect_to_server() @@ -416,7 +416,7 @@ class SyncServer(object): if request.llm_config is None: if request.model is None: if settings.default_llm_handle is None: - raise ValueError("Must specify either model or llm_config in request") + raise LettaInvalidArgumentError("Must specify either model or llm_config in request", argument_name="model") else: request.model = settings.default_llm_handle config_params = { @@ -436,7 +436,9 @@ class SyncServer(object): if request.embedding_config is None: if request.embedding is None: if settings.default_embedding_handle is None: - raise ValueError("Must specify either embedding or embedding_config in request") + raise LettaInvalidArgumentError( + "Must specify either embedding or embedding_config in request", argument_name="embedding" + ) else: request.embedding = settings.default_embedding_handle embedding_config_params = { @@ -760,7 +762,7 @@ class SyncServer(object): # TODO: move this into a thread source = await self.source_manager.get_source_by_id(source_id=source_id) if source is None: - raise ValueError(f"Source {source_id} does not exist") + raise NoResultFound(f"Source {source_id} does not exist") connector = DirectoryConnector(input_files=[file_path]) num_passages, num_documents = await self.load_data(user_id=source.created_by_id, source_name=source.name, connector=connector) @@ -894,7 +896,7 @@ class SyncServer(object): actor = await self.user_manager.get_actor_by_id_async(actor_id=user_id) source = await self.source_manager.get_source_by_name(source_name=source_name, actor=actor) if source is None: - raise ValueError(f"Data source {source_name} does not exist for user {user_id}") + raise NoResultFound(f"Data source {source_name} does not exist for user {user_id}") # load data into the document store passage_count, document_count = await load_data(connector, source, self.passage_manager, self.file_manager, actor=actor) @@ -1042,13 +1044,18 @@ class SyncServer(object): if len(llm_configs) == 1: llm_config = llm_configs[0] elif len(llm_configs) > 1: - raise ValueError(f"Multiple LLM models with name {model_name} supported by {provider_name}") + raise LettaInvalidArgumentError( + f"Multiple LLM models with name {model_name} supported by {provider_name}", argument_name="model_name" + ) else: llm_config = llm_configs[0] if context_window_limit is not None: if context_window_limit > llm_config.context_window: - raise ValueError(f"Context window limit ({context_window_limit}) is greater than maximum of ({llm_config.context_window})") + raise LettaInvalidArgumentError( + f"Context window limit ({context_window_limit}) is greater than maximum of ({llm_config.context_window})", + argument_name="context_window_limit", + ) llm_config.context_window = context_window_limit else: llm_config.context_window = min(llm_config.context_window, model_settings.global_max_context_window_limit) @@ -1057,7 +1064,10 @@ class SyncServer(object): llm_config.max_tokens = max_tokens if max_reasoning_tokens is not None: if not max_tokens or max_reasoning_tokens > max_tokens: - raise ValueError(f"Max reasoning tokens ({max_reasoning_tokens}) must be less than max tokens ({max_tokens})") + raise LettaInvalidArgumentError( + f"Max reasoning tokens ({max_reasoning_tokens}) must be less than max tokens ({max_tokens})", + argument_name="max_reasoning_tokens", + ) llm_config.max_reasoning_tokens = max_reasoning_tokens if enable_reasoner is not None: llm_config.enable_reasoner = enable_reasoner @@ -1077,8 +1087,10 @@ class SyncServer(object): all_embedding_configs = await provider.list_embedding_models_async() embedding_configs = [config for config in all_embedding_configs if config.handle == handle] if not embedding_configs: - raise ValueError(f"Embedding model {model_name} is not supported by {provider_name}") - except ValueError as e: + raise LettaInvalidArgumentError( + f"Embedding model {model_name} is not supported by {provider_name}", argument_name="model_name" + ) + except LettaInvalidArgumentError as e: # search local configs embedding_configs = [config for config in self.get_local_embedding_configs() if config.handle == handle] if not embedding_configs: @@ -1087,7 +1099,9 @@ class SyncServer(object): if len(embedding_configs) == 1: embedding_config = embedding_configs[0] elif len(embedding_configs) > 1: - raise ValueError(f"Multiple embedding models with name {model_name} supported by {provider_name}") + raise LettaInvalidArgumentError( + f"Multiple embedding models with name {model_name} supported by {provider_name}", argument_name="model_name" + ) else: embedding_config = embedding_configs[0] @@ -1100,11 +1114,12 @@ class SyncServer(object): all_providers = await self.get_enabled_providers_async(actor) providers = [provider for provider in all_providers if provider.name == provider_name] if not providers: - raise ValueError( - f"Provider {provider_name} is not supported (supported providers: {', '.join([provider.name for provider in self._enabled_providers])})" + raise LettaInvalidArgumentError( + f"Provider {provider_name} is not supported (supported providers: {', '.join([provider.name for provider in self._enabled_providers])})", + argument_name="provider_name", ) elif len(providers) > 1: - raise ValueError(f"Multiple providers with name {provider_name} supported") + raise LettaInvalidArgumentError(f"Multiple providers with name {provider_name} supported", argument_name="provider_name") else: provider = providers[0] @@ -1173,7 +1188,9 @@ class SyncServer(object): from letta.services.tool_schema_generator import generate_schema_for_tool_creation, generate_schema_for_tool_update if tool_source_type not in (None, ToolSourceType.python, ToolSourceType.typescript): - raise ValueError("Tool source type is not supported at this time. Found {tool_source_type}") + raise LettaInvalidArgumentError( + f"Tool source type is not supported at this time. Found {tool_source_type}", argument_name="tool_source_type" + ) # If tools_json_schema is explicitly passed in, override it on the created Tool object if tool_json_schema: @@ -1307,7 +1324,7 @@ class SyncServer(object): async def get_tools_from_mcp_server(self, mcp_server_name: str) -> List[MCPTool]: """List the tools in an MCP server. Requires a client to be created.""" if mcp_server_name not in self.mcp_clients: - raise ValueError(f"No client was created for MCP server: {mcp_server_name}") + raise LettaInvalidArgumentError(f"No client was created for MCP server: {mcp_server_name}", argument_name="mcp_server_name") tools = await self.mcp_clients[mcp_server_name].list_tools() # Add health information to each tool @@ -1339,11 +1356,13 @@ class SyncServer(object): except Exception as e: # Raise an error telling the user to fix the config file logger.error(f"Failed to parse MCP config file at {mcp_config_path}: {e}") - raise ValueError(f"Failed to parse MCP config file {mcp_config_path}") + raise LettaInvalidArgumentError(f"Failed to parse MCP config file {mcp_config_path}") # Check if the server name is already in the config if server_config.server_name in current_mcp_servers and not allow_upsert: - raise ValueError(f"Server name {server_config.server_name} is already in the config file") + raise LettaInvalidArgumentError( + f"Server name {server_config.server_name} is already in the config file", argument_name="server_name" + ) # Attempt to initialize the connection to the server if server_config.type == MCPServerType.SSE: @@ -1351,7 +1370,7 @@ class SyncServer(object): elif server_config.type == MCPServerType.STDIO: new_mcp_client = AsyncStdioMCPClient(server_config) else: - raise ValueError(f"Invalid MCP server config: {server_config}") + raise LettaInvalidArgumentError(f"Invalid MCP server config: {server_config}", argument_name="server_config") try: await new_mcp_client.connect_to_server() except: @@ -1376,7 +1395,7 @@ class SyncServer(object): json.dump(new_mcp_file, f, indent=4) except Exception as e: logger.error(f"Failed to write MCP config file at {mcp_config_path}: {e}") - raise ValueError(f"Failed to write MCP config file {mcp_config_path}") + raise LettaInvalidArgumentError(f"Failed to write MCP config file {mcp_config_path}") return list(current_mcp_servers.values()) @@ -1399,12 +1418,12 @@ class SyncServer(object): except Exception as e: # Raise an error telling the user to fix the config file logger.error(f"Failed to parse MCP config file at {mcp_config_path}: {e}") - raise ValueError(f"Failed to parse MCP config file {mcp_config_path}") + raise LettaInvalidArgumentError(f"Failed to parse MCP config file {mcp_config_path}") # Check if the server name is already in the config # If it's not, throw an error if server_name not in current_mcp_servers: - raise ValueError(f"Server name {server_name} not found in MCP config file") + raise LettaInvalidArgumentError(f"Server name {server_name} not found in MCP config file", argument_name="server_name") # Remove from the server file del current_mcp_servers[server_name] @@ -1416,7 +1435,7 @@ class SyncServer(object): json.dump(new_mcp_file, f, indent=4) except Exception as e: logger.error(f"Failed to write MCP config file at {mcp_config_path}: {e}") - raise ValueError(f"Failed to write MCP config file {mcp_config_path}") + raise LettaInvalidArgumentError(f"Failed to write MCP config file {mcp_config_path}") return list(current_mcp_servers.values()) @@ -1478,7 +1497,9 @@ class SyncServer(object): ) streaming_interface = letta_agent.interface if not isinstance(streaming_interface, StreamingServerInterface): - raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}") + raise LettaInvalidArgumentError( + f"Agent has wrong type of interface: {type(streaming_interface)}", argument_name="interface" + ) # Enable token-streaming within the request if desired streaming_interface.streaming_mode = stream_tokens @@ -1583,7 +1604,7 @@ class SyncServer(object): ) -> Union[StreamingResponse, LettaResponse]: include_final_message = True if not stream_steps and stream_tokens: - raise ValueError("stream_steps must be 'true' if stream_tokens is 'true'") + raise LettaInvalidArgumentError("stream_steps must be 'true' if stream_tokens is 'true'", argument_name="stream_steps") group = await self.group_manager.retrieve_group_async(group_id=group_id, actor=actor) agent_state_id = group.manager_agent_id or (group.agent_ids[0] if len(group.agent_ids) > 0 else None) @@ -1609,7 +1630,7 @@ class SyncServer(object): ) streaming_interface = letta_multi_agent.interface if not isinstance(streaming_interface, StreamingServerInterface): - raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}") + raise LettaInvalidArgumentError(f"Agent has wrong type of interface: {type(streaming_interface)}", argument_name="interface") streaming_interface.streaming_mode = stream_tokens streaming_interface.streaming_chat_completion_mode = chat_completion_mode if metadata and hasattr(streaming_interface, "metadata"): diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 2e98dbec..5091b12d 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -74,7 +74,7 @@ from letta.serialize_schemas.marshmallow_tool import SerializedToolSchema from letta.serialize_schemas.pydantic_agent_schema import AgentSchema from letta.server.db import db_registry from letta.services.archive_manager import ArchiveManager -from letta.services.block_manager import BlockManager +from letta.services.block_manager import BlockManager, validate_block_limit_constraint from letta.services.context_window_calculator.context_window_calculator import ContextWindowCalculator from letta.services.context_window_calculator.token_counter import AnthropicTokenCounter, TiktokenCounter from letta.services.file_processor.chunker.line_chunker import LineChunker @@ -1661,6 +1661,9 @@ class AgentManager: update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) + # Validate limit constraints before updating + validate_block_limit_constraint(update_data, block) + for key, value in update_data.items(): setattr(block, key, value) diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index bf757bb4..8231f854 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -5,6 +5,7 @@ from typing import Dict, List, Optional from sqlalchemy import and_, delete, func, or_, select from sqlalchemy.ext.asyncio import AsyncSession +from letta.errors import LettaInvalidArgumentError from letta.log import get_logger from letta.orm.agent import Agent as AgentModel from letta.orm.block import Block as BlockModel @@ -23,6 +24,62 @@ from letta.utils import enforce_types logger = get_logger(__name__) +def validate_block_limit_constraint(update_data: dict, existing_block: BlockModel) -> None: + """ + Validates that block limit constraints are satisfied when updating a block. + + Rules: + - If limit is being updated, it must be >= the length of the value (existing or new) + - If value is being updated, its length must not exceed the limit (existing or new) + + Args: + update_data: Dictionary of fields to update + existing_block: The current block being updated + + Raises: + LettaInvalidArgumentError: If validation fails + """ + # If limit is being updated, ensure it's >= current value length + if "limit" in update_data: + # Get the value that will be used (either from update_data or existing) + value_to_check = update_data.get("value", existing_block.value) + limit_to_check = update_data["limit"] + if value_to_check and limit_to_check < len(value_to_check): + raise LettaInvalidArgumentError( + f"Limit ({limit_to_check}) cannot be less than current value length ({len(value_to_check)} characters)", + argument_name="limit", + ) + # If value is being updated and there's an existing limit, ensure value doesn't exceed limit + elif "value" in update_data and existing_block.limit: + if len(update_data["value"]) > existing_block.limit: + raise LettaInvalidArgumentError( + f"Value length ({len(update_data['value'])} characters) exceeds block limit ({existing_block.limit} characters)", + argument_name="value", + ) + + +def validate_block_creation(block_data: dict) -> None: + """ + Validates that block limit constraints are satisfied when creating a block. + + Rules: + - If both value and limit are provided, limit must be >= value length + + Args: + block_data: Dictionary of block fields for creation + + Raises: + LettaInvalidArgumentError: If validation fails + """ + value = block_data.get("value") + limit = block_data.get("limit") + + if value and limit and len(value) > limit: + raise LettaInvalidArgumentError( + f"Block limit ({limit}) must be greater than or equal to value length ({len(value)} characters)", argument_name="limit" + ) + + class BlockManager: """Manager class to handle business logic related to Blocks.""" @@ -37,6 +94,8 @@ class BlockManager: else: async with db_registry.async_session() as session: data = block.model_dump(to_orm=True, exclude_none=True) + # Validate block creation constraints + validate_block_creation(data) block = BlockModel(**data, organization_id=actor.organization_id) await block.create_async(session, actor=actor, no_commit=True, no_refresh=True) pydantic_block = block.to_pydantic() @@ -58,6 +117,11 @@ class BlockManager: return [] async with db_registry.async_session() as session: + # Validate all blocks before creating any + for block in blocks: + block_data = block.model_dump(to_orm=True, exclude_none=True) + validate_block_creation(block_data) + block_models = [ BlockModel(**block.model_dump(to_orm=True, exclude_none=True), organization_id=actor.organization_id) for block in blocks ] @@ -78,6 +142,9 @@ class BlockManager: block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor) update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) + # Validate limit constraints before updating + validate_block_limit_constraint(update_data, block) + for key, value in update_data.items(): setattr(block, key, value) @@ -653,7 +720,7 @@ class BlockManager: ) if not block.current_history_entry_id: - raise ValueError(f"Block {block_id} has no history entry - cannot undo.") + raise LettaInvalidArgumentError(f"Block {block_id} has no history entry - cannot undo.", argument_name="block_id") current_entry = await session.get(BlockHistory, block.current_history_entry_id) if not current_entry: @@ -672,7 +739,10 @@ class BlockManager: previous_entry = result.scalar_one_or_none() if not previous_entry: # No earlier checkpoint available - raise ValueError(f"Block {block_id} is already at the earliest checkpoint (seq={current_seq}). Cannot undo further.") + raise LettaInvalidArgumentError( + f"Block {block_id} is already at the earliest checkpoint (seq={current_seq}). Cannot undo further.", + argument_name="block_id", + ) # 3) Move to that sequence block = await self._move_block_to_sequence(session, block, previous_entry.sequence_number, actor) @@ -699,11 +769,13 @@ class BlockManager: ) if not block.current_history_entry_id: - raise ValueError(f"Block {block_id} has no history entry - cannot redo.") + raise LettaInvalidArgumentError(f"Block {block_id} has no history entry - cannot redo.", argument_name="block_id") current_entry = await session.get(BlockHistory, block.current_history_entry_id) if not current_entry: - raise NoResultFound(f"BlockHistory row not found for id={block.current_history_entry_id}") + raise LettaInvalidArgumentError( + f"BlockHistory row not found for id={block.current_history_entry_id}", argument_name="block_id" + ) current_seq = current_entry.sequence_number @@ -717,7 +789,9 @@ class BlockManager: result = await session.execute(stmt) next_entry = result.scalar_one_or_none() if not next_entry: - raise ValueError(f"Block {block_id} is at the highest checkpoint (seq={current_seq}). Cannot redo further.") + raise LettaInvalidArgumentError( + f"Block {block_id} is at the highest checkpoint (seq={current_seq}). Cannot redo further.", argument_name="block_id" + ) block = await self._move_block_to_sequence(session, block, next_entry.sequence_number, actor) diff --git a/tests/managers/test_block_manager.py b/tests/managers/test_block_manager.py index 3d5aa801..de8f01e7 100644 --- a/tests/managers/test_block_manager.py +++ b/tests/managers/test_block_manager.py @@ -44,7 +44,7 @@ from letta.constants import ( MULTI_AGENT_TOOLS, ) from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client -from letta.errors import LettaAgentNotFoundError +from letta.errors import LettaAgentNotFoundError, LettaInvalidArgumentError from letta.functions.functions import derive_openai_json_schema, parse_source_code from letta.functions.mcp_client.types import MCPTool from letta.helpers import ToolRulesSolver @@ -451,7 +451,7 @@ async def test_update_block_limit(server: SyncServer, default_user): update_data = BlockUpdate(value="Updated Content" * 2000, description="Updated description") # Check that exceeding the block limit raises an exception - with pytest.raises(ValueError): + with pytest.raises(LettaInvalidArgumentError): await block_manager.update_block_async(block_id=block.id, block_update=update_data, actor=default_user) # Ensure the update works when within limits @@ -908,8 +908,8 @@ async def test_undo_checkpoint_block(server: SyncServer, default_user): assert undone_block.label == "undo_test", "Label should also revert if changed (or remain the same if unchanged)" -#@pytest.mark.asyncio -#async def test_checkpoint_deletes_future_states_after_undo(server: SyncServer, default_user): +# @pytest.mark.asyncio +# async def test_checkpoint_deletes_future_states_after_undo(server: SyncServer, default_user): # """ # Verifies that once we've undone to an earlier checkpoint, creating a new # checkpoint removes any leftover 'future' states that existed beyond that sequence. @@ -980,7 +980,7 @@ async def test_undo_checkpoint_block(server: SyncServer, default_user): async def test_undo_no_history(server: SyncServer, default_user): """ If a block has never been checkpointed (no current_history_entry_id), - undo_checkpoint_block should raise a ValueError. + undo_checkpoint_block should raise a LettaInvalidArgumentError. """ block_manager = BlockManager() @@ -988,7 +988,7 @@ async def test_undo_no_history(server: SyncServer, default_user): block = await block_manager.create_or_update_block_async(PydanticBlock(label="no_history_test", value="initial"), actor=default_user) # Attempt to undo - with pytest.raises(ValueError, match="has no history entry - cannot undo"): + with pytest.raises(LettaInvalidArgumentError): await block_manager.undo_checkpoint_block(block_id=block.id, actor=default_user) @@ -1007,8 +1007,8 @@ async def test_undo_first_checkpoint(server: SyncServer, default_user): # 2) First checkpoint => seq=1 await block_manager.checkpoint_block_async(block_id=block.id, actor=default_user) - # Attempt undo -> expect ValueError - with pytest.raises(ValueError, match="Cannot undo further"): + # Attempt undo -> expect LettaInvalidArgumentError + with pytest.raises(LettaInvalidArgumentError): await block_manager.undo_checkpoint_block(block_id=block.id, actor=default_user) @@ -1048,7 +1048,7 @@ async def test_undo_multiple_checkpoints(server: SyncServer, default_user): assert undone_block.value == "v1" # Try once more -> fails because seq=1 is the earliest - with pytest.raises(ValueError, match="Cannot undo further"): + with pytest.raises(LettaInvalidArgumentError): await block_manager.undo_checkpoint_block(block_v1.id, actor=default_user) @@ -1145,15 +1145,15 @@ async def test_redo_checkpoint_block(server: SyncServer, default_user): async def test_redo_no_history(server: SyncServer, default_user): """ If a block has no current_history_entry_id (never checkpointed), - then redo_checkpoint_block should raise ValueError. + then redo_checkpoint_block should raise LettaInvalidArgumentError. """ block_manager = BlockManager() # Create block with no checkpoint block = await block_manager.create_or_update_block_async(PydanticBlock(label="redo_no_history", value="v0"), actor=default_user) - # Attempt to redo => expect ValueError - with pytest.raises(ValueError, match="no history entry - cannot redo"): + # Attempt to redo => expect LettaInvalidArgumentError + with pytest.raises(LettaInvalidArgumentError): await block_manager.redo_checkpoint_block(block.id, actor=default_user) @@ -1161,7 +1161,7 @@ async def test_redo_no_history(server: SyncServer, default_user): async def test_redo_at_highest_checkpoint(server: SyncServer, default_user): """ If the block is at the maximum sequence number, there's no higher checkpoint to move to. - redo_checkpoint_block should raise ValueError. + redo_checkpoint_block should raise LettaInvalidArgumentError. """ block_manager = BlockManager() @@ -1177,7 +1177,7 @@ async def test_redo_at_highest_checkpoint(server: SyncServer, default_user): # We are at seq=2, which is the highest checkpoint. # Attempt redo => there's no seq=3 - with pytest.raises(ValueError, match="Cannot redo further"): + with pytest.raises(LettaInvalidArgumentError): await block_manager.redo_checkpoint_block(b_init.id, actor=default_user)