fix: summarization includes tool call message before truncation (#2084)
Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
This commit is contained in:
@@ -337,6 +337,10 @@ def calculate_summarizer_cutoff(in_context_messages: List[Message], token_counts
|
||||
)
|
||||
break
|
||||
|
||||
# Include the tool response that follows a tool call among the messages to be summarized, so trimming does not separate a tool call from its response.
|
||||
if i + 1 < len(in_context_messages_openai) and in_context_messages_openai[i + 1]["role"] == "tool":
|
||||
cutoff += 1
|
||||
|
||||
logger.info(f"Evicting {cutoff}/{len(in_context_messages)} messages...")
|
||||
return cutoff + 1
|
||||
|
||||
|
||||
@@ -866,9 +866,8 @@ class AgentManager:
|
||||
@enforce_types
|
||||
def trim_older_in_context_messages(self, num: int, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||||
newer_messages = self._trim_tool_response(agent_id=agent_id, actor=actor, message_ids=message_ids[num:])
|
||||
trimmed_messages = [message_ids[0]] + newer_messages # 0 is system message
|
||||
return self.set_in_context_messages(agent_id=agent_id, message_ids=trimmed_messages, actor=actor)
|
||||
new_messages = [message_ids[0]] + message_ids[num:] # 0 is system message
|
||||
return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def trim_all_in_context_messages_except_system(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
@@ -877,16 +876,6 @@ class AgentManager:
|
||||
new_messages = [message_ids[0]] # 0 is system message
|
||||
return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)
|
||||
|
||||
def _trim_tool_response(self, agent_id: str, actor: PydanticUser, message_ids: list[str]) -> PydanticAgentState:
|
||||
"""
|
||||
Trims the tool response from the in-context messages if there is no tool call present in trimmed messages.
|
||||
"""
|
||||
if message_ids:
|
||||
messages = self.message_manager.get_messages_by_ids(message_ids=[message_ids[0]], actor=actor)
|
||||
if messages and messages[0].role == "tool":
|
||||
return message_ids[1:]
|
||||
return message_ids
|
||||
|
||||
@enforce_types
|
||||
def prepend_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||||
|
||||
@@ -107,13 +107,9 @@ class Summarizer:
|
||||
self.summarizer_agent.update_message_transcript(message_transcripts=formatted_evicted_messages + formatted_in_context_messages)
|
||||
|
||||
# Add line numbers to the formatted messages
|
||||
line_number = 0
|
||||
for i in range(len(formatted_evicted_messages)):
|
||||
formatted_evicted_messages[i] = f"{line_number}. " + formatted_evicted_messages[i]
|
||||
line_number += 1
|
||||
for i in range(len(formatted_in_context_messages)):
|
||||
formatted_in_context_messages[i] = f"{line_number}. " + formatted_in_context_messages[i]
|
||||
line_number += 1
|
||||
offset = len(formatted_evicted_messages)
|
||||
formatted_evicted_messages = [f"{i}. {msg}" for (i, msg) in enumerate(formatted_evicted_messages)]
|
||||
formatted_in_context_messages = [f"{i + offset}. {msg}" for (i, msg) in enumerate(formatted_in_context_messages)]
|
||||
|
||||
evicted_messages_str = "\n".join(formatted_evicted_messages)
|
||||
in_context_messages_str = "\n".join(formatted_in_context_messages)
|
||||
|
||||
@@ -147,7 +147,7 @@ def test_cutoff_calculation_with_tool_call(mocker, client: LocalClient, agent_st
|
||||
client.server.agent_manager.trim_older_in_context_messages(agent_id=agent_state.id, num=3, actor=client.user)
|
||||
|
||||
test1 = mock_set_messages.call_args_list[0][1]
|
||||
assert len(test1["message_ids"]) == 4
|
||||
assert len(test1["message_ids"]) == 5
|
||||
|
||||
mock_set_messages.reset_mock()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user