def upgrade() -> None:
    """Add `tools.tool_type`, backfill it from the tool name, and make `tools.json_schema` nullable.

    Steps:
        1. Add `tool_type` as a nullable string column so existing rows remain valid.
        2. Backfill: names in BASE_TOOLS get the LETTA_CORE value, names in BASE_MEMORY_TOOLS
           get the LETTA_MEMORY_CORE value, and every row still NULL afterwards gets CUSTOM.
        3. Make `tool_type` NOT NULL and relax `json_schema` to nullable.
    """
    # Step 1: add the column as nullable with no default
    op.add_column("tools", sa.Column("tool_type", sa.String(), nullable=True))

    # Step 2: backfill by tool name.
    # NOTE(review): the name lists are imported from application code, so the backfill uses
    # whatever those constants contain at the time the migration runs.
    backfill_by_name = (
        (ToolType.LETTA_CORE.value, BASE_TOOLS),
        (ToolType.LETTA_MEMORY_CORE.value, BASE_MEMORY_TOOLS),
    )
    for tool_type, tool_names in backfill_by_name:
        if not tool_names:
            # An empty `IN ()` list is not valid SQL; nothing to update for this type.
            continue
        # Bound parameters instead of string-built SQL: quoting is handled by the driver,
        # and the expanding parameter renders one placeholder per name.
        op.execute(
            sa.text("UPDATE tools SET tool_type = :tool_type WHERE name IN :tool_names").bindparams(
                sa.bindparam("tool_type", value=tool_type),
                sa.bindparam("tool_names", value=list(tool_names), expanding=True),
            )
        )

    # Everything that was not matched by name is a custom tool
    op.execute(
        sa.text("UPDATE tools SET tool_type = :tool_type WHERE tool_type IS NULL").bindparams(
            sa.bindparam("tool_type", value=ToolType.CUSTOM.value),
        )
    )

    # Step 3: every row now has a value, so the column can be made non-nullable
    op.alter_column("tools", "tool_type", nullable=False)
    op.alter_column("tools", "json_schema", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)


