feat: remove organization from tool pydantic schema (#3430)
This commit is contained in:
@@ -50,7 +50,6 @@ class Tool(BaseTool):
|
||||
tool_type: ToolType = Field(ToolType.CUSTOM, description="The type of the tool.")
|
||||
description: Optional[str] = Field(None, description="The description of the tool.")
|
||||
source_type: Optional[str] = Field(None, description="The type of the source code.")
|
||||
organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the tool.")
|
||||
name: Optional[str] = Field(None, description="The name of the function.")
|
||||
tags: List[str] = Field([], description="Metadata tags.")
|
||||
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
from typing import Dict
|
||||
|
||||
from marshmallow import post_dump, pre_load
|
||||
|
||||
from letta.orm import Tool
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.serialize_schemas.marshmallow_base import BaseSchema
|
||||
@@ -10,6 +14,24 @@ class SerializedToolSchema(BaseSchema):
|
||||
|
||||
__pydantic_model__ = PydanticTool
|
||||
|
||||
@post_dump
|
||||
def sanitize_ids(self, data: Dict, **kwargs) -> Dict:
|
||||
# delete id
|
||||
del data["id"]
|
||||
del data["_created_by_id"]
|
||||
del data["_last_updated_by_id"]
|
||||
|
||||
return data
|
||||
|
||||
@pre_load
|
||||
def regenerate_ids(self, data: Dict, **kwargs) -> Dict:
|
||||
if self.Meta.model:
|
||||
data["id"] = self.generate_id()
|
||||
data["_created_by_id"] = self.actor.id
|
||||
data["_last_updated_by_id"] = self.actor.id
|
||||
|
||||
return data
|
||||
|
||||
class Meta(BaseSchema.Meta):
|
||||
model = Tool
|
||||
exclude = BaseSchema.Meta.exclude + ("is_deleted",)
|
||||
exclude = BaseSchema.Meta.exclude + ("is_deleted", "organization")
|
||||
|
||||
@@ -76,6 +76,7 @@ class ToolManager:
|
||||
if tool_id:
|
||||
# Put to dict and remove fields that should not be reset
|
||||
update_data = pydantic_tool.model_dump(exclude_unset=True, exclude_none=True)
|
||||
update_data["organization_id"] = actor.organization_id
|
||||
|
||||
# If there's anything to update
|
||||
if update_data:
|
||||
@@ -148,12 +149,12 @@ class ToolManager:
|
||||
def create_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
|
||||
"""Create a new tool based on the ToolCreate schema."""
|
||||
with db_registry.session() as session:
|
||||
# Set the organization id at the ORM layer
|
||||
pydantic_tool.organization_id = actor.organization_id
|
||||
# Auto-generate description if not provided
|
||||
if pydantic_tool.description is None:
|
||||
pydantic_tool.description = pydantic_tool.json_schema.get("description", None)
|
||||
tool_data = pydantic_tool.model_dump(to_orm=True)
|
||||
# Set the organization id at the ORM layer
|
||||
tool_data["organization_id"] = actor.organization_id
|
||||
|
||||
tool = ToolModel(**tool_data)
|
||||
tool.create(session, actor=actor) # Re-raise other database-related errors
|
||||
@@ -164,12 +165,12 @@ class ToolManager:
|
||||
async def create_tool_async(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
|
||||
"""Create a new tool based on the ToolCreate schema."""
|
||||
async with db_registry.async_session() as session:
|
||||
# Set the organization id at the ORM layer
|
||||
pydantic_tool.organization_id = actor.organization_id
|
||||
# Auto-generate description if not provided
|
||||
if pydantic_tool.description is None:
|
||||
pydantic_tool.description = pydantic_tool.json_schema.get("description", None)
|
||||
tool_data = pydantic_tool.model_dump(to_orm=True)
|
||||
# Set the organization id at the ORM layer
|
||||
tool_data["organization_id"] = actor.organization_id
|
||||
|
||||
tool = ToolModel(**tool_data)
|
||||
await tool.create_async(session, actor=actor) # Re-raise other database-related errors
|
||||
@@ -516,7 +517,6 @@ class ToolManager:
|
||||
source_type="python",
|
||||
tool_type=tool_type,
|
||||
return_char_limit=BASE_FUNCTION_RETURN_CHAR_LIMIT,
|
||||
organization_id=actor.organization_id,
|
||||
)
|
||||
|
||||
# auto-generate description if not provided
|
||||
@@ -551,6 +551,7 @@ class ToolManager:
|
||||
if actor:
|
||||
tool_dict["_created_by_id"] = actor.id
|
||||
tool_dict["_last_updated_by_id"] = actor.id
|
||||
tool_dict["organization_id"] = actor.organization_id
|
||||
|
||||
# filter to only include columns that exist in the table
|
||||
filtered_dict = {k: v for k, v in tool_dict.items() if k in valid_columns}
|
||||
|
||||
@@ -3025,21 +3025,18 @@ async def test_user_caching(server: SyncServer, event_loop, default_user, perfor
|
||||
def test_create_tool(server: SyncServer, print_tool, default_user, default_organization):
|
||||
# Assertions to ensure the created tool matches the expected values
|
||||
assert print_tool.created_by_id == default_user.id
|
||||
assert print_tool.organization_id == default_organization.id
|
||||
assert print_tool.tool_type == ToolType.CUSTOM
|
||||
|
||||
|
||||
def test_create_composio_tool(server: SyncServer, composio_github_star_tool, default_user, default_organization):
|
||||
# Assertions to ensure the created tool matches the expected values
|
||||
assert composio_github_star_tool.created_by_id == default_user.id
|
||||
assert composio_github_star_tool.organization_id == default_organization.id
|
||||
assert composio_github_star_tool.tool_type == ToolType.EXTERNAL_COMPOSIO
|
||||
|
||||
|
||||
def test_create_mcp_tool(server: SyncServer, mcp_tool, default_user, default_organization):
|
||||
# Assertions to ensure the created tool matches the expected values
|
||||
assert mcp_tool.created_by_id == default_user.id
|
||||
assert mcp_tool.organization_id == default_organization.id
|
||||
assert mcp_tool.tool_type == ToolType.EXTERNAL_MCP
|
||||
assert mcp_tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX]["server_name"] == "test"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user