fix: summarization trims tool call without trimming tool response (#2010)
Co-authored-by: cthomas <caren@letta.com>
Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
This commit is contained in:
@@ -866,8 +866,9 @@ class AgentManager:
|
||||
@enforce_types
def trim_older_in_context_messages(self, num: int, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
    """Drop the oldest in-context messages of an agent while keeping the system message.

    The message at index 0 is treated as the system message and is always
    retained. The retained tail starts at index ``num`` of the agent's current
    ``message_ids``; that tail is passed through ``_trim_tool_response`` so it
    does not begin with a tool response whose tool call was just trimmed.

    Args:
        num: Index into the current ``message_ids`` where the retained tail begins.
        agent_id: Identifier of the agent whose in-context messages are trimmed.
        actor: User on whose behalf the agent is read and updated.

    Returns:
        The result of ``set_in_context_messages`` for the trimmed id list.
    """
    message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
    # Trimming can cut between an assistant tool call and its tool response;
    # the helper removes a response left at the head of the retained tail.
    newer_messages = self._trim_tool_response(agent_id=agent_id, actor=actor, message_ids=message_ids[num:])
    trimmed_messages = [message_ids[0]] + newer_messages  # 0 is system message
    return self.set_in_context_messages(agent_id=agent_id, message_ids=trimmed_messages, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def trim_all_in_context_messages_except_system(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
@@ -876,6 +877,16 @@ class AgentManager:
|
||||
new_messages = [message_ids[0]] # 0 is system message
|
||||
return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)
|
||||
|
||||
def _trim_tool_response(self, agent_id: str, actor: PydanticUser, message_ids: list[str]) -> list[str]:
    """Strip leading tool responses from a list of message ids.

    After older messages are trimmed, a tool response at the head of the
    remaining list no longer has its originating tool call in context. Every
    consecutive tool response at the head is dropped, not only the first one,
    so several responses belonging to one trimmed assistant turn are all removed.

    Args:
        agent_id: Identifier of the agent the messages belong to (not used by the lookup itself).
        actor: User passed through to the message lookup.
        message_ids: Candidate in-context message ids, oldest first.

    Returns:
        ``message_ids`` without its leading run of tool-response messages. The
        input object is returned unchanged when nothing needs to be trimmed.
    """
    if not message_ids:
        return message_ids

    start = 0
    while start < len(message_ids):
        # Look up one message at a time: in the common case (head is not a
        # tool response) this costs exactly one lookup, as before.
        messages = self.message_manager.get_messages_by_ids(message_ids=[message_ids[start]], actor=actor)
        if not (messages and messages[0].role == "tool"):
            # Unknown id or a non-tool message: stop trimming here.
            break
        start += 1

    return message_ids[start:] if start else message_ids
|
||||
|
||||
@enforce_types
|
||||
def prepend_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||||
|
||||
@@ -113,6 +113,50 @@ def test_cutoff_calculation(mocker):
|
||||
assert messages[cutoff - 1].role == MessageRole.user
|
||||
|
||||
|
||||
def test_cutoff_calculation_with_tool_call(mocker, client: LocalClient, agent_state):
    """Test that trim_older_in_context_messages properly handles tool responses with _trim_tool_response.

    Two cuts are exercised against the same seven-message history: one that
    lands directly on a tool response (which must be dropped together with its
    trimmed tool call) and one that lands on the assistant tool call itself
    (where the call/response pair must be kept intact).
    """
    agent_state = client.get_agent(agent_id=agent_state.id)

    # Setup
    messages = [
        generate_message("system"),
        generate_message("user", text="First user message"),
        generate_message(
            "assistant", tool_calls=[{"id": "tool_call_1", "type": "function", "function": {"name": "test_function", "arguments": "{}"}}]
        ),
        generate_message("tool", text="First tool response"),
        generate_message("assistant", text="First assistant response after tool"),
        generate_message("user", text="Second user message"),
        generate_message("assistant", text="Second assistant response"),
    ]
    all_ids = [msg.id for msg in messages]

    def mock_get_messages_by_ids(message_ids, actor):
        return [msg for msg in messages if msg.id in message_ids]

    mocker.patch.object(client.server.agent_manager.message_manager, "get_messages_by_ids", side_effect=mock_get_messages_by_ids)

    # Mock get_agent_by_id to return an agent with our message IDs
    mock_agent = mocker.Mock()
    mock_agent.message_ids = list(all_ids)
    mocker.patch.object(client.server.agent_manager, "get_agent_by_id", return_value=mock_agent)

    # Mock set_in_context_messages to capture what messages are being set
    mock_set_messages = mocker.patch.object(client.server.agent_manager, "set_in_context_messages", return_value=agent_state)

    # Test Case: the cut lands on the tool response, which is now orphaned and must be removed
    client.server.agent_manager.trim_older_in_context_messages(agent_id=agent_state.id, num=3, actor=client.user)

    test1 = mock_set_messages.call_args.kwargs
    # Check the exact ids, not only the count, so dropping the wrong message is caught.
    assert test1["message_ids"] == [all_ids[0]] + all_ids[4:]
    assert len(test1["message_ids"]) == 4

    mock_set_messages.reset_mock()

    # Test Case: the cut lands on the assistant tool call, so the tool response is not orphaned and is kept
    client.server.agent_manager.trim_older_in_context_messages(agent_id=agent_state.id, num=2, actor=client.user)

    test2 = mock_set_messages.call_args.kwargs
    assert test2["message_ids"] == [all_ids[0]] + all_ids[2:]
    assert len(test2["message_ids"]) == 6
|
||||
|
||||
|
||||
def test_summarize_many_messages_basic(client, disable_e2b_api_key):
|
||||
small_context_llm_config = LLMConfig.default_config("gpt-4o-mini")
|
||||
small_context_llm_config.context_window = 3000
|
||||
|
||||
Reference in New Issue
Block a user