feat: Add metadata_ field to Tool (#1321)
This commit is contained in:
31
alembic/versions/1e553a664210_add_metadata_to_tools.py
Normal file
31
alembic/versions/1e553a664210_add_metadata_to_tools.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""Add metadata to Tools
|
||||
|
||||
Revision ID: 1e553a664210
|
||||
Revises: 2cceb07c2384
|
||||
Create Date: 2025-03-17 15:50:05.562302
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "1e553a664210"
|
||||
down_revision: Union[str, None] = "2cceb07c2384"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("tools", sa.Column("metadata_", sa.JSON(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("tools", "metadata_")
|
||||
# ### end Alembic commands ###
|
||||
@@ -508,10 +508,13 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
raise NotImplementedError("Sqlalchemy models must declare a __pydantic_model__ property to be convertable.")
|
||||
|
||||
def to_pydantic(self) -> "BaseModel":
|
||||
"""converts to the basic pydantic model counterpart"""
|
||||
model = self.__pydantic_model__.model_validate(self)
|
||||
if hasattr(self, "metadata_"):
|
||||
model.metadata = self.metadata_
|
||||
"""Converts the SQLAlchemy model to its corresponding Pydantic model."""
|
||||
model = self.__pydantic_model__.model_validate(self, from_attributes=True)
|
||||
|
||||
# Explicitly map metadata_ to metadata in Pydantic model
|
||||
if hasattr(self, "metadata_") and hasattr(model, "metadata_"):
|
||||
setattr(model, "metadata_", self.metadata_) # Ensures correct assignment
|
||||
|
||||
return model
|
||||
|
||||
def pretty_print_columns(self) -> str:
|
||||
|
||||
@@ -44,5 +44,6 @@ class Tool(SqlalchemyBase, OrganizationMixin):
|
||||
source_code: Mapped[Optional[str]] = mapped_column(String, doc="The source code of the function.")
|
||||
json_schema: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="The OAI compatable JSON schema of the function.")
|
||||
args_json_schema: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="The JSON schema of the function arguments.")
|
||||
metadata_: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="A dictionary of additional metadata for the tool.")
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin")
|
||||
|
||||
@@ -66,6 +66,7 @@ class Tool(BaseTool):
|
||||
# metadata fields
|
||||
created_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
|
||||
last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
|
||||
metadata_: Optional[Dict[str, Any]] = Field(default_factory=dict, description="A dictionary of additional metadata for the tool.")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def refresh_source_code_and_json_schema(self):
|
||||
@@ -137,10 +138,6 @@ class ToolCreate(LettaBase):
|
||||
|
||||
@classmethod
|
||||
def from_mcp(cls, mcp_server_name: str, mcp_tool: MCPTool) -> "ToolCreate":
|
||||
|
||||
# Get the MCP tool from the MCP server
|
||||
# NVM
|
||||
|
||||
# Pass the MCP tool to the schema generator
|
||||
json_schema = generate_tool_schema_for_mcp(mcp_tool=mcp_tool)
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ class CoreMemoryBlockSchema(BaseModel):
|
||||
is_template: bool
|
||||
label: str
|
||||
limit: int
|
||||
metadata_: Dict[str, Any] = Field(default_factory=dict)
|
||||
metadata_: Optional[Dict] = None
|
||||
template_name: Optional[str]
|
||||
updated_at: str
|
||||
value: str
|
||||
@@ -85,6 +85,7 @@ class ToolSchema(BaseModel):
|
||||
tags: List[str]
|
||||
tool_type: str
|
||||
updated_at: str
|
||||
metadata_: Optional[Dict] = None
|
||||
|
||||
|
||||
class AgentSchema(BaseModel):
|
||||
@@ -99,7 +100,7 @@ class AgentSchema(BaseModel):
|
||||
llm_config: LLMConfig
|
||||
message_buffer_autoclear: bool
|
||||
messages: List[MessageSchema]
|
||||
metadata_: Dict
|
||||
metadata_: Optional[Dict] = None
|
||||
multi_agent_group: Optional[Any]
|
||||
name: str
|
||||
system: str
|
||||
|
||||
@@ -145,8 +145,9 @@ def print_tool(server: SyncServer, default_user, default_organization):
|
||||
source_type = "python"
|
||||
description = "test_description"
|
||||
tags = ["test"]
|
||||
metadata = {"a": "b"}
|
||||
|
||||
tool = PydanticTool(description=description, tags=tags, source_code=source_code, source_type=source_type)
|
||||
tool = PydanticTool(description=description, tags=tags, source_code=source_code, source_type=source_type, metadata_=metadata)
|
||||
derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name)
|
||||
|
||||
derived_name = derived_json_schema["name"]
|
||||
@@ -1834,6 +1835,7 @@ def test_get_tool_by_id(server: SyncServer, print_tool, default_user):
|
||||
assert fetched_tool.name == print_tool.name
|
||||
assert fetched_tool.description == print_tool.description
|
||||
assert fetched_tool.tags == print_tool.tags
|
||||
assert fetched_tool.metadata_ == print_tool.metadata_
|
||||
assert fetched_tool.source_code == print_tool.source_code
|
||||
assert fetched_tool.source_type == print_tool.source_type
|
||||
assert fetched_tool.tool_type == ToolType.CUSTOM
|
||||
|
||||
Reference in New Issue
Block a user