feat: Add metadata_ field to Tool (#1321)

This commit is contained in:
Matthew Zhou
2025-03-17 17:14:08 -07:00
committed by GitHub
parent 6962a6bcb3
commit 54bb536beb
6 changed files with 46 additions and 11 deletions

View File

@@ -0,0 +1,31 @@
"""Add metadata to Tools
Revision ID: 1e553a664210
Revises: 2cceb07c2384
Create Date: 2025-03-17 15:50:05.562302
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "1e553a664210"
down_revision: Union[str, None] = "2cceb07c2384"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("tools", sa.Column("metadata_", sa.JSON(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("tools", "metadata_")
# ### end Alembic commands ###

View File

@@ -508,10 +508,13 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
raise NotImplementedError("Sqlalchemy models must declare a __pydantic_model__ property to be convertable.")
def to_pydantic(self) -> "BaseModel":
"""converts to the basic pydantic model counterpart"""
model = self.__pydantic_model__.model_validate(self)
if hasattr(self, "metadata_"):
model.metadata = self.metadata_
"""Converts the SQLAlchemy model to its corresponding Pydantic model."""
model = self.__pydantic_model__.model_validate(self, from_attributes=True)
# Explicitly map metadata_ to metadata in Pydantic model
if hasattr(self, "metadata_") and hasattr(model, "metadata_"):
setattr(model, "metadata_", self.metadata_) # Ensures correct assignment
return model
def pretty_print_columns(self) -> str:

View File

@@ -44,5 +44,6 @@ class Tool(SqlalchemyBase, OrganizationMixin):
source_code: Mapped[Optional[str]] = mapped_column(String, doc="The source code of the function.")
json_schema: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="The OAI compatable JSON schema of the function.")
args_json_schema: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="The JSON schema of the function arguments.")
metadata_: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="A dictionary of additional metadata for the tool.")
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin")

View File

@@ -66,6 +66,7 @@ class Tool(BaseTool):
# metadata fields
created_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
metadata_: Optional[Dict[str, Any]] = Field(default_factory=dict, description="A dictionary of additional metadata for the tool.")
@model_validator(mode="after")
def refresh_source_code_and_json_schema(self):
@@ -137,10 +138,6 @@ class ToolCreate(LettaBase):
@classmethod
def from_mcp(cls, mcp_server_name: str, mcp_tool: MCPTool) -> "ToolCreate":
# Get the MCP tool from the MCP server
# NVM
# Pass the MCP tool to the schema generator
json_schema = generate_tool_schema_for_mcp(mcp_tool=mcp_tool)

View File

@@ -15,7 +15,7 @@ class CoreMemoryBlockSchema(BaseModel):
is_template: bool
label: str
limit: int
metadata_: Dict[str, Any] = Field(default_factory=dict)
metadata_: Optional[Dict] = None
template_name: Optional[str]
updated_at: str
value: str
@@ -85,6 +85,7 @@ class ToolSchema(BaseModel):
tags: List[str]
tool_type: str
updated_at: str
metadata_: Optional[Dict] = None
class AgentSchema(BaseModel):
@@ -99,7 +100,7 @@ class AgentSchema(BaseModel):
llm_config: LLMConfig
message_buffer_autoclear: bool
messages: List[MessageSchema]
metadata_: Dict
metadata_: Optional[Dict] = None
multi_agent_group: Optional[Any]
name: str
system: str

View File

@@ -145,8 +145,9 @@ def print_tool(server: SyncServer, default_user, default_organization):
source_type = "python"
description = "test_description"
tags = ["test"]
metadata = {"a": "b"}
tool = PydanticTool(description=description, tags=tags, source_code=source_code, source_type=source_type)
tool = PydanticTool(description=description, tags=tags, source_code=source_code, source_type=source_type, metadata_=metadata)
derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name)
derived_name = derived_json_schema["name"]
@@ -1834,6 +1835,7 @@ def test_get_tool_by_id(server: SyncServer, print_tool, default_user):
assert fetched_tool.name == print_tool.name
assert fetched_tool.description == print_tool.description
assert fetched_tool.tags == print_tool.tags
assert fetched_tool.metadata_ == print_tool.metadata_
assert fetched_tool.source_code == print_tool.source_code
assert fetched_tool.source_type == print_tool.source_type
assert fetched_tool.tool_type == ToolType.CUSTOM