From 75430d82a81b324205ed82f4d3a8fad9447de767 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Thu, 6 Mar 2025 15:44:12 -0800 Subject: [PATCH] feat: Add query parameter for project ID on `/upload` route and adapt tests (#1210) --- letta/serialize_schemas/agent.py | 9 ++++++++- letta/server/rest_api/routers/v1/agents.py | 7 ++++++- letta/services/agent_manager.py | 11 +++++++++-- tests/test_agent_serialization.py | 12 +++++++----- 4 files changed, 30 insertions(+), 9 deletions(-) diff --git a/letta/serialize_schemas/agent.py b/letta/serialize_schemas/agent.py index 0baea8c2..7cef7d15 100644 --- a/letta/serialize_schemas/agent.py +++ b/letta/serialize_schemas/agent.py @@ -70,4 +70,11 @@ class SerializedAgentSchema(BaseSchema): class Meta(BaseSchema.Meta): model = Agent # TODO: Serialize these as well... - exclude = BaseSchema.Meta.exclude + ("sources", "source_passages", "agent_passages") + exclude = BaseSchema.Meta.exclude + ( + "project_id", + "template_id", + "base_template_id", + "sources", + "source_passages", + "agent_passages", + ) diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 3af6033d..7d549b3b 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -119,6 +119,7 @@ async def upload_agent_serialized( True, description="If set to True, existing tools can get their source code overwritten by the uploaded tool definitions. Note that Letta core tools can never be updated externally.", ), + project_id: Optional[str] = Query(None, description="The project ID to associate the uploaded agent with."), ): """ Upload a serialized agent JSON file and recreate the agent in the system. @@ -129,7 +130,11 @@ async def upload_agent_serialized( serialized_data = await file.read() agent_json = json.loads(serialized_data) new_agent = server.agent_manager.deserialize( - serialized_agent=agent_json, actor=actor, append_copy_suffix=append_copy_suffix, override_existing_tools=override_existing_tools + serialized_agent=agent_json, + actor=actor, + append_copy_suffix=append_copy_suffix, + override_existing_tools=override_existing_tools, + project_id=project_id, ) return new_agent diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 8f92eaf2..2a6d8aaa 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -444,7 +444,12 @@ class AgentManager: @enforce_types def deserialize( - self, serialized_agent: dict, actor: PydanticUser, append_copy_suffix: bool = True, override_existing_tools: bool = True + self, + serialized_agent: dict, + actor: PydanticUser, + append_copy_suffix: bool = True, + override_existing_tools: bool = True, + project_id: Optional[str] = None, ) -> PydanticAgentState: tool_data_list = serialized_agent.pop("tools", []) @@ -453,7 +458,9 @@ class AgentManager: agent = schema.load(serialized_agent, session=session) if append_copy_suffix: agent.name += "_copy" - agent.create(session, actor=actor) + if project_id: + agent.project_id = project_id + agent = agent.create(session, actor=actor) pydantic_agent = agent.to_pydantic() # Need to do this separately as there's some fancy upsert logic that SqlAlchemy cannot handle diff --git a/tests/test_agent_serialization.py b/tests/test_agent_serialization.py index facbd552..b7f6cc7d 100644 --- a/tests/test_agent_serialization.py +++ b/tests/test_agent_serialization.py @@ -229,7 +229,7 @@ def _compare_agent_state_model_dump(d1: Dict[str, Any], d2: Dict[str, Any], log: - Datetime fields are ignored. - Order-independent comparison for lists of dicts. """ - ignore_prefix_fields = {"id", "last_updated_by_id", "organization_id", "created_by_id", "agent_id"} + ignore_prefix_fields = {"id", "last_updated_by_id", "organization_id", "created_by_id", "agent_id", "project_id"} # Remove datetime fields upfront d1 = strip_datetime_fields(d1) @@ -476,8 +476,9 @@ def test_agent_serialize_tool_calls(mock_e2b_api_key_none, local_client, server, # FastAPI endpoint tests -@pytest.mark.parametrize("append_copy_suffix", [True]) -def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent, default_user, other_user, append_copy_suffix): +@pytest.mark.parametrize("append_copy_suffix", [True, False]) +@pytest.mark.parametrize("project_id", ["project-12345", None]) +def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent, default_user, other_user, append_copy_suffix, project_id): """ Test the full E2E serialization and deserialization flow using FastAPI endpoints. """ @@ -495,7 +496,7 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent upload_response = fastapi_client.post( "/v1/agents/upload", headers={"user_id": other_user.id}, - params={"append_copy_suffix": append_copy_suffix, "override_existing_tools": False}, + params={"append_copy_suffix": append_copy_suffix, "override_existing_tools": False, "project_id": project_id}, files=files, ) assert upload_response.status_code == 200, f"Upload failed: {upload_response.text}" @@ -504,7 +505,8 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent copied_agent = upload_response.json() copied_agent_id = copied_agent["id"] assert copied_agent_id != agent_id, "Copied agent should have a different ID" - assert copied_agent["name"] == serialize_test_agent.name + "_copy", "Copied agent name should have '_copy' suffix" + if append_copy_suffix: + assert copied_agent["name"] == serialize_test_agent.name + "_copy", "Copied agent name should have '_copy' suffix" # Step 3: Retrieve the copied agent serialize_test_agent = server.agent_manager.get_agent_by_id(agent_id=serialize_test_agent.id, actor=default_user)