feat: Add query parameter for project ID on /upload route and adapt tests (#1210)
This commit is contained in:
@@ -70,4 +70,11 @@ class SerializedAgentSchema(BaseSchema):
|
||||
class Meta(BaseSchema.Meta):
|
||||
model = Agent
|
||||
# TODO: Serialize these as well...
|
||||
exclude = BaseSchema.Meta.exclude + ("sources", "source_passages", "agent_passages")
|
||||
exclude = BaseSchema.Meta.exclude + (
|
||||
"project_id",
|
||||
"template_id",
|
||||
"base_template_id",
|
||||
"sources",
|
||||
"source_passages",
|
||||
"agent_passages",
|
||||
)
|
||||
|
||||
@@ -119,6 +119,7 @@ async def upload_agent_serialized(
|
||||
True,
|
||||
description="If set to True, existing tools can get their source code overwritten by the uploaded tool definitions. Note that Letta core tools can never be updated externally.",
|
||||
),
|
||||
project_id: Optional[str] = Query(None, description="The project ID to associate the uploaded agent with."),
|
||||
):
|
||||
"""
|
||||
Upload a serialized agent JSON file and recreate the agent in the system.
|
||||
@@ -129,7 +130,11 @@ async def upload_agent_serialized(
|
||||
serialized_data = await file.read()
|
||||
agent_json = json.loads(serialized_data)
|
||||
new_agent = server.agent_manager.deserialize(
|
||||
serialized_agent=agent_json, actor=actor, append_copy_suffix=append_copy_suffix, override_existing_tools=override_existing_tools
|
||||
serialized_agent=agent_json,
|
||||
actor=actor,
|
||||
append_copy_suffix=append_copy_suffix,
|
||||
override_existing_tools=override_existing_tools,
|
||||
project_id=project_id,
|
||||
)
|
||||
return new_agent
|
||||
|
||||
|
||||
@@ -444,7 +444,12 @@ class AgentManager:
|
||||
|
||||
@enforce_types
|
||||
def deserialize(
|
||||
self, serialized_agent: dict, actor: PydanticUser, append_copy_suffix: bool = True, override_existing_tools: bool = True
|
||||
self,
|
||||
serialized_agent: dict,
|
||||
actor: PydanticUser,
|
||||
append_copy_suffix: bool = True,
|
||||
override_existing_tools: bool = True,
|
||||
project_id: Optional[str] = None,
|
||||
) -> PydanticAgentState:
|
||||
tool_data_list = serialized_agent.pop("tools", [])
|
||||
|
||||
@@ -453,7 +458,9 @@ class AgentManager:
|
||||
agent = schema.load(serialized_agent, session=session)
|
||||
if append_copy_suffix:
|
||||
agent.name += "_copy"
|
||||
agent.create(session, actor=actor)
|
||||
if project_id:
|
||||
agent.project_id = project_id
|
||||
agent = agent.create(session, actor=actor)
|
||||
pydantic_agent = agent.to_pydantic()
|
||||
|
||||
# Need to do this separately as there's some fancy upsert logic that SqlAlchemy cannot handle
|
||||
|
||||
@@ -229,7 +229,7 @@ def _compare_agent_state_model_dump(d1: Dict[str, Any], d2: Dict[str, Any], log:
|
||||
- Datetime fields are ignored.
|
||||
- Order-independent comparison for lists of dicts.
|
||||
"""
|
||||
ignore_prefix_fields = {"id", "last_updated_by_id", "organization_id", "created_by_id", "agent_id"}
|
||||
ignore_prefix_fields = {"id", "last_updated_by_id", "organization_id", "created_by_id", "agent_id", "project_id"}
|
||||
|
||||
# Remove datetime fields upfront
|
||||
d1 = strip_datetime_fields(d1)
|
||||
@@ -476,8 +476,9 @@ def test_agent_serialize_tool_calls(mock_e2b_api_key_none, local_client, server,
|
||||
# FastAPI endpoint tests
|
||||
|
||||
|
||||
@pytest.mark.parametrize("append_copy_suffix", [True])
|
||||
def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent, default_user, other_user, append_copy_suffix):
|
||||
@pytest.mark.parametrize("append_copy_suffix", [True, False])
|
||||
@pytest.mark.parametrize("project_id", ["project-12345", None])
|
||||
def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent, default_user, other_user, append_copy_suffix, project_id):
|
||||
"""
|
||||
Test the full E2E serialization and deserialization flow using FastAPI endpoints.
|
||||
"""
|
||||
@@ -495,7 +496,7 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent
|
||||
upload_response = fastapi_client.post(
|
||||
"/v1/agents/upload",
|
||||
headers={"user_id": other_user.id},
|
||||
params={"append_copy_suffix": append_copy_suffix, "override_existing_tools": False},
|
||||
params={"append_copy_suffix": append_copy_suffix, "override_existing_tools": False, "project_id": project_id},
|
||||
files=files,
|
||||
)
|
||||
assert upload_response.status_code == 200, f"Upload failed: {upload_response.text}"
|
||||
@@ -504,7 +505,8 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent
|
||||
copied_agent = upload_response.json()
|
||||
copied_agent_id = copied_agent["id"]
|
||||
assert copied_agent_id != agent_id, "Copied agent should have a different ID"
|
||||
assert copied_agent["name"] == serialize_test_agent.name + "_copy", "Copied agent name should have '_copy' suffix"
|
||||
if append_copy_suffix:
|
||||
assert copied_agent["name"] == serialize_test_agent.name + "_copy", "Copied agent name should have '_copy' suffix"
|
||||
|
||||
# Step 3: Retrieve the copied agent
|
||||
serialize_test_agent = server.agent_manager.get_agent_by_id(agent_id=serialize_test_agent.id, actor=default_user)
|
||||
|
||||
Reference in New Issue
Block a user