diff --git a/fern/openapi.json b/fern/openapi.json index 02aa4198..29d1d1b4 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -36716,7 +36716,8 @@ "type": "null" } ], - "description": "The agent state" + "description": "The agent state", + "deprecated": true }, "stdout": { "anyOf": [ diff --git a/letta/schemas/tool_execution_result.py b/letta/schemas/tool_execution_result.py index fd5bd6b4..71c4e713 100644 --- a/letta/schemas/tool_execution_result.py +++ b/letta/schemas/tool_execution_result.py @@ -8,7 +8,7 @@ from letta.schemas.agent import AgentState class ToolExecutionResult(BaseModel): status: Literal["success", "error"] = Field(..., description="The status of the tool execution and return object") func_return: Optional[Any] = Field(None, description="The function return object") - agent_state: Optional[AgentState] = Field(None, description="The agent state") + agent_state: Optional[AgentState] = Field(None, description="The agent state", deprecated=True) stdout: Optional[List[str]] = Field(None, description="Captured stdout (prints, logs) from function invocation") stderr: Optional[List[str]] = Field(None, description="Captured stderr from the function invocation") sandbox_config_fingerprint: Optional[str] = Field(None, description="The fingerprint of the config for the sandbox") diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 110fef91..41ae7022 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -616,11 +616,9 @@ async def run_tool_for_agent( """ actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - # Get agent with tools and environment variables + # Get agent with all relationships agent = await server.agent_manager.get_agent_by_id_async( - agent_id=agent_id, - actor=actor, - include_relationships=["tools", "tool_exec_environment_variables"], + agent_id, actor, include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools"] ) # Find the tool by name among attached tools @@ -663,6 +661,11 @@ async def run_tool_for_agent( tool=tool, ) + # don't return a result if the tool execution failed + if tool_execution_result.status == "error": + tool_execution_result.func_return = None + # remove deprecated agent_state field + tool_execution_result.agent_state = None return tool_execution_result diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index 70240f77..cfc05dd3 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -2436,3 +2436,23 @@ def test_create_agent_with_tools(client: LettaSDKClient) -> None: # clean up client.tools.delete(tool_from_func.id) client.tools.delete(tool_from_class.id) + + +def test_calling_tools(client: LettaSDKClient, agent: AgentState) -> None: + """Test to make sure calling tools through the SDK works as expected""" + + blocks = list(client.agents.blocks.list(agent_id=agent.id)) + assert len(blocks) == 1, f"Expected 1 block, got {len(blocks)}" + + # test calling a stateful tool + result = client.agents.tools.run(agent_id=agent.id, tool_name="memory_insert", args={"label": "human", "new_str": "test"}) + assert result.status == "success", f"Expected success, got {result.status}" + # get the block + block = client.agents.blocks.retrieve(agent_id=agent.id, block_label="human") + assert "test" in block.value, f"Test value not found in block value {block.value}" + + # test calling a tool wrong + result = client.agents.tools.run(agent_id=agent.id, tool_name="memory_insert", args={"label": "human", "FAKE_ARG": "test"}) + assert result.status == "error", f"Expected error, got {result.status}" + assert result.func_return is None, f"Expected func_return to be None, got {result.func_return}" + print(result)