feat: patch tool calling endpoint and add SDK testing (#6456)
This commit is contained in:
committed by
Caren Thomas
parent
ceadacd30e
commit
e862bae524
@@ -36716,7 +36716,8 @@
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"description": "The agent state"
|
||||
"description": "The agent state",
|
||||
"deprecated": true
|
||||
},
|
||||
"stdout": {
|
||||
"anyOf": [
|
||||
|
||||
@@ -8,7 +8,7 @@ from letta.schemas.agent import AgentState
|
||||
class ToolExecutionResult(BaseModel):
|
||||
status: Literal["success", "error"] = Field(..., description="The status of the tool execution and return object")
|
||||
func_return: Optional[Any] = Field(None, description="The function return object")
|
||||
agent_state: Optional[AgentState] = Field(None, description="The agent state")
|
||||
agent_state: Optional[AgentState] = Field(None, description="The agent state", deprecated=True)
|
||||
stdout: Optional[List[str]] = Field(None, description="Captured stdout (prints, logs) from function invocation")
|
||||
stderr: Optional[List[str]] = Field(None, description="Captured stderr from the function invocation")
|
||||
sandbox_config_fingerprint: Optional[str] = Field(None, description="The fingerprint of the config for the sandbox")
|
||||
|
||||
@@ -616,11 +616,9 @@ async def run_tool_for_agent(
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
|
||||
# Get agent with tools and environment variables
|
||||
# Get agent with all relationships
|
||||
agent = await server.agent_manager.get_agent_by_id_async(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
include_relationships=["tools", "tool_exec_environment_variables"],
|
||||
agent_id, actor, include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools"]
|
||||
)
|
||||
|
||||
# Find the tool by name among attached tools
|
||||
@@ -663,6 +661,11 @@ async def run_tool_for_agent(
|
||||
tool=tool,
|
||||
)
|
||||
|
||||
# don't return a result if the tool execution failed
|
||||
if tool_execution_result.status == "error":
|
||||
tool_execution_result.func_return = None
|
||||
# remove deprecated agent_state field
|
||||
tool_execution_result.agent_state = None
|
||||
return tool_execution_result
|
||||
|
||||
|
||||
|
||||
@@ -2436,3 +2436,23 @@ def test_create_agent_with_tools(client: LettaSDKClient) -> None:
|
||||
# clean up
|
||||
client.tools.delete(tool_from_func.id)
|
||||
client.tools.delete(tool_from_class.id)
|
||||
|
||||
|
||||
def test_calling_tools(client: LettaSDKClient, agent: AgentState) -> None:
|
||||
"""Test to make sure calling tools through the SDK works as expected"""
|
||||
|
||||
blocks = list(client.agents.blocks.list(agent_id=agent.id))
|
||||
assert len(blocks) == 1, f"Expected 1 block, got {len(blocks)}"
|
||||
|
||||
# test calling a stateful tool
|
||||
result = client.agents.tools.run(agent_id=agent.id, tool_name="memory_insert", args={"label": "human", "new_str": "test"})
|
||||
assert result.status == "success", f"Expected success, got {result.status}"
|
||||
# get the block
|
||||
block = client.agents.blocks.retrieve(agent_id=agent.id, block_label="human")
|
||||
assert "test" in block.value, f"Test value not found in block value {block.value}"
|
||||
|
||||
# test calling a tool wrong
|
||||
result = client.agents.tools.run(agent_id=agent.id, tool_name="memory_insert", args={"label": "human", "FAKE_ARG": "test"})
|
||||
assert result.status == "error", f"Expected error, got {result.status}"
|
||||
assert result.func_return is None, f"Expected func_return to be None, got {result.func_return}"
|
||||
print(result)
|
||||
|
||||
Reference in New Issue
Block a user