feat: Adjust import/export agent endpoints to accept new agent file schema (#3506)

Co-authored-by: Shubham Naik <shub@memgpt.ai>
Co-authored-by: Shubham Naik <shub@letta.com>
This commit is contained in:
Matthew Zhou
2025-08-12 11:18:56 -07:00
committed by GitHub
parent df51974f31
commit 99902ff05e
7 changed files with 201 additions and 48 deletions

View File

@@ -236,5 +236,33 @@ class AgentFileExportError(Exception):
"""Exception raised during agent file export operations"""
class AgentNotFoundForExportError(AgentFileExportError):
"""Exception raised when requested agents are not found during export"""
def __init__(self, missing_ids: List[str]):
self.missing_ids = missing_ids
super().__init__(f"The following agent IDs were not found: {missing_ids}")
class AgentExportIdMappingError(AgentFileExportError):
"""Exception raised when ID mapping fails during export conversion"""
def __init__(self, db_id: str, entity_type: str):
self.db_id = db_id
self.entity_type = entity_type
super().__init__(
f"Unexpected new {entity_type} ID '{db_id}' encountered during conversion. "
f"All IDs should have been mapped during agent processing."
)
class AgentExportProcessingError(AgentFileExportError):
"""Exception raised when general export processing fails"""
def __init__(self, message: str, original_error: Optional[Exception] = None):
self.original_error = original_error
super().__init__(f"Export failed: {message}")
class AgentFileImportError(Exception):
"""Exception raised during agent file import operations"""

View File

@@ -24,12 +24,14 @@ class ImportResult:
success: bool,
message: str = "",
imported_count: int = 0,
imported_agent_ids: Optional[List[str]] = None,
errors: Optional[List[str]] = None,
id_mappings: Optional[Dict[str, str]] = None,
):
self.success = success
self.message = message
self.imported_count = imported_count
self.imported_agent_ids = imported_agent_ids or []
self.errors = errors or []
self.id_mappings = id_mappings or {}

View File

