From acda68c0a839a874e9569247d98048f383e81c57 Mon Sep 17 00:00:00 2001
From: Andy Li <55300002+cliandy@users.noreply.github.com>
Date: Mon, 5 May 2025 21:02:23 -0700
Subject: [PATCH] fix: summarization trims tool call without trimming tool response (#2010)

trim_older_in_context_messages now drops a leading tool response from the
retained messages, so it is not kept in context without its tool call.

Co-authored-by: cthomas
Co-authored-by: Sarah Wooders
---
 letta/services/agent_manager.py      | 15 ++++++++--
 tests/integration_test_summarizer.py | 44 ++++++++++++++++++++++++++++
 2 files changed, 57 insertions(+), 2 deletions(-)

diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py
index 190ca8a8..96b891dd 100644
--- a/letta/services/agent_manager.py
+++ b/letta/services/agent_manager.py
@@ -866,8 +866,9 @@ class AgentManager:
     @enforce_types
     def trim_older_in_context_messages(self, num: int, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
         message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
-        new_messages = [message_ids[0]] + message_ids[num:]  # 0 is system message
-        return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)
+        newer_messages = self._trim_tool_response(agent_id=agent_id, actor=actor, message_ids=message_ids[num:])
+        trimmed_messages = [message_ids[0]] + newer_messages  # 0 is system message
+        return self.set_in_context_messages(agent_id=agent_id, message_ids=trimmed_messages, actor=actor)
 
     @enforce_types
     def trim_all_in_context_messages_except_system(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
@@ -876,6 +877,16 @@ class AgentManager:
         new_messages = [message_ids[0]]  # 0 is system message
         return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)
 
+    def _trim_tool_response(self, agent_id: str, actor: PydanticUser, message_ids: list[str]) -> list[str]:
+        """
+        Returns message_ids without its first entry when that message has the "tool" role (its tool call was trimmed); else unchanged.
+        """
+        if message_ids:
+            messages = self.message_manager.get_messages_by_ids(message_ids=[message_ids[0]], actor=actor)
+            if messages and messages[0].role == "tool":
+                return message_ids[1:]
+        return message_ids
+
     @enforce_types
     def prepend_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
         message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
diff --git a/tests/integration_test_summarizer.py b/tests/integration_test_summarizer.py
index 43a51e16..b38e03ba 100644
--- a/tests/integration_test_summarizer.py
+++ b/tests/integration_test_summarizer.py
@@ -113,6 +113,50 @@ def test_cutoff_calculation(mocker):
     assert messages[cutoff - 1].role == MessageRole.user
 
 
+def test_cutoff_calculation_with_tool_call(mocker, client: LocalClient, agent_state):
+    """Test that trim_older_in_context_messages properly handles tool responses with _trim_tool_response."""
+    agent_state = client.get_agent(agent_id=agent_state.id)
+
+    # Setup
+    messages = [
+        generate_message("system"),
+        generate_message("user", text="First user message"),
+        generate_message(
+            "assistant", tool_calls=[{"id": "tool_call_1", "type": "function", "function": {"name": "test_function", "arguments": "{}"}}]
+        ),
+        generate_message("tool", text="First tool response"),
+        generate_message("assistant", text="First assistant response after tool"),
+        generate_message("user", text="Second user message"),
+        generate_message("assistant", text="Second assistant response"),
+    ]
+
+    def mock_get_messages_by_ids(message_ids, actor):
+        return [msg for msg in messages if msg.id in message_ids]
+
+    mocker.patch.object(client.server.agent_manager.message_manager, "get_messages_by_ids", side_effect=mock_get_messages_by_ids)
+
+    # Mock get_agent_by_id to return an agent with our message IDs
+    mock_agent = mocker.Mock()
+    mock_agent.message_ids = [msg.id for msg in messages]
+    mocker.patch.object(client.server.agent_manager, "get_agent_by_id", return_value=mock_agent)
+
+    # Mock set_in_context_messages to capture what messages are being set
+    mock_set_messages = mocker.patch.object(client.server.agent_manager, "set_in_context_messages", return_value=agent_state)
+
+    # Test Case: Trim to remove orphaned tool response
+    client.server.agent_manager.trim_older_in_context_messages(agent_id=agent_state.id, num=3, actor=client.user)
+
+    test1 = mock_set_messages.call_args_list[0][1]
+    assert len(test1["message_ids"]) == 4
+
+    mock_set_messages.reset_mock()
+
+    # Test Case: Does not result in trimming the orphaned tool response
+    client.server.agent_manager.trim_older_in_context_messages(agent_id=agent_state.id, num=2, actor=client.user)
+    test2 = mock_set_messages.call_args_list[0][1]
+    assert len(test2["message_ids"]) == 6
+
+
 def test_summarize_many_messages_basic(client, disable_e2b_api_key):
     small_context_llm_config = LLMConfig.default_config("gpt-4o-mini")
     small_context_llm_config.context_window = 3000