feat: Add query parameter for project ID on /upload route and adapt tests (#1210)

This commit is contained in:
Matthew Zhou
2025-03-06 15:44:12 -08:00
committed by GitHub
parent 5503135240
commit 021e28a98e
4 changed files with 30 additions and 9 deletions

View File

@@ -70,4 +70,11 @@ class SerializedAgentSchema(BaseSchema):
class Meta(BaseSchema.Meta):
model = Agent
# TODO: Serialize these as well...
exclude = BaseSchema.Meta.exclude + ("sources", "source_passages", "agent_passages")
exclude = BaseSchema.Meta.exclude + (
"project_id",
"template_id",
"base_template_id",
"sources",
"source_passages",
"agent_passages",
)

View File

@@ -119,6 +119,7 @@ async def upload_agent_serialized(
True,
description="If set to True, existing tools can get their source code overwritten by the uploaded tool definitions. Note that Letta core tools can never be updated externally.",
),
project_id: Optional[str] = Query(None, description="The project ID to associate the uploaded agent with."),
):
"""
Upload a serialized agent JSON file and recreate the agent in the system.
@@ -129,7 +130,11 @@ async def upload_agent_serialized(
serialized_data = await file.read()
agent_json = json.loads(serialized_data)
new_agent = server.agent_manager.deserialize(
serialized_agent=agent_json, actor=actor, append_copy_suffix=append_copy_suffix, override_existing_tools=override_existing_tools
serialized_agent=agent_json,
actor=actor,
append_copy_suffix=append_copy_suffix,
override_existing_tools=override_existing_tools,
project_id=project_id,
)
return new_agent

View File

@@ -444,7 +444,12 @@ class AgentManager:
@enforce_types
def deserialize(
self, serialized_agent: dict, actor: PydanticUser, append_copy_suffix: bool = True, override_existing_tools: bool = True
self,
serialized_agent: dict,
actor: PydanticUser,
append_copy_suffix: bool = True,
override_existing_tools: bool = True,
project_id: Optional[str] = None,
) -> PydanticAgentState:
tool_data_list = serialized_agent.pop("tools", [])
@@ -453,7 +458,9 @@ class AgentManager:
agent = schema.load(serialized_agent, session=session)
if append_copy_suffix:
agent.name += "_copy"
agent.create(session, actor=actor)
if project_id:
agent.project_id = project_id
agent = agent.create(session, actor=actor)
pydantic_agent = agent.to_pydantic()
# Need to do this separately as there's some fancy upsert logic that SqlAlchemy cannot handle

View File

@@ -229,7 +229,7 @@ def _compare_agent_state_model_dump(d1: Dict[str, Any], d2: Dict[str, Any], log:
- Datetime fields are ignored.
- Order-independent comparison for lists of dicts.
"""
ignore_prefix_fields = {"id", "last_updated_by_id", "organization_id", "created_by_id", "agent_id"}
ignore_prefix_fields = {"id", "last_updated_by_id", "organization_id", "created_by_id", "agent_id", "project_id"}
# Remove datetime fields upfront
d1 = strip_datetime_fields(d1)
@@ -476,8 +476,9 @@ def test_agent_serialize_tool_calls(mock_e2b_api_key_none, local_client, server,
# FastAPI endpoint tests
@pytest.mark.parametrize("append_copy_suffix", [True])
def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent, default_user, other_user, append_copy_suffix):
@pytest.mark.parametrize("append_copy_suffix", [True, False])
@pytest.mark.parametrize("project_id", ["project-12345", None])
def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent, default_user, other_user, append_copy_suffix, project_id):
"""
Test the full E2E serialization and deserialization flow using FastAPI endpoints.
"""
@@ -495,7 +496,7 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent
upload_response = fastapi_client.post(
"/v1/agents/upload",
headers={"user_id": other_user.id},
params={"append_copy_suffix": append_copy_suffix, "override_existing_tools": False},
params={"append_copy_suffix": append_copy_suffix, "override_existing_tools": False, "project_id": project_id},
files=files,
)
assert upload_response.status_code == 200, f"Upload failed: {upload_response.text}"
@@ -504,7 +505,8 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent
copied_agent = upload_response.json()
copied_agent_id = copied_agent["id"]
assert copied_agent_id != agent_id, "Copied agent should have a different ID"
assert copied_agent["name"] == serialize_test_agent.name + "_copy", "Copied agent name should have '_copy' suffix"
if append_copy_suffix:
assert copied_agent["name"] == serialize_test_agent.name + "_copy", "Copied agent name should have '_copy' suffix"
# Step 3: Retrieve the copied agent
serialize_test_agent = server.agent_manager.get_agent_by_id(agent_id=serialize_test_agent.id, actor=default_user)