feat: add tool return truncation to summarization as a fallback [LET-5970] (#5859)

This commit is contained in:
Sarah Wooders
2025-10-31 15:29:14 -07:00
committed by Caren Thomas
parent cdde791b11
commit 57bb051ea4
13 changed files with 209 additions and 16 deletions

View File

@@ -378,6 +378,9 @@ FUNCTION_RETURN_CHAR_LIMIT = 50000 # ~300 words
BASE_FUNCTION_RETURN_CHAR_LIMIT = 50000 # same as regular function limit
FILE_IS_TRUNCATED_WARNING = "# NOTE: This block is truncated, use functions to view the full content."
# Tool return truncation limit for LLM context window management
TOOL_RETURN_TRUNCATION_CHARS = 5000
MAX_PAUSE_HEARTBEATS = 360 # in min
MESSAGE_CHATGPT_FUNCTION_MODEL = "gpt-3.5-turbo"

View File

@@ -231,6 +231,7 @@ class AnthropicClient(LLMClientBase):
tools: Optional[List[dict]] = None,
force_tool_call: Optional[str] = None,
requires_subsequent_tool_call: bool = False,
tool_return_truncation_chars: Optional[int] = None,
) -> dict:
# TODO: This needs to get cleaned up. The logic here is pretty confusing.
# TODO: I really want to get rid of prefixing, it's a recipe for disaster code maintenance wise
@@ -336,6 +337,7 @@ class AnthropicClient(LLMClientBase):
# if react, use native content + strip heartbeats
native_content=is_v1,
strip_request_heartbeat=is_v1,
tool_return_truncation_chars=tool_return_truncation_chars,
)
# Ensure first message is user
@@ -474,6 +476,14 @@ class AnthropicClient(LLMClientBase):
@trace_method
def handle_llm_error(self, e: Exception) -> Exception:
# make sure to check for overflow errors, regardless of error type
error_str = str(e).lower()
if "prompt is too long" in error_str or "exceed context limit" in error_str or "exceeds context" in error_str:
logger.warning(f"[Anthropic] Context window exceeded: {str(e)}")
return ContextWindowExceededError(
message=f"Context window exceeded for Anthropic: {str(e)}",
)
if isinstance(e, anthropic.APITimeoutError):
logger.warning(f"[Anthropic] Request timeout: {e}")
return LLMTimeoutError(

View File

@@ -71,6 +71,7 @@ class BedrockClient(AnthropicClient):
tools: Optional[List[dict]] = None,
force_tool_call: Optional[str] = None,
requires_subsequent_tool_call: bool = False,
tool_return_truncation_chars: Optional[int] = None,
) -> dict:
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)
# remove disallowed fields

View File

@@ -340,6 +340,7 @@ class DeepseekClient(OpenAIClient):
tools: Optional[List[dict]] = None,
force_tool_call: Optional[str] = None,
requires_subsequent_tool_call: bool = False,
tool_return_truncation_chars: Optional[int] = None,
) -> dict:
# Override put_inner_thoughts_in_kwargs to False for DeepSeek
llm_config.put_inner_thoughts_in_kwargs = False

View File

