From 2b0dc4a1f9c920beae4938bcdffc7a9cc9d0d391 Mon Sep 17 00:00:00 2001 From: cthomas Date: Sun, 20 Jul 2025 00:00:31 -0700 Subject: [PATCH] feat: remove organization from tool pydantic schema (#3430) --- letta/schemas/tool.py | 1 - letta/serialize_schemas/marshmallow_tool.py | 24 ++++++++++++++++++++- letta/services/tool_manager.py | 11 +++++----- tests/test_managers.py | 3 --- 4 files changed, 29 insertions(+), 10 deletions(-) diff --git a/letta/schemas/tool.py b/letta/schemas/tool.py index 3b8b2233..c2964e63 100644 --- a/letta/schemas/tool.py +++ b/letta/schemas/tool.py @@ -50,7 +50,6 @@ class Tool(BaseTool): tool_type: ToolType = Field(ToolType.CUSTOM, description="The type of the tool.") description: Optional[str] = Field(None, description="The description of the tool.") source_type: Optional[str] = Field(None, description="The type of the source code.") - organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the tool.") name: Optional[str] = Field(None, description="The name of the function.") tags: List[str] = Field([], description="Metadata tags.") diff --git a/letta/serialize_schemas/marshmallow_tool.py b/letta/serialize_schemas/marshmallow_tool.py index 44f64a97..a6d1c91e 100644 --- a/letta/serialize_schemas/marshmallow_tool.py +++ b/letta/serialize_schemas/marshmallow_tool.py @@ -1,3 +1,7 @@ +from typing import Dict + +from marshmallow import post_dump, pre_load + from letta.orm import Tool from letta.schemas.tool import Tool as PydanticTool from letta.serialize_schemas.marshmallow_base import BaseSchema @@ -10,6 +14,24 @@ class SerializedToolSchema(BaseSchema): __pydantic_model__ = PydanticTool + @post_dump + def sanitize_ids(self, data: Dict, **kwargs) -> Dict: + # delete id + del data["id"] + del data["_created_by_id"] + del data["_last_updated_by_id"] + + return data + + @pre_load + def regenerate_ids(self, data: Dict, **kwargs) -> Dict: + if self.Meta.model: + data["id"] = self.generate_id() + data["_created_by_id"] = self.actor.id + data["_last_updated_by_id"] = self.actor.id + + return data + class Meta(BaseSchema.Meta): model = Tool - exclude = BaseSchema.Meta.exclude + ("is_deleted",) + exclude = BaseSchema.Meta.exclude + ("is_deleted", "organization") diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index ed0803e5..13254049 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -76,6 +76,7 @@ class ToolManager: if tool_id: # Put to dict and remove fields that should not be reset update_data = pydantic_tool.model_dump(exclude_unset=True, exclude_none=True) + update_data["organization_id"] = actor.organization_id # If there's anything to update if update_data: @@ -148,12 +149,12 @@ class ToolManager: def create_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool: """Create a new tool based on the ToolCreate schema.""" with db_registry.session() as session: - # Set the organization id at the ORM layer - pydantic_tool.organization_id = actor.organization_id # Auto-generate description if not provided if pydantic_tool.description is None: pydantic_tool.description = pydantic_tool.json_schema.get("description", None) tool_data = pydantic_tool.model_dump(to_orm=True) + # Set the organization id at the ORM layer + tool_data["organization_id"] = actor.organization_id tool = ToolModel(**tool_data) tool.create(session, actor=actor) # Re-raise other database-related errors @@ -164,12 +165,12 @@ class ToolManager: async def create_tool_async(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool: """Create a new tool based on the ToolCreate schema.""" async with db_registry.async_session() as session: - # Set the organization id at the ORM layer - pydantic_tool.organization_id = actor.organization_id # Auto-generate description if not provided if pydantic_tool.description is None: pydantic_tool.description = pydantic_tool.json_schema.get("description", None) tool_data = pydantic_tool.model_dump(to_orm=True) + # Set the organization id at the ORM layer + tool_data["organization_id"] = actor.organization_id tool = ToolModel(**tool_data) await tool.create_async(session, actor=actor) # Re-raise other database-related errors @@ -516,7 +517,6 @@ class ToolManager: source_type="python", tool_type=tool_type, return_char_limit=BASE_FUNCTION_RETURN_CHAR_LIMIT, - organization_id=actor.organization_id, ) # auto-generate description if not provided @@ -551,6 +551,7 @@ class ToolManager: if actor: tool_dict["_created_by_id"] = actor.id tool_dict["_last_updated_by_id"] = actor.id + tool_dict["organization_id"] = actor.organization_id # filter to only include columns that exist in the table filtered_dict = {k: v for k, v in tool_dict.items() if k in valid_columns} diff --git a/tests/test_managers.py b/tests/test_managers.py index 8874403c..c6a7e1e5 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -3025,21 +3025,18 @@ async def test_user_caching(server: SyncServer, event_loop, default_user, perfor def test_create_tool(server: SyncServer, print_tool, default_user, default_organization): # Assertions to ensure the created tool matches the expected values assert print_tool.created_by_id == default_user.id - assert print_tool.organization_id == default_organization.id assert print_tool.tool_type == ToolType.CUSTOM def test_create_composio_tool(server: SyncServer, composio_github_star_tool, default_user, default_organization): # Assertions to ensure the created tool matches the expected values assert composio_github_star_tool.created_by_id == default_user.id - assert composio_github_star_tool.organization_id == default_organization.id assert composio_github_star_tool.tool_type == ToolType.EXTERNAL_COMPOSIO def test_create_mcp_tool(server: SyncServer, mcp_tool, default_user, default_organization): # Assertions to ensure the created tool matches the expected values assert mcp_tool.created_by_id == default_user.id - assert mcp_tool.organization_id == default_organization.id assert mcp_tool.tool_type == ToolType.EXTERNAL_MCP assert mcp_tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX]["server_name"] == "test"