feat: Adjust import/export agent endpoints to accept new agent file schema (#3506)
Co-authored-by: Shubham Naik <shub@memgpt.ai> Co-authored-by: Shubham Naik <shub@letta.com>
This commit is contained in:
@@ -236,5 +236,33 @@ class AgentFileExportError(Exception):
|
||||
"""Exception raised during agent file export operations"""
|
||||
|
||||
|
||||
class AgentNotFoundForExportError(AgentFileExportError):
|
||||
"""Exception raised when requested agents are not found during export"""
|
||||
|
||||
def __init__(self, missing_ids: List[str]):
|
||||
self.missing_ids = missing_ids
|
||||
super().__init__(f"The following agent IDs were not found: {missing_ids}")
|
||||
|
||||
|
||||
class AgentExportIdMappingError(AgentFileExportError):
|
||||
"""Exception raised when ID mapping fails during export conversion"""
|
||||
|
||||
def __init__(self, db_id: str, entity_type: str):
|
||||
self.db_id = db_id
|
||||
self.entity_type = entity_type
|
||||
super().__init__(
|
||||
f"Unexpected new {entity_type} ID '{db_id}' encountered during conversion. "
|
||||
f"All IDs should have been mapped during agent processing."
|
||||
)
|
||||
|
||||
|
||||
class AgentExportProcessingError(AgentFileExportError):
|
||||
"""Exception raised when general export processing fails"""
|
||||
|
||||
def __init__(self, message: str, original_error: Optional[Exception] = None):
|
||||
self.original_error = original_error
|
||||
super().__init__(f"Export failed: {message}")
|
||||
|
||||
|
||||
class AgentFileImportError(Exception):
|
||||
"""Exception raised during agent file import operations"""
|
||||
|
||||
@@ -24,12 +24,14 @@ class ImportResult:
|
||||
success: bool,
|
||||
message: str = "",
|
||||
imported_count: int = 0,
|
||||
imported_agent_ids: Optional[List[str]] = None,
|
||||
errors: Optional[List[str]] = None,
|
||||
id_mappings: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
self.success = success
|
||||
self.message = message
|
||||
self.imported_count = imported_count
|
||||
self.imported_agent_ids = imported_agent_ids or []
|
||||
self.errors = errors or []
|
||||
self.id_mappings = id_mappings or {}
|
||||
|
||||
|
||||
@@ -8,13 +8,14 @@ from fastapi import APIRouter, Body, Depends, File, Form, Header, HTTPException,
|
||||
from fastapi.responses import JSONResponse
|
||||
from marshmallow import ValidationError
|
||||
from orjson import orjson
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from starlette.responses import Response, StreamingResponse
|
||||
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, REDIS_RUN_ID_PREFIX
|
||||
from letta.data_sources.redis_client import get_redis_client
|
||||
from letta.errors import AgentExportIdMappingError, AgentExportProcessingError, AgentFileImportError, AgentNotFoundForExportError
|
||||
from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2
|
||||
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
|
||||
from letta.log import get_logger
|
||||
@@ -22,6 +23,7 @@ from letta.orm.errors import NoResultFound
|
||||
from letta.otel.context import get_ctx_attributes
|
||||
from letta.otel.metric_registry import MetricRegistry
|
||||
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent
|
||||
from letta.schemas.agent_file import AgentFileSchema
|
||||
from letta.schemas.block import Block, BlockUpdate
|
||||
from letta.schemas.group import Group
|
||||
from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig
|
||||
@@ -144,29 +146,143 @@ class IndentedORJSONResponse(Response):
|
||||
|
||||
|
||||
@router.get("/{agent_id}/export", response_class=IndentedORJSONResponse, operation_id="export_agent_serialized")
|
||||
def export_agent_serialized(
|
||||
async def export_agent_serialized(
|
||||
agent_id: str,
|
||||
max_steps: int = 100,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: str | None = Header(None, alias="user_id"),
|
||||
use_legacy_format: bool = Query(
|
||||
True,
|
||||
description="If true, exports using the legacy single-agent format. If false, exports using the new multi-entity format.",
|
||||
),
|
||||
# do not remove, used to autogeneration of spec
|
||||
# TODO: Think of a better way to export AgentSchema
|
||||
spec: AgentSchema | None = None,
|
||||
# TODO: Think of a better way to export AgentFileSchema
|
||||
spec: AgentFileSchema | None = None,
|
||||
legacy_spec: AgentSchema | None = None,
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
Export the serialized JSON representation of an agent, formatted with indentation.
|
||||
|
||||
Supports two export formats:
|
||||
- Legacy format (use_legacy_format=true): Single agent with inline tools/blocks
|
||||
- New format (default): Multi-entity format with separate agents, tools, blocks, files, etc.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
if use_legacy_format:
|
||||
# Use the legacy serialization method
|
||||
try:
|
||||
agent = server.agent_manager.serialize(agent_id=agent_id, actor=actor, max_steps=max_steps)
|
||||
return agent.model_dump()
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail=f"Agent with id={agent_id} not found for user_id={actor.id}.")
|
||||
else:
|
||||
# Use the new multi-entity export format
|
||||
try:
|
||||
agent_file_schema = await server.agent_serialization_manager.export(agent_ids=[agent_id], actor=actor)
|
||||
return agent_file_schema.model_dump()
|
||||
except AgentNotFoundForExportError:
|
||||
raise HTTPException(status_code=404, detail=f"Agent with id={agent_id} not found for user_id={actor.id}.")
|
||||
except AgentExportIdMappingError as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Internal error during export: ID mapping failed for {e.entity_type} ID '{e.db_id}'"
|
||||
)
|
||||
except AgentExportProcessingError as e:
|
||||
raise HTTPException(status_code=500, detail=f"Export processing failed: {str(e.original_error)}")
|
||||
|
||||
|
||||
class ImportedAgentsResponse(BaseModel):
|
||||
"""Response model for imported agents"""
|
||||
|
||||
agent_ids: List[str] = Field(..., description="List of IDs of the imported agents")
|
||||
|
||||
|
||||
def import_agent_legacy(
|
||||
agent_json: dict,
|
||||
server: "SyncServer",
|
||||
actor: User,
|
||||
append_copy_suffix: bool = True,
|
||||
override_existing_tools: bool = True,
|
||||
project_id: str | None = None,
|
||||
strip_messages: bool = False,
|
||||
env_vars: Optional[dict[str, Any]] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Import an agent using the legacy AgentSchema format.
|
||||
"""
|
||||
try:
|
||||
agent = server.agent_manager.serialize(agent_id=agent_id, actor=actor, max_steps=max_steps)
|
||||
return agent.model_dump()
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail=f"Agent with id={agent_id} not found for user_id={actor.id}.")
|
||||
# Validate the JSON against AgentSchema before passing it to deserialize
|
||||
agent_schema = AgentSchema.model_validate(agent_json)
|
||||
|
||||
new_agent = server.agent_manager.deserialize(
|
||||
serialized_agent=agent_schema, # Ensure we're passing a validated AgentSchema
|
||||
actor=actor,
|
||||
append_copy_suffix=append_copy_suffix,
|
||||
override_existing_tools=override_existing_tools,
|
||||
project_id=project_id,
|
||||
strip_messages=strip_messages,
|
||||
env_vars=env_vars,
|
||||
)
|
||||
return [new_agent.id]
|
||||
|
||||
except ValidationError as e:
|
||||
raise HTTPException(status_code=422, detail=f"Invalid agent schema: {e!s}")
|
||||
|
||||
except IntegrityError as e:
|
||||
raise HTTPException(status_code=409, detail=f"Database integrity error: {e!s}")
|
||||
|
||||
except OperationalError as e:
|
||||
raise HTTPException(status_code=503, detail=f"Database connection error. Please try again later: {e!s}")
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail=f"An unexpected error occurred while uploading the agent: {e!s}")
|
||||
|
||||
|
||||
@router.post("/import", response_model=AgentState, operation_id="import_agent_serialized")
|
||||
def import_agent_serialized(
|
||||
async def import_agent(
|
||||
agent_file_json: dict,
|
||||
server: "SyncServer",
|
||||
actor: User,
|
||||
# TODO: Support these fields for new agent file
|
||||
append_copy_suffix: bool = True,
|
||||
override_existing_tools: bool = True,
|
||||
project_id: str | None = None,
|
||||
strip_messages: bool = False,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Import an agent using the new AgentFileSchema format.
|
||||
"""
|
||||
try:
|
||||
agent_schema = AgentFileSchema.model_validate(agent_file_json)
|
||||
except ValidationError as e:
|
||||
raise HTTPException(status_code=422, detail=f"Invalid agent file schema: {e!s}")
|
||||
|
||||
try:
|
||||
import_result = await server.agent_serialization_manager.import_file(schema=agent_schema, actor=actor)
|
||||
|
||||
if not import_result.success:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Import failed: {import_result.message}. Errors: {', '.join(import_result.errors)}"
|
||||
)
|
||||
|
||||
return import_result.imported_agent_ids
|
||||
|
||||
except AgentFileImportError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Agent file import error: {str(e)}")
|
||||
|
||||
except IntegrityError as e:
|
||||
raise HTTPException(status_code=409, detail=f"Database integrity error: {e!s}")
|
||||
|
||||
except OperationalError as e:
|
||||
raise HTTPException(status_code=503, detail=f"Database connection error. Please try again later: {e!s}")
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail=f"An unexpected error occurred while importing agents: {e!s}")
|
||||
|
||||
|
||||
@router.post("/import", response_model=ImportedAgentsResponse, operation_id="import_agent_serialized")
|
||||
async def import_agent_serialized(
|
||||
file: UploadFile = File(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: str | None = Header(None, alias="user_id"),
|
||||
@@ -183,19 +299,35 @@ def import_agent_serialized(
|
||||
env_vars: Optional[Dict[str, Any]] = Form(None, description="Environment variables to pass to the agent for tool execution."),
|
||||
):
|
||||
"""
|
||||
Import a serialized agent file and recreate the agent in the system.
|
||||
Import a serialized agent file and recreate the agent(s) in the system.
|
||||
Returns the IDs of all imported agents.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
serialized_data = file.file.read()
|
||||
agent_json = json.loads(serialized_data)
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(status_code=400, detail="Corrupted agent file format.")
|
||||
|
||||
# Validate the JSON against AgentSchema before passing it to deserialize
|
||||
agent_schema = AgentSchema.model_validate(agent_json)
|
||||
|
||||
new_agent = server.agent_manager.deserialize(
|
||||
serialized_agent=agent_schema, # Ensure we're passing a validated AgentSchema
|
||||
# Check if the JSON is AgentFileSchema or AgentSchema
|
||||
# TODO: This is kind of hacky, but should work as long as dont' change the schema
|
||||
if "agents" in agent_json and isinstance(agent_json.get("agents"), list):
|
||||
# This is an AgentFileSchema
|
||||
agent_ids = await import_agent(
|
||||
agent_file_json=agent_json,
|
||||
server=server,
|
||||
actor=actor,
|
||||
append_copy_suffix=append_copy_suffix,
|
||||
override_existing_tools=override_existing_tools,
|
||||
project_id=project_id,
|
||||
strip_messages=strip_messages,
|
||||
)
|
||||
else:
|
||||
# This is a legacy AgentSchema
|
||||
agent_ids = import_agent_legacy(
|
||||
agent_json=agent_json,
|
||||
server=server,
|
||||
actor=actor,
|
||||
append_copy_suffix=append_copy_suffix,
|
||||
override_existing_tools=override_existing_tools,
|
||||
@@ -203,23 +335,8 @@ def import_agent_serialized(
|
||||
strip_messages=strip_messages,
|
||||
env_vars=env_vars,
|
||||
)
|
||||
return new_agent
|
||||
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(status_code=400, detail="Corrupted agent file format.")
|
||||
|
||||
except ValidationError as e:
|
||||
raise HTTPException(status_code=422, detail=f"Invalid agent schema: {e!s}")
|
||||
|
||||
except IntegrityError as e:
|
||||
raise HTTPException(status_code=409, detail=f"Database integrity error: {e!s}")
|
||||
|
||||
except OperationalError as e:
|
||||
raise HTTPException(status_code=503, detail=f"Database connection error. Please try again later: {e!s}")
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail=f"An unexpected error occurred while uploading the agent: {e!s}")
|
||||
return ImportedAgentsResponse(agent_ids=agent_ids)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/context", response_model=ContextWindowOverview, operation_id="retrieve_agent_context_window")
|
||||
|
||||
@@ -2,7 +2,7 @@ from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from letta.constants import MCP_TOOL_TAG_NAME_PREFIX
|
||||
from letta.errors import AgentFileExportError, AgentFileImportError
|
||||
from letta.errors import AgentExportIdMappingError, AgentExportProcessingError, AgentFileImportError, AgentNotFoundForExportError
|
||||
from letta.helpers.pinecone_utils import should_use_pinecone
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState, CreateAgent
|
||||
@@ -118,10 +118,7 @@ class AgentSerializationManager:
|
||||
return self._db_to_file_ids[db_id]
|
||||
|
||||
if not allow_new:
|
||||
raise AgentFileExportError(
|
||||
f"Unexpected new {entity_type} ID '{db_id}' encountered during conversion. "
|
||||
f"All IDs should have been mapped during agent processing."
|
||||
)
|
||||
raise AgentExportIdMappingError(db_id, entity_type)
|
||||
|
||||
file_id = self._generate_file_id(entity_type)
|
||||
self._db_to_file_ids[db_id] = file_id
|
||||
@@ -352,7 +349,7 @@ class AgentSerializationManager:
|
||||
if len(agent_states) != len(agent_ids):
|
||||
found_ids = {agent.id for agent in agent_states}
|
||||
missing_ids = [agent_id for agent_id in agent_ids if agent_id not in found_ids]
|
||||
raise AgentFileExportError(f"The following agent IDs were not found: {missing_ids}")
|
||||
raise AgentNotFoundForExportError(missing_ids)
|
||||
|
||||
groups = []
|
||||
group_agent_ids = []
|
||||
@@ -417,7 +414,7 @@ class AgentSerializationManager:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to export agent file: {e}")
|
||||
raise AgentFileExportError(f"Export failed: {e}") from e
|
||||
raise AgentExportProcessingError(str(e), e) from e
|
||||
|
||||
async def import_file(
|
||||
self,
|
||||
@@ -657,6 +654,12 @@ class AgentSerializationManager:
|
||||
)
|
||||
imported_count += len(files_for_agent)
|
||||
|
||||
# Extract the imported agent IDs (database IDs)
|
||||
imported_agent_ids = []
|
||||
for agent_schema in schema.agents:
|
||||
if agent_schema.id in file_to_db_ids:
|
||||
imported_agent_ids.append(file_to_db_ids[agent_schema.id])
|
||||
|
||||
for group in schema.groups:
|
||||
group_data = group.model_dump(exclude={"id"})
|
||||
group_data["agent_ids"] = [file_to_db_ids[agent_id] for agent_id in group_data["agent_ids"]]
|
||||
@@ -670,6 +673,7 @@ class AgentSerializationManager:
|
||||
success=True,
|
||||
message=f"Import completed successfully. Imported {imported_count} entities.",
|
||||
imported_count=imported_count,
|
||||
imported_agent_ids=imported_agent_ids,
|
||||
id_mappings=file_to_db_ids,
|
||||
)
|
||||
|
||||
|
||||
@@ -65,10 +65,12 @@ async def test_pinecone_tool(client: AsyncLetta) -> None:
|
||||
Test the Pinecone tool integration with the Letta client.
|
||||
"""
|
||||
with open("../../scripts/test-afs/knowledge-base.af", "rb") as f:
|
||||
agent = await client.agents.import_file(file=f)
|
||||
response = await client.agents.import_file(file=f)
|
||||
|
||||
agent_id = response.agent_ids[0]
|
||||
|
||||
agent = await client.agents.modify(
|
||||
agent_id=agent.id,
|
||||
agent_id=agent_id,
|
||||
tool_exec_environment_variables={
|
||||
"PINECONE_INDEX_HOST": os.getenv("PINECONE_INDEX_HOST"),
|
||||
"PINECONE_API_KEY": os.getenv("PINECONE_API_KEY"),
|
||||
|
||||
@@ -1 +1 @@
|
||||
{}
|
||||
{}
|
||||
|
||||
@@ -655,14 +655,15 @@ def test_agent_download_upload_flow(server, server_url, serialize_test_agent, de
|
||||
|
||||
# Sanity checks
|
||||
copied_agent = upload_response.json()
|
||||
copied_agent_id = copied_agent["id"]
|
||||
copied_agent_id = copied_agent["agent_ids"][0]
|
||||
assert copied_agent_id != agent_id, "Copied agent should have a different ID"
|
||||
|
||||
agent_copy = server.agent_manager.get_agent_by_id(agent_id=copied_agent_id, actor=other_user)
|
||||
if append_copy_suffix:
|
||||
assert copied_agent["name"] == serialize_test_agent.name + "_copy", "Copied agent name should have '_copy' suffix"
|
||||
assert agent_copy.name == serialize_test_agent.name + "_copy", "Copied agent name should have '_copy' suffix"
|
||||
|
||||
# Step 3: Retrieve the copied agent
|
||||
serialize_test_agent = server.agent_manager.get_agent_by_id(agent_id=serialize_test_agent.id, actor=default_user)
|
||||
agent_copy = server.agent_manager.get_agent_by_id(agent_id=copied_agent_id, actor=other_user)
|
||||
|
||||
print_dict_diff(json.loads(serialize_test_agent.model_dump_json()), json.loads(agent_copy.model_dump_json()))
|
||||
assert compare_agent_state(server, serialize_test_agent, agent_copy, append_copy_suffix, default_user, other_user)
|
||||
@@ -702,9 +703,8 @@ def test_upload_agentfile_from_disk(server, server_url, disable_e2b_api_key, oth
|
||||
|
||||
assert response.status_code == 200, f"Failed to upload {filename}: {response.text}"
|
||||
json_response = response.json()
|
||||
assert "id" in json_response and json_response["id"].startswith("agent-"), "Uploaded agent response is malformed"
|
||||
|
||||
copied_agent_id = json_response["id"]
|
||||
copied_agent_id = json_response["agent_ids"][0]
|
||||
|
||||
server.send_messages(
|
||||
actor=other_user,
|
||||
@@ -735,7 +735,7 @@ def test_serialize_with_max_steps(server, server_url, default_user, other_user):
|
||||
|
||||
assert response.status_code == 200, f"Failed to upload agent: {response.text}"
|
||||
agent_data = response.json()
|
||||
agent_id = agent_data["id"]
|
||||
agent_id = agent_data["agent_ids"][0]
|
||||
|
||||
# test with default max_steps (should use None, returning all messages)
|
||||
full_result = server.agent_manager.serialize(agent_id=agent_id, actor=default_user)
|
||||
|
||||
Reference in New Issue
Block a user