@@ -291,6 +291,7 @@ class GoogleVertexClient(LLMClientBase):
tools: List[dict],
force_tool_call: Optional[str] = None,
requires_subsequent_tool_call: bool = False,
tool_return_truncation_chars: Optional[int] = None,
) -> dict:
"""
Constructs a request object in the expected data format for this client.

View File

@@ -30,6 +30,7 @@ class GroqClient(OpenAIClient):
tools: Optional[List[dict]] = None,
force_tool_call: Optional[str] = None,
requires_subsequent_tool_call: bool = False,
tool_return_truncation_chars: Optional[int] = None,
) -> dict:
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)

View File

@@ -47,13 +47,22 @@ class LLMClientBase:
force_tool_call: Optional[str] = None,
telemetry_manager: Optional["TelemetryManager"] = None,
step_id: Optional[str] = None,
tool_return_truncation_chars: Optional[int] = None,
) -> Union[ChatCompletionResponse, Stream[ChatCompletionChunk]]:
"""
Issues a request to the downstream model endpoint and parses response.
If stream=True, returns a Stream[ChatCompletionChunk] that can be iterated over.
Otherwise returns a ChatCompletionResponse.
"""
request_data = self.build_request_data(agent_type, messages, llm_config, tools, force_tool_call)
request_data = self.build_request_data(
agent_type,
messages,
llm_config,
tools,
force_tool_call,
requires_subsequent_tool_call=False,
tool_return_truncation_chars=tool_return_truncation_chars,
)
try:
log_event(name="llm_request_sent", attributes=request_data)
@@ -128,9 +137,14 @@ class LLMClientBase:
tools: List[dict],
force_tool_call: Optional[str] = None,
requires_subsequent_tool_call: bool = False,
tool_return_truncation_chars: Optional[int] = None,
) -> dict:
"""
Constructs a request object in the expected data format for this client.
Args:
tool_return_truncation_chars: If set, truncates tool return content to this many characters.
Used during summarization to avoid context window issues.
"""
raise NotImplementedError

View File

@@ -229,6 +229,7 @@ class OpenAIClient(LLMClientBase):
tools: Optional[List[dict]] = None, # Keep as dict for now as per base class
force_tool_call: Optional[str] = None,
requires_subsequent_tool_call: bool = False,
tool_return_truncation_chars: Optional[int] = None,
) -> dict:
"""
Constructs a request object in the expected data format for the OpenAI Responses API.
@@ -236,7 +237,9 @@ class OpenAIClient(LLMClientBase):
if llm_config.put_inner_thoughts_in_kwargs:
raise ValueError("Inner thoughts in kwargs are not supported for the OpenAI Responses API")
openai_messages_list = PydanticMessage.to_openai_responses_dicts_from_list(messages)
openai_messages_list = PydanticMessage.to_openai_responses_dicts_from_list(
messages, tool_return_truncation_chars=tool_return_truncation_chars
)
# Add multi-modal support for Responses API by rewriting user messages
# into input_text/input_image parts.
openai_messages_list = fill_image_content_in_responses_input(openai_messages_list, messages)
@@ -377,6 +380,7 @@ class OpenAIClient(LLMClientBase):
tools: Optional[List[dict]] = None, # Keep as dict for now as per base class
force_tool_call: Optional[str] = None,
requires_subsequent_tool_call: bool = False,
tool_return_truncation_chars: Optional[int] = None,
) -> dict:
"""
Constructs a request object in the expected data format for the OpenAI API.
@@ -390,6 +394,7 @@ class OpenAIClient(LLMClientBase):
tools=tools,
force_tool_call=force_tool_call,
requires_subsequent_tool_call=requires_subsequent_tool_call,
tool_return_truncation_chars=tool_return_truncation_chars,
)
if agent_type == AgentType.letta_v1_agent:
@@ -419,6 +424,7 @@ class OpenAIClient(LLMClientBase):
messages,
put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs,
use_developer_message=use_developer_message,
tool_return_truncation_chars=tool_return_truncation_chars,
)
]

View File

@@ -30,6 +30,7 @@ class XAIClient(OpenAIClient):
tools: Optional[List[dict]] = None,
force_tool_call: Optional[str] = None,
requires_subsequent_tool_call: bool = False,
tool_return_truncation_chars: Optional[int] = None,
) -> dict:
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)

View File

@@ -56,6 +56,14 @@ from letta.system import unpack_message
from letta.utils import parse_json, validate_function_response
def truncate_tool_return(content: Optional[str], limit: Optional[int]) -> Optional[str]:
    """Truncate a tool-return payload to at most *limit* characters.

    The input is returned unchanged when no limit is configured, when the
    content is None, or when it already fits within the limit. Otherwise the
    first *limit* characters are kept and a suffix recording how many
    characters were dropped is appended (so the final string is slightly
    longer than *limit* — the suffix is informational, not counted).
    """
    if content is None or limit is None:
        return content
    overflow = len(content) - limit
    if overflow <= 0:
        return content
    return f"{content[:limit]}... [truncated {overflow} chars]"
def add_inner_thoughts_to_tool_call(
tool_call: OpenAIToolCall,
inner_thoughts: str,
@@ -1090,6 +1098,7 @@ class Message(BaseMessage):
# if true, then treat the content field as AssistantMessage
native_content: bool = False,
strip_request_heartbeat: bool = False,
tool_return_truncation_chars: Optional[int] = None,
) -> dict | None:
"""Go from Message class to ChatCompletion message object"""
assert not (native_content and put_inner_thoughts_in_kwargs), "native_content and put_inner_thoughts_in_kwargs cannot both be true"
@@ -1191,16 +1200,18 @@ class Message(BaseMessage):
tool_return = self.tool_returns[0]
if not tool_return.tool_call_id:
raise TypeError("OpenAI API requires tool_call_id to be set.")
func_response = truncate_tool_return(tool_return.func_response, tool_return_truncation_chars)
openai_message = {
"content": tool_return.func_response,
"content": func_response,
"role": self.role,
"tool_call_id": tool_return.tool_call_id[:max_tool_id_length] if max_tool_id_length else tool_return.tool_call_id,
}
else:
# Legacy fallback for old message format
assert self.tool_call_id is not None, vars(self)
legacy_content = truncate_tool_return(text_content, tool_return_truncation_chars)
openai_message = {
"content": text_content,
"content": legacy_content,
"role": self.role,
"tool_call_id": self.tool_call_id[:max_tool_id_length] if max_tool_id_length else self.tool_call_id,
}
@@ -1232,6 +1243,7 @@ class Message(BaseMessage):
max_tool_id_length: int = TOOL_CALL_ID_MAX_LEN,
put_inner_thoughts_in_kwargs: bool = False,
use_developer_message: bool = False,
tool_return_truncation_chars: Optional[int] = None,
) -> List[dict]:
messages = Message.filter_messages_for_llm_api(messages)
result: List[dict] = []
@@ -1256,6 +1268,7 @@ class Message(BaseMessage):
max_tool_id_length=max_tool_id_length,
put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
use_developer_message=use_developer_message,
tool_return_truncation_chars=tool_return_truncation_chars,
)
if d is not None:
result.append(d)
@@ -1265,6 +1278,7 @@ class Message(BaseMessage):
def to_openai_responses_dicts(
self,
max_tool_id_length: int = TOOL_CALL_ID_MAX_LEN,
tool_return_truncation_chars: Optional[int] = None,
) -> List[dict]:
"""Go from Message class to ChatCompletion message object"""
@@ -1345,22 +1359,24 @@ class Message(BaseMessage):
for tool_return in self.tool_returns:
if not tool_return.tool_call_id:
raise TypeError("OpenAI Responses API requires tool_call_id to be set.")
func_response = truncate_tool_return(tool_return.func_response, tool_return_truncation_chars)
message_dicts.append(
{
"type": "function_call_output",
"call_id": tool_return.tool_call_id[:max_tool_id_length] if max_tool_id_length else tool_return.tool_call_id,
"output": tool_return.func_response,
"output": func_response,
}
)
else:
# Legacy fallback for old message format
assert self.tool_call_id is not None, vars(self)
assert len(self.content) == 1 and isinstance(self.content[0], TextContent), vars(self)
legacy_output = truncate_tool_return(self.content[0].text, tool_return_truncation_chars)
message_dicts.append(
{
"type": "function_call_output",
"call_id": self.tool_call_id[:max_tool_id_length] if max_tool_id_length else self.tool_call_id,
"output": self.content[0].text,
"output": legacy_output,
}
)
@@ -1373,11 +1389,16 @@ class Message(BaseMessage):
def to_openai_responses_dicts_from_list(
messages: List[Message],
max_tool_id_length: int = TOOL_CALL_ID_MAX_LEN,
tool_return_truncation_chars: Optional[int] = None,
) -> List[dict]:
messages = Message.filter_messages_for_llm_api(messages)
result = []
for message in messages:
result.extend(message.to_openai_responses_dicts(max_tool_id_length=max_tool_id_length))
result.extend(
message.to_openai_responses_dicts(
max_tool_id_length=max_tool_id_length, tool_return_truncation_chars=tool_return_truncation_chars
)
)
return result
def to_anthropic_dict(
@@ -1388,6 +1409,7 @@ class Message(BaseMessage):
# if true, then treat the content field as AssistantMessage
native_content: bool = False,
strip_request_heartbeat: bool = False,
tool_return_truncation_chars: Optional[int] = None,
) -> dict | None:
"""
Convert to an Anthropic message dictionary
@@ -1563,11 +1585,12 @@ class Message(BaseMessage):
for tool_return in self.tool_returns:
if not tool_return.tool_call_id:
raise TypeError("Anthropic API requires tool_use_id to be set.")
func_response = truncate_tool_return(tool_return.func_response, tool_return_truncation_chars)
content.append(
{
"type": "tool_result",
"tool_use_id": tool_return.tool_call_id,
"content": tool_return.func_response,
"content": func_response,
}
)
if content:
@@ -1580,6 +1603,7 @@ class Message(BaseMessage):
raise TypeError("Anthropic API requires tool_use_id to be set.")
# This is for legacy reasons
legacy_content = truncate_tool_return(text_content, tool_return_truncation_chars)
anthropic_message = {
"role": "user", # NOTE: diff
"content": [
@@ -1587,7 +1611,7 @@ class Message(BaseMessage):
{
"type": "tool_result",
"tool_use_id": self.tool_call_id,
"content": text_content,
"content": legacy_content,
}
],
}
@@ -1606,6 +1630,7 @@ class Message(BaseMessage):
# if true, then treat the content field as AssistantMessage
native_content: bool = False,
strip_request_heartbeat: bool = False,
tool_return_truncation_chars: Optional[int] = None,
) -> List[dict]:
messages = Message.filter_messages_for_llm_api(messages)
result = [
@@ -1615,6 +1640,7 @@ class Message(BaseMessage):
put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
native_content=native_content,
strip_request_heartbeat=strip_request_heartbeat,
tool_return_truncation_chars=tool_return_truncation_chars,
)
for m in messages
]
@@ -1628,6 +1654,7 @@ class Message(BaseMessage):
# if true, then treat the content field as AssistantMessage
native_content: bool = False,
strip_request_heartbeat: bool = False,
tool_return_truncation_chars: Optional[int] = None,
) -> dict | None:
"""
Go from Message class to Google AI REST message object
@@ -1776,11 +1803,14 @@ class Message(BaseMessage):
# Use the function name if available, otherwise use tool_call_id
function_name = self.name if self.name else tool_return.tool_call_id
# Truncate the tool return if needed
func_response = truncate_tool_return(tool_return.func_response, tool_return_truncation_chars)
# NOTE: Google AI API wants the function response as JSON only, no string
try:
function_response = parse_json(tool_return.func_response)
function_response = parse_json(func_response)
except:
function_response = {"function_response": tool_return.func_response}
function_response = {"function_response": func_response}
parts.append(
{
@@ -1808,11 +1838,14 @@ class Message(BaseMessage):
else:
function_name = self.name
# Truncate the legacy content if needed
legacy_content = truncate_tool_return(text_content, tool_return_truncation_chars)
# NOTE: Google AI API wants the function response as JSON only, no string
try:
function_response = parse_json(text_content)
function_response = parse_json(legacy_content)
except:
function_response = {"function_response": text_content}
function_response = {"function_response": legacy_content}
google_ai_message = {
"role": "function",
@@ -1848,6 +1881,7 @@ class Message(BaseMessage):
current_model: str,
put_inner_thoughts_in_kwargs: bool = True,
native_content: bool = False,
tool_return_truncation_chars: Optional[int] = None,
):
messages = Message.filter_messages_for_llm_api(messages)
result = [
@@ -1855,6 +1889,7 @@ class Message(BaseMessage):
current_model=current_model,
put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
native_content=native_content,
tool_return_truncation_chars=tool_return_truncation_chars,
)
for m in messages
]

View File

@@ -4,7 +4,13 @@ import traceback
from typing import List, Optional, Tuple, Union
from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, MESSAGE_SUMMARY_REQUEST_ACK
from letta.constants import (
DEFAULT_MESSAGE_TOOL,
DEFAULT_MESSAGE_TOOL_KWARG,
MESSAGE_SUMMARY_REQUEST_ACK,
TOOL_RETURN_TRUNCATION_CHARS,
)
from letta.errors import ContextWindowExceededError
from letta.helpers.message_helper import convert_message_creates_to_messages
from letta.llm_api.llm_client import LLMClient
from letta.log import get_logger
@@ -394,7 +400,27 @@ async def simple_summary(messages: List[Message], llm_config: LLMConfig, actor:
response_data = await llm_client.request_async(request_data, summarizer_llm_config)
except Exception as e:
# handle LLM error (likely a context window exceeded error)
raise llm_client.handle_llm_error(e)
try:
raise llm_client.handle_llm_error(e)
except ContextWindowExceededError as context_error:
logger.warning(
f"Context window exceeded during summarization, falling back to truncated tool returns. Original error: {context_error}"
)
# Fallback: rebuild request with truncated tool returns
request_data = llm_client.build_request_data(
AgentType.letta_v1_agent,
input_messages_obj,
summarizer_llm_config,
tools=[],
tool_return_truncation_chars=TOOL_RETURN_TRUNCATION_CHARS,
)
try:
response_data = await llm_client.request_async(request_data, summarizer_llm_config)
except Exception as fallback_error:
logger.error(f"Fallback summarization also failed: {fallback_error}")
raise llm_client.handle_llm_error(fallback_error)
response = llm_client.convert_response_to_chat_completion(response_data, input_messages_obj, summarizer_llm_config)
if response.choices[0].message.content is None:
logger.warning("No content returned from summarizer")

View File

@@ -0,0 +1,9 @@
{
"model": "claude-haiku-4-5",
"model_endpoint_type": "anthropic",
"model_endpoint": "https://api.anthropic.com/v1",
"model_wrapper": null,
"context_window": 200000,
"put_inner_thoughts_in_kwargs": true,
"enable_reasoner": true
}

View File

@@ -39,13 +39,15 @@ def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model
# Test configurations - using a subset of models for summarization tests
all_configs = [
"openai-gpt-5-mini.json",
"claude-4-5-haiku.json",
"gemini-2.5-flash.json",
# "gemini-2.5-flash-vertex.json", # Requires Vertex AI credentials
# "openai-gpt-4.1.json",
# "openai-o1.json",
# "openai-o3.json",
# "openai-o4-mini.json",
# "claude-4-sonnet.json",
# "claude-3-7-sonnet.json",
# "gemini-2.5-flash-vertex.json",
# "gemini-2.5-pro-vertex.json",
]
@@ -517,3 +519,86 @@ async def test_summarize_multiple_large_tool_calls(server: SyncServer, actor, ll
assert hasattr(msg, "content")
print(f"Summarized {len(in_context_messages)} messages with {total_content_size} chars to {len(result)} messages")
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
    ids=[c.model for c in TESTED_LLM_CONFIGS],
)
async def test_summarize_truncates_large_tool_return(server: SyncServer, actor, llm_config: LLMConfig):
    """
    Test that summarization properly truncates very large tool returns.
    This ensures that oversized tool returns don't consume excessive context.
    """
    # Create an extremely large tool return (100k chars)
    # NOTE(review): create_large_tool_return is a module-level helper not shown
    # here — presumably it returns a str of roughly the requested length; the
    # >90k assertion below tolerates helper-side overhead/rounding.
    large_return = create_large_tool_return(100000)
    original_size = len(large_return)
    # Create messages with a large tool return
    # Conversation shape: user request -> assistant tool call -> oversized
    # tool return -> assistant follow-up. The tool_call_id ("call_1") must
    # match between the ToolCallContent and the ToolReturnContent.
    messages = [
        PydanticMessage(
            role=MessageRole.user,
            content=[TextContent(type="text", text="Please run the database query.")],
        ),
        PydanticMessage(
            role=MessageRole.assistant,
            content=[
                TextContent(type="text", text="Running query..."),
                ToolCallContent(
                    type="tool_call",
                    id="call_1",
                    name="run_query",
                    input={"query": "SELECT * FROM large_table"},
                ),
            ],
        ),
        PydanticMessage(
            role=MessageRole.tool,
            tool_call_id="call_1",
            content=[
                ToolReturnContent(
                    type="tool_return",
                    tool_call_id="call_1",
                    content=large_return,
                    is_error=False,
                )
            ],
        ),
        PydanticMessage(
            role=MessageRole.assistant,
            content=[TextContent(type="text", text="Query completed successfully with many results.")],
        ),
    ]
    agent_state, in_context_messages = await create_agent_with_messages(server, actor, llm_config, messages)
    # Verify the original tool return is indeed large
    assert original_size > 90000, f"Expected tool return >90k chars, got {original_size}"
    # Run summarization
    result = await run_summarization(server, agent_state, in_context_messages, actor)
    # Verify result
    assert isinstance(result, list)
    assert len(result) >= 1
    # Find tool return messages in the result and verify truncation occurred
    tool_returns_found = False
    for msg in result:
        if msg.role == MessageRole.tool:
            for content in msg.content:
                if isinstance(content, ToolReturnContent):
                    tool_returns_found = True
                    result_size = len(content.content)
                    # Verify that the tool return has been truncated
                    # NOTE(review): this only checks "smaller than original",
                    # not that it hit a specific truncation limit — summarization
                    # may shrink returns by other means as well.
                    assert result_size < original_size, (
                        f"Expected tool return to be truncated from {original_size} chars, but got {result_size} chars"
                    )
                    print(f"Tool return successfully truncated from {original_size} to {result_size} chars")
    # If we didn't find any tool returns in the result, that's also acceptable
    # (they may have been completely removed during aggressive summarization)
    if not tool_returns_found:
        print("Tool returns were completely removed during summarization")