@@ -8,13 +8,14 @@ from fastapi import APIRouter, Body, Depends, File, Form, Header, HTTPException,
from fastapi.responses import JSONResponse
from marshmallow import ValidationError
from orjson import orjson
from pydantic import Field
from pydantic import BaseModel, Field
from sqlalchemy.exc import IntegrityError, OperationalError
from starlette.responses import Response, StreamingResponse
from letta.agents.letta_agent import LettaAgent
from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, REDIS_RUN_ID_PREFIX
from letta.data_sources.redis_client import get_redis_client
from letta.errors import AgentExportIdMappingError, AgentExportProcessingError, AgentFileImportError, AgentNotFoundForExportError
from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
from letta.log import get_logger
@@ -22,6 +23,7 @@ from letta.orm.errors import NoResultFound
from letta.otel.context import get_ctx_attributes
from letta.otel.metric_registry import MetricRegistry
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent
from letta.schemas.agent_file import AgentFileSchema
from letta.schemas.block import Block, BlockUpdate
from letta.schemas.group import Group
from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig
@@ -144,29 +146,143 @@ class IndentedORJSONResponse(Response):
@router.get("/{agent_id}/export", response_class=IndentedORJSONResponse, operation_id="export_agent_serialized")
def export_agent_serialized(
async def export_agent_serialized(
agent_id: str,
max_steps: int = 100,
server: "SyncServer" = Depends(get_letta_server),
actor_id: str | None = Header(None, alias="user_id"),
use_legacy_format: bool = Query(
True,
description="If true, exports using the legacy single-agent format. If false, exports using the new multi-entity format.",
),
# do not remove, used to autogeneration of spec
# TODO: Think of a better way to export AgentSchema
spec: AgentSchema | None = None,
# TODO: Think of a better way to export AgentFileSchema
spec: AgentFileSchema | None = None,
legacy_spec: AgentSchema | None = None,
) -> JSONResponse:
"""
Export the serialized JSON representation of an agent, formatted with indentation.
Supports two export formats:
- Legacy format (use_legacy_format=true): Single agent with inline tools/blocks
- New format (default): Multi-entity format with separate agents, tools, blocks, files, etc.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
if use_legacy_format:
# Use the legacy serialization method
try:
agent = server.agent_manager.serialize(agent_id=agent_id, actor=actor, max_steps=max_steps)
return agent.model_dump()
except NoResultFound:
raise HTTPException(status_code=404, detail=f"Agent with id={agent_id} not found for user_id={actor.id}.")
else:
# Use the new multi-entity export format
try:
agent_file_schema = await server.agent_serialization_manager.export(agent_ids=[agent_id], actor=actor)
return agent_file_schema.model_dump()
except AgentNotFoundForExportError:
raise HTTPException(status_code=404, detail=f"Agent with id={agent_id} not found for user_id={actor.id}.")
except AgentExportIdMappingError as e:
raise HTTPException(
status_code=500, detail=f"Internal error during export: ID mapping failed for {e.entity_type} ID '{e.db_id}'"
)
except AgentExportProcessingError as e:
raise HTTPException(status_code=500, detail=f"Export processing failed: {str(e.original_error)}")
class ImportedAgentsResponse(BaseModel):
"""Response model for imported agents"""
agent_ids: List[str] = Field(..., description="List of IDs of the imported agents")
def import_agent_legacy(
agent_json: dict,
server: "SyncServer",
actor: User,
append_copy_suffix: bool = True,
override_existing_tools: bool = True,
project_id: str | None = None,
strip_messages: bool = False,
env_vars: Optional[dict[str, Any]] = None,
) -> List[str]:
"""
Import an agent using the legacy AgentSchema format.
"""
try:
agent = server.agent_manager.serialize(agent_id=agent_id, actor=actor, max_steps=max_steps)
return agent.model_dump()
except NoResultFound:
raise HTTPException(status_code=404, detail=f"Agent with id={agent_id} not found for user_id={actor.id}.")
# Validate the JSON against AgentSchema before passing it to deserialize
agent_schema = AgentSchema.model_validate(agent_json)
new_agent = server.agent_manager.deserialize(
serialized_agent=agent_schema, # Ensure we're passing a validated AgentSchema
actor=actor,
append_copy_suffix=append_copy_suffix,
override_existing_tools=override_existing_tools,
project_id=project_id,
strip_messages=strip_messages,
env_vars=env_vars,
)
return [new_agent.id]
except ValidationError as e:
raise HTTPException(status_code=422, detail=f"Invalid agent schema: {e!s}")
except IntegrityError as e:
raise HTTPException(status_code=409, detail=f"Database integrity error: {e!s}")
except OperationalError as e:
raise HTTPException(status_code=503, detail=f"Database connection error. Please try again later: {e!s}")
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"An unexpected error occurred while uploading the agent: {e!s}")
@router.post("/import", response_model=AgentState, operation_id="import_agent_serialized")
def import_agent_serialized(
async def import_agent(
agent_file_json: dict,
server: "SyncServer",
actor: User,
# TODO: Support these fields for new agent file
append_copy_suffix: bool = True,
override_existing_tools: bool = True,
project_id: str | None = None,
strip_messages: bool = False,
) -> List[str]:
"""
Import an agent using the new AgentFileSchema format.
"""
try:
agent_schema = AgentFileSchema.model_validate(agent_file_json)
except ValidationError as e:
raise HTTPException(status_code=422, detail=f"Invalid agent file schema: {e!s}")
try:
import_result = await server.agent_serialization_manager.import_file(schema=agent_schema, actor=actor)
if not import_result.success:
raise HTTPException(
status_code=500, detail=f"Import failed: {import_result.message}. Errors: {', '.join(import_result.errors)}"
)
return import_result.imported_agent_ids
except AgentFileImportError as e:
raise HTTPException(status_code=400, detail=f"Agent file import error: {str(e)}")
except IntegrityError as e:
raise HTTPException(status_code=409, detail=f"Database integrity error: {e!s}")
except OperationalError as e:
raise HTTPException(status_code=503, detail=f"Database connection error. Please try again later: {e!s}")
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"An unexpected error occurred while importing agents: {e!s}")
@router.post("/import", response_model=ImportedAgentsResponse, operation_id="import_agent_serialized")
async def import_agent_serialized(
file: UploadFile = File(...),
server: "SyncServer" = Depends(get_letta_server),
actor_id: str | None = Header(None, alias="user_id"),
@@ -183,19 +299,35 @@ def import_agent_serialized(
env_vars: Optional[Dict[str, Any]] = Form(None, description="Environment variables to pass to the agent for tool execution."),
):
"""
Import a serialized agent file and recreate the agent in the system.
Import a serialized agent file and recreate the agent(s) in the system.
Returns the IDs of all imported agents.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
try:
serialized_data = file.file.read()
agent_json = json.loads(serialized_data)
except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="Corrupted agent file format.")
# Validate the JSON against AgentSchema before passing it to deserialize
agent_schema = AgentSchema.model_validate(agent_json)
new_agent = server.agent_manager.deserialize(
serialized_agent=agent_schema, # Ensure we're passing a validated AgentSchema
# Check if the JSON is AgentFileSchema or AgentSchema
# TODO: This is kind of hacky, but should work as long as dont' change the schema
if "agents" in agent_json and isinstance(agent_json.get("agents"), list):
# This is an AgentFileSchema
agent_ids = await import_agent(
agent_file_json=agent_json,
server=server,
actor=actor,
append_copy_suffix=append_copy_suffix,
override_existing_tools=override_existing_tools,
project_id=project_id,
strip_messages=strip_messages,
)
else:
# This is a legacy AgentSchema
agent_ids = import_agent_legacy(
agent_json=agent_json,
server=server,
actor=actor,
append_copy_suffix=append_copy_suffix,
override_existing_tools=override_existing_tools,
@@ -203,23 +335,8 @@ def import_agent_serialized(
strip_messages=strip_messages,
env_vars=env_vars,
)
return new_agent
except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="Corrupted agent file format.")
except ValidationError as e:
raise HTTPException(status_code=422, detail=f"Invalid agent schema: {e!s}")
except IntegrityError as e:
raise HTTPException(status_code=409, detail=f"Database integrity error: {e!s}")
except OperationalError as e:
raise HTTPException(status_code=503, detail=f"Database connection error. Please try again later: {e!s}")
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"An unexpected error occurred while uploading the agent: {e!s}")
return ImportedAgentsResponse(agent_ids=agent_ids)
@router.get("/{agent_id}/context", response_model=ContextWindowOverview, operation_id="retrieve_agent_context_window")

