diff --git a/alembic/versions/1e553a664210_add_metadata_to_tools.py b/alembic/versions/1e553a664210_add_metadata_to_tools.py new file mode 100644 index 00000000..51b6da20 --- /dev/null +++ b/alembic/versions/1e553a664210_add_metadata_to_tools.py @@ -0,0 +1,31 @@ +"""Add metadata to Tools + +Revision ID: 1e553a664210 +Revises: 2cceb07c2384 +Create Date: 2025-03-17 15:50:05.562302 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "1e553a664210" +down_revision: Union[str, None] = "2cceb07c2384" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("tools", sa.Column("metadata_", sa.JSON(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("tools", "metadata_") + # ### end Alembic commands ### diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index fd211b86..3b45c6ee 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -508,10 +508,13 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): raise NotImplementedError("Sqlalchemy models must declare a __pydantic_model__ property to be convertable.") def to_pydantic(self) -> "BaseModel": - """converts to the basic pydantic model counterpart""" - model = self.__pydantic_model__.model_validate(self) - if hasattr(self, "metadata_"): - model.metadata = self.metadata_ + """Converts the SQLAlchemy model to its corresponding Pydantic model.""" + model = self.__pydantic_model__.model_validate(self, from_attributes=True) + + # Explicitly map metadata_ to metadata in Pydantic model + if hasattr(self, "metadata_") and hasattr(model, "metadata_"): + setattr(model, "metadata_", self.metadata_) # Ensures correct assignment + return model def pretty_print_columns(self) -> str: diff --git a/letta/orm/tool.py b/letta/orm/tool.py index c43a0a88..7a7c3199 100644 --- a/letta/orm/tool.py +++ b/letta/orm/tool.py @@ -44,5 +44,6 @@ class Tool(SqlalchemyBase, OrganizationMixin): source_code: Mapped[Optional[str]] = mapped_column(String, doc="The source code of the function.") json_schema: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="The OAI compatable JSON schema of the function.") args_json_schema: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="The JSON schema of the function arguments.") + metadata_: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="A dictionary of additional metadata for the tool.") # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin") diff --git a/letta/schemas/tool.py b/letta/schemas/tool.py index 55fac00c..b23d9ac2 100644 --- a/letta/schemas/tool.py +++ b/letta/schemas/tool.py @@ -66,6 +66,7 @@ class Tool(BaseTool): # metadata fields created_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.") last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.") + metadata_: Optional[Dict[str, Any]] = Field(default_factory=dict, description="A dictionary of additional metadata for the tool.") @model_validator(mode="after") def refresh_source_code_and_json_schema(self): @@ -137,10 +138,6 @@ class ToolCreate(LettaBase): @classmethod def from_mcp(cls, mcp_server_name: str, mcp_tool: MCPTool) -> "ToolCreate": - - # Get the MCP tool from the MCP server - # NVM - # Pass the MCP tool to the schema generator json_schema = generate_tool_schema_for_mcp(mcp_tool=mcp_tool) diff --git a/letta/serialize_schemas/pydantic_agent_schema.py b/letta/serialize_schemas/pydantic_agent_schema.py index ce1d65ce..593c0377 100644 --- a/letta/serialize_schemas/pydantic_agent_schema.py +++ b/letta/serialize_schemas/pydantic_agent_schema.py @@ -15,7 +15,7 @@ class CoreMemoryBlockSchema(BaseModel): is_template: bool label: str limit: int - metadata_: Dict[str, Any] = Field(default_factory=dict) + metadata_: Optional[Dict] = None template_name: Optional[str] updated_at: str value: str @@ -85,6 +85,7 @@ class ToolSchema(BaseModel): tags: List[str] tool_type: str updated_at: str + metadata_: Optional[Dict] = None class AgentSchema(BaseModel): @@ -99,7 +100,7 @@ class AgentSchema(BaseModel): llm_config: LLMConfig message_buffer_autoclear: bool messages: List[MessageSchema] - metadata_: Dict + metadata_: Optional[Dict] = None multi_agent_group: Optional[Any] name: str system: str diff --git a/tests/test_managers.py b/tests/test_managers.py index 1d1dc240..e8da8406 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -145,8 +145,9 @@ def print_tool(server: SyncServer, default_user, default_organization): source_type = "python" description = "test_description" tags = ["test"] + metadata = {"a": "b"} - tool = PydanticTool(description=description, tags=tags, source_code=source_code, source_type=source_type) + tool = PydanticTool(description=description, tags=tags, source_code=source_code, source_type=source_type, metadata_=metadata) derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name) derived_name = derived_json_schema["name"] @@ -1834,6 +1835,7 @@ def test_get_tool_by_id(server: SyncServer, print_tool, default_user): assert fetched_tool.name == print_tool.name assert fetched_tool.description == print_tool.description assert fetched_tool.tags == print_tool.tags + assert fetched_tool.metadata_ == print_tool.metadata_ assert fetched_tool.source_code == print_tool.source_code assert fetched_tool.source_type == print_tool.source_type assert fetched_tool.tool_type == ToolType.CUSTOM