diff --git a/letta/errors.py b/letta/errors.py index d2e643e3..b057bc40 100644 --- a/letta/errors.py +++ b/letta/errors.py @@ -236,5 +236,33 @@ class AgentFileExportError(Exception): """Exception raised during agent file export operations""" +class AgentNotFoundForExportError(AgentFileExportError): + """Exception raised when requested agents are not found during export""" + + def __init__(self, missing_ids: List[str]): + self.missing_ids = missing_ids + super().__init__(f"The following agent IDs were not found: {missing_ids}") + + +class AgentExportIdMappingError(AgentFileExportError): + """Exception raised when ID mapping fails during export conversion""" + + def __init__(self, db_id: str, entity_type: str): + self.db_id = db_id + self.entity_type = entity_type + super().__init__( + f"Unexpected new {entity_type} ID '{db_id}' encountered during conversion. " + f"All IDs should have been mapped during agent processing." + ) + + +class AgentExportProcessingError(AgentFileExportError): + """Exception raised when general export processing fails""" + + def __init__(self, message: str, original_error: Optional[Exception] = None): + self.original_error = original_error + super().__init__(f"Export failed: {message}") + + class AgentFileImportError(Exception): """Exception raised during agent file import operations""" diff --git a/letta/schemas/agent_file.py b/letta/schemas/agent_file.py index d00615dd..1d87fa20 100644 --- a/letta/schemas/agent_file.py +++ b/letta/schemas/agent_file.py @@ -24,12 +24,14 @@ class ImportResult: success: bool, message: str = "", imported_count: int = 0, + imported_agent_ids: Optional[List[str]] = None, errors: Optional[List[str]] = None, id_mappings: Optional[Dict[str, str]] = None, ): self.success = success self.message = message self.imported_count = imported_count + self.imported_agent_ids = imported_agent_ids or [] self.errors = errors or [] self.id_mappings = id_mappings or {} diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 19eb17c2..048f083b 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -8,13 +8,14 @@ from fastapi import APIRouter, Body, Depends, File, Form, Header, HTTPException, from fastapi.responses import JSONResponse from marshmallow import ValidationError from orjson import orjson -from pydantic import Field +from pydantic import BaseModel, Field from sqlalchemy.exc import IntegrityError, OperationalError from starlette.responses import Response, StreamingResponse from letta.agents.letta_agent import LettaAgent from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, REDIS_RUN_ID_PREFIX from letta.data_sources.redis_client import get_redis_client +from letta.errors import AgentExportIdMappingError, AgentExportProcessingError, AgentFileImportError, AgentNotFoundForExportError from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2 from letta.helpers.datetime_helpers import get_utc_timestamp_ns from letta.log import get_logger @@ -22,6 +23,7 @@ from letta.orm.errors import NoResultFound from letta.otel.context import get_ctx_attributes from letta.otel.metric_registry import MetricRegistry from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent +from letta.schemas.agent_file import AgentFileSchema from letta.schemas.block import Block, BlockUpdate from letta.schemas.group import Group from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig @@ -144,29 +146,143 @@ class IndentedORJSONResponse(Response): @router.get("/{agent_id}/export", response_class=IndentedORJSONResponse, operation_id="export_agent_serialized") -def export_agent_serialized( +async def export_agent_serialized( agent_id: str, max_steps: int = 100, server: "SyncServer" = Depends(get_letta_server), actor_id: str | None = Header(None, alias="user_id"), + use_legacy_format: bool = Query( + True, + description="If true, exports using the legacy single-agent format. If false, exports using the new multi-entity format.", + ), # do not remove, used to autogeneration of spec - # TODO: Think of a better way to export AgentSchema - spec: AgentSchema | None = None, + # TODO: Think of a better way to export AgentFileSchema + spec: AgentFileSchema | None = None, + legacy_spec: AgentSchema | None = None, ) -> JSONResponse: """ Export the serialized JSON representation of an agent, formatted with indentation. + + Supports two export formats: + - Legacy format (use_legacy_format=true): Single agent with inline tools/blocks + - New format (default): Multi-entity format with separate agents, tools, blocks, files, etc. """ actor = server.user_manager.get_user_or_default(user_id=actor_id) + if use_legacy_format: + # Use the legacy serialization method + try: + agent = server.agent_manager.serialize(agent_id=agent_id, actor=actor, max_steps=max_steps) + return agent.model_dump() + except NoResultFound: + raise HTTPException(status_code=404, detail=f"Agent with id={agent_id} not found for user_id={actor.id}.") + else: + # Use the new multi-entity export format + try: + agent_file_schema = await server.agent_serialization_manager.export(agent_ids=[agent_id], actor=actor) + return agent_file_schema.model_dump() + except AgentNotFoundForExportError: + raise HTTPException(status_code=404, detail=f"Agent with id={agent_id} not found for user_id={actor.id}.") + except AgentExportIdMappingError as e: + raise HTTPException( + status_code=500, detail=f"Internal error during export: ID mapping failed for {e.entity_type} ID '{e.db_id}'" + ) + except AgentExportProcessingError as e: + raise HTTPException(status_code=500, detail=f"Export processing failed: {str(e.original_error)}") + + +class ImportedAgentsResponse(BaseModel): + """Response model for imported agents""" + + agent_ids: List[str] = Field(..., description="List of IDs of the imported agents") + + +def import_agent_legacy( + agent_json: dict, + server: "SyncServer", + actor: User, + append_copy_suffix: bool = True, + override_existing_tools: bool = True, + project_id: str | None = None, + strip_messages: bool = False, + env_vars: Optional[dict[str, Any]] = None, +) -> List[str]: + """ + Import an agent using the legacy AgentSchema format. + """ try: - agent = server.agent_manager.serialize(agent_id=agent_id, actor=actor, max_steps=max_steps) - return agent.model_dump() - except NoResultFound: - raise HTTPException(status_code=404, detail=f"Agent with id={agent_id} not found for user_id={actor.id}.") + # Validate the JSON against AgentSchema before passing it to deserialize + agent_schema = AgentSchema.model_validate(agent_json) + + new_agent = server.agent_manager.deserialize( + serialized_agent=agent_schema, # Ensure we're passing a validated AgentSchema + actor=actor, + append_copy_suffix=append_copy_suffix, + override_existing_tools=override_existing_tools, + project_id=project_id, + strip_messages=strip_messages, + env_vars=env_vars, + ) + return [new_agent.id] + + except ValidationError as e: + raise HTTPException(status_code=422, detail=f"Invalid agent schema: {e!s}") + + except IntegrityError as e: + raise HTTPException(status_code=409, detail=f"Database integrity error: {e!s}") + + except OperationalError as e: + raise HTTPException(status_code=503, detail=f"Database connection error. Please try again later: {e!s}") + + except Exception as e: + traceback.print_exc() + raise HTTPException(status_code=500, detail=f"An unexpected error occurred while uploading the agent: {e!s}") -@router.post("/import", response_model=AgentState, operation_id="import_agent_serialized") -def import_agent_serialized( +async def import_agent( + agent_file_json: dict, + server: "SyncServer", + actor: User, + # TODO: Support these fields for new agent file + append_copy_suffix: bool = True, + override_existing_tools: bool = True, + project_id: str | None = None, + strip_messages: bool = False, +) -> List[str]: + """ + Import an agent using the new AgentFileSchema format. + """ + try: + agent_schema = AgentFileSchema.model_validate(agent_file_json) + except ValidationError as e: + raise HTTPException(status_code=422, detail=f"Invalid agent file schema: {e!s}") + + try: + import_result = await server.agent_serialization_manager.import_file(schema=agent_schema, actor=actor) + + if not import_result.success: + raise HTTPException( + status_code=500, detail=f"Import failed: {import_result.message}. Errors: {', '.join(import_result.errors)}" + ) + + return import_result.imported_agent_ids + + except AgentFileImportError as e: + raise HTTPException(status_code=400, detail=f"Agent file import error: {str(e)}") + + except IntegrityError as e: + raise HTTPException(status_code=409, detail=f"Database integrity error: {e!s}") + + except OperationalError as e: + raise HTTPException(status_code=503, detail=f"Database connection error. Please try again later: {e!s}") + + except Exception as e: + traceback.print_exc() + raise HTTPException(status_code=500, detail=f"An unexpected error occurred while importing agents: {e!s}") + + +@router.post("/import", response_model=ImportedAgentsResponse, operation_id="import_agent_serialized") +async def import_agent_serialized( file: UploadFile = File(...), server: "SyncServer" = Depends(get_letta_server), actor_id: str | None = Header(None, alias="user_id"), @@ -183,19 +299,35 @@ def import_agent_serialized( env_vars: Optional[Dict[str, Any]] = Form(None, description="Environment variables to pass to the agent for tool execution."), ): """ - Import a serialized agent file and recreate the agent in the system. + Import a serialized agent file and recreate the agent(s) in the system. + Returns the IDs of all imported agents. """ actor = server.user_manager.get_user_or_default(user_id=actor_id) try: serialized_data = file.file.read() agent_json = json.loads(serialized_data) + except json.JSONDecodeError: + raise HTTPException(status_code=400, detail="Corrupted agent file format.") - # Validate the JSON against AgentSchema before passing it to deserialize - agent_schema = AgentSchema.model_validate(agent_json) - - new_agent = server.agent_manager.deserialize( - serialized_agent=agent_schema, # Ensure we're passing a validated AgentSchema + # Check if the JSON is AgentFileSchema or AgentSchema + # TODO: This is kind of hacky, but should work as long as dont' change the schema + if "agents" in agent_json and isinstance(agent_json.get("agents"), list): + # This is an AgentFileSchema + agent_ids = await import_agent( + agent_file_json=agent_json, + server=server, + actor=actor, + append_copy_suffix=append_copy_suffix, + override_existing_tools=override_existing_tools, + project_id=project_id, + strip_messages=strip_messages, + ) + else: + # This is a legacy AgentSchema + agent_ids = import_agent_legacy( + agent_json=agent_json, + server=server, actor=actor, append_copy_suffix=append_copy_suffix, override_existing_tools=override_existing_tools, @@ -203,23 +335,8 @@ def import_agent_serialized( strip_messages=strip_messages, env_vars=env_vars, ) - return new_agent - except json.JSONDecodeError: - raise HTTPException(status_code=400, detail="Corrupted agent file format.") - - except ValidationError as e: - raise HTTPException(status_code=422, detail=f"Invalid agent schema: {e!s}") - - except IntegrityError as e: - raise HTTPException(status_code=409, detail=f"Database integrity error: {e!s}") - - except OperationalError as e: - raise HTTPException(status_code=503, detail=f"Database connection error. Please try again later: {e!s}") - - except Exception as e: - traceback.print_exc() - raise HTTPException(status_code=500, detail=f"An unexpected error occurred while uploading the agent: {e!s}") + return ImportedAgentsResponse(agent_ids=agent_ids) @router.get("/{agent_id}/context", response_model=ContextWindowOverview, operation_id="retrieve_agent_context_window") diff --git a/letta/services/agent_serialization_manager.py b/letta/services/agent_serialization_manager.py index a045d1a4..edccd7b6 100644 --- a/letta/services/agent_serialization_manager.py +++ b/letta/services/agent_serialization_manager.py @@ -2,7 +2,7 @@ from datetime import datetime, timezone from typing import Any, Dict, List, Optional from letta.constants import MCP_TOOL_TAG_NAME_PREFIX -from letta.errors import AgentFileExportError, AgentFileImportError +from letta.errors import AgentExportIdMappingError, AgentExportProcessingError, AgentFileImportError, AgentNotFoundForExportError from letta.helpers.pinecone_utils import should_use_pinecone from letta.log import get_logger from letta.schemas.agent import AgentState, CreateAgent @@ -118,10 +118,7 @@ class AgentSerializationManager: return self._db_to_file_ids[db_id] if not allow_new: - raise AgentFileExportError( - f"Unexpected new {entity_type} ID '{db_id}' encountered during conversion. " - f"All IDs should have been mapped during agent processing." - ) + raise AgentExportIdMappingError(db_id, entity_type) file_id = self._generate_file_id(entity_type) self._db_to_file_ids[db_id] = file_id @@ -352,7 +349,7 @@ class AgentSerializationManager: if len(agent_states) != len(agent_ids): found_ids = {agent.id for agent in agent_states} missing_ids = [agent_id for agent_id in agent_ids if agent_id not in found_ids] - raise AgentFileExportError(f"The following agent IDs were not found: {missing_ids}") + raise AgentNotFoundForExportError(missing_ids) groups = [] group_agent_ids = [] @@ -417,7 +414,7 @@ class AgentSerializationManager: except Exception as e: logger.error(f"Failed to export agent file: {e}") - raise AgentFileExportError(f"Export failed: {e}") from e + raise AgentExportProcessingError(str(e), e) from e async def import_file( self, @@ -657,6 +654,12 @@ class AgentSerializationManager: ) imported_count += len(files_for_agent) + # Extract the imported agent IDs (database IDs) + imported_agent_ids = [] + for agent_schema in schema.agents: + if agent_schema.id in file_to_db_ids: + imported_agent_ids.append(file_to_db_ids[agent_schema.id]) + for group in schema.groups: group_data = group.model_dump(exclude={"id"}) group_data["agent_ids"] = [file_to_db_ids[agent_id] for agent_id in group_data["agent_ids"]] @@ -670,6 +673,7 @@ class AgentSerializationManager: success=True, message=f"Import completed successfully. Imported {imported_count} entities.", imported_count=imported_count, + imported_agent_ids=imported_agent_ids, id_mappings=file_to_db_ids, ) diff --git a/tests/integration_test_pinecone_tool.py b/tests/integration_test_pinecone_tool.py index c9aecfb7..3caf4d5f 100644 --- a/tests/integration_test_pinecone_tool.py +++ b/tests/integration_test_pinecone_tool.py @@ -65,10 +65,12 @@ async def test_pinecone_tool(client: AsyncLetta) -> None: Test the Pinecone tool integration with the Letta client. """ with open("../../scripts/test-afs/knowledge-base.af", "rb") as f: - agent = await client.agents.import_file(file=f) + response = await client.agents.import_file(file=f) + + agent_id = response.agent_ids[0] agent = await client.agents.modify( - agent_id=agent.id, + agent_id=agent_id, tool_exec_environment_variables={ "PINECONE_INDEX_HOST": os.getenv("PINECONE_INDEX_HOST"), "PINECONE_API_KEY": os.getenv("PINECONE_API_KEY"), diff --git a/tests/mcp/mcp_config.json b/tests/mcp/mcp_config.json index 9e26dfee..0967ef42 100644 --- a/tests/mcp/mcp_config.json +++ b/tests/mcp/mcp_config.json @@ -1 +1 @@ -{} \ No newline at end of file +{} diff --git a/tests/test_agent_serialization.py b/tests/test_agent_serialization.py index 5be01735..7ea29ea5 100644 --- a/tests/test_agent_serialization.py +++ b/tests/test_agent_serialization.py @@ -655,14 +655,15 @@ def test_agent_download_upload_flow(server, server_url, serialize_test_agent, de # Sanity checks copied_agent = upload_response.json() - copied_agent_id = copied_agent["id"] + copied_agent_id = copied_agent["agent_ids"][0] assert copied_agent_id != agent_id, "Copied agent should have a different ID" + + agent_copy = server.agent_manager.get_agent_by_id(agent_id=copied_agent_id, actor=other_user) if append_copy_suffix: - assert copied_agent["name"] == serialize_test_agent.name + "_copy", "Copied agent name should have '_copy' suffix" + assert agent_copy.name == serialize_test_agent.name + "_copy", "Copied agent name should have '_copy' suffix" # Step 3: Retrieve the copied agent serialize_test_agent = server.agent_manager.get_agent_by_id(agent_id=serialize_test_agent.id, actor=default_user) - agent_copy = server.agent_manager.get_agent_by_id(agent_id=copied_agent_id, actor=other_user) print_dict_diff(json.loads(serialize_test_agent.model_dump_json()), json.loads(agent_copy.model_dump_json())) assert compare_agent_state(server, serialize_test_agent, agent_copy, append_copy_suffix, default_user, other_user) @@ -702,9 +703,8 @@ def test_upload_agentfile_from_disk(server, server_url, disable_e2b_api_key, oth assert response.status_code == 200, f"Failed to upload {filename}: {response.text}" json_response = response.json() - assert "id" in json_response and json_response["id"].startswith("agent-"), "Uploaded agent response is malformed" - copied_agent_id = json_response["id"] + copied_agent_id = json_response["agent_ids"][0] server.send_messages( actor=other_user, @@ -735,7 +735,7 @@ def test_serialize_with_max_steps(server, server_url, default_user, other_user): assert response.status_code == 200, f"Failed to upload agent: {response.text}" agent_data = response.json() - agent_id = agent_data["id"] + agent_id = agent_data["agent_ids"][0] # test with default max_steps (should use None, returning all messages) full_result = server.agent_manager.serialize(agent_id=agent_id, actor=default_user)