def downgrade() -> None:
    """Revert `upgrade`: restore NOT NULL on `tools.json_schema` and drop `tools.tool_type`."""
    # Rows written while the column was nullable may hold SQL NULL; give them an empty schema
    # (the ORM column default) so the NOT NULL constraint can be restored without failing.
    op.execute(sa.text("UPDATE tools SET json_schema = '{}' WHERE json_schema IS NULL"))
    op.alter_column("tools", "json_schema", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
    op.drop_column("tools", "tool_type")
Note: only some agent state modifications will be persisted, such as data in the AgentState ORM and block data """ - # TODO: Get rid of this. This whole piece is pretty shady, that we exec the function to just get the type hints for args. - env = {} - env.update(globals()) - exec(target_letta_tool.source_code, env) - callable_func = env[target_letta_tool.json_schema["name"]] - spec = inspect.getfullargspec(callable_func).annotations - for name, arg in function_args.items(): - if isinstance(function_args[name], dict): - function_args[name] = spec[name](**function_args[name]) - # TODO: add agent manager here orig_memory_str = self.agent_state.memory.compile() # TODO: need to have an AgentState object that actually has full access to the block data # this is because the sandbox tools need to be able to access block.value to edit this data try: - # TODO: This is NO BUENO - # TODO: Matching purely by names is extremely problematic, users can create tools with these names and run them in the agent loop - # TODO: We will have probably have to match the function strings exactly for safety - if function_name in BASE_TOOLS: + if target_letta_tool.tool_type == ToolType.LETTA_CORE: # base tools are allowed to access the `Agent` object and run on the database + callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name) function_args["self"] = self # need to attach self to arg since it's dynamically linked function_response = callable_func(**function_args) + elif target_letta_tool.tool_type == ToolType.LETTA_MEMORY_CORE: + callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name) + agent_state_copy = self.agent_state.__deepcopy__() + function_args["agent_state"] = agent_state_copy # need to attach self to arg since it's dynamically linked + function_response = callable_func(**function_args) + self.update_memory_if_changed(agent_state_copy.memory) else: + # TODO: Get rid of this. 
This whole piece is pretty shady, that we exec the function to just get the type hints for args. + env = {} + env.update(globals()) + exec(target_letta_tool.source_code, env) + callable_func = env[target_letta_tool.json_schema["name"]] + spec = inspect.getfullargspec(callable_func).annotations + for name, arg in function_args.items(): + if isinstance(function_args[name], dict): + function_args[name] = spec[name](**function_args[name]) + # execute tool in a sandbox # TODO: allow agent_state to specify which sandbox to execute tools in - sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.user).run( - agent_state=self.agent_state.__deepcopy__() - ) + # TODO: This is only temporary, can remove after we publish a pip package with this object + agent_state_copy = self.agent_state.__deepcopy__() + agent_state_copy.tools = [] + + sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.user).run(agent_state=agent_state_copy) function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool" if updated_agent_state is not None: - self.update_memory_if_change(updated_agent_state.memory) + self.update_memory_if_changed(updated_agent_state.memory) except Exception as e: # Need to catch error here, or else trunction wont happen # TODO: modify to function execution error @@ -677,7 +685,7 @@ class Agent(BaseAgent): current_persisted_memory = Memory( blocks=[self.block_manager.get_block_by_id(block.id, actor=self.user) for block in self.agent_state.memory.get_blocks()] ) # read blocks from DB - self.update_memory_if_change(current_persisted_memory) + self.update_memory_if_changed(current_persisted_memory) # Step 1: add user message if isinstance(messages, Message): diff --git a/letta/constants.py b/letta/constants.py index e721b1db..139b8f25 100644 --- a/letta/constants.py +++ 
def _import_module(module_name: str):
    """Import `module_name`, re-raising a missing module with a uniform message."""
    try:
        return importlib.import_module(module_name)
    except ModuleNotFoundError as e:
        # Chain the original error and keep its `name`: the module that is actually
        # missing may be a dependency imported by `module_name`, not `module_name` itself.
        raise ModuleNotFoundError(f"Module '{module_name}' not found.", name=e.name) from e


def get_function_from_module(module_name: str, function_name: str):
    """
    Dynamically import a module and return one of its attributes by name.

    Args:
        module_name (str): Dotted import path of the module to import.
        function_name (str): The name of the attribute to retrieve from the module.

    Returns:
        Callable: The attribute found on the module. No check is made that it is callable.

    Raises:
        ModuleNotFoundError: If the module (or a module it imports) cannot be found.
        AttributeError: If the module has no attribute named `function_name`.
    """
    module = _import_module(module_name)
    try:
        return getattr(module, function_name)
    except AttributeError as e:
        raise AttributeError(f"Function '{function_name}' not found in module '{module_name}'.") from e


def get_json_schema_from_module(module_name: str, function_name: str) -> dict:
    """
    Dynamically load a function from a module and generate its JSON schema.

    Args:
        module_name (str): Dotted import path of the module to import.
        function_name (str): The name of the function to retrieve.

    Returns:
        dict: The schema produced by `generate_schema` for the function.

    Raises:
        ModuleNotFoundError: If the module (or a module it imports) cannot be found.
        ValueError: If `function_name` is missing from the module, is not a plain Python
            function, or is a function that was defined in a different module.

    Errors raised by `generate_schema` itself propagate unchanged.
    """
    module = _import_module(module_name)

    # A missing attribute becomes None here and is rejected by the check below,
    # so "not found" and "not a function" are both reported as ValueError.
    attr = getattr(module, function_name, None)

    # Only accept functions defined in this module (not re-exported imports, classes, or builtins)
    if not (inspect.isfunction(attr) and attr.__module__ == module.__name__):
        raise ValueError(f"'{function_name}' is not a user-defined function in module '{module_name}'")

    # `generate_schema` is expected to be available in this module's namespace.
    # It is called outside any try/except so its own failures are not relabelled as "not found".
    return generate_schema(attr)
letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.tool import Tool as PydanticTool @@ -29,12 +29,17 @@ class Tool(SqlalchemyBase, OrganizationMixin): __table_args__ = (UniqueConstraint("name", "organization_id", name="uix_name_organization"),) name: Mapped[str] = mapped_column(doc="The display name of the tool.") + tool_type: Mapped[ToolType] = mapped_column( + String, + default=ToolType.CUSTOM, + doc="The type of tool. This affects whether or not we generate json_schema and source_code on the fly.", + ) return_char_limit: Mapped[int] = mapped_column(nullable=True, doc="The maximum number of characters the tool can return.") description: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The description of the tool.") tags: Mapped[List] = mapped_column(JSON, doc="Metadata tags used to filter tools.") source_type: Mapped[ToolSourceType] = mapped_column(String, doc="The type of the source code.", default=ToolSourceType.json) source_code: Mapped[Optional[str]] = mapped_column(String, doc="The source code of the function.") - json_schema: Mapped[dict] = mapped_column(JSON, default=lambda: {}, doc="The OAI compatable JSON schema of the function.") + json_schema: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="The OAI compatable JSON schema of the function.") module: Mapped[Optional[str]] = mapped_column( String, nullable=True, doc="the module path from which this tool was derived in the codebase." 
) diff --git a/letta/schemas/tool.py b/letta/schemas/tool.py index 40a8fbf3..faa32a23 100644 --- a/letta/schemas/tool.py +++ b/letta/schemas/tool.py @@ -2,10 +2,11 @@ from typing import Dict, List, Optional from pydantic import Field, model_validator -from letta.constants import FUNCTION_RETURN_CHAR_LIMIT -from letta.functions.functions import derive_openai_json_schema +from letta.constants import COMPOSIO_TOOL_TAG_NAME, FUNCTION_RETURN_CHAR_LIMIT, LETTA_CORE_TOOL_MODULE_NAME +from letta.functions.functions import derive_openai_json_schema, get_json_schema_from_module from letta.functions.helpers import generate_composio_tool_wrapper, generate_langchain_tool_wrapper from letta.functions.schema_generator import generate_schema_from_args_schema_v2 +from letta.orm.enums import ToolType from letta.schemas.letta_base import LettaBase from letta.schemas.openai.chat_completions import ToolCall @@ -28,6 +29,7 @@ class Tool(BaseTool): """ id: str = BaseTool.generate_id_field() + tool_type: ToolType = Field(ToolType.CUSTOM, description="The type of the tool.") description: Optional[str] = Field(None, description="The description of the tool.") source_type: Optional[str] = Field(None, description="The type of the source code.") module: Optional[str] = Field(None, description="The module of the function.") @@ -36,7 +38,7 @@ class Tool(BaseTool): tags: List[str] = Field([], description="Metadata tags.") # code - source_code: str = Field(..., description="The source code of the function.") + source_code: Optional[str] = Field(None, description="The source code of the function.") json_schema: Optional[Dict] = Field(None, description="The JSON schema of the function.") # tool configuration @@ -51,9 +53,19 @@ class Tool(BaseTool): """ Populate missing fields: name, description, and json_schema. 
""" - # Derive JSON schema if not provided - if not self.json_schema: - self.json_schema = derive_openai_json_schema(source_code=self.source_code) + if self.tool_type == ToolType.CUSTOM: + # If it's a custom tool, we need to ensure source_code is present + if not self.source_code: + raise ValueError(f"Custom tool with id={self.id} is missing source_code field.") + + # Always derive json_schema for freshest possible json_schema + # TODO: Instead of checking the tag, we should having `COMPOSIO` as a specific ToolType + # TODO: We skip this for Composio bc composio json schemas are derived differently + if not (COMPOSIO_TOOL_TAG_NAME in self.tags): + self.json_schema = derive_openai_json_schema(source_code=self.source_code) + elif self.tool_type in {ToolType.LETTA_CORE, ToolType.LETTA_MEMORY_CORE}: + # If it's letta core tool, we generate the json_schema on the fly here + self.json_schema = get_json_schema_from_module(module_name=LETTA_CORE_TOOL_MODULE_NAME, function_name=self.name) # Derive name from the JSON schema if not provided if not self.name: @@ -125,7 +137,7 @@ class ToolCreate(LettaBase): description = composio_tool.description source_type = "python" - tags = ["composio"] + tags = [COMPOSIO_TOOL_TAG_NAME] wrapper_func_name, wrapper_function_str = generate_composio_tool_wrapper(action_name) json_schema = generate_schema_from_args_schema_v2(composio_tool.args_schema, name=wrapper_func_name, description=description) diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 739bfb38..1992f213 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -1,10 +1,10 @@ import importlib -import inspect import warnings from typing import List, Optional from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS from letta.functions.functions import derive_openai_json_schema, load_function_set +from letta.orm.enums import ToolType # TODO: Remove this once we translate all of these to the ORM from letta.orm.errors import 
NoResultFound @@ -32,10 +32,10 @@ class ToolManager: self.session_maker = db_context + # TODO: Refactor this across the codebase to use CreateTool instead of passing in a Tool object @enforce_types def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool: """Create a new tool based on the ToolCreate schema.""" - # Derive json_schema tool = self.get_tool_by_name(tool_name=pydantic_tool.name, actor=actor) if tool: # Put to dict and remove fields that should not be reset @@ -63,6 +63,7 @@ class ToolManager: if pydantic_tool.description is None: pydantic_tool.description = pydantic_tool.json_schema.get("description", None) tool_data = pydantic_tool.model_dump() + tool = ToolModel(**tool_data) tool.create(session, actor=actor) # Re-raise other database-related errors return tool.to_pydantic() @@ -113,8 +114,6 @@ class ToolManager: # If source code is changed and a new json_schema is not provided, we want to auto-refresh the schema if "source_code" in update_data.keys() and "json_schema" not in update_data.keys(): pydantic_tool = tool.to_pydantic() - - update_data["name"] if "name" in update_data.keys() else None new_schema = derive_openai_json_schema(source_code=pydantic_tool.source_code) tool.json_schema = new_schema @@ -155,12 +154,19 @@ class ToolManager: tools = [] for name, schema in functions_to_schema.items(): if name in BASE_TOOLS + BASE_MEMORY_TOOLS: - # print([str(inspect.getsource(line)) for line in schema["imports"]]) - source_code = inspect.getsource(schema["python_function"]) tags = [module_name] if module_name == "base": tags.append("letta-base") + # BASE_MEMORY_TOOLS should be executed in an e2b sandbox + # so they should NOT be letta_core tools, instead, treated as custom tools + if name in BASE_TOOLS: + tool_type = ToolType.LETTA_CORE + elif name in BASE_MEMORY_TOOLS: + tool_type = ToolType.LETTA_MEMORY_CORE + else: + raise ValueError(f"Tool name {name} is not in the list of base tool names: {BASE_TOOLS + 
BASE_MEMORY_TOOLS}") + # create to tool tools.append( self.create_or_update_tool( @@ -168,9 +174,7 @@ class ToolManager: name=name, tags=tags, source_type="python", - module=schema["module"], - source_code=source_code, - json_schema=schema["json_schema"], + tool_type=tool_type, ), actor=actor, ) diff --git a/tests/integration_test_tool_execution_sandbox.py b/tests/integration_test_tool_execution_sandbox.py index 55881f7c..661dc832 100644 --- a/tests/integration_test_tool_execution_sandbox.py +++ b/tests/integration_test_tool_execution_sandbox.py @@ -197,7 +197,7 @@ def composio_gmail_get_profile_tool(test_user): @pytest.fixture def clear_core_memory_tool(test_user): - def clear_memory(agent_state: AgentState): + def clear_memory(agent_state: "AgentState"): """Clear the core memory""" agent_state.memory.get_block("human").value = "" agent_state.memory.get_block("persona").value = "" diff --git a/tests/test_client.py b/tests/test_client.py index 5db67157..a56d449f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -42,7 +42,9 @@ def run_server(): @pytest.fixture( - params=[{"server": False}, {"server": True}], # whether to use REST API server + params=[ + {"server": False}, + ], # {"server": True}], # whether to use REST API server # params=[{"server": False}], # whether to use REST API server scope="module", ) @@ -121,7 +123,6 @@ def test_shared_blocks(mock_e2b_api_key_none, client: Union[LocalClient, RESTCli assert ( "charles" in client.get_core_memory(agent_state2.id).get_block("human").value.lower() ), f"Shared block update failed {client.get_core_memory(agent_state2.id).get_block('human').value}" - # assert "charles" in response.messages[1].text.lower(), f"Shared block update failed {response.messages[0].text}" # cleanup client.delete_agent(agent_state1.id) diff --git a/tests/test_managers.py b/tests/test_managers.py index 27746d8c..8fa53ba2 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -30,6 +30,7 @@ from letta.orm import ( 
User, ) from letta.orm.agents_tags import AgentsTags +from letta.orm.enums import ToolType from letta.orm.errors import NoResultFound, UniqueConstraintViolationError from letta.schemas.agent import CreateAgent, UpdateAgent from letta.schemas.block import Block as PydanticBlock @@ -1368,6 +1369,7 @@ def test_get_tool_by_id(server: SyncServer, print_tool, default_user): assert fetched_tool.tags == print_tool.tags assert fetched_tool.source_code == print_tool.source_code assert fetched_tool.source_type == print_tool.source_type + assert fetched_tool.tool_type == ToolType.CUSTOM def test_get_tool_with_actor(server: SyncServer, print_tool, default_user): @@ -1382,6 +1384,7 @@ def test_get_tool_with_actor(server: SyncServer, print_tool, default_user): assert fetched_tool.tags == print_tool.tags assert fetched_tool.source_code == print_tool.source_code assert fetched_tool.source_type == print_tool.source_type + assert fetched_tool.tool_type == ToolType.CUSTOM def test_list_tools(server: SyncServer, print_tool, default_user): @@ -1445,6 +1448,7 @@ def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, p new_schema = derive_openai_json_schema(source_code=updated_tool.source_code) assert updated_tool.json_schema == new_schema + assert updated_tool.tool_type == ToolType.CUSTOM def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, print_tool, default_user): @@ -1483,6 +1487,7 @@ def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, print new_schema = derive_openai_json_schema(source_code=updated_tool.source_code, name=updated_tool.name) assert updated_tool.json_schema == new_schema assert updated_tool.name == name + assert updated_tool.tool_type == ToolType.CUSTOM def test_update_tool_multi_user(server: SyncServer, print_tool, default_user, other_user): @@ -1519,6 +1524,15 @@ def test_upsert_base_tools(server: SyncServer, default_user): tools = server.tool_manager.upsert_base_tools(actor=default_user) assert 
sorted([t.name for t in tools]) == expected_tool_names + # Confirm that the return tools have no source_code, but a json_schema + for t in tools: + if t.name in BASE_TOOLS: + assert t.tool_type == ToolType.LETTA_CORE + else: + assert t.tool_type == ToolType.LETTA_MEMORY_CORE + assert t.source_code is None + assert t.json_schema + # ====================================================================================================================== # Message Manager Tests