feat: patch tool calling endpoint and add SDK testing (#6456)
This commit is contained in:
committed by
Caren Thomas
parent
ceadacd30e
commit
e862bae524
@@ -36716,7 +36716,8 @@
|
|||||||
"type": "null"
|
"type": "null"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"description": "The agent state"
|
"description": "The agent state",
|
||||||
|
"deprecated": true
|
||||||
},
|
},
|
||||||
"stdout": {
|
"stdout": {
|
||||||
"anyOf": [
|
"anyOf": [
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from letta.schemas.agent import AgentState
|
|||||||
class ToolExecutionResult(BaseModel):
|
class ToolExecutionResult(BaseModel):
|
||||||
status: Literal["success", "error"] = Field(..., description="The status of the tool execution and return object")
|
status: Literal["success", "error"] = Field(..., description="The status of the tool execution and return object")
|
||||||
func_return: Optional[Any] = Field(None, description="The function return object")
|
func_return: Optional[Any] = Field(None, description="The function return object")
|
||||||
agent_state: Optional[AgentState] = Field(None, description="The agent state")
|
agent_state: Optional[AgentState] = Field(None, description="The agent state", deprecated=True)
|
||||||
stdout: Optional[List[str]] = Field(None, description="Captured stdout (prints, logs) from function invocation")
|
stdout: Optional[List[str]] = Field(None, description="Captured stdout (prints, logs) from function invocation")
|
||||||
stderr: Optional[List[str]] = Field(None, description="Captured stderr from the function invocation")
|
stderr: Optional[List[str]] = Field(None, description="Captured stderr from the function invocation")
|
||||||
sandbox_config_fingerprint: Optional[str] = Field(None, description="The fingerprint of the config for the sandbox")
|
sandbox_config_fingerprint: Optional[str] = Field(None, description="The fingerprint of the config for the sandbox")
|
||||||
|
|||||||
@@ -616,11 +616,9 @@ async def run_tool_for_agent(
|
|||||||
"""
|
"""
|
||||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||||
|
|
||||||
# Get agent with tools and environment variables
|
# Get agent with all relationships
|
||||||
agent = await server.agent_manager.get_agent_by_id_async(
|
agent = await server.agent_manager.get_agent_by_id_async(
|
||||||
agent_id=agent_id,
|
agent_id, actor, include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools"]
|
||||||
actor=actor,
|
|
||||||
include_relationships=["tools", "tool_exec_environment_variables"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Find the tool by name among attached tools
|
# Find the tool by name among attached tools
|
||||||
@@ -663,6 +661,11 @@ async def run_tool_for_agent(
|
|||||||
tool=tool,
|
tool=tool,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# don't return a result if the tool execution failed
|
||||||
|
if tool_execution_result.status == "error":
|
||||||
|
tool_execution_result.func_return = None
|
||||||
|
# remove deprecated agent_state field
|
||||||
|
tool_execution_result.agent_state = None
|
||||||
return tool_execution_result
|
return tool_execution_result
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2436,3 +2436,23 @@ def test_create_agent_with_tools(client: LettaSDKClient) -> None:
|
|||||||
# clean up
|
# clean up
|
||||||
client.tools.delete(tool_from_func.id)
|
client.tools.delete(tool_from_func.id)
|
||||||
client.tools.delete(tool_from_class.id)
|
client.tools.delete(tool_from_class.id)
|
||||||
|
|
||||||
|
|
||||||
|
def test_calling_tools(client: LettaSDKClient, agent: AgentState) -> None:
|
||||||
|
"""Test to make sure calling tools through the SDK works as expected"""
|
||||||
|
|
||||||
|
blocks = list(client.agents.blocks.list(agent_id=agent.id))
|
||||||
|
assert len(blocks) == 1, f"Expected 1 block, got {len(blocks)}"
|
||||||
|
|
||||||
|
# test calling a stateful tool
|
||||||
|
result = client.agents.tools.run(agent_id=agent.id, tool_name="memory_insert", args={"label": "human", "new_str": "test"})
|
||||||
|
assert result.status == "success", f"Expected success, got {result.status}"
|
||||||
|
# get the block
|
||||||
|
block = client.agents.blocks.retrieve(agent_id=agent.id, block_label="human")
|
||||||
|
assert "test" in block.value, f"Test value not found in block value {block.value}"
|
||||||
|
|
||||||
|
# test calling a tool wrong
|
||||||
|
result = client.agents.tools.run(agent_id=agent.id, tool_name="memory_insert", args={"label": "human", "FAKE_ARG": "test"})
|
||||||
|
assert result.status == "error", f"Expected error, got {result.status}"
|
||||||
|
assert result.func_return is None, f"Expected func_return to be None, got {result.func_return}"
|
||||||
|
print(result)
|
||||||
|
|||||||
Reference in New Issue
Block a user