From d74406af4179e694694bd66bc4d992a1a1897a56 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Fri, 25 Oct 2024 14:25:40 -0700 Subject: [PATCH] feat: Add orm for Tools and clean up Tool logic (#1935) --- .github/workflows/code_style_checks.yml | 16 +- .github/workflows/test_cli.yml | 4 +- .github/workflows/tests.yml | 39 +- examples/composio_tool_usage.py | 6 +- examples/crewai_tool_usage.py | 11 +- examples/langchain_tool_usage.py | 24 +- letta/agent.py | 5 - letta/cli/cli.py | 7 +- letta/client/client.py | 145 +++---- letta/config.py | 4 +- letta/constants.py | 9 - letta/functions/functions.py | 24 ++ letta/functions/helpers.py | 7 +- letta/functions/schema_generator.py | 12 +- letta/metadata.py | 97 ----- letta/o1_agent.py | 4 +- letta/orm/__all__.py | 15 + letta/orm/mixins.py | 17 +- letta/orm/organization.py | 2 + letta/orm/sqlalchemy_base.py | 35 +- letta/orm/tool.py | 54 +++ letta/orm/user.py | 8 +- letta/schemas/tool.py | 123 +++--- letta/schemas/user.py | 4 +- letta/server/rest_api/admin/users.py | 2 +- letta/server/rest_api/routers/v1/tools.py | 42 +- letta/server/server.py | 361 ++---------------- letta/services/organization_manager.py | 31 +- letta/services/tool_manager.py | 193 ++++++++++ letta/services/user_manager.py | 27 +- tests/helpers/endpoints_helper.py | 7 +- tests/test_cli.py | 8 +- tests/test_client.py | 9 - ...est_new_client.py => test_local_client.py} | 60 ++- tests/test_managers.py | 169 +++++++- tests/test_server.py | 22 +- tests/test_tools.py | 19 +- 37 files changed, 833 insertions(+), 789 deletions(-) create mode 100644 letta/orm/tool.py create mode 100644 letta/services/tool_manager.py rename tests/{test_new_client.py => test_local_client.py} (88%) diff --git a/.github/workflows/code_style_checks.yml b/.github/workflows/code_style_checks.yml index 2fd9d9b9..8e7b7e94 100644 --- a/.github/workflows/code_style_checks.yml +++ b/.github/workflows/code_style_checks.yml @@ -1,21 +1,13 @@ name: Code Style Checks on: + push: + branches: [ main ] pull_request: - paths: - - '**.py' - pull_request_target: - types: - - opened - - edited - - synchronize - workflow_dispatch: - -permissions: - pull-requests: read + branches: [ main ] jobs: - validation-checks: + style-checks: runs-on: ubuntu-latest strategy: matrix: diff --git a/.github/workflows/test_cli.yml b/.github/workflows/test_cli.yml index 1257572c..6c3a658b 100644 --- a/.github/workflows/test_cli.yml +++ b/.github/workflows/test_cli.yml @@ -1,4 +1,4 @@ -name: Run CLI tests +name: Test CLI env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} @@ -10,7 +10,7 @@ on: branches: [ main ] jobs: - test: + test-cli: runs-on: ubuntu-latest timeout-minutes: 15 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index cdf3edaa..c15d1a90 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,4 +1,4 @@ -name: Run All pytest Tests +name: Unit Tests env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} @@ -13,7 +13,7 @@ on: branches: [ main ] jobs: - test: + unit-tests: runs-on: ubuntu-latest timeout-minutes: 15 @@ -37,6 +37,28 @@ jobs: poetry-version: "1.8.2" install-args: "-E dev -E postgres -E milvus -E external-tools -E tests" + - name: Run LocalClient tests + env: + LETTA_PG_PORT: 8888 + LETTA_PG_USER: letta + LETTA_PG_PASSWORD: letta + LETTA_PG_DB: letta + LETTA_PG_HOST: localhost + LETTA_SERVER_PASS: test_server_token + run: | + poetry run pytest -s -vv tests/test_local_client.py + + - name: Run RESTClient tests + env: + LETTA_PG_PORT: 8888 + LETTA_PG_USER: letta + LETTA_PG_PASSWORD: letta + LETTA_PG_DB: letta + LETTA_PG_HOST: localhost + LETTA_SERVER_PASS: test_server_token + run: | + poetry run pytest -s -vv tests/test_client.py + - name: Run server tests env: LETTA_PG_PORT: 8888 @@ -70,6 +92,17 @@ jobs: run: | poetry run pytest -s -vv tests/test_tools.py + - name: Run o1 agent tests + env: + LETTA_PG_PORT: 8888 + LETTA_PG_USER: letta + LETTA_PG_PASSWORD: letta + LETTA_PG_DB: letta + LETTA_PG_HOST: localhost + LETTA_SERVER_PASS: test_server_token + run: | + poetry run pytest -s -vv tests/test_o1_agent.py + - name: Run tests with pytest env: LETTA_PG_PORT: 8888 @@ -80,4 +113,4 @@ jobs: LETTA_SERVER_PASS: test_server_token PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }} run: | - poetry run pytest -s -vv -k "not test_cli.py and not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_endpoints and not test_storage and not test_server and not test_openai_client and not test_providers" tests + poetry run pytest -s -vv -k "not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_endpoints and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client.py" tests diff --git a/examples/composio_tool_usage.py b/examples/composio_tool_usage.py index 7246a9f4..26508fcd 100644 --- a/examples/composio_tool_usage.py +++ b/examples/composio_tool_usage.py @@ -5,7 +5,6 @@ from letta import create_client from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import ChatMemory -from letta.schemas.tool import Tool """ Setup here. @@ -49,10 +48,7 @@ def main(): from composio_langchain import Action # Add the composio tool - tool = Tool.get_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER) - - # create tool - client.add_tool(tool) + tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER) persona = f""" My name is Letta. diff --git a/examples/crewai_tool_usage.py b/examples/crewai_tool_usage.py index c8d6f1cf..ddd2715e 100644 --- a/examples/crewai_tool_usage.py +++ b/examples/crewai_tool_usage.py @@ -2,8 +2,9 @@ import json import uuid from letta import create_client +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import ChatMemory -from letta.schemas.tool import Tool """ This example show how you can add CrewAI tools . @@ -21,14 +22,14 @@ def main(): crewai_tool = ScrapeWebsiteTool(website_url="https://www.example.com") - example_website_scrape_tool = Tool.from_crewai(crewai_tool) - tool_name = example_website_scrape_tool.name - # Create a `LocalClient` (you can also use a `RESTClient`, see the letta_rest_client.py example) client = create_client() + client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) + client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) # create tool - client.add_tool(example_website_scrape_tool) + example_website_scrape_tool = client.load_crewai_tool(crewai_tool) + tool_name = example_website_scrape_tool.name # Confirm that the tool is in tools = client.list_tools() diff --git a/examples/langchain_tool_usage.py b/examples/langchain_tool_usage.py index e6cd94de..eb207694 100644 --- a/examples/langchain_tool_usage.py +++ b/examples/langchain_tool_usage.py @@ -5,7 +5,6 @@ from letta import create_client from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import ChatMemory -from letta.schemas.tool import Tool """ This example show how you can add LangChain tools . @@ -26,25 +25,22 @@ def main(): api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=500) langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper) - # Translate to memGPT Tool - # Note the additional_imports_module_attr_map - # We need to pass in a map of all the additional imports necessary to run this tool - # Because an object of type WikipediaAPIWrapper is passed into WikipediaQueryRun to initialize langchain_tool, - # We need to also import WikipediaAPIWrapper - # The map is a mapping of the module name to the attribute name - # langchain_community.utilities.WikipediaAPIWrapper - wikipedia_query_tool = Tool.from_langchain( - langchain_tool, additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"} - ) - tool_name = wikipedia_query_tool.name - # Create a `LocalClient` (you can also use a `RESTClient`, see the letta_rest_client.py example) client = create_client() client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) # create tool - client.add_tool(wikipedia_query_tool) + # Note the additional_imports_module_attr_map + # We need to pass in a map of all the additional imports necessary to run this tool + # Because an object of type WikipediaAPIWrapper is passed into WikipediaQueryRun to initialize langchain_tool, + # We need to also import WikipediaAPIWrapper + # The map is a mapping of the module name to the attribute name + # langchain_community.utilities.WikipediaAPIWrapper + wikipedia_query_tool = client.load_langchain_tool( + langchain_tool, additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"} + ) + tool_name = wikipedia_query_tool.name # Confirm that the tool is in tools = client.list_tools() diff --git a/letta/agent.py b/letta/agent.py index 8ffb33ef..c865993d 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -240,7 +240,6 @@ class Agent(BaseAgent): assert isinstance(self.agent_state.memory, Memory), f"Memory object is not of type Memory: {type(self.agent_state.memory)}" # link tools - self.tools = tools self.link_tools(tools) # gpt-4, gpt-3.5-turbo, ... @@ -1521,10 +1520,6 @@ def save_agent(agent: Agent, ms: MetadataStore): else: ms.create_agent(agent_state) - for tool in agent.tools: - if ms.get_tool(tool_name=tool.name, user_id=tool.user_id) is None: - ms.create_tool(tool) - agent.agent_state = ms.get_agent(agent_id=agent_id) assert isinstance(agent.agent_state.memory, Memory), f"Memory is not a Memory object: {type(agent_state.memory)}" diff --git a/letta/cli/cli.py b/letta/cli/cli.py index dbecab15..c6e435a8 100644 --- a/letta/cli/cli.py +++ b/letta/cli/cli.py @@ -133,6 +133,7 @@ def run( # read user id from config ms = MetadataStore(config) client = create_client() + server = client.server # determine agent to use, if not provided if not yes and not agent: @@ -217,7 +218,9 @@ def run( ) # create agent - tools = [ms.get_tool(tool_name, user_id=client.user_id) for tool_name in agent_state.tools] + tools = [ + server.tool_manager.get_tool_by_name_and_user_id(tool_name=tool_name, user_id=client.user_id) for tool_name in agent_state.tools + ] letta_agent = Agent(agent_state=agent_state, interface=interface(), tools=tools) else: # create new agent @@ -297,7 +300,7 @@ def run( ) assert isinstance(agent_state.memory, Memory), f"Expected Memory, got {type(agent_state.memory)}" typer.secho(f"-> 🛠️ {len(agent_state.tools)} tools: {', '.join([t for t in agent_state.tools])}", fg=typer.colors.WHITE) - tools = [ms.get_tool(tool_name, user_id=client.user_id) for tool_name in agent_state.tools] + tools = [server.tool_manager.get_tool_by_name_and_user_id(tool_name, user_id=client.user_id) for tool_name in agent_state.tools] letta_agent = Agent( interface=interface(), diff --git a/letta/client/client.py b/letta/client/client.py index 35aed129..e21a0cb0 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -182,6 +182,15 @@ class AbstractClient(object): def delete_human(self, id: str): raise NotImplementedError + def load_langchain_tool(self, langchain_tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool: + raise NotImplementedError + + def load_crewai_tool(self, crewai_tool: "CrewAIBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool: + raise NotImplementedError + + def load_composio_tool(self, action: "ActionType") -> Tool: + raise NotImplementedError + def create_tool( self, func, @@ -1298,12 +1307,6 @@ class RESTClient(AbstractClient): source_code = parse_source_code(func) source_type = "python" - # TODO: Check if tool already exists - # if name: - # tool_id = self.get_tool_id(tool_name=name) - # if tool_id: - # raise ValueError(f"Tool with name {name} (id={tool_id}) already exists") - # call server function request = ToolCreate(source_type=source_type, source_code=source_code, name=name, tags=tags) response = requests.post(f"{self.base_url}/{self.api_prefix}/tools", json=request.model_dump(), headers=self.headers) @@ -1337,7 +1340,7 @@ class RESTClient(AbstractClient): source_type = "python" - request = ToolUpdate(id=id, source_type=source_type, source_code=source_code, tags=tags, name=name) + request = ToolUpdate(source_type=source_type, source_code=source_code, tags=tags, name=name) response = requests.patch(f"{self.base_url}/{self.api_prefix}/tools/{id}", json=request.model_dump(), headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to update tool: {response.text}") @@ -1394,6 +1397,7 @@ class RESTClient(AbstractClient): params["cursor"] = str(cursor) if limit: params["limit"] = limit + response = requests.get(f"{self.base_url}/{self.api_prefix}/tools", params=params, headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to list tools: {response.text}") @@ -1503,6 +1507,7 @@ class LocalClient(AbstractClient): self, auto_save: bool = False, user_id: Optional[str] = None, + org_id: Optional[str] = None, debug: bool = False, default_llm_config: Optional[LLMConfig] = None, default_embedding_config: Optional[EmbeddingConfig] = None, @@ -1529,15 +1534,19 @@ class LocalClient(AbstractClient): self.interface = QueuingInterface(debug=debug) self.server = SyncServer(default_interface_factory=lambda: self.interface) + # save org_id that `LocalClient` is associated with + if org_id: + self.org_id = org_id + else: + self.org_id = self.server.organization_manager.DEFAULT_ORG_ID # save user_id that `LocalClient` is associated with if user_id: self.user_id = user_id else: # get default user - self.user_id = self.server.get_default_user().id + self.user_id = self.server.user_manager.DEFAULT_USER_ID # agents - def list_agents(self) -> List[AgentState]: self.interface.clear() @@ -2186,50 +2195,27 @@ class LocalClient(AbstractClient): self.server.delete_block(id) # tools + def load_langchain_tool(self, langchain_tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool: + tool_create = ToolCreate.from_langchain( + langchain_tool=langchain_tool, + user_id=self.user_id, + organization_id=self.org_id, + additional_imports_module_attr_map=additional_imports_module_attr_map, + ) + return self.server.tool_manager.create_or_update_tool(tool_create) - # TODO: merge this into create_tool - def add_tool(self, tool: Tool, update: Optional[bool] = True) -> Tool: - """ - Adds a tool directly. + def load_crewai_tool(self, crewai_tool: "CrewAIBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool: + tool_create = ToolCreate.from_crewai( + crewai_tool=crewai_tool, + additional_imports_module_attr_map=additional_imports_module_attr_map, + user_id=self.user_id, + organization_id=self.org_id, + ) + return self.server.tool_manager.create_or_update_tool(tool_create) - Args: - tool (Tool): The tool to add. - update (bool, optional): Update the tool if it already exists. Defaults to True. - - Returns: - None - """ - if self.tool_with_name_and_user_id_exists(tool): - if update: - return self.server.update_tool( - ToolUpdate( - id=tool.id, - description=tool.description, - source_type=tool.source_type, - source_code=tool.source_code, - tags=tool.tags, - json_schema=tool.json_schema, - name=tool.name, - ), - self.user_id, - ) - else: - raise ValueError(f"Tool with id={tool.id} and name={tool.name}already exists") - else: - # call server function - return self.server.create_tool( - ToolCreate( - id=tool.id, - description=tool.description, - source_type=tool.source_type, - source_code=tool.source_code, - name=tool.name, - json_schema=tool.json_schema, - tags=tool.tags, - ), - user_id=self.user_id, - update=update, - ) + def load_composio_tool(self, action: "ActionType") -> Tool: + tool_create = ToolCreate.from_composio(action=action, user_id=self.user_id, organization_id=self.org_id) + return self.server.tool_manager.create_or_update_tool(tool_create) # TODO: Use the above function `add_tool` here as there is duplicate logic def create_tool( @@ -2262,11 +2248,16 @@ class LocalClient(AbstractClient): tags = [] # call server function - return self.server.create_tool( - # ToolCreate(source_type=source_type, source_code=source_code, name=tool_name, json_schema=json_schema, tags=tags), - ToolCreate(source_type=source_type, source_code=source_code, name=name, tags=tags, terminal=terminal), - user_id=self.user_id, - update=update, + return self.server.tool_manager.create_or_update_tool( + ToolCreate( + user_id=self.user_id, + organization_id=self.org_id, + source_type=source_type, + source_code=source_code, + name=name, + tags=tags, + terminal=terminal, + ), ) def update_tool( @@ -2288,16 +2279,17 @@ class LocalClient(AbstractClient): Returns: tool (Tool): Updated tool """ - if func: - source_code = parse_source_code(func) - else: - source_code = None + update_data = { + "source_type": "python", # Always include source_type + "source_code": parse_source_code(func) if func else None, + "tags": tags, + "name": name, + } - source_type = "python" + # Filter out any None values from the dictionary + update_data = {key: value for key, value in update_data.items() if value is not None} - return self.server.update_tool( - ToolUpdate(id=id, source_type=source_type, source_code=source_code, tags=tags, name=name), self.user_id - ) + return self.server.tool_manager.update_tool_by_id(id, ToolUpdate(**update_data)) def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Tool]: """ @@ -2306,11 +2298,11 @@ class LocalClient(AbstractClient): Returns: tools (List[Tool]): List of tools """ - return self.server.list_tools(cursor=cursor, limit=limit, user_id=self.user_id) + return self.server.tool_manager.list_tools_for_org(cursor=cursor, limit=limit, organization_id=self.org_id) def get_tool(self, id: str) -> Optional[Tool]: """ - Get a tool give its ID. + Get a tool given its ID. Args: id (str): ID of the tool @@ -2318,7 +2310,7 @@ class LocalClient(AbstractClient): Returns: tool (Tool): Tool """ - return self.server.get_tool(id) + return self.server.tool_manager.get_tool_by_id(id) def delete_tool(self, id: str): """ @@ -2327,11 +2319,11 @@ class LocalClient(AbstractClient): Args: id (str): ID of the tool """ - return self.server.delete_tool(id) + return self.server.tool_manager.delete_tool_by_id(id) def get_tool_id(self, name: str) -> Optional[str]: """ - Get the ID of a tool + Get the ID of a tool from its name. The client will use the org_id it is configured with. Args: name (str): Name of the tool @@ -2339,19 +2331,8 @@ class LocalClient(AbstractClient): Returns: id (str): ID of the tool (`None` if not found) """ - return self.server.get_tool_id(name, self.user_id) - - def tool_with_name_and_user_id_exists(self, tool: Tool) -> bool: - """ - Check if the tool with name and user_id exists - - Args: - tool (Tool): the tool - - Returns: - (bool): True if the id exists, False otherwise. - """ - return self.server.tool_with_name_and_user_id_exists(tool, self.user_id) + tool = self.server.tool_manager.get_tool_by_name_and_org_id(tool_name=name, organization_id=self.org_id) + return tool.id def load_data(self, connector: DataConnector, source_name: str): """ diff --git a/letta/config.py b/letta/config.py index b07c7f4c..cd1ec37c 100644 --- a/letta/config.py +++ b/letta/config.py @@ -13,13 +13,13 @@ from letta.constants import ( DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET, - DEFAULT_USER_ID, LETTA_DIR, ) from letta.log import get_logger from letta.schemas.agent import AgentState from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig +from letta.services.user_manager import UserManager logger = get_logger(__name__) @@ -45,7 +45,7 @@ def set_field(config, section, field, value): @dataclass class LettaConfig: config_path: str = os.getenv("MEMGPT_CONFIG_PATH") or os.path.join(LETTA_DIR, "config") - anon_clientid: str = DEFAULT_USER_ID + anon_clientid: str = UserManager.DEFAULT_USER_ID # preset preset: str = DEFAULT_PRESET # TODO: rename to system prompt diff --git a/letta/constants.py b/letta/constants.py index a581f7b3..39319c5a 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -3,15 +3,6 @@ from logging import CRITICAL, DEBUG, ERROR, INFO, NOTSET, WARN, WARNING LETTA_DIR = os.path.join(os.path.expanduser("~"), ".letta") -# Defaults -DEFAULT_USER_ID = "user-00000000-0000-4000-8000-000000000000" -# This UUID follows the UUID4 rules: -# The 13th character (4) indicates it's version 4. -# The first character of the third segment (8) ensures the variant is correctly set. -DEFAULT_ORG_ID = "organization-00000000-0000-4000-8000-000000000000" -DEFAULT_USER_NAME = "default_user" -DEFAULT_ORG_NAME = "default_org" - # String in the error message for when the context window is too large # Example full message: diff --git a/letta/functions/functions.py b/letta/functions/functions.py index b06d1402..422cf72b 100644 --- a/letta/functions/functions.py +++ b/letta/functions/functions.py @@ -3,9 +3,33 @@ import inspect import os from textwrap import dedent # remove indentation from types import ModuleType +from typing import Optional from letta.constants import CLI_WARNING_PREFIX from letta.functions.schema_generator import generate_schema +from letta.schemas.tool import ToolCreate + + +def derive_openai_json_schema(tool_create: ToolCreate) -> dict: + # auto-generate openai schema + try: + # Define a custom environment with necessary imports + env = { + "Optional": Optional, # Add any other required imports here + } + + env.update(globals()) + exec(tool_create.source_code, env) + + # get available functions + functions = [f for f in env if callable(env[f])] + + # TODO: not sure if this always works + func = env[functions[-1]] + json_schema = generate_schema(func, terminal=tool_create.terminal, name=tool_create.name) + return json_schema + except Exception as e: + raise RuntimeError(f"Failed to execute source code: {e}") def parse_source_code(func) -> str: diff --git a/letta/functions/helpers.py b/letta/functions/helpers.py index addcbbc7..fd12463f 100644 --- a/letta/functions/helpers.py +++ b/letta/functions/helpers.py @@ -1,5 +1,6 @@ from typing import Any, Optional, Union +import humps from pydantic import BaseModel @@ -8,7 +9,7 @@ def generate_composio_tool_wrapper(action: "ActionType") -> tuple[str, str]: tool_instantiation_str = f"composio_toolset.get_tools(actions=[Action.{str(action)}])[0]" # Generate func name - func_name = f"run_{action.name.lower()}" + func_name = action.name.lower() wrapper_function_str = f""" def {func_name}(**kwargs): @@ -40,7 +41,7 @@ def generate_langchain_tool_wrapper( tool_instantiation = f"tool = {generate_imported_tool_instantiation_call_str(tool)}" run_call = f"return tool._run(**kwargs)" - func_name = f"run_{tool_name.lower()}" + func_name = humps.decamelize(tool_name) # Combine all parts into the wrapper function wrapper_function_str = f""" @@ -70,7 +71,7 @@ def generate_crewai_tool_wrapper(tool: "CrewAIBaseTool", additional_imports_modu tool_instantiation = f"tool = {generate_imported_tool_instantiation_call_str(tool)}" run_call = f"return tool._run(**kwargs)" - func_name = f"run_{tool_name.lower()}" + func_name = humps.decamelize(tool_name) # Combine all parts into the wrapper function wrapper_function_str = f""" diff --git a/letta/functions/schema_generator.py b/letta/functions/schema_generator.py index fbd65b97..2abb0f4b 100644 --- a/letta/functions/schema_generator.py +++ b/letta/functions/schema_generator.py @@ -74,7 +74,7 @@ def pydantic_model_to_open_ai(model): } -def generate_schema(function, terminal: Optional[bool], name: Optional[str] = None, description: Optional[str] = None): +def generate_schema(function, terminal: Optional[bool], name: Optional[str] = None, description: Optional[str] = None) -> dict: # Get the signature of the function sig = inspect.signature(function) @@ -139,7 +139,7 @@ def generate_schema(function, terminal: Optional[bool], name: Optional[str] = No def generate_schema_from_args_schema( - args_schema: Type[BaseModel], name: Optional[str] = None, description: Optional[str] = None + args_schema: Type[BaseModel], name: Optional[str] = None, description: Optional[str] = None, append_heartbeat: bool = True ) -> Dict[str, Any]: properties = {} required = [] @@ -163,4 +163,12 @@ def generate_schema_from_args_schema( "parameters": {"type": "object", "properties": properties, "required": required}, } + # append heartbeat (necessary for triggering another reasoning step after this tool call) + if append_heartbeat: + function_call_json["parameters"]["properties"]["request_heartbeat"] = { + "type": "boolean", + "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.", + } + function_call_json["parameters"]["required"].append("request_heartbeat") + return function_call_json diff --git a/letta/metadata.py b/letta/metadata.py index e36fbad6..b0150ac7 100644 --- a/letta/metadata.py +++ b/letta/metadata.py @@ -14,8 +14,6 @@ from sqlalchemy import ( Integer, String, TypeDecorator, - asc, - or_, ) from sqlalchemy.sql import func @@ -32,7 +30,6 @@ from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import Memory from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction from letta.schemas.source import Source -from letta.schemas.tool import Tool from letta.schemas.user import User from letta.settings import settings from letta.utils import enforce_types, get_utc_time, printd @@ -359,37 +356,6 @@ class BlockModel(Base): ) -class ToolModel(Base): - __tablename__ = "tools" - __table_args__ = {"extend_existing": True} - - id = Column(String, primary_key=True) - name = Column(String, nullable=False) - user_id = Column(String) - description = Column(String) - source_type = Column(String) - source_code = Column(String) - json_schema = Column(JSON) - module = Column(String) - tags = Column(JSON) - - def __repr__(self) -> str: - return f"" - - def to_record(self) -> Tool: - return Tool( - id=self.id, - name=self.name, - user_id=self.user_id, - description=self.description, - source_type=self.source_type, - source_code=self.source_code, - json_schema=self.json_schema, - module=self.module, - tags=self.tags, - ) - - class JobModel(Base): __tablename__ = "jobs" __table_args__ = {"extend_existing": True} @@ -516,14 +482,6 @@ class MetadataStore: session.add(BlockModel(**vars(block))) session.commit() - @enforce_types - def create_tool(self, tool: Tool): - with self.session_maker() as session: - if self.get_tool(tool_id=tool.id, tool_name=tool.name, user_id=tool.user_id) is not None: - raise ValueError(f"Tool with name {tool.name} already exists") - session.add(ToolModel(**vars(tool))) - session.commit() - @enforce_types def update_agent(self, agent: AgentState): with self.session_maker() as session: @@ -556,18 +514,6 @@ class MetadataStore: session.add(BlockModel(**vars(block))) session.commit() - @enforce_types - def update_tool(self, tool_id: str, tool: Tool): - with self.session_maker() as session: - session.query(ToolModel).filter(ToolModel.id == tool_id).update(vars(tool)) - session.commit() - - @enforce_types - def delete_tool(self, tool_id: str): - with self.session_maker() as session: - session.query(ToolModel).filter(ToolModel.id == tool_id).delete() - session.commit() - @enforce_types def delete_file_from_source(self, source_id: str, file_id: str, user_id: Optional[str]): with self.session_maker() as session: @@ -612,23 +558,6 @@ class MetadataStore: session.commit() - @enforce_types - def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50, user_id: Optional[str] = None) -> List[ToolModel]: - with self.session_maker() as session: - # Query for public tools or user-specific tools - query = session.query(ToolModel).filter(or_(ToolModel.user_id == None, ToolModel.user_id == user_id)) - - # Apply cursor if provided (assuming cursor is an ID) - if cursor: - query = query.filter(ToolModel.id > cursor) - - # Order by ID and apply limit - results = query.order_by(asc(ToolModel.id)).limit(limit).all() - - # Convert to records - res = [r.to_record() for r in results] - return res - @enforce_types def list_agents(self, user_id: str) -> List[AgentState]: with self.session_maker() as session: @@ -672,32 +601,6 @@ class MetadataStore: assert len(results) == 1, f"Expected 1 result, got {len(results)}" return results[0].to_record() - @enforce_types - def get_tool( - self, tool_name: Optional[str] = None, tool_id: Optional[str] = None, user_id: Optional[str] = None - ) -> Optional[ToolModel]: - with self.session_maker() as session: - if tool_id: - results = session.query(ToolModel).filter(ToolModel.id == tool_id).all() - else: - assert tool_name is not None - results = session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == None).all() - if user_id: - results += session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == user_id).all() - if len(results) == 0: - return None - # assert len(results) == 1, f"Expected 1 result, got {len(results)}" - return results[0].to_record() - - @enforce_types - def get_tool_with_name_and_user_id(self, tool_name: Optional[str] = None, user_id: Optional[str] = None) -> Optional[ToolModel]: - with self.session_maker() as session: - results = session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == user_id).all() - if len(results) == 0: - return None - assert len(results) == 1, f"Expected 1 result, got {len(results)}" - return results[0].to_record() - @enforce_types def get_block(self, block_id: str) -> Optional[Block]: with self.session_maker() as session: diff --git a/letta/o1_agent.py b/letta/o1_agent.py index ddba1e40..b1aadec4 100644 --- a/letta/o1_agent.py +++ b/letta/o1_agent.py @@ -10,7 +10,7 @@ from letta.schemas.tool import Tool from letta.schemas.usage import LettaUsageStatistics -def send_thinking_message(self: Agent, message: str) -> Optional[str]: +def send_thinking_message(self: "Agent", message: str) -> Optional[str]: """ Sends a thinking message so that the model can reason out loud before responding. @@ -24,7 +24,7 @@ def send_thinking_message(self: Agent, message: str) -> Optional[str]: return None -def send_final_message(self: Agent, message: str) -> Optional[str]: +def send_final_message(self: "Agent", message: str) -> Optional[str]: """ Sends a final message to the human user after thinking for a while. diff --git a/letta/orm/__all__.py b/letta/orm/__all__.py index e69de29b..ed823219 100644 --- a/letta/orm/__all__.py +++ b/letta/orm/__all__.py @@ -0,0 +1,15 @@ +"""__all__ acts as manual import management to avoid collisions and circular imports.""" + +# from letta.orm.agent import Agent +# from letta.orm.users_agents import UsersAgents +# from letta.orm.blocks_agents import BlocksAgents +# from letta.orm.token import Token +# from letta.orm.source import Source +# from letta.orm.document import Document +# from letta.orm.passage import Passage +# from letta.orm.memory_templates import MemoryTemplate, HumanMemoryTemplate, PersonaMemoryTemplate +# from letta.orm.sources_agents import SourcesAgents +# from letta.orm.tools_agents import ToolsAgents +# from letta.orm.job import Job +# from letta.orm.block import Block +# from letta.orm.message import Message diff --git a/letta/orm/mixins.py b/letta/orm/mixins.py index 71845b6e..6ff3ec19 100644 --- a/letta/orm/mixins.py +++ b/letta/orm/mixins.py @@ -55,7 +55,6 @@ class OrganizationMixin(Base): __abstract__ = True - # Changed _organization_id to store string (still a valid UUID4 string) _organization_id: Mapped[str] = mapped_column(String, ForeignKey("organization._id")) @property @@ -65,3 +64,19 @@ class OrganizationMixin(Base): @organization_id.setter def organization_id(self, value: str) -> None: _relation_setter(self, "organization", value) + + +class UserMixin(Base): + """Mixin for models that belong to a user.""" + + __abstract__ = True + + _user_id: Mapped[str] = mapped_column(String, ForeignKey("user._id")) + + @property + def user_id(self) -> str: + return _relation_getter(self, "user") + + @user_id.setter + def user_id(self, value: str) -> None: + _relation_setter(self, "user", value) diff --git a/letta/orm/organization.py b/letta/orm/organization.py index 244c49e0..51e87e8a 100644 --- a/letta/orm/organization.py +++ b/letta/orm/organization.py @@ -7,6 +7,7 @@ from letta.schemas.organization import Organization as PydanticOrganization if TYPE_CHECKING: + from letta.orm.tool import Tool from letta.orm.user import User @@ -19,6 +20,7 @@ class Organization(SqlalchemyBase): name: Mapped[str] = mapped_column(doc="The display name of the organization.") users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan") + tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan") # TODO: Map these relationships later when we actually make these models # below is just a suggestion diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index 0e0a3821..a23d03da 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -184,21 +184,20 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): logger.warning("to_record is deprecated, use to_pydantic instead.") return self.to_pydantic() - # TODO: Look into this later and maybe add back? - # def _infer_organization(self, db_session: "Session") -> None: - # """🪄 MAGIC ALERT! 🪄 - # Because so much of the original API is centered around user scopes, - # this allows us to continue with that scope and then infer the org from the creating user. - # - # IF a created_by_id is set, we will use that to infer the organization and magic set it at create time! - # If not do nothing to the object. Mutates in place. - # """ - # if self.created_by_id and hasattr(self, "_organization_id"): - # try: - # from letta.orm.user import User # to avoid circular import - # - # created_by = User.read(db_session, self.created_by_id) - # except NoResultFound: - # logger.warning(f"User {self.created_by_id} not found, unable to infer organization.") - # return - # self._organization_id = created_by._organization_id + def _infer_organization(self, db_session: "Session") -> None: + """🪄 MAGIC ALERT! 🪄 + Because so much of the original API is centered around user scopes, + this allows us to continue with that scope and then infer the org from the creating user. + + IF a created_by_id is set, we will use that to infer the organization and magic set it at create time! + If not do nothing to the object. Mutates in place. + """ + if self.created_by_id and hasattr(self, "_organization_id"): + try: + from letta.orm.user import User # to avoid circular import + + created_by = User.read(db_session, self.created_by_id) + except NoResultFound: + logger.warning(f"User {self.created_by_id} not found, unable to infer organization.") + return + self._organization_id = created_by._organization_id diff --git a/letta/orm/tool.py b/letta/orm/tool.py new file mode 100644 index 00000000..158ed235 --- /dev/null +++ b/letta/orm/tool.py @@ -0,0 +1,54 @@ +from typing import TYPE_CHECKING, List, Optional + +from sqlalchemy import JSON, String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column, relationship + +# TODO everything in functions should live in this model +from letta.orm.enums import ToolSourceType +from letta.orm.mixins import OrganizationMixin, UserMixin +from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.tool import Tool as PydanticTool + +if TYPE_CHECKING: + pass + + from letta.orm.organization import Organization + from letta.orm.user import User + + +class Tool(SqlalchemyBase, OrganizationMixin, UserMixin): + """Represents an available tool that the LLM can invoke. + + NOTE: polymorphic inheritance makes more sense here as a TODO. We want a superset of tools + that are always available, and a subset scoped to the organization. Alternatively, we could use the apply_access_predicate to build + more granular permissions. + """ + + __tablename__ = "tool" + __pydantic_model__ = PydanticTool + + # Add unique constraint on (name, _organization_id) + # An organization should not have multiple tools with the same name + __table_args__ = ( + UniqueConstraint("name", "_organization_id", name="uix_name_organization"), + UniqueConstraint("name", "_user_id", name="uix_name_user"), + ) + + name: Mapped[str] = mapped_column(doc="The display name of the tool.") + description: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The description of the tool.") + tags: Mapped[List] = mapped_column(JSON, doc="Metadata tags used to filter tools.") + source_type: Mapped[ToolSourceType] = mapped_column(String, doc="The type of the source code.", default=ToolSourceType.json) + source_code: Mapped[Optional[str]] = mapped_column(String, doc="The source code of the function.") + json_schema: Mapped[dict] = mapped_column(JSON, default=lambda: {}, doc="The OAI compatable JSON schema of the function.") + module: Mapped[Optional[str]] = mapped_column( + String, nullable=True, doc="the module path from which this tool was derived in the codebase." + ) + + # TODO: add terminal here eventually + # This was an intentional decision by Sarah + + # relationships + # TODO: Possibly add in user in the future + # This will require some more thought and justification to add this in. + user: Mapped["User"] = relationship("User", back_populates="tools", lazy="selectin") + organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin") diff --git a/letta/orm/user.py b/letta/orm/user.py index bb555721..31bd40f8 100644 --- a/letta/orm/user.py +++ b/letta/orm/user.py @@ -1,10 +1,15 @@ +from typing import TYPE_CHECKING, List + from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.mixins import OrganizationMixin -from letta.orm.organization import Organization from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.user import User as PydanticUser +if TYPE_CHECKING: + from letta.orm.organization import Organization + from letta.orm.tool import Tool + class User(SqlalchemyBase, OrganizationMixin): """User ORM class""" @@ -16,6 +21,7 @@ class User(SqlalchemyBase, OrganizationMixin): # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="users") + tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="user", cascade="all, delete-orphan") # TODO: Add this back later potentially # agents: Mapped[List["Agent"]] = relationship( diff --git a/letta/schemas/tool.py b/letta/schemas/tool.py index c1aa7ee7..e31a3856 100644 --- a/letta/schemas/tool.py +++ b/letta/schemas/tool.py @@ -10,19 +10,13 @@ from letta.functions.helpers import ( from letta.functions.schema_generator import generate_schema_from_args_schema from letta.schemas.letta_base import LettaBase from letta.schemas.openai.chat_completions import ToolCall +from letta.services.organization_manager import OrganizationManager +from letta.services.user_manager import UserManager class BaseTool(LettaBase): __id_prefix__ = "tool" - # optional fields - description: Optional[str] = Field(None, description="The description of the tool.") - source_type: Optional[str] = Field(None, description="The type of the source code.") - module: Optional[str] = Field(None, description="The module of the function.") - - # optional: user_id (user-specific tools) - user_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the function.") - class Tool(BaseTool): """ @@ -37,8 +31,12 @@ class Tool(BaseTool): """ - id: str = BaseTool.generate_id_field() - + id: str = Field(..., description="The id of the tool.") + description: Optional[str] = Field(None, description="The description of the tool.") + source_type: Optional[str] = Field(None, description="The type of the source code.") + module: Optional[str] = Field(None, description="The module of the function.") + user_id: str = Field(..., description="The unique identifier of the user associated with the tool.") + organization_id: str = Field(..., description="The unique identifier of the organization associated with the tool.") name: str = Field(..., description="The name of the function.") tags: List[str] = Field(..., description="Metadata tags.") @@ -58,14 +56,31 @@ class Tool(BaseTool): ) ) + +class ToolCreate(LettaBase): + user_id: str = Field(UserManager.DEFAULT_USER_ID, description="The user that this tool belongs to. Defaults to the default user ID.") + organization_id: str = Field( + OrganizationManager.DEFAULT_ORG_ID, + description="The organization that this tool belongs to. Defaults to the default organization ID.", + ) + name: Optional[str] = Field(None, description="The name of the function (auto-generated from source_code if not provided).") + description: Optional[str] = Field(None, description="The description of the tool.") + tags: List[str] = Field([], description="Metadata tags.") + module: Optional[str] = Field(None, description="The source code of the function.") + source_code: str = Field(..., description="The source code of the function.") + source_type: str = Field(..., description="The source type of the function.") + json_schema: Optional[Dict] = Field( + None, description="The JSON schema of the function (auto-generated from source_code if not provided)" + ) + terminal: Optional[bool] = Field(None, description="Whether the tool is a terminal tool (allow requesting heartbeats).") + @classmethod - def get_composio_tool( - cls, - action: "ActionType", - ) -> "Tool": + def from_composio( + cls, action: "ActionType", user_id: str = UserManager.DEFAULT_USER_ID, organization_id: str = OrganizationManager.DEFAULT_ORG_ID + ) -> "ToolCreate": """ Class method to create an instance of Letta-compatible Composio Tool. - Check https://docs.composio.dev/introduction/intro/overview to look at options for get_composio_tool + Check https://docs.composio.dev/introduction/intro/overview to look at options for from_composio This function will error if we find more than one tool, or 0 tools. @@ -90,14 +105,9 @@ class Tool(BaseTool): wrapper_func_name, wrapper_function_str = generate_composio_tool_wrapper(action) json_schema = generate_schema_from_args_schema(composio_tool.args_schema, name=wrapper_func_name, description=description) - # append heartbeat (necessary for triggering another reasoning step after this tool call) - json_schema["parameters"]["properties"]["request_heartbeat"] = { - "type": "boolean", - "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.", - } - json_schema["parameters"]["required"].append("request_heartbeat") - return cls( + user_id=user_id, + organization_id=organization_id, name=wrapper_func_name, description=description, source_type=source_type, @@ -107,7 +117,13 @@ class Tool(BaseTool): ) @classmethod - def from_langchain(cls, langchain_tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> "Tool": + def from_langchain( + cls, + langchain_tool: "LangChainBaseTool", + additional_imports_module_attr_map: dict[str, str] = None, + user_id: str = UserManager.DEFAULT_USER_ID, + organization_id: str = OrganizationManager.DEFAULT_ORG_ID, + ) -> "ToolCreate": """ Class method to create an instance of Tool from a Langchain tool (must be from langchain_community.tools). @@ -125,14 +141,9 @@ class Tool(BaseTool): wrapper_func_name, wrapper_function_str = generate_langchain_tool_wrapper(langchain_tool, additional_imports_module_attr_map) json_schema = generate_schema_from_args_schema(langchain_tool.args_schema, name=wrapper_func_name, description=description) - # append heartbeat (necessary for triggering another reasoning step after this tool call) - json_schema["parameters"]["properties"]["request_heartbeat"] = { - "type": "boolean", - "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.", - } - json_schema["parameters"]["required"].append("request_heartbeat") - return cls( + user_id=user_id, + organization_id=organization_id, name=wrapper_func_name, description=description, source_type=source_type, @@ -142,7 +153,13 @@ class Tool(BaseTool): ) @classmethod - def from_crewai(cls, crewai_tool: "CrewAIBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> "Tool": + def from_crewai( + cls, + crewai_tool: "CrewAIBaseTool", + additional_imports_module_attr_map: dict[str, str] = None, + user_id: str = UserManager.DEFAULT_USER_ID, + organization_id: str = OrganizationManager.DEFAULT_ORG_ID, + ) -> "ToolCreate": """ Class method to create an instance of Tool from a crewAI BaseTool object. @@ -158,14 +175,9 @@ class Tool(BaseTool): wrapper_func_name, wrapper_function_str = generate_crewai_tool_wrapper(crewai_tool, additional_imports_module_attr_map) json_schema = generate_schema_from_args_schema(crewai_tool.args_schema, name=wrapper_func_name, description=description) - # append heartbeat (necessary for triggering another reasoning step after this tool call) - json_schema["parameters"]["properties"]["request_heartbeat"] = { - "type": "boolean", - "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.", - } - json_schema["parameters"]["required"].append("request_heartbeat") - return cls( + user_id=user_id, + organization_id=organization_id, name=wrapper_func_name, description=description, source_type=source_type, @@ -175,54 +187,43 @@ class Tool(BaseTool): ) @classmethod - def load_default_langchain_tools(cls) -> List["Tool"]: + def load_default_langchain_tools(cls) -> List["ToolCreate"]: # For now, we only support wikipedia tool from langchain_community.tools import WikipediaQueryRun from langchain_community.utilities import WikipediaAPIWrapper - wikipedia_tool = Tool.from_langchain( + wikipedia_tool = ToolCreate.from_langchain( WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()), {"langchain_community.utilities": "WikipediaAPIWrapper"} ) return [wikipedia_tool] @classmethod - def load_default_crewai_tools(cls) -> List["Tool"]: + def load_default_crewai_tools(cls) -> List["ToolCreate"]: # For now, we only support scrape website tool from crewai_tools import ScrapeWebsiteTool - web_scrape_tool = Tool.from_crewai(ScrapeWebsiteTool()) + web_scrape_tool = ToolCreate.from_crewai(ScrapeWebsiteTool()) return [web_scrape_tool] @classmethod - def load_default_composio_tools(cls) -> List["Tool"]: + def load_default_composio_tools(cls) -> List["ToolCreate"]: from composio_langchain import Action - calculator = Tool.get_composio_tool(action=Action.MATHEMATICAL_CALCULATOR) - serp_news = Tool.get_composio_tool(action=Action.SERPAPI_NEWS_SEARCH) - serp_google_search = Tool.get_composio_tool(action=Action.SERPAPI_SEARCH) - serp_google_maps = Tool.get_composio_tool(action=Action.SERPAPI_GOOGLE_MAPS_SEARCH) + calculator = ToolCreate.from_composio(action=Action.MATHEMATICAL_CALCULATOR) + serp_news = ToolCreate.from_composio(action=Action.SERPAPI_NEWS_SEARCH) + serp_google_search = ToolCreate.from_composio(action=Action.SERPAPI_SEARCH) + serp_google_maps = ToolCreate.from_composio(action=Action.SERPAPI_GOOGLE_MAPS_SEARCH) return [calculator, serp_news, serp_google_search, serp_google_maps] -class ToolCreate(BaseTool): - id: Optional[str] = Field(None, description="The unique identifier of the tool. If this is not provided, it will be autogenerated.") - name: Optional[str] = Field(None, description="The name of the function (auto-generated from source_code if not provided).") - description: Optional[str] = Field(None, description="The description of the tool.") - tags: List[str] = Field([], description="Metadata tags.") - source_code: str = Field(..., description="The source code of the function.") - json_schema: Optional[Dict] = Field( - None, description="The JSON schema of the function (auto-generated from source_code if not provided)" - ) - terminal: Optional[bool] = Field(None, description="Whether the tool is a terminal tool (allow requesting heartbeats).") - - -class ToolUpdate(ToolCreate): - id: str = Field(..., description="The unique identifier of the tool.") +class ToolUpdate(LettaBase): description: Optional[str] = Field(None, description="The description of the tool.") name: Optional[str] = Field(None, description="The name of the function.") tags: Optional[List[str]] = Field(None, description="Metadata tags.") + module: Optional[str] = Field(None, description="The source code of the function.") source_code: Optional[str] = Field(None, description="The source code of the function.") json_schema: Optional[Dict] = Field(None, description="The JSON schema of the function.") + source_type: Optional[str] = Field(None, description="The type of the source code.") diff --git a/letta/schemas/user.py b/letta/schemas/user.py index b499e126..f6947c7c 100644 --- a/letta/schemas/user.py +++ b/letta/schemas/user.py @@ -3,8 +3,8 @@ from typing import Optional from pydantic import Field -from letta.constants import DEFAULT_ORG_ID from letta.schemas.letta_base import LettaBase +from letta.services.organization_manager import OrganizationManager class UserBase(LettaBase): @@ -22,7 +22,7 @@ class User(UserBase): """ id: str = Field(..., description="The id of the user.") - organization_id: Optional[str] = Field(DEFAULT_ORG_ID, description="The organization id of the user") + organization_id: Optional[str] = Field(OrganizationManager.DEFAULT_ORG_ID, description="The organization id of the user") name: str = Field(..., description="The name of the user.") created_at: datetime = Field(default_factory=datetime.utcnow, description="The creation date of the user.") updated_at: datetime = Field(default_factory=datetime.utcnow, description="The update date of the user.") diff --git a/letta/server/rest_api/admin/users.py b/letta/server/rest_api/admin/users.py index e0d333d8..02fd91aa 100644 --- a/letta/server/rest_api/admin/users.py +++ b/letta/server/rest_api/admin/users.py @@ -31,7 +31,7 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface): Create a new user in the database """ try: - user = server.create_user(request) + user = server.user_manager.create_user(request) except HTTPException: raise except Exception as e: diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index e8782b89..35f41a26 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -2,6 +2,7 @@ from typing import List, Optional from fastapi import APIRouter, Body, Depends, Header, HTTPException +from letta.orm.errors import NoResultFound from letta.schemas.tool import Tool, ToolCreate, ToolUpdate from letta.server.rest_api.utils import get_letta_server from letta.server.server import SyncServer @@ -13,13 +14,12 @@ router = APIRouter(prefix="/tools", tags=["tools"]) def delete_tool( tool_id: str, server: SyncServer = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Delete a tool by name """ # actor = server.get_user_or_default(user_id=user_id) - server.delete_tool(tool_id=tool_id) + server.tool_manager.delete_tool(tool_id=tool_id) @router.get("/{tool_id}", response_model=Tool, operation_id="get_tool") @@ -30,9 +30,7 @@ def get_tool( """ Get a tool by ID """ - # actor = server.get_current_user() - - tool = server.get_tool(tool_id=tool_id) + tool = server.tool_manager.get_tool_by_id(tool_id=tool_id) if tool is None: # return 404 error raise HTTPException(status_code=404, detail=f"Tool with id {tool_id} not found.") @@ -50,26 +48,26 @@ def get_tool_id( """ actor = server.get_user_or_default(user_id=user_id) - tool_id = server.get_tool_id(tool_name, user_id=actor.id) - if tool_id is None: - # return 404 error - raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} not found.") - return tool_id + try: + tool = server.tool_manager.get_tool_by_name_and_org_id(tool_name=tool_name, organization_id=actor.organization_id) + return tool.id + except NoResultFound: + raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} and organization id {actor.organization_id} not found.") @router.get("/", response_model=List[Tool], operation_id="list_tools") -def list_all_tools( +def list_tools( cursor: Optional[str] = None, limit: Optional[int] = 50, server: SyncServer = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ - Get a list of all tools available to agents created by a user + Get a list of all tools available to agents belonging to the org of the user """ try: actor = server.get_user_or_default(user_id=user_id) - return server.list_tools(cursor=cursor, limit=limit, user_id=actor.id) + return server.tool_manager.list_tools_for_org(organization_id=actor.organization_id, cursor=cursor, limit=limit) except Exception as e: # Log or print the full exception here for debugging print(f"Error occurred: {e}") @@ -78,21 +76,21 @@ def list_all_tools( @router.post("/", response_model=Tool, operation_id="create_tool") def create_tool( - tool: ToolCreate = Body(...), - update: bool = False, + request: ToolCreate = Body(...), server: SyncServer = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Create a new tool """ + # Derive user and org id from actor actor = server.get_user_or_default(user_id=user_id) + request.organization_id = actor.organization_id + request.user_id = actor.id - return server.create_tool( - request=tool, - # update=update, - update=True, - user_id=actor.id, + # Send request to create the tool + return server.tool_manager.create_or_update_tool( + tool_create=request, ) @@ -106,6 +104,4 @@ def update_tool( """ Update an existing tool """ - assert tool_id == request.id, "Tool ID in path must match tool ID in request body" - # actor = server.get_user_or_default(user_id=user_id) - return server.update_tool(request, user_id) + return server.tool_manager.update_tool_by_id(tool_id, request) diff --git a/letta/server/server.py b/letta/server/server.py index 34363ad1..98e0ef83 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1,6 +1,4 @@ # inspecting tools -import importlib -import inspect import os import traceback import warnings @@ -16,7 +14,6 @@ import letta.system as system from letta.agent import Agent, save_agent from letta.agent_store.db import attach_base from letta.agent_store.storage import StorageConnector, TableType -from letta.client.utils import derive_function_name_regex from letta.credentials import LettaCredentials from letta.data_sources.connectors import DataConnector, load_data @@ -30,11 +27,7 @@ from letta.data_sources.connectors import DataConnector, load_data # Token, # User, # ) -from letta.functions.functions import ( - generate_schema, - load_function_set, - parse_source_code, -) +from letta.functions.functions import generate_schema, parse_source_code from letta.functions.schema_generator import generate_schema # TODO use custom interface @@ -82,10 +75,11 @@ from letta.schemas.memory import ( from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage from letta.schemas.passage import Passage from letta.schemas.source import Source, SourceCreate, SourceUpdate -from letta.schemas.tool import Tool, ToolCreate, ToolUpdate +from letta.schemas.tool import Tool, ToolCreate from letta.schemas.usage import LettaUsageStatistics -from letta.schemas.user import User, UserCreate +from letta.schemas.user import User from letta.services.organization_manager import OrganizationManager +from letta.services.tool_manager import ToolManager from letta.services.user_manager import UserManager from letta.utils import create_random_username, json_dumps, json_loads @@ -214,6 +208,7 @@ class SyncServer(Server): chaining: bool = True, max_chaining_steps: Optional[bool] = None, default_interface_factory: Callable[[], AgentInterface] = lambda: CLIInterface(), + init_with_default_org_and_user: bool = True, # default_interface: AgentInterface = CLIInterface(), # default_persistence_manager_cls: PersistenceManager = LocalStateManager, # auth_mode: str = "none", # "none, "jwt", "external" @@ -249,13 +244,19 @@ class SyncServer(Server): # Managers that interface with data models self.organization_manager = OrganizationManager() self.user_manager = UserManager() + self.tool_manager = ToolManager() - # TODO: this should be removed - # add global default tools (for admin) - self.add_default_tools(module_name="base") + # Make default user and org + if init_with_default_org_and_user: + self.default_org = self.organization_manager.create_default_organization() + self.default_user = self.user_manager.create_default_user() + self.add_default_blocks(self.default_user.id) + self.tool_manager.add_default_tools(module_name="base", user_id=self.default_user.id, org_id=self.default_org.id) - if settings.load_default_external_tools: - self.add_default_external_tools() + # If there is a default org/user + # This logic may have to change in the future + if settings.load_default_external_tools: + self.add_default_external_tools(user_id=self.default_user.id, org_id=self.default_org.id) # collect providers (always has Letta as a default) self._enabled_providers: List[Provider] = [LettaProvider()] @@ -364,7 +365,7 @@ class SyncServer(Server): logger.debug(f"Creating an agent object") tool_objs = [] for name in agent_state.tools: - tool_obj = self.ms.get_tool(tool_name=name, user_id=user_id) + tool_obj = self.tool_manager.get_tool_by_name_and_user_id(tool_name=name, user_id=user_id) if not tool_obj: logger.exception(f"Tool {name} does not exist for user {user_id}") raise ValueError(f"Tool {name} does not exist for user {user_id}") @@ -755,22 +756,6 @@ class SyncServer(Server): command = command[1:] # strip the prefix return self._command(user_id=user_id, agent_id=agent_id, command=command) - def create_user(self, request: UserCreate) -> User: - """Create a new user using a config""" - if not request.name: - # auto-generate a name - request.name = create_random_username() - user = self.user_manager.create_user(request) - logger.debug(f"Created new user from config: {user}") - - # add default for the user - # TODO: move to org - assert user.id is not None, f"User id is None: {user}" - self.add_default_blocks(user.id) - self.add_default_tools(module_name="base", user_id=user.id) - - return user - def create_agent( self, request: CreateAgent, @@ -816,8 +801,7 @@ class SyncServer(Server): tool_objs = [] if request.tools: for tool_name in request.tools: - tool_obj = self.ms.get_tool(tool_name=tool_name, user_id=user_id) - assert tool_obj, f"Tool {tool_name} does not exist" + tool_obj = self.tool_manager.get_tool_by_name_and_user_id(tool_name=tool_name, user_id=user_id) tool_objs.append(tool_obj) assert request.memory is not None @@ -832,16 +816,15 @@ class SyncServer(Server): json_schema = generate_schema(func, terminal=False, name=func_name) source_type = "python" tags = ["memory", "memgpt-base"] - tool = self.create_tool( - request=ToolCreate( + tool = self.tool_manager.create_or_update_tool( + ToolCreate( source_code=source_code, source_type=source_type, tags=tags, json_schema=json_schema, user_id=user_id, - ), - update=True, - user_id=user_id, + organization_id=user.organization_id, + ) ) tool_objs.append(tool) if not request.tools: @@ -939,7 +922,7 @@ class SyncServer(Server): # (1) get tools + make sure they exist tool_objs = [] for tool_name in request.tools: - tool_obj = self.ms.get_tool(tool_name=tool_name, user_id=user_id) + tool_obj = self.tool_manager.get_tool_by_name_and_user_id(tool_name=tool_name, user_id=user_id) assert tool_obj, f"Tool {tool_name} does not exist" tool_objs.append(tool_obj) @@ -995,12 +978,12 @@ class SyncServer(Server): # Get all the tool objects from the request tool_objs = [] - tool_obj = self.ms.get_tool(tool_id=tool_id, user_id=user_id) + tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool_id) assert tool_obj, f"Tool with id={tool_id} does not exist" tool_objs.append(tool_obj) for tool in letta_agent.tools: - tool_obj = self.ms.get_tool(tool_id=tool.id, user_id=user_id) + tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id) assert tool_obj, f"Tool with id={tool.id} does not exist" # If it's not the already added tool @@ -1035,7 +1018,7 @@ class SyncServer(Server): # Get all the tool_objs tool_objs = [] for tool in letta_agent.tools: - tool_obj = self.ms.get_tool(tool_id=tool.id, user_id=user_id) + tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id) assert tool_obj, f"Tool with id={tool.id} does not exist" # If it's not the tool we want to remove @@ -1076,86 +1059,6 @@ class SyncServer(Server): agents_states = self.ms.list_agents(user_id=user_id) return agents_states - # TODO make return type pydantic - def list_agents_legacy( - self, - user_id: str, - ) -> dict: - """List all available agents to a user""" - - if user_id is None: - agents_states = self.ms.list_all_agents() - else: - if self.user_manager.get_user_by_id(user_id=user_id) is None: - raise ValueError(f"User user_id={user_id} does not exist") - - agents_states = self.ms.list_agents(user_id=user_id) - - agents_states_dicts = [self._agent_state_to_config(state) for state in agents_states] - - # TODO add a get_message_obj_from_message_id(...) function - # this would allow grabbing Message.created_by without having to load the agent object - # all_available_tools = self.ms.list_tools(user_id=user_id) # TODO: add back when user-specific - self.ms.list_tools() - - for agent_state, return_dict in zip(agents_states, agents_states_dicts): - - # Get the agent object (loaded in memory) - letta_agent = self._get_or_load_agent(user_id=agent_state.user_id, agent_id=agent_state.id) - - # TODO remove this eventually when return type get pydanticfied - # this is to add persona_name and human_name so that the columns in UI can populate - # TODO hack for frontend, remove - # (top level .persona is persona_name, and nested memory.persona is the state) - # TODO: eventually modify this to be contained in the metadata - return_dict["persona"] = agent_state._metadata.get("persona", None) - return_dict["human"] = agent_state._metadata.get("human", None) - - # Add information about tools - # TODO letta_agent should really have a field of List[ToolModel] - # then we could just pull that field and return it here - # return_dict["tools"] = [tool for tool in all_available_tools if tool.json_schema in letta_agent.functions] - - # get tool info from agent state - tools = [] - for tool_name in agent_state.tools: - tool = self.ms.get_tool(tool_name=tool_name, user_id=user_id) - tools.append(tool) - return_dict["tools"] = tools - - # Add information about memory (raw core, size of recall, size of archival) - core_memory = letta_agent.memory - recall_memory = letta_agent.persistence_manager.recall_memory - archival_memory = letta_agent.persistence_manager.archival_memory - memory_obj = { - "core_memory": core_memory.to_flat_dict(), - "recall_memory": len(recall_memory) if recall_memory is not None else None, - "archival_memory": len(archival_memory) if archival_memory is not None else None, - } - return_dict["memory"] = memory_obj - - # Add information about last run - # NOTE: 'last_run' is just the timestamp on the latest message in the buffer - # Retrieve the Message object via the recall storage or by directly access _messages - last_msg_obj = letta_agent._messages[-1] - return_dict["last_run"] = last_msg_obj.created_at - - # Add information about attached sources - sources_ids = self.ms.list_attached_sources(agent_id=agent_state.id) - sources = [self.ms.get_source(source_id=s_id) for s_id in sources_ids] - return_dict["sources"] = [vars(s) for s in sources] - - # Sort agents by "last_run" in descending order, most recent first - agents_states_dicts.sort(key=lambda x: x["last_run"], reverse=True) - - logger.debug(f"Retrieved {len(agents_states)} agents for user {user_id}") - return { - "num_agents": len(agents_states), - "agents": agents_states_dicts, - } - - # blocks - def get_blocks( self, user_id: Optional[str] = None, @@ -1830,195 +1733,17 @@ class SyncServer(Server): return sources_with_metadata - def get_tool(self, tool_id: str) -> Optional[Tool]: - """Get tool by ID.""" - return self.ms.get_tool(tool_id=tool_id) - - def tool_with_name_and_user_id_exists(self, tool: Tool, user_id: Optional[str] = None) -> bool: - """Check if tool exists""" - tool = self.ms.get_tool_with_name_and_user_id(tool_name=tool.name, user_id=user_id) - - if tool is None: - return False - else: - return True - - def get_tool_id(self, name: str, user_id: str) -> Optional[str]: - """Get tool ID from name and user_id.""" - tool = self.ms.get_tool(tool_name=name, user_id=user_id) - if not tool or tool.id is None: - return None - return tool.id - - def update_tool(self, request: ToolUpdate, user_id: Optional[str] = None) -> Tool: - """Update an existing tool""" - if request.name: - existing_tool = self.ms.get_tool_with_name_and_user_id(tool_name=request.name, user_id=user_id) - if existing_tool is None: - raise ValueError(f"Tool with name={request.name}, user_id={user_id} does not exist") - else: - existing_tool = self.ms.get_tool(tool_id=request.id) - if existing_tool is None: - raise ValueError(f"Tool with id={request.id} does not exist") - - # Preserve the original tool id - # As we can override the tool id as well - # This is probably bad design if this is exposed to users... - original_id = existing_tool.id - - # override updated fields - if request.id: - existing_tool.id = request.id - if request.description: - existing_tool.description = request.description - if request.source_code: - existing_tool.source_code = request.source_code - if request.source_type: - existing_tool.source_type = request.source_type - if request.tags: - existing_tool.tags = request.tags - if request.json_schema: - existing_tool.json_schema = request.json_schema - - # If name is explicitly provided here, overide the tool name - if request.name: - existing_tool.name = request.name - # Otherwise, if there's no name, and there's source code, we try to derive the name - elif request.source_code: - existing_tool.name = derive_function_name_regex(request.source_code) - - self.ms.update_tool(original_id, existing_tool) - return self.ms.get_tool(tool_id=request.id) - - def create_tool(self, request: ToolCreate, user_id: Optional[str] = None, update: bool = True) -> Tool: # TODO: add other fields - """Create a new tool""" - - # NOTE: deprecated code that existed when we were trying to pretend that `self` was the memory object - # if request.tags and "memory" in request.tags: - # # special modifications to memory functions - # # self.memory -> self.memory.memory, since Agent.memory.memory needs to be modified (not BaseMemory.memory) - # request.source_code = request.source_code.replace("self.memory", "self.memory.memory") - - if not request.json_schema: - # auto-generate openai schema - try: - env = {} - env.update(globals()) - exec(request.source_code, env) - - # get available functions - functions = [f for f in env if callable(env[f])] - - except Exception as e: - logger.error(f"Failed to execute source code: {e}") - - # TODO: not sure if this always works - func = env[functions[-1]] - json_schema = generate_schema(func, terminal=request.terminal) - else: - # provided by client - json_schema = request.json_schema - - if not request.name: - # use name from JSON schema - request.name = json_schema["name"] - assert request.name, f"Tool name must be provided in json_schema {json_schema}. This should never happen." - - # check if already exists: - existing_tool = self.ms.get_tool(tool_id=request.id, tool_name=request.name, user_id=user_id) - if existing_tool: - if update: - # id is an optional field, so we will fill it with the existing tool id - if not request.id: - request.id = existing_tool.id - updated_tool = self.update_tool(ToolUpdate(**vars(request)), user_id) - assert updated_tool is not None, f"Failed to update tool {request.name}" - return updated_tool - else: - raise ValueError(f"Tool {request.name} already exists and update=False") - - # check for description - description = None - if request.description: - description = request.description - - tool = Tool( - name=request.name, - source_code=request.source_code, - source_type=request.source_type, - tags=request.tags, - json_schema=json_schema, - user_id=user_id, - description=description, - ) - - if request.id: - tool.id = request.id - - self.ms.create_tool(tool) - created_tool = self.ms.get_tool(tool_id=tool.id, user_id=user_id) - return created_tool - - def delete_tool(self, tool_id: str): - """Delete a tool""" - self.ms.delete_tool(tool_id) - - def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50, user_id: Optional[str] = None) -> List[Tool]: - """List tools available to user_id""" - tools = self.ms.list_tools(cursor=cursor, limit=limit, user_id=user_id) - return tools - - def add_default_tools(self, module_name="base", user_id: Optional[str] = None): - """Add default tools in {module_name}.py""" - full_module_name = f"letta.functions.function_sets.{module_name}" - try: - module = importlib.import_module(full_module_name) - except Exception as e: - # Handle other general exceptions - raise e - - functions_to_schema = [] - try: - # Load the function set - functions_to_schema = load_function_set(module) - except ValueError as e: - err = f"Error loading function set '{module_name}': {e}" - warnings.warn(err) - - # create tool in db - for name, schema in functions_to_schema.items(): - # print([str(inspect.getsource(line)) for line in schema["imports"]]) - source_code = inspect.getsource(schema["python_function"]) - tags = [module_name] - if module_name == "base": - tags.append("letta-base") - - # create to tool - self.create_tool( - ToolCreate( - name=name, - tags=tags, - source_type="python", - module=schema["module"], - source_code=source_code, - json_schema=schema["json_schema"], - user_id=user_id, - ), - update=True, - ) - - def add_default_external_tools(self, user_id: Optional[str] = None) -> bool: + def add_default_external_tools(self, user_id: str, org_id: str) -> bool: """Add default langchain tools. Return true if successful, false otherwise.""" success = True + tool_creates = ToolCreate.load_default_langchain_tools() + ToolCreate.load_default_crewai_tools() if tool_settings.composio_api_key: - tools = Tool.load_default_langchain_tools() + Tool.load_default_crewai_tools() + Tool.load_default_composio_tools() - else: - tools = Tool.load_default_langchain_tools() + Tool.load_default_crewai_tools() - for tool in tools: + tool_creates += ToolCreate.load_default_composio_tools() + for tool_create in tool_creates: try: - self.ms.create_tool(tool) + self.tool_manager.create_or_update_tool(tool_create) except Exception as e: - warnings.warn(f"An error occurred while creating tool {tool}: {e}") + warnings.warn(f"An error occurred while creating tool {tool_create}: {e}") warnings.warn(traceback.format_exc()) success = False @@ -2108,25 +1833,15 @@ class SyncServer(Server): letta_agent = self._get_or_load_agent(agent_id=agent_id) return letta_agent.retry_message() - # TODO: Move a lot of this default logic to the ORM - def get_default_user(self) -> User: - self.organization_manager.create_default_organization() - user = self.user_manager.create_default_user() - - self.add_default_blocks(user.id) - self.add_default_tools(module_name="base", user_id=user.id) - - return user - def get_user_or_default(self, user_id: Optional[str]) -> User: """Get the user object for user_id if it exists, otherwise return the default user object""" if user_id is None: - return self.get_default_user() - else: - try: - return self.user_manager.get_user_by_id(user_id=user_id) - except ValueError: - raise HTTPException(status_code=404, detail=f"User with id {user_id} not found") + user_id = self.user_manager.DEFAULT_USER_ID + + try: + return self.user_manager.get_user_by_id(user_id=user_id) + except ValueError: + raise HTTPException(status_code=404, detail=f"User with id {user_id} not found") def list_llm_models(self) -> List[LLMConfig]: """List available models""" diff --git a/letta/services/organization_manager.py b/letta/services/organization_manager.py index 7e90602e..8c13c037 100644 --- a/letta/services/organization_manager.py +++ b/letta/services/organization_manager.py @@ -1,8 +1,7 @@ from typing import List, Optional -from letta.constants import DEFAULT_ORG_ID, DEFAULT_ORG_NAME from letta.orm.errors import NoResultFound -from letta.orm.organization import Organization +from letta.orm.organization import Organization as OrganizationModel from letta.schemas.organization import Organization as PydanticOrganization from letta.utils import create_random_username, enforce_types @@ -10,6 +9,9 @@ from letta.utils import create_random_username, enforce_types class OrganizationManager: """Manager class to handle business logic related to Organizations.""" + DEFAULT_ORG_ID = "organization-00000000-0000-4000-8000-000000000000" + DEFAULT_ORG_NAME = "default_org" + def __init__(self): # This is probably horrible but we reuse this technique from metadata.py # TODO: Please refactor this out @@ -19,12 +21,17 @@ class OrganizationManager: self.session_maker = db_context + @enforce_types + def get_default_organization(self) -> PydanticOrganization: + """Fetch the default organization.""" + return self.get_organization_by_id(self.DEFAULT_ORG_ID) + @enforce_types def get_organization_by_id(self, org_id: str) -> PydanticOrganization: """Fetch an organization by ID.""" with self.session_maker() as session: try: - organization = Organization.read(db_session=session, identifier=org_id) + organization = OrganizationModel.read(db_session=session, identifier=org_id) return organization.to_pydantic() except NoResultFound: raise ValueError(f"Organization with id {org_id} not found.") @@ -33,7 +40,7 @@ class OrganizationManager: def create_organization(self, name: Optional[str] = None) -> PydanticOrganization: """Create a new organization. If a name is provided, it is used, otherwise, a random one is generated.""" with self.session_maker() as session: - org = Organization(name=name if name else create_random_username()) + org = OrganizationModel(name=name if name else create_random_username()) org.create(session) return org.to_pydantic() @@ -43,10 +50,10 @@ class OrganizationManager: with self.session_maker() as session: # Try to get it first try: - org = Organization.read(db_session=session, identifier=DEFAULT_ORG_ID) + org = OrganizationModel.read(db_session=session, identifier=self.DEFAULT_ORG_ID) # If it doesn't exist, make it except NoResultFound: - org = Organization(name=DEFAULT_ORG_NAME, id=DEFAULT_ORG_ID) + org = OrganizationModel(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID) org.create(session) return org.to_pydantic() @@ -55,22 +62,22 @@ class OrganizationManager: def update_organization_name_using_id(self, org_id: str, name: Optional[str] = None) -> PydanticOrganization: """Update an organization.""" with self.session_maker() as session: - organization = Organization.read(db_session=session, identifier=org_id) + org = OrganizationModel.read(db_session=session, identifier=org_id) if name: - organization.name = name - organization.update(session) - return organization.to_pydantic() + org.name = name + org.update(session) + return org.to_pydantic() @enforce_types def delete_organization_by_id(self, org_id: str): """Delete an organization by marking it as deleted.""" with self.session_maker() as session: - organization = Organization.read(db_session=session, identifier=org_id) + organization = OrganizationModel.read(db_session=session, identifier=org_id) organization.delete(session) @enforce_types def list_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticOrganization]: """List organizations with pagination based on cursor (org_id) and limit.""" with self.session_maker() as session: - results = Organization.list(db_session=session, cursor=cursor, limit=limit) + results = OrganizationModel.list(db_session=session, cursor=cursor, limit=limit) return [org.to_pydantic() for org in results] diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py new file mode 100644 index 00000000..2f939613 --- /dev/null +++ b/letta/services/tool_manager.py @@ -0,0 +1,193 @@ +import importlib +import inspect +import warnings +from typing import List, Optional + +from letta.functions.functions import derive_openai_json_schema, load_function_set + +# TODO: Remove this once we translate all of these to the ORM +from letta.orm.errors import NoResultFound +from letta.orm.organization import Organization as OrganizationModel +from letta.orm.tool import Tool as ToolModel +from letta.orm.user import User as UserModel +from letta.schemas.tool import Tool as PydanticTool +from letta.schemas.tool import ToolCreate, ToolUpdate +from letta.utils import enforce_types + + +class ToolManager: + """Manager class to handle business logic related to Tools.""" + + def __init__(self): + # Fetching the db_context similarly as in OrganizationManager + from letta.server.server import db_context + + self.session_maker = db_context + + @enforce_types + def create_or_update_tool(self, tool_create: ToolCreate) -> PydanticTool: + """Create a new tool based on the ToolCreate schema.""" + # Derive json_schema + derived_json_schema = tool_create.json_schema or derive_openai_json_schema(tool_create) + derived_name = tool_create.name or derived_json_schema["name"] + + try: + # NOTE: We use the organization id here + # This is important, because even if it's a different user, adding the same tool to the org should not happen + tool = self.get_tool_by_name_and_org_id(tool_name=derived_name, organization_id=tool_create.organization_id) + # Put to dict and remove fields that should not be reset + update_data = tool_create.model_dump(exclude={"user_id", "organization_id", "module", "terminal"}, exclude_unset=True) + # Remove redundant update fields + update_data = {key: value for key, value in update_data.items() if getattr(tool, key) != value} + + # If there's anything to update + if update_data: + self.update_tool_by_id(tool.id, ToolUpdate(**update_data)) + else: + warnings.warn( + f"`create_or_update_tool` was called with user_id={tool_create.user_id}, organization_id={tool_create.organization_id}, name={tool_create.name}, but found existing tool with nothing to update." + ) + except NoResultFound: + tool_create.json_schema = derived_json_schema + tool_create.name = derived_name + tool = self.create_tool(tool_create) + + return tool + + @enforce_types + def create_tool(self, tool_create: ToolCreate) -> PydanticTool: + """Create a new tool based on the ToolCreate schema.""" + # Create the tool + with self.session_maker() as session: + # Include all fields except 'terminal' (which is not part of ToolModel) at the moment + create_data = tool_create.model_dump(exclude={"terminal"}) + tool = ToolModel(**create_data) # Unpack everything directly into ToolModel + tool.create(session) + + return tool.to_pydantic() + + @enforce_types + def get_tool_by_id(self, tool_id: str) -> PydanticTool: + """Fetch a tool by its ID.""" + with self.session_maker() as session: + try: + # Retrieve tool by id using the Tool model's read method + tool = ToolModel.read(db_session=session, identifier=tool_id) + # Convert the SQLAlchemy Tool object to PydanticTool + return tool.to_pydantic() + except NoResultFound: + raise ValueError(f"Tool with id {tool_id} not found.") + + @enforce_types + def get_tool_by_name_and_user_id(self, tool_name: str, user_id: str) -> PydanticTool: + """Retrieve a tool by its name and organization_id.""" + with self.session_maker() as session: + # Use the list method to apply filters + results = ToolModel.list(db_session=session, name=tool_name, _user_id=UserModel.get_uid_from_identifier(user_id)) + + # Ensure only one result is returned (since there is a unique constraint) + if not results: + raise NoResultFound(f"Tool with name {tool_name} and user_id {user_id} not found.") + + if len(results) > 1: + raise RuntimeError( + f"Multiple tools with name {tool_name} and user_id {user_id} were found. This is a serious error, and means that our table does not have uniqueness constraints properly set up. Please reach out to the letta development team if you see this error." + ) + + # Return the single result + return results[0] + + @enforce_types + def get_tool_by_name_and_org_id(self, tool_name: str, organization_id: str) -> PydanticTool: + """Retrieve a tool by its name and organization_id.""" + with self.session_maker() as session: + # Use the list method to apply filters + results = ToolModel.list( + db_session=session, name=tool_name, _organization_id=OrganizationModel.get_uid_from_identifier(organization_id) + ) + + # Ensure only one result is returned (since there is a unique constraint) + if not results: + raise NoResultFound(f"Tool with name {tool_name} and organization_id {organization_id} not found.") + + if len(results) > 1: + raise RuntimeError( + f"Multiple tools with name {tool_name} and organization_id {organization_id} were found. This is a serious error, and means that our table does not have uniqueness constraints properly set up. Please reach out to the letta development team if you see this error." + ) + + # Return the single result + return results[0] + + @enforce_types + def list_tools_for_org(self, organization_id: str, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]: + """List all tools with optional pagination using cursor and limit.""" + with self.session_maker() as session: + tools = ToolModel.list( + db_session=session, cursor=cursor, limit=limit, _organization_id=OrganizationModel.get_uid_from_identifier(organization_id) + ) + return [tool.to_pydantic() for tool in tools] + + @enforce_types + def update_tool_by_id(self, tool_id: str, tool_update: ToolUpdate) -> None: + """Update a tool by its ID with the given ToolUpdate object.""" + with self.session_maker() as session: + # Fetch the tool by ID + tool = ToolModel.read(db_session=session, identifier=tool_id) + + # Update tool attributes with only the fields that were explicitly set + update_data = tool_update.model_dump(exclude_unset=True, exclude_none=True) + for key, value in update_data.items(): + setattr(tool, key, value) + + # Save the updated tool to the database + tool.update(db_session=session) + + @enforce_types + def delete_tool_by_id(self, tool_id: str) -> None: + """Delete a tool by its ID.""" + with self.session_maker() as session: + try: + tool = ToolModel.read(db_session=session, identifier=tool_id) + tool.delete(db_session=session) + except NoResultFound: + raise ValueError(f"Tool with id {tool_id} not found.") + + @enforce_types + def add_default_tools(self, user_id: str, org_id: str, module_name="base"): + """Add default tools in {module_name}.py""" + full_module_name = f"letta.functions.function_sets.{module_name}" + try: + module = importlib.import_module(full_module_name) + except Exception as e: + # Handle other general exceptions + raise e + + functions_to_schema = [] + try: + # Load the function set + functions_to_schema = load_function_set(module) + except ValueError as e: + err = f"Error loading function set '{module_name}': {e}" + warnings.warn(err) + + # create tool in db + for name, schema in functions_to_schema.items(): + # print([str(inspect.getsource(line)) for line in schema["imports"]]) + source_code = inspect.getsource(schema["python_function"]) + tags = [module_name] + if module_name == "base": + tags.append("letta-base") + + # create to tool + self.create_or_update_tool( + ToolCreate( + name=name, + tags=tags, + source_type="python", + module=schema["module"], + source_code=source_code, + json_schema=schema["json_schema"], + organization_id=org_id, + user_id=user_id, + ), + ) diff --git a/letta/services/user_manager.py b/letta/services/user_manager.py index ddd1fc5b..10116c0b 100644 --- a/letta/services/user_manager.py +++ b/letta/services/user_manager.py @@ -1,20 +1,20 @@ from typing import List, Optional, Tuple -from letta.constants import DEFAULT_ORG_ID, DEFAULT_USER_ID, DEFAULT_USER_NAME - -# TODO: Remove this once we translate all of these to the ORM -from letta.metadata import AgentModel, AgentSourceMappingModel, SourceModel from letta.orm.errors import NoResultFound from letta.orm.organization import Organization as OrganizationModel from letta.orm.user import User as UserModel from letta.schemas.user import User as PydanticUser from letta.schemas.user import UserCreate, UserUpdate +from letta.services.organization_manager import OrganizationManager from letta.utils import enforce_types class UserManager: """Manager class to handle business logic related to Users.""" + DEFAULT_USER_NAME = "default_user" + DEFAULT_USER_ID = "user-00000000-0000-4000-8000-000000000000" + def __init__(self): # Fetching the db_context similarly as in OrganizationManager from letta.server.server import db_context @@ -22,7 +22,7 @@ class UserManager: self.session_maker = db_context @enforce_types - def create_default_user(self, org_id: str = DEFAULT_ORG_ID) -> PydanticUser: + def create_default_user(self, org_id: str = OrganizationManager.DEFAULT_ORG_ID) -> PydanticUser: """Create the default user.""" with self.session_maker() as session: # Make sure the org id exists @@ -33,10 +33,10 @@ class UserManager: # Try to retrieve the user try: - user = UserModel.read(db_session=session, identifier=DEFAULT_USER_ID) + user = UserModel.read(db_session=session, identifier=self.DEFAULT_USER_ID) except NoResultFound: # If it doesn't exist, make it - user = UserModel(id=DEFAULT_USER_ID, name=DEFAULT_USER_NAME, organization_id=org_id) + user = UserModel(id=self.DEFAULT_USER_ID, name=self.DEFAULT_USER_NAME, organization_id=org_id) user.create(session) return user.to_pydantic() @@ -73,11 +73,11 @@ class UserManager: user = UserModel.read(db_session=session, identifier=user_id) user.delete(session) - # TODO: Remove this once we have ORM models for the Agent, Source, and AgentSourceMapping + # TODO: Integrate this via the ORM models for the Agent, Source, and AgentSourceMapping # Cascade delete for related models: Agent, Source, AgentSourceMapping - session.query(AgentModel).filter(AgentModel.user_id == user_id).delete() - session.query(SourceModel).filter(SourceModel.user_id == user_id).delete() - session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.user_id == user_id).delete() + # session.query(AgentModel).filter(AgentModel.user_id == user_id).delete() + # session.query(SourceModel).filter(SourceModel.user_id == user_id).delete() + # session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.user_id == user_id).delete() session.commit() @@ -91,6 +91,11 @@ class UserManager: except NoResultFound: raise ValueError(f"User with id {user_id} not found.") + @enforce_types + def get_default_user(self) -> PydanticUser: + """Fetch the default user.""" + return self.get_user_by_id(self.DEFAULT_USER_ID) + @enforce_types def list_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> Tuple[Optional[str], List[PydanticUser]]: """List users with pagination using cursor (id) and limit.""" diff --git a/tests/helpers/endpoints_helper.py b/tests/helpers/endpoints_helper.py index 225b323b..23f27e9b 100644 --- a/tests/helpers/endpoints_helper.py +++ b/tests/helpers/endpoints_helper.py @@ -165,16 +165,13 @@ def check_agent_uses_external_tool(filename: str) -> LettaResponse: """ from crewai_tools import ScrapeWebsiteTool - from letta.schemas.tool import Tool - crewai_tool = ScrapeWebsiteTool(website_url="https://www.example.com") - tool = Tool.from_crewai(crewai_tool) - tool_name = tool.name # Set up client client = create_client() cleanup(client=client, agent_uuid=agent_uuid) - client.add_tool(tool) + tool = client.load_crewai_tool(crewai_tool=crewai_tool) + tool_name = tool.name # Set up persona for tool usage persona = f""" diff --git a/tests/test_cli.py b/tests/test_cli.py index df93bb4b..2b6c00b2 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -41,17 +41,17 @@ def test_letta_run_create_new_agent(swap_letta_config): child = pexpect.spawn("poetry run letta run", encoding="utf-8") # Start the letta run command child.logfile = sys.stdout - child.expect("Creating new agent", timeout=10) + child.expect("Creating new agent", timeout=20) # Optional: LLM model selection try: - child.expect("Select LLM model:", timeout=10) - child.sendline("\033[B\033[B\033[B\033[B\033[B") + child.expect("Select LLM model:", timeout=20) + child.sendline("") except (pexpect.TIMEOUT, pexpect.EOF): print("[WARNING] LLM model selection step was skipped.") # Optional: Embedding model selection try: - child.expect("Select embedding model:", timeout=10) + child.expect("Select embedding model:", timeout=20) child.sendline("text-embedding-ada-002") except (pexpect.TIMEOUT, pexpect.EOF): print("[WARNING] Embedding model selection step was skipped.") diff --git a/tests/test_client.py b/tests/test_client.py index c9619c62..718f6045 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -262,15 +262,6 @@ def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentSta assert human.value == "Human text", "Creating human failed" -# def test_tools(client, agent): -# tools_response = client.list_tools() -# print("TOOLS", tools_response) -# -# tool_name = "TestTool" -# tool_response = client.create_tool(name=tool_name, source_code="print('Hello World')", source_type="python") -# assert tool_response, "Creating tool failed" - - def test_config(client: Union[LocalClient, RESTClient], agent: AgentState): # _reset_config() diff --git a/tests/test_new_client.py b/tests/test_local_client.py similarity index 88% rename from tests/test_new_client.py rename to tests/test_local_client.py index fd502ab8..a7297cb2 100644 --- a/tests/test_new_client.py +++ b/tests/test_local_client.py @@ -1,16 +1,15 @@ import uuid -from typing import Union import pytest from letta import create_client -from letta.client.client import LocalClient, RESTClient +from letta.client.client import LocalClient from letta.schemas.agent import AgentState from letta.schemas.block import Block from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import BasicBlockMemory, ChatMemory, Memory -from letta.schemas.tool import Tool +from letta.schemas.tool import ToolCreate @pytest.fixture(scope="module") @@ -35,7 +34,7 @@ def agent(client): assert client.get_agent(agent_state.id) is None, f"Failed to properly delete agent {agent_state.id}" -def test_agent(client: Union[LocalClient, RESTClient]): +def test_agent(client: LocalClient): # create agent agent_state_test = client.create_agent( name="test_agent2", @@ -120,18 +119,16 @@ def test_agent(client: Union[LocalClient, RESTClient]): client.delete_agent(agent_state_test.id) -def test_agent_add_remove_tools(client: Union[LocalClient, RESTClient], agent): +def test_agent_add_remove_tools(client: LocalClient, agent): # Create and add two tools to the client # tool 1 from composio_langchain import Action - github_tool = Tool.get_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER) - client.add_tool(github_tool) + github_tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER) # tool 2 from crewai_tools import ScrapeWebsiteTool - scrape_website_tool = Tool.from_crewai(ScrapeWebsiteTool(website_url="https://www.example.com")) - client.add_tool(scrape_website_tool) + scrape_website_tool = client.load_crewai_tool(crewai_tool=ScrapeWebsiteTool(website_url="https://www.example.com")) # assert both got added tools = client.list_tools() @@ -171,7 +168,7 @@ def test_agent_add_remove_tools(client: Union[LocalClient, RESTClient], agent): assert scrape_website_tool.name in curr_tool_names -def test_agent_with_shared_blocks(client: Union[LocalClient, RESTClient]): +def test_agent_with_shared_blocks(client: LocalClient): persona_block = Block(name="persona", value="Here to test things!", label="persona", user_id=client.user_id) human_block = Block(name="human", value="Me Human, I swear. Beep boop.", label="human", user_id=client.user_id) existing_non_template_blocks = [persona_block, human_block] @@ -222,7 +219,7 @@ def test_agent_with_shared_blocks(client: Union[LocalClient, RESTClient]): client.delete_agent(second_agent_state_test.id) -def test_memory(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_memory(client: LocalClient, agent: AgentState): # get agent memory original_memory = client.get_in_context_memory(agent.id) assert original_memory is not None @@ -235,7 +232,7 @@ def test_memory(client: Union[LocalClient, RESTClient], agent: AgentState): assert updated_memory.get_block("human").value != original_memory_value # check if the memory has been updated -def test_archival_memory(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_archival_memory(client: LocalClient, agent: AgentState): """Test functions for interacting with archival memory store""" # add archival memory @@ -250,7 +247,7 @@ def test_archival_memory(client: Union[LocalClient, RESTClient], agent: AgentSta client.delete_archival_memory(agent.id, passage.id) -def test_recall_memory(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_recall_memory(client: LocalClient, agent: AgentState): """Test functions for interacting with recall memory store""" # send message to the agent @@ -274,7 +271,7 @@ def test_recall_memory(client: Union[LocalClient, RESTClient], agent: AgentState assert exists -def test_tools(client: Union[LocalClient, RESTClient]): +def test_tools(client: LocalClient): def print_tool(message: str): """ A tool to print a message @@ -314,20 +311,18 @@ def test_tools(client: Union[LocalClient, RESTClient]): assert client.get_tool(tool.id).tags == extras2 # update tool: source code - client.update_tool(tool.id, func=print_tool2) + client.update_tool(tool.id, name="print_tool2", func=print_tool2) assert client.get_tool(tool.id).name == "print_tool2" -def test_tools_from_composio_basic(client: Union[LocalClient, RESTClient]): +def test_tools_from_composio_basic(client: LocalClient): from composio_langchain import Action # Create a `LocalClient` (you can also use a `RESTClient`, see the letta_rest_client.py example) client = create_client() - tool = Tool.get_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER) - # create tool - client.add_tool(tool) + tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER) # list tools tools = client.list_tools() @@ -337,18 +332,15 @@ def test_tools_from_composio_basic(client: Union[LocalClient, RESTClient]): # The tool creation includes a compile safety check, so if this test doesn't error out, at least the code is compilable -def test_tools_from_crewai(client: Union[LocalClient, RESTClient]): +def test_tools_from_crewai(client: LocalClient): # create crewAI tool from crewai_tools import ScrapeWebsiteTool crewai_tool = ScrapeWebsiteTool() - # Translate to memGPT Tool - tool = Tool.from_crewai(crewai_tool) - # Add the tool - client.add_tool(tool) + tool = client.load_crewai_tool(crewai_tool=crewai_tool) # list tools tools = client.list_tools() @@ -372,18 +364,15 @@ def test_tools_from_crewai(client: Union[LocalClient, RESTClient]): assert expected_content in func(website_url=simple_webpage_url) -def test_tools_from_crewai_with_params(client: Union[LocalClient, RESTClient]): +def test_tools_from_crewai_with_params(client: LocalClient): # create crewAI tool from crewai_tools import ScrapeWebsiteTool crewai_tool = ScrapeWebsiteTool(website_url="https://www.example.com") - # Translate to memGPT Tool - tool = Tool.from_crewai(crewai_tool) - # Add the tool - client.add_tool(tool) + tool = client.load_crewai_tool(crewai_tool=crewai_tool) # list tools tools = client.list_tools() @@ -404,7 +393,7 @@ def test_tools_from_crewai_with_params(client: Union[LocalClient, RESTClient]): assert expected_content in func() -def test_tools_from_langchain(client: Union[LocalClient, RESTClient]): +def test_tools_from_langchain(client: LocalClient): # create langchain tool from langchain_community.tools import WikipediaQueryRun from langchain_community.utilities import WikipediaAPIWrapper @@ -412,11 +401,10 @@ def test_tools_from_langchain(client: Union[LocalClient, RESTClient]): api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=100) langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper) - # Translate to memGPT Tool - tool = Tool.from_langchain(langchain_tool, additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"}) - # Add the tool - client.add_tool(tool) + tool = client.load_langchain_tool( + langchain_tool, additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"} + ) # list tools tools = client.list_tools() @@ -436,7 +424,7 @@ def test_tools_from_langchain(client: Union[LocalClient, RESTClient]): assert expected_content in func(query="Albert Einstein") -def test_tool_creation_langchain_missing_imports(client: Union[LocalClient, RESTClient]): +def test_tool_creation_langchain_missing_imports(client: LocalClient): # create langchain tool from langchain_community.tools import WikipediaQueryRun from langchain_community.utilities import WikipediaAPIWrapper @@ -447,4 +435,4 @@ def test_tool_creation_langchain_missing_imports(client: Union[LocalClient, REST # Translate to memGPT Tool # Intentionally missing {"langchain_community.utilities": "WikipediaAPIWrapper"} with pytest.raises(RuntimeError): - Tool.from_langchain(langchain_tool) + ToolCreate.from_langchain(langchain_tool) diff --git a/tests/test_managers.py b/tests/test_managers.py index 1f9e5616..b6ec71b7 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -2,14 +2,12 @@ import pytest from sqlalchemy import delete import letta.utils as utils -from letta.constants import ( - DEFAULT_ORG_ID, - DEFAULT_ORG_NAME, - DEFAULT_USER_ID, - DEFAULT_USER_NAME, -) +from letta.functions.functions import derive_openai_json_schema, parse_source_code from letta.orm.organization import Organization +from letta.orm.tool import Tool from letta.orm.user import User +from letta.schemas.tool import ToolCreate, ToolUpdate +from letta.services.organization_manager import OrganizationManager utils.DEBUG = True from letta.config import LettaConfig @@ -18,21 +16,58 @@ from letta.server.server import SyncServer @pytest.fixture(autouse=True) -def clear_organization_and_user_table(server: SyncServer): +def clear_tables(server: SyncServer): """Fixture to clear the organization table before each test.""" with server.organization_manager.session_maker() as session: + session.execute(delete(Tool)) # Clear all records from the Tool table session.execute(delete(User)) # Clear all records from the user table session.execute(delete(Organization)) # Clear all records from the organization table session.commit() # Commit the deletion +@pytest.fixture +def tool_fixture(server: SyncServer): + """Fixture to create a tool with default settings and clean up after the test.""" + + def print_tool(message: str): + """ + Args: + message (str): The message to print. + + Returns: + str: The message that was printed. + """ + print(message) + return message + + source_code = parse_source_code(print_tool) + source_type = "python" + description = "test_description" + tags = ["test"] + + org = server.organization_manager.create_default_organization() + user = server.user_manager.create_default_user() + tool_create = ToolCreate( + user_id=user.id, organization_id=org.id, description=description, tags=tags, source_code=source_code, source_type=source_type + ) + derived_json_schema = derive_openai_json_schema(tool_create) + derived_name = derived_json_schema["name"] + tool_create.json_schema = derived_json_schema + tool_create.name = derived_name + + tool = server.tool_manager.create_tool(tool_create) + + # Yield the created tool, organization, and user for use in tests + yield {"tool": tool, "organization": org, "user": user, "tool_create": tool_create} + + @pytest.fixture(scope="module") def server(): config = LettaConfig.load() config.save() - server = SyncServer() + server = SyncServer(init_with_default_org_and_user=False) return server @@ -55,8 +90,8 @@ def test_list_organizations(server: SyncServer): def test_create_default_organization(server: SyncServer): server.organization_manager.create_default_organization() - retrieved = server.organization_manager.get_organization_by_id(DEFAULT_ORG_ID) - assert retrieved.name == DEFAULT_ORG_NAME + retrieved = server.organization_manager.get_default_organization() + assert retrieved.name == server.organization_manager.DEFAULT_ORG_NAME def test_update_organization_name(server: SyncServer): @@ -105,8 +140,8 @@ def test_list_users(server: SyncServer): def test_create_default_user(server: SyncServer): org = server.organization_manager.create_default_organization() server.user_manager.create_default_user(org_id=org.id) - retrieved = server.user_manager.get_user_by_id(DEFAULT_USER_ID) - assert retrieved.name == DEFAULT_USER_NAME + retrieved = server.user_manager.get_default_user() + assert retrieved.name == server.user_manager.DEFAULT_USER_NAME def test_update_user(server: SyncServer): @@ -124,9 +159,117 @@ def test_update_user(server: SyncServer): # Adjust name user = server.user_manager.update_user(UserUpdate(id=user.id, name=user_name_b)) assert user.name == user_name_b - assert user.organization_id == DEFAULT_ORG_ID + assert user.organization_id == OrganizationManager.DEFAULT_ORG_ID # Adjust org id user = server.user_manager.update_user(UserUpdate(id=user.id, organization_id=test_org.id)) assert user.name == user_name_b assert user.organization_id == test_org.id + + +# ====================================================================================================================== +# Tool Manager Tests +# ====================================================================================================================== +def test_create_tool(server: SyncServer, tool_fixture): + tool = tool_fixture["tool"] + tool_create = tool_fixture["tool_create"] + user = tool_fixture["user"] + org = tool_fixture["organization"] + + # Assertions to ensure the created tool matches the expected values + assert tool.user_id == user.id + assert tool.organization_id == org.id + assert tool.description == tool_create.description + assert tool.tags == tool_create.tags + assert tool.source_code == tool_create.source_code + assert tool.source_type == tool_create.source_type + assert tool.json_schema == derive_openai_json_schema(tool_create) + + +def test_get_tool_by_id(server: SyncServer, tool_fixture): + tool = tool_fixture["tool"] + + # Fetch the tool by ID using the manager method + fetched_tool = server.tool_manager.get_tool_by_id(tool.id) + + # Assertions to check if the fetched tool matches the created tool + assert fetched_tool.id == tool.id + assert fetched_tool.name == tool.name + assert fetched_tool.description == tool.description + assert fetched_tool.tags == tool.tags + assert fetched_tool.source_code == tool.source_code + assert fetched_tool.source_type == tool.source_type + + +def test_get_tool_by_name_and_org_id(server: SyncServer, tool_fixture): + tool = tool_fixture["tool"] + org = tool_fixture["organization"] + + # Fetch the tool by name and organization ID + fetched_tool = server.tool_manager.get_tool_by_name_and_org_id(tool.name, org.id) + + # Assertions to check if the fetched tool matches the created tool + assert fetched_tool.id == tool.id + assert fetched_tool.name == tool.name + assert fetched_tool.organization_id == org.id + assert fetched_tool.description == tool.description + assert fetched_tool.tags == tool.tags + assert fetched_tool.source_code == tool.source_code + assert fetched_tool.source_type == tool.source_type + + +def test_get_tool_by_name_and_user_id(server: SyncServer, tool_fixture): + tool = tool_fixture["tool"] + user = tool_fixture["user"] + + # Fetch the tool by name and organization ID + fetched_tool = server.tool_manager.get_tool_by_name_and_user_id(tool.name, user.id) + + # Assertions to check if the fetched tool matches the created tool + assert fetched_tool.id == tool.id + assert fetched_tool.name == tool.name + assert fetched_tool.user_id == user.id + assert fetched_tool.description == tool.description + assert fetched_tool.tags == tool.tags + assert fetched_tool.source_code == tool.source_code + assert fetched_tool.source_type == tool.source_type + + +def test_list_tools(server: SyncServer, tool_fixture): + tool = tool_fixture["tool"] + org = tool_fixture["organization"] + + # List tools (should include the one created by the fixture) + tools = server.tool_manager.list_tools_for_org(organization_id=org.id) + + # Assertions to check that the created tool is listed + assert len(tools) == 1 + assert any(t.id == tool.id for t in tools) + + +def test_update_tool_by_id(server: SyncServer, tool_fixture): + tool = tool_fixture["tool"] + updated_description = "updated_description" + + # Create a ToolUpdate object to modify the tool's description + tool_update = ToolUpdate(description=updated_description) + + # Update the tool using the manager method + server.tool_manager.update_tool_by_id(tool.id, tool_update) + + # Fetch the updated tool to verify the changes + updated_tool = server.tool_manager.get_tool_by_id(tool.id) + + # Assertions to check if the update was successful + assert updated_tool.description == updated_description + + +def test_delete_tool_by_id(server: SyncServer, tool_fixture): + tool = tool_fixture["tool"] + org = tool_fixture["organization"] + + # Delete the tool using the manager method + server.tool_manager.delete_tool_by_id(tool.id) + + tools = server.tool_manager.list_tools_for_org(organization_id=org.id) + assert len(tools) == 0 diff --git a/tests/test_server.py b/tests/test_server.py index 0622d50e..5bbffdd9 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -25,7 +25,6 @@ from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import ChatMemory from letta.schemas.message import Message from letta.schemas.source import SourceCreate -from letta.schemas.user import UserCreate from letta.server.server import SyncServer from .utils import DummyDataConnector @@ -33,26 +32,9 @@ from .utils import DummyDataConnector @pytest.fixture(scope="module") def server(): - # if os.getenv("OPENAI_API_KEY"): - # create_config("openai") - # credentials = LettaCredentials( - # openai_key=os.getenv("OPENAI_API_KEY"), - # ) - # else: # hosted - # create_config("letta_hosted") - # credentials = LettaCredentials() - config = LettaConfig.load() print("CONFIG PATH", config.config_path) - ## set to use postgres - # config.archival_storage_uri = db_url - # config.recall_storage_uri = db_url - # config.metadata_storage_uri = db_url - # config.archival_storage_type = "postgres" - # config.recall_storage_type = "postgres" - # config.metadata_storage_type = "postgres" - config.save() server = SyncServer() @@ -62,7 +44,7 @@ def server(): @pytest.fixture(scope="module") def org_id(server): # create org - org = server.organization_manager.create_organization(name="test_org") + org = server.organization_manager.create_default_organization() print(f"Created org\n{org.id}") yield org.id @@ -74,7 +56,7 @@ def org_id(server): @pytest.fixture(scope="module") def user_id(server, org_id): # create user - user = server.create_user(UserCreate(name="test_user", organization_id=org_id)) + user = server.user_manager.create_default_user() print(f"Created user\n{user.id}") yield user.id diff --git a/tests/test_tools.py b/tests/test_tools.py index c607aa96..b8507d65 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -56,13 +56,13 @@ def client(request): time.sleep(5) print("Running client tests with server:", server_url) else: - server_url = None assert False, "Local client not implemented" assert server_url is not None client = create_client(base_url=server_url) # This yields control back to the test function client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) + # Clear all records from the Tool table yield client @@ -93,7 +93,16 @@ def test_create_tool(client: Union[LocalClient, RESTClient]): return message tools = client.list_tools() - print(f"Original tools {[t.name for t in tools]}") + assert sorted([t.name for t in tools]) == sorted( + [ + "archival_memory_search", + "send_message", + "pause_heartbeats", + "conversation_search", + "conversation_search_date", + "archival_memory_insert", + ] + ) tool = client.create_tool(print_tool, name="my_name", tags=["extras"]) @@ -108,13 +117,15 @@ def test_create_tool(client: Union[LocalClient, RESTClient]): # create agent with tool agent_state = client.create_agent(tools=[tool.name]) - response = client.user_message(agent_id=agent_state.id, message="hi") + + # Send message without error + client.user_message(agent_id=agent_state.id, message="hi") def test_create_agent_tool(client): """Test creation of a agent tool""" - def core_memory_clear(self: Agent): + def core_memory_clear(self: "Agent"): """ Args: agent (Agent): The agent to delete from memory.