diff --git a/letta/client/client.py b/letta/client/client.py index 8e1cf629..c6f960df 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -225,6 +225,9 @@ class AbstractClient(object): def get_tool_id(self, name: str) -> Optional[str]: raise NotImplementedError + def add_base_tools(self) -> List[Tool]: + raise NotImplementedError + def load_data(self, connector: DataConnector, source_name: str): raise NotImplementedError @@ -1271,6 +1274,13 @@ class RESTClient(AbstractClient): raise ValueError(f"Failed to get tool: {response.text}") return response.json() + def add_base_tools(self) -> List[Tool]: + response = requests.post(f"{self.base_url}/{self.api_prefix}/tools/add-base-tools/", headers=self.headers) + if response.status_code != 200: + raise ValueError(f"Failed to add base tools: {response.text}") + + return [Tool(**tool) for tool in response.json()] + def create_tool( self, func: Callable, diff --git a/letta/constants.py b/letta/constants.py index ccbd4fb0..fdce01ad 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -36,7 +36,6 @@ DEFAULT_PRESET = "memgpt_chat" # Tools BASE_TOOLS = [ "send_message", - # "pause_heartbeats", "conversation_search", "conversation_search_date", "archival_memory_insert", diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index 59c90d94..2e5954e4 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -107,23 +107,32 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): """ # Start the query query = select(cls) + # Collect query conditions for better error reporting + query_conditions = [] # If an identifier is provided, add it to the query conditions if identifier is not None: identifier = cls.get_uid_from_identifier(identifier) query = query.where(cls._id == identifier) + query_conditions.append(f"id='{identifier}'") if kwargs: query = query.filter_by(**kwargs) + query_conditions.append(", ".join(f"{key}='{value}'" for key, value in kwargs.items())) if actor: query = cls.apply_access_predicate(query, actor, access) + query_conditions.append(f"access level in {access} for actor='{actor}'") if hasattr(cls, "is_deleted"): query = query.where(cls.is_deleted == False) + query_conditions.append("is_deleted=False") if found := db_session.execute(query).scalar(): return found - raise NoResultFound(f"{cls.__name__} with id {identifier} not found") + + # Construct a detailed error message based on query conditions + conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions" + raise NoResultFound(f"{cls.__name__} not found with {conditions_str}") def create(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]: if actor: diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index f12fcd18..d1b442f7 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -104,3 +104,15 @@ def update_tool( """ actor = server.get_user_or_default(user_id=user_id) return server.tool_manager.update_tool_by_id(tool_id, actor.id, request) + + +@router.post("/add-base-tools", response_model=List[Tool], operation_id="add_base_tools") +def add_base_tools( + server: SyncServer = Depends(get_letta_server), + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present +): + """ + Add base tools + """ + actor = server.get_user_or_default(user_id=user_id) + return server.tool_manager.add_base_tools(actor=actor) diff --git a/letta/server/server.py b/letta/server/server.py index a05c87cb..33900dd3 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -254,7 +254,7 @@ class SyncServer(Server): self.default_org = self.organization_manager.create_default_organization() self.default_user = self.user_manager.create_default_user() self.add_default_blocks(self.default_user.id) - self.tool_manager.add_default_tools(module_name="base", actor=self.default_user) + self.tool_manager.add_base_tools(actor=self.default_user) # If there is a default org/user # This logic may have to change in the future diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index b8325b6c..7658464c 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -18,6 +18,14 @@ from letta.utils import enforce_types class ToolManager: """Manager class to handle business logic related to Tools.""" + BASE_TOOL_NAMES = [ + "send_message", + "conversation_search", + "conversation_search_date", + "archival_memory_insert", + "archival_memory_search", + ] + def __init__(self): # Fetching the db_context similarly as in OrganizationManager from letta.server.server import db_context @@ -137,8 +145,9 @@ class ToolManager: raise ValueError(f"Tool with id {tool_id} not found.") @enforce_types - def add_default_tools(self, actor: PydanticUser, module_name="base"): - """Add default tools in {module_name}.py""" + def add_base_tools(self, actor: PydanticUser) -> List[PydanticTool]: + """Add default tools in base.py""" + module_name = "base" full_module_name = f"letta.functions.function_sets.{module_name}" try: module = importlib.import_module(full_module_name) @@ -155,22 +164,28 @@ class ToolManager: warnings.warn(err) # create tool in db + tools = [] for name, schema in functions_to_schema.items(): - # print([str(inspect.getsource(line)) for line in schema["imports"]]) - source_code = inspect.getsource(schema["python_function"]) - tags = [module_name] - if module_name == "base": - tags.append("letta-base") + if name in self.BASE_TOOL_NAMES: + # print([str(inspect.getsource(line)) for line in schema["imports"]]) + source_code = inspect.getsource(schema["python_function"]) + tags = [module_name] + if module_name == "base": + tags.append("letta-base") - # create to tool - self.create_or_update_tool( - ToolCreate( - name=name, - tags=tags, - source_type="python", - module=schema["module"], - source_code=source_code, - json_schema=schema["json_schema"], - ), - actor=actor, - ) + # create to tool + tools.append( + self.create_or_update_tool( + ToolCreate( + name=name, + tags=tags, + source_type="python", + module=schema["module"], + source_code=source_code, + json_schema=schema["json_schema"], + ), + actor=actor, + ) + ) + + return tools diff --git a/tests/test_client.py b/tests/test_client.py index 14e251e0..2c990ff1 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -26,6 +26,7 @@ from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message from letta.schemas.usage import LettaUsageStatistics +from letta.services.tool_manager import ToolManager from letta.settings import model_settings from tests.helpers.client_helper import upload_file_using_client @@ -299,7 +300,7 @@ def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentSta assert human.value == "Human text", "Creating human failed" -def test_list_tools_pagination(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_list_tools_pagination(client: Union[LocalClient, RESTClient]): tools = client.list_tools() visited_ids = {t.id: False for t in tools} @@ -321,6 +322,13 @@ def test_list_tools_pagination(client: Union[LocalClient, RESTClient], agent: Ag assert all(visited_ids.values()) +def test_list_tools(client: Union[LocalClient, RESTClient]): + tools = client.add_base_tools() + tool_names = [t.name for t in tools] + expected = ToolManager.BASE_TOOL_NAMES + assert sorted(tool_names) == sorted(expected) + + def test_list_files_pagination(client: Union[LocalClient, RESTClient], agent: AgentState): # clear sources for source in client.list_sources(): diff --git a/tests/test_tools.py b/tests/test_tools.py index f987afca..4195b220 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -97,7 +97,6 @@ def test_create_tool(client: Union[LocalClient, RESTClient]): [ "archival_memory_search", "send_message", - "pause_heartbeats", "conversation_search", "conversation_search_date", "archival_memory_insert",