fix: summarization includes tool call message before truncation (#2084)

Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
This commit is contained in:
Andy Li
2025-05-09 15:01:12 -07:00
committed by GitHub
parent f67ad6e0c6
commit bf1874dbc9
4 changed files with 10 additions and 21 deletions

View File

@@ -337,6 +337,10 @@ def calculate_summarizer_cutoff(in_context_messages: List[Message], token_counts
)
break
# If the next message is a tool response, include it in the range to be summarized so that trimming does not leave a tool response in context without its tool call.
if i + 1 < len(in_context_messages_openai) and in_context_messages_openai[i + 1]["role"] == "tool":
cutoff += 1
logger.info(f"Evicting {cutoff}/{len(in_context_messages)} messages...")
return cutoff + 1

View File

@@ -866,9 +866,8 @@ class AgentManager:
@enforce_types
def trim_older_in_context_messages(self, num: int, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
newer_messages = self._trim_tool_response(agent_id=agent_id, actor=actor, message_ids=message_ids[num:])
trimmed_messages = [message_ids[0]] + newer_messages # 0 is system message
return self.set_in_context_messages(agent_id=agent_id, message_ids=trimmed_messages, actor=actor)
new_messages = [message_ids[0]] + message_ids[num:] # 0 is system message
return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)
@enforce_types
def trim_all_in_context_messages_except_system(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
@@ -877,16 +876,6 @@ class AgentManager:
new_messages = [message_ids[0]] # 0 is system message
return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)
def _trim_tool_response(self, agent_id: str, actor: PydanticUser, message_ids: list[str]) -> PydanticAgentState:
"""
Trims the tool response from the in-context messages if there is no tool call present in trimmed messages.
"""
if message_ids:
messages = self.message_manager.get_messages_by_ids(message_ids=[message_ids[0]], actor=actor)
if messages and messages[0].role == "tool":
return message_ids[1:]
return message_ids
@enforce_types
def prepend_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids

View File

@@ -107,13 +107,9 @@ class Summarizer:
self.summarizer_agent.update_message_transcript(message_transcripts=formatted_evicted_messages + formatted_in_context_messages)
# Add line numbers to the formatted messages
line_number = 0
for i in range(len(formatted_evicted_messages)):
formatted_evicted_messages[i] = f"{line_number}. " + formatted_evicted_messages[i]
line_number += 1
for i in range(len(formatted_in_context_messages)):
formatted_in_context_messages[i] = f"{line_number}. " + formatted_in_context_messages[i]
line_number += 1
offset = len(formatted_evicted_messages)
formatted_evicted_messages = [f"{i}. {msg}" for (i, msg) in enumerate(formatted_evicted_messages)]
formatted_in_context_messages = [f"{i + offset}. {msg}" for (i, msg) in enumerate(formatted_in_context_messages)]
evicted_messages_str = "\n".join(formatted_evicted_messages)
in_context_messages_str = "\n".join(formatted_in_context_messages)

View File

@@ -147,7 +147,7 @@ def test_cutoff_calculation_with_tool_call(mocker, client: LocalClient, agent_st
client.server.agent_manager.trim_older_in_context_messages(agent_id=agent_state.id, num=3, actor=client.user)
test1 = mock_set_messages.call_args_list[0][1]
assert len(test1["message_ids"]) == 4
assert len(test1["message_ids"]) == 5
mock_set_messages.reset_mock()