feat: remove organization from tool pydantic schema (#3430)

This commit is contained in:
cthomas
2025-07-20 00:00:31 -07:00
committed by GitHub
parent 555db59f2c
commit 2b0dc4a1f9
4 changed files with 29 additions and 10 deletions

View File

@@ -50,7 +50,6 @@ class Tool(BaseTool):
tool_type: ToolType = Field(ToolType.CUSTOM, description="The type of the tool.")
description: Optional[str] = Field(None, description="The description of the tool.")
source_type: Optional[str] = Field(None, description="The type of the source code.")
organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the tool.")
name: Optional[str] = Field(None, description="The name of the function.")
tags: List[str] = Field([], description="Metadata tags.")

View File

@@ -1,3 +1,7 @@
from typing import Dict
from marshmallow import post_dump, pre_load
from letta.orm import Tool
from letta.schemas.tool import Tool as PydanticTool
from letta.serialize_schemas.marshmallow_base import BaseSchema
@@ -10,6 +14,24 @@ class SerializedToolSchema(BaseSchema):
__pydantic_model__ = PydanticTool
@post_dump
def sanitize_ids(self, data: Dict, **kwargs) -> Dict:
# delete id
del data["id"]
del data["_created_by_id"]
del data["_last_updated_by_id"]
return data
@pre_load
def regenerate_ids(self, data: Dict, **kwargs) -> Dict:
if self.Meta.model:
data["id"] = self.generate_id()
data["_created_by_id"] = self.actor.id
data["_last_updated_by_id"] = self.actor.id
return data
class Meta(BaseSchema.Meta):
model = Tool
exclude = BaseSchema.Meta.exclude + ("is_deleted",)
exclude = BaseSchema.Meta.exclude + ("is_deleted", "organization")

View File

@@ -76,6 +76,7 @@ class ToolManager:
if tool_id:
# Put to dict and remove fields that should not be reset
update_data = pydantic_tool.model_dump(exclude_unset=True, exclude_none=True)
update_data["organization_id"] = actor.organization_id
# If there's anything to update
if update_data:
@@ -148,12 +149,12 @@ class ToolManager:
def create_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
"""Create a new tool based on the ToolCreate schema."""
with db_registry.session() as session:
# Set the organization id at the ORM layer
pydantic_tool.organization_id = actor.organization_id
# Auto-generate description if not provided
if pydantic_tool.description is None:
pydantic_tool.description = pydantic_tool.json_schema.get("description", None)
tool_data = pydantic_tool.model_dump(to_orm=True)
# Set the organization id at the ORM layer
tool_data["organization_id"] = actor.organization_id
tool = ToolModel(**tool_data)
tool.create(session, actor=actor) # Re-raise other database-related errors
@@ -164,12 +165,12 @@ class ToolManager:
async def create_tool_async(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
"""Create a new tool based on the ToolCreate schema."""
async with db_registry.async_session() as session:
# Set the organization id at the ORM layer
pydantic_tool.organization_id = actor.organization_id
# Auto-generate description if not provided
if pydantic_tool.description is None:
pydantic_tool.description = pydantic_tool.json_schema.get("description", None)
tool_data = pydantic_tool.model_dump(to_orm=True)
# Set the organization id at the ORM layer
tool_data["organization_id"] = actor.organization_id
tool = ToolModel(**tool_data)
await tool.create_async(session, actor=actor) # Re-raise other database-related errors
@@ -516,7 +517,6 @@ class ToolManager:
source_type="python",
tool_type=tool_type,
return_char_limit=BASE_FUNCTION_RETURN_CHAR_LIMIT,
organization_id=actor.organization_id,
)
# auto-generate description if not provided
@@ -551,6 +551,7 @@ class ToolManager:
if actor:
tool_dict["_created_by_id"] = actor.id
tool_dict["_last_updated_by_id"] = actor.id
tool_dict["organization_id"] = actor.organization_id
# filter to only include columns that exist in the table
filtered_dict = {k: v for k, v in tool_dict.items() if k in valid_columns}

View File

@@ -3025,21 +3025,18 @@ async def test_user_caching(server: SyncServer, event_loop, default_user, perfor
def test_create_tool(server: SyncServer, print_tool, default_user, default_organization):
# Assertions to ensure the created tool matches the expected values
assert print_tool.created_by_id == default_user.id
assert print_tool.organization_id == default_organization.id
assert print_tool.tool_type == ToolType.CUSTOM
def test_create_composio_tool(server: SyncServer, composio_github_star_tool, default_user, default_organization):
# Assertions to ensure the created tool matches the expected values
assert composio_github_star_tool.created_by_id == default_user.id
assert composio_github_star_tool.organization_id == default_organization.id
assert composio_github_star_tool.tool_type == ToolType.EXTERNAL_COMPOSIO
def test_create_mcp_tool(server: SyncServer, mcp_tool, default_user, default_organization):
# Assertions to ensure the created tool matches the expected values
assert mcp_tool.created_by_id == default_user.id
assert mcp_tool.organization_id == default_organization.id
assert mcp_tool.tool_type == ToolType.EXTERNAL_MCP
assert mcp_tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX]["server_name"] == "test"