From bf1874dbc9ba95dd0e75725c3db67a8594e3d618 Mon Sep 17 00:00:00 2001 From: Andy Li <55300002+cliandy@users.noreply.github.com> Date: Fri, 9 May 2025 15:01:12 -0700 Subject: [PATCH] fix: summarization includes tool call message before truncation (#2084) Co-authored-by: Sarah Wooders --- letta/llm_api/helpers.py | 4 ++++ letta/services/agent_manager.py | 15 ++------------- letta/services/summarizer/summarizer.py | 10 +++------- tests/integration_test_summarizer.py | 2 +- 4 files changed, 10 insertions(+), 21 deletions(-) diff --git a/letta/llm_api/helpers.py b/letta/llm_api/helpers.py index 6d6f77c1..ed497e09 100644 --- a/letta/llm_api/helpers.py +++ b/letta/llm_api/helpers.py @@ -337,6 +337,10 @@ def calculate_summarizer_cutoff(in_context_messages: List[Message], token_counts ) break + # includes the tool response to be summarized after a tool call so we don't have any hanging tool calls after trimming. + if i + 1 < len(in_context_messages_openai) and in_context_messages_openai[i + 1]["role"] == "tool": + cutoff += 1 + logger.info(f"Evicting {cutoff}/{len(in_context_messages)} messages...") return cutoff + 1 diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 96b891dd..190ca8a8 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -866,9 +866,8 @@ class AgentManager: @enforce_types def trim_older_in_context_messages(self, num: int, agent_id: str, actor: PydanticUser) -> PydanticAgentState: message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids - newer_messages = self._trim_tool_response(agent_id=agent_id, actor=actor, message_ids=message_ids[num:]) - trimmed_messages = [message_ids[0]] + newer_messages # 0 is system message - return self.set_in_context_messages(agent_id=agent_id, message_ids=trimmed_messages, actor=actor) + new_messages = [message_ids[0]] + message_ids[num:] # 0 is system message + return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor) @enforce_types def trim_all_in_context_messages_except_system(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState: @@ -877,16 +876,6 @@ class AgentManager: new_messages = [message_ids[0]] # 0 is system message return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor) - def _trim_tool_response(self, agent_id: str, actor: PydanticUser, message_ids: list[str]) -> PydanticAgentState: - """ - Trims the tool response from the in-context messages if there is no tool call present in trimmed messages. - """ - if message_ids: - messages = self.message_manager.get_messages_by_ids(message_ids=[message_ids[0]], actor=actor) - if messages and messages[0].role == "tool": - return message_ids[1:] - return message_ids - @enforce_types def prepend_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState: message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids diff --git a/letta/services/summarizer/summarizer.py b/letta/services/summarizer/summarizer.py index 6fd2d1e3..55dd62f8 100644 --- a/letta/services/summarizer/summarizer.py +++ b/letta/services/summarizer/summarizer.py @@ -107,13 +107,9 @@ class Summarizer: self.summarizer_agent.update_message_transcript(message_transcripts=formatted_evicted_messages + formatted_in_context_messages) # Add line numbers to the formatted messages - line_number = 0 - for i in range(len(formatted_evicted_messages)): - formatted_evicted_messages[i] = f"{line_number}. " + formatted_evicted_messages[i] - line_number += 1 - for i in range(len(formatted_in_context_messages)): - formatted_in_context_messages[i] = f"{line_number}. " + formatted_in_context_messages[i] - line_number += 1 + offset = len(formatted_evicted_messages) + formatted_evicted_messages = [f"{i}. {msg}" for (i, msg) in enumerate(formatted_evicted_messages)] + formatted_in_context_messages = [f"{i + offset}. {msg}" for (i, msg) in enumerate(formatted_in_context_messages)] evicted_messages_str = "\n".join(formatted_evicted_messages) in_context_messages_str = "\n".join(formatted_in_context_messages) diff --git a/tests/integration_test_summarizer.py b/tests/integration_test_summarizer.py index b38e03ba..6c0f74c1 100644 --- a/tests/integration_test_summarizer.py +++ b/tests/integration_test_summarizer.py @@ -147,7 +147,7 @@ def test_cutoff_calculation_with_tool_call(mocker, client: LocalClient, agent_st client.server.agent_manager.trim_older_in_context_messages(agent_id=agent_state.id, num=3, actor=client.user) test1 = mock_set_messages.call_args_list[0][1] - assert len(test1["message_ids"]) == 4 + assert len(test1["message_ids"]) == 5 mock_set_messages.reset_mock()