feat: Support override tool functionality to agent file v2 (#4092)
This commit is contained in:
@@ -260,7 +260,9 @@ async def import_agent(
|
||||
raise HTTPException(status_code=422, detail=f"Invalid agent file schema: {e!s}")
|
||||
|
||||
try:
|
||||
import_result = await server.agent_serialization_manager.import_file(schema=agent_schema, actor=actor)
|
||||
import_result = await server.agent_serialization_manager.import_file(
|
||||
schema=agent_schema, actor=actor, append_copy_suffix=append_copy_suffix, override_existing_tools=override_existing_tools
|
||||
)
|
||||
|
||||
if not import_result.success:
|
||||
raise HTTPException(
|
||||
|
||||
@@ -428,6 +428,8 @@ class AgentSerializationManager:
|
||||
self,
|
||||
schema: AgentFileSchema,
|
||||
actor: User,
|
||||
append_copy_suffix: bool = False,
|
||||
override_existing_tools: bool = True,
|
||||
dry_run: bool = False,
|
||||
env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> ImportResult:
|
||||
@@ -489,7 +491,9 @@ class AgentSerializationManager:
|
||||
pydantic_tools.append(Tool(**tool_schema.model_dump(exclude={"id"})))
|
||||
|
||||
# bulk upsert all tools at once
|
||||
created_tools = await self.tool_manager.bulk_upsert_tools_async(pydantic_tools, actor)
|
||||
created_tools = await self.tool_manager.bulk_upsert_tools_async(
|
||||
pydantic_tools, actor, override_existing_tools=override_existing_tools
|
||||
)
|
||||
|
||||
# map file ids to database ids
|
||||
# note: tools are matched by name during upsert, so we need to match by name here too
|
||||
@@ -611,6 +615,8 @@ class AgentSerializationManager:
|
||||
for agent_schema in schema.agents:
|
||||
# Convert AgentSchema back to CreateAgent, remapping tool/block IDs
|
||||
agent_data = agent_schema.model_dump(exclude={"id", "in_context_message_ids", "messages"})
|
||||
if append_copy_suffix:
|
||||
agent_data["name"] = agent_data.get("name") + "_copy"
|
||||
|
||||
# Remap tool_ids from file IDs to database IDs
|
||||
if agent_data.get("tool_ids"):
|
||||
|
||||
@@ -184,7 +184,9 @@ class ToolManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def bulk_upsert_tools_async(self, pydantic_tools: List[PydanticTool], actor: PydanticUser) -> List[PydanticTool]:
|
||||
async def bulk_upsert_tools_async(
|
||||
self, pydantic_tools: List[PydanticTool], actor: PydanticUser, override_existing_tools: bool = True
|
||||
) -> List[PydanticTool]:
|
||||
"""
|
||||
Bulk create or update multiple tools in a single database transaction.
|
||||
|
||||
@@ -227,10 +229,10 @@ class ToolManager:
|
||||
if settings.letta_pg_uri_no_default:
|
||||
# use optimized postgresql bulk upsert
|
||||
async with db_registry.async_session() as session:
|
||||
return await self._bulk_upsert_postgresql(session, pydantic_tools, actor)
|
||||
return await self._bulk_upsert_postgresql(session, pydantic_tools, actor, override_existing_tools)
|
||||
else:
|
||||
# fallback to individual upserts for sqlite
|
||||
return await self._upsert_tools_individually(pydantic_tools, actor)
|
||||
return await self._upsert_tools_individually(pydantic_tools, actor, override_existing_tools)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -784,8 +786,10 @@ class ToolManager:
|
||||
return await self._upsert_tools_individually(tool_data_list, actor)
|
||||
|
||||
@trace_method
|
||||
async def _bulk_upsert_postgresql(self, session, tool_data_list: List[PydanticTool], actor: PydanticUser) -> List[PydanticTool]:
|
||||
"""hyper-optimized postgresql bulk upsert using on_conflict_do_update."""
|
||||
async def _bulk_upsert_postgresql(
|
||||
self, session, tool_data_list: List[PydanticTool], actor: PydanticUser, override_existing_tools: bool = True
|
||||
) -> List[PydanticTool]:
|
||||
"""hyper-optimized postgresql bulk upsert using on_conflict_do_update or on_conflict_do_nothing."""
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
|
||||
@@ -809,32 +813,51 @@ class ToolManager:
|
||||
# use postgresql's native bulk upsert
|
||||
stmt = insert(table).values(insert_data)
|
||||
|
||||
# on conflict, update all columns except id, created_at, and _created_by_id
|
||||
excluded = stmt.excluded
|
||||
update_dict = {}
|
||||
for col in table.columns:
|
||||
if col.name not in ("id", "created_at", "_created_by_id"):
|
||||
if col.name == "updated_at":
|
||||
update_dict[col.name] = func.now()
|
||||
else:
|
||||
update_dict[col.name] = excluded[col.name]
|
||||
if override_existing_tools:
|
||||
# on conflict, update all columns except id, created_at, and _created_by_id
|
||||
excluded = stmt.excluded
|
||||
update_dict = {}
|
||||
for col in table.columns:
|
||||
if col.name not in ("id", "created_at", "_created_by_id"):
|
||||
if col.name == "updated_at":
|
||||
update_dict[col.name] = func.now()
|
||||
else:
|
||||
update_dict[col.name] = excluded[col.name]
|
||||
|
||||
upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict)
|
||||
upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict)
|
||||
else:
|
||||
# on conflict, do nothing (skip existing tools)
|
||||
upsert_stmt = stmt.on_conflict_do_nothing(index_elements=["name", "organization_id"])
|
||||
|
||||
await session.execute(upsert_stmt)
|
||||
await session.commit()
|
||||
|
||||
# fetch results
|
||||
# fetch results (includes both inserted and skipped tools)
|
||||
tool_names = [tool.name for tool in tool_data_list]
|
||||
result_query = select(ToolModel).where(ToolModel.name.in_(tool_names), ToolModel.organization_id == actor.organization_id)
|
||||
result = await session.execute(result_query)
|
||||
return [tool.to_pydantic() for tool in result.scalars()]
|
||||
|
||||
@trace_method
|
||||
async def _upsert_tools_individually(self, tool_data_list: List[PydanticTool], actor: PydanticUser) -> List[PydanticTool]:
|
||||
async def _upsert_tools_individually(
|
||||
self, tool_data_list: List[PydanticTool], actor: PydanticUser, override_existing_tools: bool = True
|
||||
) -> List[PydanticTool]:
|
||||
"""fallback to individual upserts for sqlite (original approach)."""
|
||||
tools = []
|
||||
for tool in tool_data_list:
|
||||
upserted_tool = await self.create_or_update_tool_async(tool, actor)
|
||||
tools.append(upserted_tool)
|
||||
if override_existing_tools:
|
||||
# update existing tools if they exist
|
||||
upserted_tool = await self.create_or_update_tool_async(tool, actor)
|
||||
tools.append(upserted_tool)
|
||||
else:
|
||||
# skip existing tools, only create new ones
|
||||
existing_tool_id = await self.get_tool_id_by_name_async(tool_name=tool.name, actor=actor)
|
||||
if existing_tool_id:
|
||||
# tool exists, fetch and return it without updating
|
||||
existing_tool = await self.get_tool_by_id_async(existing_tool_id, actor=actor)
|
||||
tools.append(existing_tool)
|
||||
else:
|
||||
# tool doesn't exist, create it
|
||||
created_tool = await self.create_tool_async(tool, actor=actor)
|
||||
tools.append(created_tool)
|
||||
return tools
|
||||
|
||||
@@ -3969,7 +3969,7 @@ async def test_bulk_upsert_tools_name_conflict(server: SyncServer, default_user)
|
||||
name="unique_name_tool",
|
||||
description="Original description",
|
||||
tags=["original"],
|
||||
source_code="def unique_name_tool():\n '''Original function'''\n return 'original'",
|
||||
source_code="def unique_name_tool():\n '''Original function'''\n return 'original'`",
|
||||
source_type="python",
|
||||
)
|
||||
|
||||
@@ -4087,6 +4087,146 @@ async def test_bulk_upsert_tools_mixed_create_update(server: SyncServer, default
|
||||
assert tool_2.tags == ["existing"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_upsert_tools_override_existing_true(server: SyncServer, default_user):
|
||||
"""Test bulk_upsert_tools_async with override_existing_tools=True (default behavior)"""
|
||||
|
||||
# create some existing tools
|
||||
existing_tool = PydanticTool(
|
||||
name="test_override_tool",
|
||||
description="Original description",
|
||||
tags=["original"],
|
||||
source_code="def test_override_tool():\n '''Original'''\n return 'original'",
|
||||
source_type="python",
|
||||
)
|
||||
created = await server.tool_manager.create_tool_async(existing_tool, default_user)
|
||||
original_id = created.id
|
||||
|
||||
# prepare updated version of the tool
|
||||
updated_tool = PydanticTool(
|
||||
name="test_override_tool",
|
||||
description="Updated description",
|
||||
tags=["updated"],
|
||||
source_code="def test_override_tool():\n '''Updated'''\n return 'updated'",
|
||||
source_type="python",
|
||||
)
|
||||
|
||||
# bulk upsert with override_existing_tools=True (default)
|
||||
result = await server.tool_manager.bulk_upsert_tools_async([updated_tool], default_user, override_existing_tools=True)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].id == original_id # id should remain the same
|
||||
assert result[0].description == "Updated description" # description should be updated
|
||||
assert result[0].tags == ["updated"] # tags should be updated
|
||||
|
||||
# verify the tool was actually updated in the database
|
||||
fetched = await server.tool_manager.get_tool_by_id_async(original_id, default_user)
|
||||
assert fetched.description == "Updated description"
|
||||
assert fetched.tags == ["updated"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_upsert_tools_override_existing_false(server: SyncServer, default_user):
|
||||
"""Test bulk_upsert_tools_async with override_existing_tools=False (skip existing)"""
|
||||
|
||||
# create some existing tools
|
||||
existing_tool = PydanticTool(
|
||||
name="test_no_override_tool",
|
||||
description="Original description",
|
||||
tags=["original"],
|
||||
source_code="def test_no_override_tool():\n '''Original'''\n return 'original'",
|
||||
source_type="python",
|
||||
)
|
||||
created = await server.tool_manager.create_tool_async(existing_tool, default_user)
|
||||
original_id = created.id
|
||||
|
||||
# prepare updated version of the tool
|
||||
updated_tool = PydanticTool(
|
||||
name="test_no_override_tool",
|
||||
description="Should not be updated",
|
||||
tags=["should_not_update"],
|
||||
source_code="def test_no_override_tool():\n '''Should not update'''\n return 'should_not_update'",
|
||||
source_type="python",
|
||||
)
|
||||
|
||||
# bulk upsert with override_existing_tools=False
|
||||
result = await server.tool_manager.bulk_upsert_tools_async([updated_tool], default_user, override_existing_tools=False)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].id == original_id # id should remain the same
|
||||
assert result[0].description == "Original description" # description should NOT be updated
|
||||
assert result[0].tags == ["original"] # tags should NOT be updated
|
||||
|
||||
# verify the tool was NOT updated in the database
|
||||
fetched = await server.tool_manager.get_tool_by_id_async(original_id, default_user)
|
||||
assert fetched.description == "Original description"
|
||||
assert fetched.tags == ["original"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_upsert_tools_override_mixed_scenario(server: SyncServer, default_user):
|
||||
"""Test bulk_upsert_tools_async with override_existing_tools=False in mixed create/update scenario"""
|
||||
|
||||
# create some existing tools
|
||||
existing_tools = []
|
||||
for i in range(2):
|
||||
tool = PydanticTool(
|
||||
name=f"mixed_existing_{i}",
|
||||
description=f"Original {i}",
|
||||
tags=["original"],
|
||||
source_code=f"def mixed_existing_{i}():\n '''Original {i}'''\n return 'original_{i}'",
|
||||
source_type="python",
|
||||
)
|
||||
created = await server.tool_manager.create_tool_async(tool, default_user)
|
||||
existing_tools.append(created)
|
||||
|
||||
# prepare bulk tools: 2 updates (that should be skipped) + 3 new creations
|
||||
bulk_tools = []
|
||||
|
||||
# these should be skipped when override_existing_tools=False
|
||||
for i in range(2):
|
||||
bulk_tools.append(
|
||||
PydanticTool(
|
||||
name=f"mixed_existing_{i}",
|
||||
description=f"Should not update {i}",
|
||||
tags=["should_not_update"],
|
||||
source_code=f"def mixed_existing_{i}():\n '''Should not update {i}'''\n return 'should_not_update_{i}'",
|
||||
source_type="python",
|
||||
)
|
||||
)
|
||||
|
||||
# these should be created
|
||||
for i in range(3):
|
||||
bulk_tools.append(
|
||||
PydanticTool(
|
||||
name=f"mixed_new_{i}",
|
||||
description=f"New tool {i}",
|
||||
tags=["new"],
|
||||
source_code=f"def mixed_new_{i}():\n '''New {i}'''\n return 'new_{i}'",
|
||||
source_type="python",
|
||||
)
|
||||
)
|
||||
|
||||
# bulk upsert with override_existing_tools=False
|
||||
result = await server.tool_manager.bulk_upsert_tools_async(bulk_tools, default_user, override_existing_tools=False)
|
||||
|
||||
assert len(result) == 5 # 2 existing (not updated) + 3 new
|
||||
|
||||
# verify existing tools were NOT updated
|
||||
for i in range(2):
|
||||
tool = await server.tool_manager.get_tool_by_name_async(f"mixed_existing_{i}", default_user)
|
||||
assert tool.description == f"Original {i}" # should remain original
|
||||
assert tool.tags == ["original"] # should remain original
|
||||
assert tool.id == existing_tools[i].id # id should remain same
|
||||
|
||||
# verify new tools were created
|
||||
for i in range(3):
|
||||
new_tool = await server.tool_manager.get_tool_by_name_async(f"mixed_new_{i}", default_user)
|
||||
assert new_tool is not None
|
||||
assert new_tool.description == f"New tool {i}"
|
||||
assert new_tool.tags == ["new"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_tool_with_pip_requirements(server: SyncServer, default_user, default_organization):
|
||||
def test_tool_with_deps():
|
||||
|
||||
Reference in New Issue
Block a user