View File

@@ -2,7 +2,7 @@ from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from letta.constants import MCP_TOOL_TAG_NAME_PREFIX
from letta.errors import AgentFileExportError, AgentFileImportError
from letta.errors import AgentExportIdMappingError, AgentExportProcessingError, AgentFileImportError, AgentNotFoundForExportError
from letta.helpers.pinecone_utils import should_use_pinecone
from letta.log import get_logger
from letta.schemas.agent import AgentState, CreateAgent
@@ -118,10 +118,7 @@ class AgentSerializationManager:
return self._db_to_file_ids[db_id]
if not allow_new:
raise AgentFileExportError(
f"Unexpected new {entity_type} ID '{db_id}' encountered during conversion. "
f"All IDs should have been mapped during agent processing."
)
raise AgentExportIdMappingError(db_id, entity_type)
file_id = self._generate_file_id(entity_type)
self._db_to_file_ids[db_id] = file_id
@@ -352,7 +349,7 @@ class AgentSerializationManager:
if len(agent_states) != len(agent_ids):
found_ids = {agent.id for agent in agent_states}
missing_ids = [agent_id for agent_id in agent_ids if agent_id not in found_ids]
raise AgentFileExportError(f"The following agent IDs were not found: {missing_ids}")
raise AgentNotFoundForExportError(missing_ids)
groups = []
group_agent_ids = []
@@ -417,7 +414,7 @@ class AgentSerializationManager:
except Exception as e:
logger.error(f"Failed to export agent file: {e}")
raise AgentFileExportError(f"Export failed: {e}") from e
raise AgentExportProcessingError(str(e), e) from e
async def import_file(
self,
@@ -657,6 +654,12 @@ class AgentSerializationManager:
)
imported_count += len(files_for_agent)
# Extract the imported agent IDs (database IDs)
imported_agent_ids = []
for agent_schema in schema.agents:
if agent_schema.id in file_to_db_ids:
imported_agent_ids.append(file_to_db_ids[agent_schema.id])
for group in schema.groups:
group_data = group.model_dump(exclude={"id"})
group_data["agent_ids"] = [file_to_db_ids[agent_id] for agent_id in group_data["agent_ids"]]
@@ -670,6 +673,7 @@ class AgentSerializationManager:
success=True,
message=f"Import completed successfully. Imported {imported_count} entities.",
imported_count=imported_count,
imported_agent_ids=imported_agent_ids,
id_mappings=file_to_db_ids,
)

