From fe592eda728f37f664b511b100d5e33e9d019bf0 Mon Sep 17 00:00:00 2001 From: cthomas Date: Mon, 19 May 2025 16:01:59 -0700 Subject: [PATCH] feat: protect against anthropic nested tool args (#2250) --- letta/agents/letta_agent.py | 3 +++ letta/interfaces/anthropic_streaming_interface.py | 9 ++++++++- letta/llm_api/anthropic_client.py | 8 ++++---- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 3e06c58e..718fd583 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -488,8 +488,11 @@ class LettaAgent(BaseAgent): try: tool_args = json.loads(tool_call_args_str) + assert isinstance(tool_args, dict), "tool_args must be a dict" except json.JSONDecodeError: tool_args = {} + except AssertionError: + tool_args = json.loads(tool_args) # Get request heartbeats and coerce to bool request_heartbeat = tool_args.pop("request_heartbeat", False) diff --git a/letta/interfaces/anthropic_streaming_interface.py b/letta/interfaces/anthropic_streaming_interface.py index 08c7ce0a..1a8aa220 100644 --- a/letta/interfaces/anthropic_streaming_interface.py +++ b/letta/interfaces/anthropic_streaming_interface.py @@ -1,3 +1,4 @@ +import json from datetime import datetime, timezone from enum import Enum from typing import AsyncGenerator, List, Union @@ -89,7 +90,13 @@ class AnthropicStreamingInterface: def get_tool_call_object(self) -> ToolCall: """Useful for agent loop""" - return ToolCall(id=self.tool_call_id, function=FunctionCall(arguments=self.accumulated_tool_call_args, name=self.tool_call_name)) + # hack for tool rules + tool_input = json.loads(self.accumulated_tool_call_args) + if "id" in tool_input and tool_input["id"].startswith("toolu_") and "function" in tool_input: + arguments = str(json.dumps(tool_input["function"]["arguments"], indent=2)) + else: + arguments = self.accumulated_tool_call_args + return ToolCall(id=self.tool_call_id, function=FunctionCall(arguments=arguments, name=self.tool_call_name)) def _check_inner_thoughts_complete(self, combined_args: str) -> bool: """ diff --git a/letta/llm_api/anthropic_client.py b/letta/llm_api/anthropic_client.py index f26d58eb..ca010d68 100644 --- a/letta/llm_api/anthropic_client.py +++ b/letta/llm_api/anthropic_client.py @@ -369,11 +369,11 @@ class AnthropicClient(LLMClientBase): content = strip_xml_tags(string=content_part.text, tag="thinking") if content_part.type == "tool_use": # hack for tool rules - input = json.loads(json.dumps(content_part.input)) - if "id" in input and input["id"].startswith("toolu_") and "function" in input: - arguments = str(input["function"]["arguments"]) + tool_input = json.loads(json.dumps(content_part.input)) + if "id" in tool_input and tool_input["id"].startswith("toolu_") and "function" in tool_input: + arguments = str(tool_input["function"]["arguments"]) else: - arguments = json.dumps(content_part.input, indent=2) + arguments = json.dumps(tool_input, indent=2) tool_calls = [ ToolCall( id=content_part.id,