feat: add tool return truncation to summarization as a fallback [LET-5970] (#5859)
This commit is contained in:
committed by
Caren Thomas
parent
cdde791b11
commit
57bb051ea4
@@ -378,6 +378,9 @@ FUNCTION_RETURN_CHAR_LIMIT = 50000 # ~300 words
|
||||
BASE_FUNCTION_RETURN_CHAR_LIMIT = 50000 # same as regular function limit
|
||||
FILE_IS_TRUNCATED_WARNING = "# NOTE: This block is truncated, use functions to view the full content."
|
||||
|
||||
# Max characters kept per tool return when summarization is retried after a context-window overflow
|
||||
TOOL_RETURN_TRUNCATION_CHARS = 5000
|
||||
|
||||
MAX_PAUSE_HEARTBEATS = 360 # in min
|
||||
|
||||
MESSAGE_CHATGPT_FUNCTION_MODEL = "gpt-3.5-turbo"
|
||||
|
||||
@@ -231,6 +231,7 @@ class AnthropicClient(LLMClientBase):
|
||||
tools: Optional[List[dict]] = None,
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
tool_return_truncation_chars: Optional[int] = None,
|
||||
) -> dict:
|
||||
# TODO: This needs to get cleaned up. The logic here is pretty confusing.
|
||||
# TODO: I really want to get rid of prefixing, it's a recipe for disaster code maintenance wise
|
||||
@@ -336,6 +337,7 @@ class AnthropicClient(LLMClientBase):
|
||||
# if react, use native content + strip heartbeats
|
||||
native_content=is_v1,
|
||||
strip_request_heartbeat=is_v1,
|
||||
tool_return_truncation_chars=tool_return_truncation_chars,
|
||||
)
|
||||
|
||||
# Ensure first message is user
|
||||
@@ -474,6 +476,14 @@ class AnthropicClient(LLMClientBase):
|
||||
|
||||
@trace_method
|
||||
def handle_llm_error(self, e: Exception) -> Exception:
|
||||
# make sure to check for overflow errors, regardless of error type
|
||||
error_str = str(e).lower()
|
||||
if "prompt is too long" in error_str or "exceed context limit" in error_str or "exceeds context" in error_str:
|
||||
logger.warning(f"[Anthropic] Context window exceeded: {str(e)}")
|
||||
return ContextWindowExceededError(
|
||||
message=f"Context window exceeded for Anthropic: {str(e)}",
|
||||
)
|
||||
|
||||
if isinstance(e, anthropic.APITimeoutError):
|
||||
logger.warning(f"[Anthropic] Request timeout: {e}")
|
||||
return LLMTimeoutError(
|
||||
|
||||
@@ -71,6 +71,7 @@ class BedrockClient(AnthropicClient):
|
||||
tools: Optional[List[dict]] = None,
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
tool_return_truncation_chars: Optional[int] = None,
|
||||
) -> dict:
|
||||
# Forward the truncation limit too: the parent defaults it to None, so omitting it
# here would silently disable tool-return truncation for this client.
data = super().build_request_data(
    agent_type,
    messages,
    llm_config,
    tools,
    force_tool_call,
    requires_subsequent_tool_call,
    tool_return_truncation_chars=tool_return_truncation_chars,
)
|
||||
# remove disallowed fields
|
||||
|
||||
@@ -340,6 +340,7 @@ class DeepseekClient(OpenAIClient):
|
||||
tools: Optional[List[dict]] = None,
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
tool_return_truncation_chars: Optional[int] = None,
|
||||
) -> dict:
|
||||
# Override put_inner_thoughts_in_kwargs to False for DeepSeek
|
||||
llm_config.put_inner_thoughts_in_kwargs = False
|
||||
|
||||
@@ -291,6 +291,7 @@ class GoogleVertexClient(LLMClientBase):
|
||||
tools: List[dict],
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
tool_return_truncation_chars: Optional[int] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Constructs a request object in the expected data format for this client.
|
||||
|
||||
@@ -30,6 +30,7 @@ class GroqClient(OpenAIClient):
|
||||
tools: Optional[List[dict]] = None,
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
tool_return_truncation_chars: Optional[int] = None,
|
||||
) -> dict:
|
||||
# Forward the truncation limit too: the parent defaults it to None, so omitting it
# here would silently disable tool-return truncation for this client.
data = super().build_request_data(
    agent_type,
    messages,
    llm_config,
    tools,
    force_tool_call,
    requires_subsequent_tool_call,
    tool_return_truncation_chars=tool_return_truncation_chars,
)
|
||||
|
||||
|
||||
@@ -47,13 +47,22 @@ class LLMClientBase:
|
||||
force_tool_call: Optional[str] = None,
|
||||
telemetry_manager: Optional["TelemetryManager"] = None,
|
||||
step_id: Optional[str] = None,
|
||||
tool_return_truncation_chars: Optional[int] = None,
|
||||
) -> Union[ChatCompletionResponse, Stream[ChatCompletionChunk]]:
|
||||
"""
|
||||
Issues a request to the downstream model endpoint and parses response.
|
||||
If stream=True, returns a Stream[ChatCompletionChunk] that can be iterated over.
|
||||
Otherwise returns a ChatCompletionResponse.
|
||||
"""
|
||||
request_data = self.build_request_data(agent_type, messages, llm_config, tools, force_tool_call)
|
||||
request_data = self.build_request_data(
|
||||
agent_type,
|
||||
messages,
|
||||
llm_config,
|
||||
tools,
|
||||
force_tool_call,
|
||||
requires_subsequent_tool_call=False,
|
||||
tool_return_truncation_chars=tool_return_truncation_chars,
|
||||
)
|
||||
|
||||
try:
|
||||
log_event(name="llm_request_sent", attributes=request_data)
|
||||
@@ -128,9 +137,14 @@ class LLMClientBase:
|
||||
tools: List[dict],
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
tool_return_truncation_chars: Optional[int] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Constructs a request object in the expected data format for this client.
|
||||
|
||||
Args:
|
||||
tool_return_truncation_chars: If set, truncates tool return content to this many characters.
|
||||
Used during summarization to avoid context window issues.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -229,6 +229,7 @@ class OpenAIClient(LLMClientBase):
|
||||
tools: Optional[List[dict]] = None, # Keep as dict for now as per base class
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
tool_return_truncation_chars: Optional[int] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Constructs a request object in the expected data format for the OpenAI Responses API.
|
||||
@@ -236,7 +237,9 @@ class OpenAIClient(LLMClientBase):
|
||||
if llm_config.put_inner_thoughts_in_kwargs:
|
||||
raise ValueError("Inner thoughts in kwargs are not supported for the OpenAI Responses API")
|
||||
|
||||
openai_messages_list = PydanticMessage.to_openai_responses_dicts_from_list(messages)
|
||||
openai_messages_list = PydanticMessage.to_openai_responses_dicts_from_list(
|
||||
messages, tool_return_truncation_chars=tool_return_truncation_chars
|
||||
)
|
||||
# Add multi-modal support for Responses API by rewriting user messages
|
||||
# into input_text/input_image parts.
|
||||
openai_messages_list = fill_image_content_in_responses_input(openai_messages_list, messages)
|
||||
@@ -377,6 +380,7 @@ class OpenAIClient(LLMClientBase):
|
||||
tools: Optional[List[dict]] = None, # Keep as dict for now as per base class
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
tool_return_truncation_chars: Optional[int] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Constructs a request object in the expected data format for the OpenAI API.
|
||||
@@ -390,6 +394,7 @@ class OpenAIClient(LLMClientBase):
|
||||
tools=tools,
|
||||
force_tool_call=force_tool_call,
|
||||
requires_subsequent_tool_call=requires_subsequent_tool_call,
|
||||
tool_return_truncation_chars=tool_return_truncation_chars,
|
||||
)
|
||||
|
||||
if agent_type == AgentType.letta_v1_agent:
|
||||
@@ -419,6 +424,7 @@ class OpenAIClient(LLMClientBase):
|
||||
messages,
|
||||
put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs,
|
||||
use_developer_message=use_developer_message,
|
||||
tool_return_truncation_chars=tool_return_truncation_chars,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@@ -30,6 +30,7 @@ class XAIClient(OpenAIClient):
|
||||
tools: Optional[List[dict]] = None,
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
tool_return_truncation_chars: Optional[int] = None,
|
||||
) -> dict:
|
||||
# Forward the truncation limit too: the parent defaults it to None, so omitting it
# here would silently disable tool-return truncation for this client.
data = super().build_request_data(
    agent_type,
    messages,
    llm_config,
    tools,
    force_tool_call,
    requires_subsequent_tool_call,
    tool_return_truncation_chars=tool_return_truncation_chars,
)
|
||||
|
||||
|
||||
@@ -56,6 +56,14 @@ from letta.system import unpack_message
|
||||
from letta.utils import parse_json, validate_function_response
|
||||
|
||||
|
||||
def truncate_tool_return(content: Optional[str], limit: Optional[int]) -> Optional[str]:
    """Cap a tool return string at ``limit`` characters.

    Args:
        content: Tool return text, or None when there is nothing to truncate.
        limit: Maximum number of characters of ``content`` to keep. None disables
            truncation; negative values are treated as 0.

    Returns:
        ``content`` unchanged when truncation is disabled or the text already fits.
        Otherwise the first ``limit`` characters followed by a marker stating how
        many characters were dropped. The marker is appended on top of ``limit``,
        so the result is slightly longer than ``limit``.
    """
    if limit is None or content is None:
        return content

    # A negative limit would make the slice below count from the end of the string
    # and report more dropped characters than actually exist, so clamp it to zero.
    limit = max(limit, 0)

    if len(content) <= limit:
        return content

    return f"{content[:limit]}... [truncated {len(content) - limit} chars]"