View File

@@ -65,10 +65,12 @@ async def test_pinecone_tool(client: AsyncLetta) -> None:
Test the Pinecone tool integration with the Letta client.
"""
with open("../../scripts/test-afs/knowledge-base.af", "rb") as f:
agent = await client.agents.import_file(file=f)
response = await client.agents.import_file(file=f)
agent_id = response.agent_ids[0]
agent = await client.agents.modify(
agent_id=agent.id,
agent_id=agent_id,
tool_exec_environment_variables={
"PINECONE_INDEX_HOST": os.getenv("PINECONE_INDEX_HOST"),
"PINECONE_API_KEY": os.getenv("PINECONE_API_KEY"),

View File

@@ -1 +1 @@
{}
{}

View File

@@ -655,14 +655,15 @@ def test_agent_download_upload_flow(server, server_url, serialize_test_agent, de
# Sanity checks
copied_agent = upload_response.json()
copied_agent_id = copied_agent["id"]
copied_agent_id = copied_agent["agent_ids"][0]
assert copied_agent_id != agent_id, "Copied agent should have a different ID"
agent_copy = server.agent_manager.get_agent_by_id(agent_id=copied_agent_id, actor=other_user)
if append_copy_suffix:
assert copied_agent["name"] == serialize_test_agent.name + "_copy", "Copied agent name should have '_copy' suffix"
assert agent_copy.name == serialize_test_agent.name + "_copy", "Copied agent name should have '_copy' suffix"
# Step 3: Retrieve the copied agent
serialize_test_agent = server.agent_manager.get_agent_by_id(agent_id=serialize_test_agent.id, actor=default_user)
agent_copy = server.agent_manager.get_agent_by_id(agent_id=copied_agent_id, actor=other_user)
print_dict_diff(json.loads(serialize_test_agent.model_dump_json()), json.loads(agent_copy.model_dump_json()))
assert compare_agent_state(server, serialize_test_agent, agent_copy, append_copy_suffix, default_user, other_user)
@@ -702,9 +703,8 @@ def test_upload_agentfile_from_disk(server, server_url, disable_e2b_api_key, oth
assert response.status_code == 200, f"Failed to upload {filename}: {response.text}"
json_response = response.json()
assert "id" in json_response and json_response["id"].startswith("agent-"), "Uploaded agent response is malformed"
copied_agent_id = json_response["id"]
copied_agent_id = json_response["agent_ids"][0]
server.send_messages(
actor=other_user,
@@ -735,7 +735,7 @@ def test_serialize_with_max_steps(server, server_url, default_user, other_user):
assert response.status_code == 200, f"Failed to upload agent: {response.text}"
agent_data = response.json()
agent_id = agent_data["id"]
agent_id = agent_data["agent_ids"][0]
# test with default max_steps (should use None, returning all messages)
full_result = server.agent_manager.serialize(agent_id=agent_id, actor=default_user)