diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py index 615811d7..b5117408 100644 --- a/letta/server/rest_api/app.py +++ b/letta/server/rest_api/app.py @@ -14,6 +14,8 @@ from starlette.middleware.cors import CORSMiddleware from letta.__init__ import __version__ from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX from letta.errors import LettaAgentNotFoundError, LettaUserNotFoundError +from letta.log import get_logger +from letta.orm.errors import NoResultFound from letta.schemas.letta_response import LettaResponse from letta.server.constants import REST_DEFAULT_PORT @@ -45,6 +47,7 @@ from letta.settings import settings # NOTE(charles): @ethan I had to add this to get the global as the bottom to work interface: StreamingServerInterface = StreamingServerInterface server = SyncServer(default_interface_factory=lambda: interface()) +logger = get_logger(__name__) # TODO: remove password = None @@ -170,6 +173,16 @@ def create_application() -> "FastAPI": }, ) + @app.exception_handler(NoResultFound) + async def no_result_found_handler(request: Request, exc: NoResultFound): + logger.error(f"NoResultFound request: {request}") + logger.error(f"NoResultFound: {exc}") + + return JSONResponse( + status_code=404, + content={"detail": str(exc)}, + ) + @app.exception_handler(ValueError) async def value_error_handler(request: Request, exc: ValueError): return JSONResponse(status_code=400, content={"detail": str(exc)}) diff --git a/letta/server/server.py b/letta/server/server.py index 96490b39..b01bfd34 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -19,7 +19,6 @@ from letta.agent import Agent, save_agent from letta.chat_only_agent import ChatOnlyAgent from letta.credentials import LettaCredentials from letta.data_sources.connectors import DataConnector, load_data -from letta.errors import LettaAgentNotFoundError # TODO use custom interface from letta.interface import AgentInterface # abstract @@ -399,9 +398,6 @@ class SyncServer(Server): with agent_lock: agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor) - if agent_state is None: - raise LettaAgentNotFoundError(f"Agent (agent_id={agent_id}) does not exist") - interface = interface or self.default_interface_factory() if agent_state.agent_type == AgentType.memgpt_agent: agent = Agent(agent_state=agent_state, interface=interface, user=actor) @@ -901,32 +897,14 @@ class SyncServer(Server): # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user actor = self.user_manager.get_user_or_default(user_id=user_id) + agent_state = self.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor) + + # TODO: This is very redundant, and should probably be simplified # Get the agent object (loaded in memory) letta_agent = self.load_agent(agent_id=agent_id, actor=actor) + letta_agent.link_tools(agent_state.tools) - # Get all the tool objects from the request - tool_objs = [] - tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool_id, actor=actor) - assert tool_obj, f"Tool with id={tool_id} does not exist" - tool_objs.append(tool_obj) - - for tool in letta_agent.agent_state.tools: - tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id, actor=actor) - assert tool_obj, f"Tool with id={tool.id} does not exist" - - # If it's not the already added tool - if tool_obj.id != tool_id: - tool_objs.append(tool_obj) - - # replace the list of tool names ("ids") inside the agent state - letta_agent.agent_state.tools = tool_objs - - # then attempt to link the tools modules - letta_agent.link_tools(tool_objs) - - # save the agent - save_agent(letta_agent) - return letta_agent.agent_state + return agent_state def remove_tool_from_agent( self, @@ -937,29 +915,13 @@ class SyncServer(Server): """Remove tools from an existing agent""" # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user actor = self.user_manager.get_user_or_default(user_id=user_id) + agent_state = self.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor) # Get the agent object (loaded in memory) letta_agent = self.load_agent(agent_id=agent_id, actor=actor) + letta_agent.link_tools(agent_state.tools) - # Get all the tool_objs - tool_objs = [] - for tool in letta_agent.agent_state.tools: - tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id, actor=actor) - assert tool_obj, f"Tool with id={tool.id} does not exist" - - # If it's not the tool we want to remove - if tool_obj.id != tool_id: - tool_objs.append(tool_obj) - - # replace the list of tool names ("ids") inside the agent state - letta_agent.agent_state.tools = tool_objs - - # then attempt to link the tools modules - letta_agent.link_tools(tool_objs) - - # save the agent - save_agent(letta_agent) - return letta_agent.agent_state + return agent_state # convert name->id diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 093831aa..52a526f9 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -1,6 +1,7 @@ from typing import Dict, List, Optional from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS +from letta.log import get_logger from letta.orm import Agent as AgentModel from letta.orm import Block as BlockModel from letta.orm import Source as SourceModel @@ -25,6 +26,8 @@ from letta.services.source_manager import SourceManager from letta.services.tool_manager import ToolManager from letta.utils import enforce_types +logger = get_logger(__name__) + # Agent Manager Class class AgentManager: @@ -403,3 +406,74 @@ class AgentManager: agent.update(session, actor=actor) return agent.to_pydantic() + + # ====================================================================================================================== + # Tool Management + # ====================================================================================================================== + @enforce_types + def attach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState: + """ + Attaches a tool to an agent. + + Args: + agent_id: ID of the agent to attach the tool to. + tool_id: ID of the tool to attach. + actor: User performing the action. + + Raises: + NoResultFound: If the agent or tool is not found. + + Returns: + PydanticAgentState: The updated agent state. + """ + with self.session_maker() as session: + # Verify the agent exists and user has permission to access it + agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) + + # Use the _process_relationship helper to attach the tool + _process_relationship( + session=session, + agent=agent, + relationship_name="tools", + model_class=ToolModel, + item_ids=[tool_id], + allow_partial=False, # Ensure the tool exists + replace=False, # Extend the existing tools + ) + + # Commit and refresh the agent + agent.update(session, actor=actor) + return agent.to_pydantic() + + @enforce_types + def detach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState: + """ + Detaches a tool from an agent. + + Args: + agent_id: ID of the agent to detach the tool from. + tool_id: ID of the tool to detach. + actor: User performing the action. + + Raises: + NoResultFound: If the agent or tool is not found. + + Returns: + PydanticAgentState: The updated agent state. + """ + with self.session_maker() as session: + # Verify the agent exists and user has permission to access it + agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) + + # Filter out the tool to be detached + remaining_tools = [tool for tool in agent.tools if tool.id != tool_id] + + if len(remaining_tools) == len(agent.tools): # Tool ID was not in the relationship + logger.warning(f"Attempted to remove unattached tool id={tool_id} from agent id={agent_id} by actor={actor}") + + # Update the tools relationship + agent.tools = remaining_tools + + # Commit and refresh the agent + agent.update(session, actor=actor) + return agent.to_pydantic() diff --git a/tests/test_managers.py b/tests/test_managers.py index 96b5faa4..3df4e8b5 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -438,6 +438,82 @@ def test_update_agent(server: SyncServer, comprehensive_test_agent_fixture, othe assert updated_agent.message_ids == update_agent_request.message_ids +# ====================================================================================================================== +# AgentManager Tests - Tools Relationship +# ====================================================================================================================== + + +def test_attach_tool(server: SyncServer, sarah_agent, print_tool, default_user): + """Test attaching a tool to an agent.""" + # Attach the tool + server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user) + + # Verify attachment through get_agent_by_id + agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user) + assert print_tool.id in [t.id for t in agent.tools] + + # Verify that attaching the same tool again doesn't cause duplication + server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user) + agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user) + assert len([t for t in agent.tools if t.id == print_tool.id]) == 1 + + +def test_detach_tool(server: SyncServer, sarah_agent, print_tool, default_user): + """Test detaching a tool from an agent.""" + # Attach the tool first + server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user) + + # Verify it's attached + agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user) + assert print_tool.id in [t.id for t in agent.tools] + + # Detach the tool + server.agent_manager.detach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user) + + # Verify it's detached + agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user) + assert print_tool.id not in [t.id for t in agent.tools] + + # Verify that detaching an already detached tool doesn't cause issues + server.agent_manager.detach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user) + + +def test_attach_tool_nonexistent_agent(server: SyncServer, print_tool, default_user): + """Test attaching a tool to a nonexistent agent.""" + with pytest.raises(NoResultFound): + server.agent_manager.attach_tool(agent_id="nonexistent-agent-id", tool_id=print_tool.id, actor=default_user) + + +def test_attach_tool_nonexistent_tool(server: SyncServer, sarah_agent, default_user): + """Test attaching a nonexistent tool to an agent.""" + with pytest.raises(NoResultFound): + server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id="nonexistent-tool-id", actor=default_user) + + +def test_detach_tool_nonexistent_agent(server: SyncServer, print_tool, default_user): + """Test detaching a tool from a nonexistent agent.""" + with pytest.raises(NoResultFound): + server.agent_manager.detach_tool(agent_id="nonexistent-agent-id", tool_id=print_tool.id, actor=default_user) + + +def test_list_attached_tools(server: SyncServer, sarah_agent, print_tool, other_tool, default_user): + """Test listing tools attached to an agent.""" + # Initially should have no tools + agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user) + assert len(agent.tools) == 0 + + # Attach tools + server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user) + server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=other_tool.id, actor=default_user) + + # List tools and verify + agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user) + attached_tool_ids = [t.id for t in agent.tools] + assert len(attached_tool_ids) == 2 + assert print_tool.id in attached_tool_ids + assert other_tool.id in attached_tool_ids + + # ====================================================================================================================== # AgentManager Tests - Sources Relationship # ====================================================================================================================== @@ -693,6 +769,7 @@ def test_attach_block(server: SyncServer, sarah_agent, default_block, default_us assert agent.memory.blocks[0].label == default_block.label +@pytest.mark.skipif(USING_SQLITE, reason="Test not applicable when using SQLite.") def test_attach_block_duplicate_label(server: SyncServer, sarah_agent, default_block, other_block, default_user): """Test attempting to attach a block with a duplicate label.""" # Set up both blocks with same label @@ -1143,6 +1220,7 @@ def test_create_tool(server: SyncServer, print_tool, default_user, default_organ assert print_tool.organization_id == default_organization.id +@pytest.mark.skipif(USING_SQLITE, reason="Test not applicable when using SQLite.") def test_create_tool_duplicate_name(server: SyncServer, print_tool, default_user, default_organization): data = print_tool.model_dump(exclude=["id"]) tool = PydanticTool(**data)