fix: summarization trims tool call without trimming tool response (#2010)

Co-authored-by: cthomas <caren@letta.com>
Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
This commit is contained in:
Andy Li
2025-05-05 21:02:23 -07:00
committed by GitHub
parent adfe8606a1
commit acda68c0a8
2 changed files with 57 additions and 2 deletions

View File

@@ -866,8 +866,9 @@ class AgentManager:
@enforce_types
def trim_older_in_context_messages(self, num: int, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
    """
    Drop the oldest non-system messages from an agent's in-context window.

    The message at index 0 (the system message) is always kept. Messages at
    indices ``1 .. num - 1`` are removed, and the remainder is passed through
    ``_trim_tool_response`` so the retained window does not begin with a tool
    response whose tool call was just trimmed away.

    Args:
        num: Index of the first message to keep (besides the system message).
        agent_id: ID of the agent whose in-context messages are trimmed.
        actor: User on whose behalf the agent and messages are accessed.

    Returns:
        The result of ``set_in_context_messages`` for the trimmed ID list.
    """
    message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
    # Trim the retained tail first: cutting at `num` may separate a tool response
    # from the assistant tool call that preceded it.
    newer_messages = self._trim_tool_response(agent_id=agent_id, actor=actor, message_ids=message_ids[num:])
    trimmed_messages = [message_ids[0]] + newer_messages  # 0 is system message
    return self.set_in_context_messages(agent_id=agent_id, message_ids=trimmed_messages, actor=actor)
@enforce_types
def trim_all_in_context_messages_except_system(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
@@ -876,6 +877,16 @@ class AgentManager:
new_messages = [message_ids[0]] # 0 is system message
return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)
def _trim_tool_response(self, agent_id: str, actor: PydanticUser, message_ids: list[str]) -> list[str]:
    """
    Strip orphaned tool responses from the front of a retained message window.

    ``message_ids`` is the tail of the in-context window that survives a trim.
    A tool response at the head of that tail has lost the tool call that
    produced it (the call lived in the trimmed part), so it is dropped. Every
    consecutive leading tool response is dropped, not only the first one,
    because each of them is orphaned for the same reason.

    Args:
        agent_id: ID of the agent being trimmed (not used by this helper).
        actor: User on whose behalf messages are looked up.
        message_ids: IDs of the messages kept after trimming, oldest first.

    Returns:
        The message IDs with any leading tool responses removed. The input
        list is not mutated.
    """
    first_kept = 0
    while first_kept < len(message_ids):
        # Look up only the current head message; its role decides whether to keep scanning.
        messages = self.message_manager.get_messages_by_ids(message_ids=[message_ids[first_kept]], actor=actor)
        if not (messages and messages[0].role == "tool"):
            break
        first_kept += 1
    return message_ids[first_kept:]
@enforce_types
def prepend_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids

View File

@@ -113,6 +113,50 @@ def test_cutoff_calculation(mocker):
assert messages[cutoff - 1].role == MessageRole.user
def test_cutoff_calculation_with_tool_call(mocker, client: LocalClient, agent_state):
    """Verify trim_older_in_context_messages drops a tool response orphaned by the trim (via _trim_tool_response)."""
    agent_state = client.get_agent(agent_id=agent_state.id)
    manager = client.server.agent_manager

    # Conversation fixture: system, user, assistant tool call, tool response, then three plain turns.
    tool_call = {"id": "tool_call_1", "type": "function", "function": {"name": "test_function", "arguments": "{}"}}
    messages = [
        generate_message("system"),
        generate_message("user", text="First user message"),
        generate_message("assistant", tool_calls=[tool_call]),
        generate_message("tool", text="First tool response"),
        generate_message("assistant", text="First assistant response after tool"),
        generate_message("user", text="Second user message"),
        generate_message("assistant", text="Second assistant response"),
    ]

    def mock_get_messages_by_ids(message_ids, actor):
        # Serve lookups from the in-memory fixture instead of the database.
        return list(filter(lambda msg: msg.id in message_ids, messages))

    mocker.patch.object(manager.message_manager, "get_messages_by_ids", side_effect=mock_get_messages_by_ids)

    # The agent lookup returns a stand-in whose message_ids mirror the fixture.
    mock_agent = mocker.Mock()
    mock_agent.message_ids = [msg.id for msg in messages]
    mocker.patch.object(manager, "get_agent_by_id", return_value=mock_agent)

    # Capture the IDs that would be persisted as the new in-context window.
    mock_set_messages = mocker.patch.object(manager, "set_in_context_messages", return_value=agent_state)

    # Case 1: cutting at index 3 leaves the tool response at the head, so it must be dropped too.
    manager.trim_older_in_context_messages(agent_id=agent_state.id, num=3, actor=client.user)
    orphan_case_kwargs = mock_set_messages.call_args_list[0].kwargs
    assert len(orphan_case_kwargs["message_ids"]) == 4

    mock_set_messages.reset_mock()

    # Case 2: cutting at index 2 keeps the tool call, so its response stays in the window.
    manager.trim_older_in_context_messages(agent_id=agent_state.id, num=2, actor=client.user)
    paired_case_kwargs = mock_set_messages.call_args_list[0].kwargs
    assert len(paired_case_kwargs["message_ids"]) == 6
def test_summarize_many_messages_basic(client, disable_e2b_api_key):
small_context_llm_config = LLMConfig.default_config("gpt-4o-mini")
small_context_llm_config.context_window = 3000