feat: Support override tool functionality to agent file v2 (#4092)

This commit is contained in:
Matthew Zhou
2025-08-21 16:58:19 -07:00
committed by GitHub
parent 223c883205
commit b3704e47be
4 changed files with 193 additions and 22 deletions

View File

@@ -260,7 +260,9 @@ async def import_agent(
raise HTTPException(status_code=422, detail=f"Invalid agent file schema: {e!s}")
try:
import_result = await server.agent_serialization_manager.import_file(schema=agent_schema, actor=actor)
import_result = await server.agent_serialization_manager.import_file(
schema=agent_schema, actor=actor, append_copy_suffix=append_copy_suffix, override_existing_tools=override_existing_tools
)
if not import_result.success:
raise HTTPException(

View File

@@ -428,6 +428,8 @@ class AgentSerializationManager:
self,
schema: AgentFileSchema,
actor: User,
append_copy_suffix: bool = False,
override_existing_tools: bool = True,
dry_run: bool = False,
env_vars: Optional[Dict[str, Any]] = None,
) -> ImportResult:
@@ -489,7 +491,9 @@ class AgentSerializationManager:
pydantic_tools.append(Tool(**tool_schema.model_dump(exclude={"id"})))
# bulk upsert all tools at once
created_tools = await self.tool_manager.bulk_upsert_tools_async(pydantic_tools, actor)
created_tools = await self.tool_manager.bulk_upsert_tools_async(
pydantic_tools, actor, override_existing_tools=override_existing_tools
)
# map file ids to database ids
# note: tools are matched by name during upsert, so we need to match by name here too
@@ -611,6 +615,8 @@ class AgentSerializationManager:
for agent_schema in schema.agents:
# Convert AgentSchema back to CreateAgent, remapping tool/block IDs
agent_data = agent_schema.model_dump(exclude={"id", "in_context_message_ids", "messages"})
if append_copy_suffix:
agent_data["name"] = agent_data.get("name") + "_copy"
# Remap tool_ids from file IDs to database IDs
if agent_data.get("tool_ids"):

View File

@@ -184,7 +184,9 @@ class ToolManager:
@enforce_types
@trace_method
async def bulk_upsert_tools_async(self, pydantic_tools: List[PydanticTool], actor: PydanticUser) -> List[PydanticTool]:
async def bulk_upsert_tools_async(
self, pydantic_tools: List[PydanticTool], actor: PydanticUser, override_existing_tools: bool = True
) -> List[PydanticTool]:
"""
Bulk create or update multiple tools in a single database transaction.
@@ -227,10 +229,10 @@ class ToolManager:
if settings.letta_pg_uri_no_default:
# use optimized postgresql bulk upsert
async with db_registry.async_session() as session:
return await self._bulk_upsert_postgresql(session, pydantic_tools, actor)
return await self._bulk_upsert_postgresql(session, pydantic_tools, actor, override_existing_tools)
else:
# fallback to individual upserts for sqlite
return await self._upsert_tools_individually(pydantic_tools, actor)
return await self._upsert_tools_individually(pydantic_tools, actor, override_existing_tools)
@enforce_types
@trace_method
@@ -784,8 +786,10 @@ class ToolManager:
return await self._upsert_tools_individually(tool_data_list, actor)
@trace_method
async def _bulk_upsert_postgresql(self, session, tool_data_list: List[PydanticTool], actor: PydanticUser) -> List[PydanticTool]:
"""hyper-optimized postgresql bulk upsert using on_conflict_do_update."""
async def _bulk_upsert_postgresql(
self, session, tool_data_list: List[PydanticTool], actor: PydanticUser, override_existing_tools: bool = True
) -> List[PydanticTool]:
"""hyper-optimized postgresql bulk upsert using on_conflict_do_update or on_conflict_do_nothing."""
from sqlalchemy import func, select
from sqlalchemy.dialects.postgresql import insert
@@ -809,32 +813,51 @@ class ToolManager:
# use postgresql's native bulk upsert
stmt = insert(table).values(insert_data)
# on conflict, update all columns except id, created_at, and _created_by_id
excluded = stmt.excluded
update_dict = {}
for col in table.columns:
if col.name not in ("id", "created_at", "_created_by_id"):
if col.name == "updated_at":
update_dict[col.name] = func.now()
else:
update_dict[col.name] = excluded[col.name]
if override_existing_tools:
# on conflict, update all columns except id, created_at, and _created_by_id
excluded = stmt.excluded
update_dict = {}
for col in table.columns:
if col.name not in ("id", "created_at", "_created_by_id"):
if col.name == "updated_at":
update_dict[col.name] = func.now()
else:
update_dict[col.name] = excluded[col.name]
upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict)
upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict)
else:
# on conflict, do nothing (skip existing tools)
upsert_stmt = stmt.on_conflict_do_nothing(index_elements=["name", "organization_id"])
await session.execute(upsert_stmt)
await session.commit()
# fetch results
# fetch results (includes both inserted and skipped tools)
tool_names = [tool.name for tool in tool_data_list]
result_query = select(ToolModel).where(ToolModel.name.in_(tool_names), ToolModel.organization_id == actor.organization_id)
result = await session.execute(result_query)
return [tool.to_pydantic() for tool in result.scalars()]
@trace_method
async def _upsert_tools_individually(self, tool_data_list: List[PydanticTool], actor: PydanticUser) -> List[PydanticTool]:
async def _upsert_tools_individually(
self, tool_data_list: List[PydanticTool], actor: PydanticUser, override_existing_tools: bool = True
) -> List[PydanticTool]:
"""fallback to individual upserts for sqlite (original approach)."""
tools = []
for tool in tool_data_list:
upserted_tool = await self.create_or_update_tool_async(tool, actor)
tools.append(upserted_tool)
if override_existing_tools:
# update existing tools if they exist
upserted_tool = await self.create_or_update_tool_async(tool, actor)
tools.append(upserted_tool)
else:
# skip existing tools, only create new ones
existing_tool_id = await self.get_tool_id_by_name_async(tool_name=tool.name, actor=actor)
if existing_tool_id:
# tool exists, fetch and return it without updating
existing_tool = await self.get_tool_by_id_async(existing_tool_id, actor=actor)
tools.append(existing_tool)
else:
# tool doesn't exist, create it
created_tool = await self.create_tool_async(tool, actor=actor)
tools.append(created_tool)
return tools

View File

@@ -3969,7 +3969,7 @@ async def test_bulk_upsert_tools_name_conflict(server: SyncServer, default_user)
name="unique_name_tool",
description="Original description",
tags=["original"],
source_code="def unique_name_tool():\n '''Original function'''\n return 'original'",
source_code="def unique_name_tool():\n '''Original function'''\n return 'original'`",
source_type="python",
)
@@ -4087,6 +4087,146 @@ async def test_bulk_upsert_tools_mixed_create_update(server: SyncServer, default
assert tool_2.tags == ["existing"]
@pytest.mark.asyncio
async def test_bulk_upsert_tools_override_existing_true(server: SyncServer, default_user):
"""Test bulk_upsert_tools_async with override_existing_tools=True (default behavior)"""
# create some existing tools
existing_tool = PydanticTool(
name="test_override_tool",
description="Original description",
tags=["original"],
source_code="def test_override_tool():\n '''Original'''\n return 'original'",
source_type="python",
)
created = await server.tool_manager.create_tool_async(existing_tool, default_user)
original_id = created.id
# prepare updated version of the tool
updated_tool = PydanticTool(
name="test_override_tool",
description="Updated description",
tags=["updated"],
source_code="def test_override_tool():\n '''Updated'''\n return 'updated'",
source_type="python",
)
# bulk upsert with override_existing_tools=True (default)
result = await server.tool_manager.bulk_upsert_tools_async([updated_tool], default_user, override_existing_tools=True)
assert len(result) == 1
assert result[0].id == original_id # id should remain the same
assert result[0].description == "Updated description" # description should be updated
assert result[0].tags == ["updated"] # tags should be updated
# verify the tool was actually updated in the database
fetched = await server.tool_manager.get_tool_by_id_async(original_id, default_user)
assert fetched.description == "Updated description"
assert fetched.tags == ["updated"]
@pytest.mark.asyncio
async def test_bulk_upsert_tools_override_existing_false(server: SyncServer, default_user):
"""Test bulk_upsert_tools_async with override_existing_tools=False (skip existing)"""
# create some existing tools
existing_tool = PydanticTool(
name="test_no_override_tool",
description="Original description",
tags=["original"],
source_code="def test_no_override_tool():\n '''Original'''\n return 'original'",
source_type="python",
)
created = await server.tool_manager.create_tool_async(existing_tool, default_user)
original_id = created.id
# prepare updated version of the tool
updated_tool = PydanticTool(
name="test_no_override_tool",
description="Should not be updated",
tags=["should_not_update"],
source_code="def test_no_override_tool():\n '''Should not update'''\n return 'should_not_update'",
source_type="python",
)
# bulk upsert with override_existing_tools=False
result = await server.tool_manager.bulk_upsert_tools_async([updated_tool], default_user, override_existing_tools=False)
assert len(result) == 1
assert result[0].id == original_id # id should remain the same
assert result[0].description == "Original description" # description should NOT be updated
assert result[0].tags == ["original"] # tags should NOT be updated
# verify the tool was NOT updated in the database
fetched = await server.tool_manager.get_tool_by_id_async(original_id, default_user)
assert fetched.description == "Original description"
assert fetched.tags == ["original"]
@pytest.mark.asyncio
async def test_bulk_upsert_tools_override_mixed_scenario(server: SyncServer, default_user):
"""Test bulk_upsert_tools_async with override_existing_tools=False in mixed create/update scenario"""
# create some existing tools
existing_tools = []
for i in range(2):
tool = PydanticTool(
name=f"mixed_existing_{i}",
description=f"Original {i}",
tags=["original"],
source_code=f"def mixed_existing_{i}():\n '''Original {i}'''\n return 'original_{i}'",
source_type="python",
)
created = await server.tool_manager.create_tool_async(tool, default_user)
existing_tools.append(created)
# prepare bulk tools: 2 updates (that should be skipped) + 3 new creations
bulk_tools = []
# these should be skipped when override_existing_tools=False
for i in range(2):
bulk_tools.append(
PydanticTool(
name=f"mixed_existing_{i}",
description=f"Should not update {i}",
tags=["should_not_update"],
source_code=f"def mixed_existing_{i}():\n '''Should not update {i}'''\n return 'should_not_update_{i}'",
source_type="python",
)
)
# these should be created
for i in range(3):
bulk_tools.append(
PydanticTool(
name=f"mixed_new_{i}",
description=f"New tool {i}",
tags=["new"],
source_code=f"def mixed_new_{i}():\n '''New {i}'''\n return 'new_{i}'",
source_type="python",
)
)
# bulk upsert with override_existing_tools=False
result = await server.tool_manager.bulk_upsert_tools_async(bulk_tools, default_user, override_existing_tools=False)
assert len(result) == 5 # 2 existing (not updated) + 3 new
# verify existing tools were NOT updated
for i in range(2):
tool = await server.tool_manager.get_tool_by_name_async(f"mixed_existing_{i}", default_user)
assert tool.description == f"Original {i}" # should remain original
assert tool.tags == ["original"] # should remain original
assert tool.id == existing_tools[i].id # id should remain same
# verify new tools were created
for i in range(3):
new_tool = await server.tool_manager.get_tool_by_name_async(f"mixed_new_{i}", default_user)
assert new_tool is not None
assert new_tool.description == f"New tool {i}"
assert new_tool.tags == ["new"]
@pytest.mark.asyncio
async def test_create_tool_with_pip_requirements(server: SyncServer, default_user, default_organization):
def test_tool_with_deps():