|
||||
|
||||
|
||||
def add_inner_thoughts_to_tool_call(
|
||||
tool_call: OpenAIToolCall,
|
||||
inner_thoughts: str,
|
||||
@@ -1090,6 +1098,7 @@ class Message(BaseMessage):
|
||||
# if true, then treat the content field as AssistantMessage
|
||||
native_content: bool = False,
|
||||
strip_request_heartbeat: bool = False,
|
||||
tool_return_truncation_chars: Optional[int] = None,
|
||||
) -> dict | None:
|
||||
"""Go from Message class to ChatCompletion message object"""
|
||||
assert not (native_content and put_inner_thoughts_in_kwargs), "native_content and put_inner_thoughts_in_kwargs cannot both be true"
|
||||
@@ -1191,16 +1200,18 @@ class Message(BaseMessage):
|
||||
tool_return = self.tool_returns[0]
|
||||
if not tool_return.tool_call_id:
|
||||
raise TypeError("OpenAI API requires tool_call_id to be set.")
|
||||
func_response = truncate_tool_return(tool_return.func_response, tool_return_truncation_chars)
|
||||
openai_message = {
|
||||
"content": tool_return.func_response,
|
||||
"content": func_response,
|
||||
"role": self.role,
|
||||
"tool_call_id": tool_return.tool_call_id[:max_tool_id_length] if max_tool_id_length else tool_return.tool_call_id,
|
||||
}
|
||||
else:
|
||||
# Legacy fallback for old message format
|
||||
assert self.tool_call_id is not None, vars(self)
|
||||
legacy_content = truncate_tool_return(text_content, tool_return_truncation_chars)
|
||||
openai_message = {
|
||||
"content": text_content,
|
||||
"content": legacy_content,
|
||||
"role": self.role,
|
||||
"tool_call_id": self.tool_call_id[:max_tool_id_length] if max_tool_id_length else self.tool_call_id,
|
||||
}
|
||||
@@ -1232,6 +1243,7 @@ class Message(BaseMessage):
|
||||
max_tool_id_length: int = TOOL_CALL_ID_MAX_LEN,
|
||||
put_inner_thoughts_in_kwargs: bool = False,
|
||||
use_developer_message: bool = False,
|
||||
tool_return_truncation_chars: Optional[int] = None,
|
||||
) -> List[dict]:
|
||||
messages = Message.filter_messages_for_llm_api(messages)
|
||||
result: List[dict] = []
|
||||
@@ -1256,6 +1268,7 @@ class Message(BaseMessage):
|
||||
max_tool_id_length=max_tool_id_length,
|
||||
put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
|
||||
use_developer_message=use_developer_message,
|
||||
tool_return_truncation_chars=tool_return_truncation_chars,
|
||||
)
|
||||
if d is not None:
|
||||
result.append(d)
|
||||
@@ -1265,6 +1278,7 @@ class Message(BaseMessage):
|
||||
def to_openai_responses_dicts(
|
||||
self,
|
||||
max_tool_id_length: int = TOOL_CALL_ID_MAX_LEN,
|
||||
tool_return_truncation_chars: Optional[int] = None,
|
||||
) -> List[dict]:
|
||||
"""Go from Message class to ChatCompletion message object"""
|
||||
|
||||
@@ -1345,22 +1359,24 @@ class Message(BaseMessage):
|
||||
for tool_return in self.tool_returns:
|
||||
if not tool_return.tool_call_id:
|
||||
raise TypeError("OpenAI Responses API requires tool_call_id to be set.")
|
||||
func_response = truncate_tool_return(tool_return.func_response, tool_return_truncation_chars)
|
||||
message_dicts.append(
|
||||
{
|
||||
"type": "function_call_output",
|
||||
"call_id": tool_return.tool_call_id[:max_tool_id_length] if max_tool_id_length else tool_return.tool_call_id,
|
||||
"output": tool_return.func_response,
|
||||
"output": func_response,
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Legacy fallback for old message format
|
||||
assert self.tool_call_id is not None, vars(self)
|
||||
assert len(self.content) == 1 and isinstance(self.content[0], TextContent), vars(self)
|
||||
legacy_output = truncate_tool_return(self.content[0].text, tool_return_truncation_chars)
|
||||
message_dicts.append(
|
||||
{
|
||||
"type": "function_call_output",
|
||||
"call_id": self.tool_call_id[:max_tool_id_length] if max_tool_id_length else self.tool_call_id,
|
||||
"output": self.content[0].text,
|
||||
"output": legacy_output,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1373,11 +1389,16 @@ class Message(BaseMessage):
|
||||
def to_openai_responses_dicts_from_list(
|
||||
messages: List[Message],
|
||||
max_tool_id_length: int = TOOL_CALL_ID_MAX_LEN,
|
||||
tool_return_truncation_chars: Optional[int] = None,
|
||||
) -> List[dict]:
|
||||
messages = Message.filter_messages_for_llm_api(messages)
|
||||
result = []
|
||||
for message in messages:
|
||||
result.extend(message.to_openai_responses_dicts(max_tool_id_length=max_tool_id_length))
|
||||
result.extend(
|
||||
message.to_openai_responses_dicts(
|
||||
max_tool_id_length=max_tool_id_length, tool_return_truncation_chars=tool_return_truncation_chars
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
def to_anthropic_dict(
|
||||
@@ -1388,6 +1409,7 @@ class Message(BaseMessage):
|
||||
# if true, then treat the content field as AssistantMessage
|
||||
native_content: bool = False,
|
||||
strip_request_heartbeat: bool = False,
|
||||
tool_return_truncation_chars: Optional[int] = None,
|
||||
) -> dict | None:
|
||||
"""
|
||||
Convert to an Anthropic message dictionary
|
||||
@@ -1563,11 +1585,12 @@ class Message(BaseMessage):
|
||||
for tool_return in self.tool_returns:
|
||||
if not tool_return.tool_call_id:
|
||||
raise TypeError("Anthropic API requires tool_use_id to be set.")
|
||||
func_response = truncate_tool_return(tool_return.func_response, tool_return_truncation_chars)
|
||||
content.append(
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_return.tool_call_id,
|
||||
"content": tool_return.func_response,
|
||||
"content": func_response,
|
||||
}
|
||||
)
|
||||
if content:
|
||||
@@ -1580,6 +1603,7 @@ class Message(BaseMessage):
|
||||
raise TypeError("Anthropic API requires tool_use_id to be set.")
|
||||
|
||||
# This is for legacy reasons
|
||||
legacy_content = truncate_tool_return(text_content, tool_return_truncation_chars)
|
||||
anthropic_message = {
|
||||
"role": "user", # NOTE: diff
|
||||
"content": [
|
||||
@@ -1587,7 +1611,7 @@ class Message(BaseMessage):
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": self.tool_call_id,
|
||||
"content": text_content,
|
||||
"content": legacy_content,
|
||||
}
|
||||
],
|
||||
}
|
||||
@@ -1606,6 +1630,7 @@ class Message(BaseMessage):
|
||||
# if true, then treat the content field as AssistantMessage
|
||||
native_content: bool = False,
|
||||
strip_request_heartbeat: bool = False,
|
||||
tool_return_truncation_chars: Optional[int] = None,
|
||||
) -> List[dict]:
|
||||
messages = Message.filter_messages_for_llm_api(messages)
|
||||
result = [
|
||||
@@ -1615,6 +1640,7 @@ class Message(BaseMessage):
|
||||
put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
|
||||
native_content=native_content,
|
||||
strip_request_heartbeat=strip_request_heartbeat,
|
||||
tool_return_truncation_chars=tool_return_truncation_chars,
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
@@ -1628,6 +1654,7 @@ class Message(BaseMessage):
|
||||
# if true, then treat the content field as AssistantMessage
|
||||
native_content: bool = False,
|
||||
strip_request_heartbeat: bool = False,
|
||||
tool_return_truncation_chars: Optional[int] = None,
|
||||
) -> dict | None:
|
||||
"""
|
||||
Go from Message class to Google AI REST message object
|
||||
@@ -1776,11 +1803,14 @@ class Message(BaseMessage):
|
||||
# Use the function name if available, otherwise use tool_call_id
|
||||
function_name = self.name if self.name else tool_return.tool_call_id
|
||||
|
||||
# Truncate the tool return if needed
|
||||
func_response = truncate_tool_return(tool_return.func_response, tool_return_truncation_chars)
|
||||
|
||||
# NOTE: Google AI API wants the function response as JSON only, no string
|
||||
try:
|
||||
function_response = parse_json(tool_return.func_response)
|
||||
function_response = parse_json(func_response)
|
||||
except:
|
||||
function_response = {"function_response": tool_return.func_response}
|
||||
function_response = {"function_response": func_response}
|
||||
|
||||
parts.append(
|
||||
{
|
||||
@@ -1808,11 +1838,14 @@ class Message(BaseMessage):
|
||||
else:
|
||||
function_name = self.name
|
||||
|
||||
# Truncate the legacy content if needed
|
||||
legacy_content = truncate_tool_return(text_content, tool_return_truncation_chars)
|
||||
|
||||
# NOTE: Google AI API wants the function response as JSON only, no string
|
||||
try:
|
||||
function_response = parse_json(text_content)
|
||||
function_response = parse_json(legacy_content)
|
||||
except:
|
||||
function_response = {"function_response": text_content}
|
||||
function_response = {"function_response": legacy_content}
|
||||
|
||||
google_ai_message = {
|
||||
"role": "function",
|
||||
@@ -1848,6 +1881,7 @@ class Message(BaseMessage):
|
||||
current_model: str,
|
||||
put_inner_thoughts_in_kwargs: bool = True,
|
||||
native_content: bool = False,
|
||||
tool_return_truncation_chars: Optional[int] = None,
|
||||
):
|
||||
messages = Message.filter_messages_for_llm_api(messages)
|
||||
result = [
|
||||
@@ -1855,6 +1889,7 @@ class Message(BaseMessage):
|
||||
current_model=current_model,
|
||||
put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
|
||||
native_content=native_content,
|
||||
tool_return_truncation_chars=tool_return_truncation_chars,
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
|
||||
@@ -4,7 +4,13 @@ import traceback
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, MESSAGE_SUMMARY_REQUEST_ACK
|
||||
from letta.constants import (
|
||||
DEFAULT_MESSAGE_TOOL,
|
||||
DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
MESSAGE_SUMMARY_REQUEST_ACK,
|
||||
TOOL_RETURN_TRUNCATION_CHARS,
|
||||
)
|
||||
from letta.errors import ContextWindowExceededError
|
||||
from letta.helpers.message_helper import convert_message_creates_to_messages
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
from letta.log import get_logger
|
||||
@@ -394,7 +400,27 @@ async def simple_summary(messages: List[Message], llm_config: LLMConfig, actor:
|
||||
response_data = await llm_client.request_async(request_data, summarizer_llm_config)
|
||||
except Exception as e:
|
||||
# handle LLM error (likely a context window exceeded error)
|
||||
raise llm_client.handle_llm_error(e)
|
||||
try:
|
||||
raise llm_client.handle_llm_error(e)
|
||||
except ContextWindowExceededError as context_error:
|
||||
logger.warning(
|
||||
f"Context window exceeded during summarization, falling back to truncated tool returns. Original error: {context_error}"
|
||||
)
|
||||
|
||||
# Fallback: rebuild request with truncated tool returns
|
||||
request_data = llm_client.build_request_data(
|
||||
AgentType.letta_v1_agent,
|
||||
input_messages_obj,
|
||||
summarizer_llm_config,
|
||||
tools=[],
|
||||
tool_return_truncation_chars=TOOL_RETURN_TRUNCATION_CHARS,
|
||||
)
|
||||
|
||||
try:
|
||||
response_data = await llm_client.request_async(request_data, summarizer_llm_config)
|
||||
except Exception as fallback_error:
|
||||
logger.error(f"Fallback summarization also failed: {fallback_error}")
|
||||
raise llm_client.handle_llm_error(fallback_error)
|
||||
response = llm_client.convert_response_to_chat_completion(response_data, input_messages_obj, summarizer_llm_config)
|
||||
if response.choices[0].message.content is None:
|
||||
logger.warning("No content returned from summarizer")
|
||||
|
||||
9
tests/configs/llm_model_configs/claude-4-5-haiku.json
Normal file
9
tests/configs/llm_model_configs/claude-4-5-haiku.json
Normal file
@@ -0,0 +1,9 @@
|
||||
{
|
||||
"model": "claude-haiku-4-5",
|
||||
"model_endpoint_type": "anthropic",
|
||||
"model_endpoint": "https://api.anthropic.com/v1",
|
||||
"model_wrapper": null,
|
||||
"context_window": 200000,
|
||||
"put_inner_thoughts_in_kwargs": true,
|
||||
"enable_reasoner": true
|
||||
}
|
||||
@@ -39,13 +39,15 @@ def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model
|
||||
# Test configurations - using a subset of models for summarization tests
|
||||
all_configs = [
|
||||
"openai-gpt-5-mini.json",
|
||||
"claude-4-5-haiku.json",
|
||||
"gemini-2.5-flash.json",
|
||||
# "gemini-2.5-flash-vertex.json", # Requires Vertex AI credentials
|
||||
# "openai-gpt-4.1.json",
|
||||
# "openai-o1.json",
|
||||
# "openai-o3.json",
|
||||
# "openai-o4-mini.json",
|
||||
# "claude-4-sonnet.json",
|
||||
# "claude-3-7-sonnet.json",
|
||||
# "gemini-2.5-flash-vertex.json",
|
||||
# "gemini-2.5-pro-vertex.json",
|
||||
]
|
||||
|
||||
@@ -517,3 +519,86 @@ async def test_summarize_multiple_large_tool_calls(server: SyncServer, actor, ll
|
||||
assert hasattr(msg, "content")
|
||||
|
||||
print(f"Summarized {len(in_context_messages)} messages with {total_content_size} chars to {len(result)} messages")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS, ids=[cfg.model for cfg in TESTED_LLM_CONFIGS])
async def test_summarize_truncates_large_tool_return(server: SyncServer, actor, llm_config: LLMConfig):
    """
    Summarize a conversation that contains one very large tool return.

    The conversation holds a tool return requested at 100000 characters. After
    summarization, every ToolReturnContent still present in the result must be
    shorter than the original payload; a result with no tool returns at all is
    accepted as well.
    """
    oversized_payload = create_large_tool_return(100000)
    original_size = len(oversized_payload)

    # Conversation: user request -> assistant tool call -> oversized tool return -> assistant wrap-up
    user_request = PydanticMessage(
        role=MessageRole.user,
        content=[TextContent(type="text", text="Please run the database query.")],
    )
    assistant_tool_call = PydanticMessage(
        role=MessageRole.assistant,
        content=[
            TextContent(type="text", text="Running query..."),
            ToolCallContent(
                type="tool_call",
                id="call_1",
                name="run_query",
                input={"query": "SELECT * FROM large_table"},
            ),
        ],
    )
    tool_result = PydanticMessage(
        role=MessageRole.tool,
        tool_call_id="call_1",
        content=[
            ToolReturnContent(
                type="tool_return",
                tool_call_id="call_1",
                content=oversized_payload,
                is_error=False,
            )
        ],
    )
    assistant_wrap_up = PydanticMessage(
        role=MessageRole.assistant,
        content=[TextContent(type="text", text="Query completed successfully with many results.")],
    )
    messages = [user_request, assistant_tool_call, tool_result, assistant_wrap_up]

    agent_state, in_context_messages = await create_agent_with_messages(server, actor, llm_config, messages)

    # Sanity check: the generated payload really is large
    assert original_size > 90000, f"Expected tool return >90k chars, got {original_size}"

    result = await run_summarization(server, agent_state, in_context_messages, actor)

    assert isinstance(result, list)
    assert len(result) >= 1

    # Every tool return that survived summarization must be smaller than the original
    tool_returns_found = False
    for summarized_msg in result:
        if summarized_msg.role != MessageRole.tool:
            continue
        for part in summarized_msg.content:
            if not isinstance(part, ToolReturnContent):
                continue
            tool_returns_found = True
            result_size = len(part.content)
            assert result_size < original_size, (
                f"Expected tool return to be truncated from {original_size} chars, but got {result_size} chars"
            )
            print(f"Tool return successfully truncated from {original_size} to {result_size} chars")

    # No tool returns left is acceptable too (they may have been summarized away entirely)
    if not tool_returns_found:
        print("Tool returns were completely removed during summarization")
|
||||
|
||||
Reference in New Issue
Block a user