From bdfcb9f87c492ba4fe85be63133c2fbfecc5b551 Mon Sep 17 00:00:00 2001 From: cthomas Date: Wed, 28 May 2025 23:30:33 -0700 Subject: [PATCH] feat: optimize in context message fetch (#2511) --- letta/agent.py | 4 ++-- letta/agents/letta_agent.py | 6 +++--- letta/services/agent_manager.py | 10 ++-------- tests/test_managers.py | 4 ++-- 4 files changed, 9 insertions(+), 15 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index d55e8a77..4c8b2553 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -1286,7 +1286,7 @@ class Agent(BaseAgent): # Grab the in-context messages # conversion of messages to OpenAI dict format, which is passed to the token counter (in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather( - self.agent_manager.get_in_context_messages_async(agent_id=self.agent_state.id, actor=self.user), + self.message_manager.get_messages_by_ids_async(message_ids=self.agent_state.message_ids, actor=self.user), self.passage_manager.size_async(actor=self.user, agent_id=self.agent_state.id), self.message_manager.size_async(actor=self.user, agent_id=self.agent_state.id), ) @@ -1408,7 +1408,7 @@ class Agent(BaseAgent): # Grab the in-context messages # conversion of messages to anthropic dict format, which is passed to the token counter (in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather( - self.agent_manager.get_in_context_messages_async(agent_id=self.agent_state.id, actor=self.user), + self.message_manager.get_messages_by_ids_async(message_ids=self.agent_state.message_ids, actor=self.user), self.passage_manager.size_async(actor=self.user, agent_id=self.agent_state.id), self.message_manager.size_async(actor=self.user, agent_id=self.agent_state.id), ) diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 1120ae04..5ff43790 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -516,7 +516,7 @@ class LettaAgent(BaseAgent): # Mirror the sync agent loop: get allowed tools or allow all if none are allowed if self.last_function_response is None: - self.last_function_response = await self._load_last_function_response_async() + self.last_function_response = await self._load_last_function_response_async(agent_state) valid_tool_names = tool_rules_solver.get_allowed_tool_names( available_tools=set([t.name for t in tools]), last_function_response=self.last_function_response, @@ -737,9 +737,9 @@ class LettaAgent(BaseAgent): return results @trace_method - async def _load_last_function_response_async(self): + async def _load_last_function_response_async(self, agent_state: AgentState): """Load the last function response from message history""" - in_context_messages = await self.agent_manager.get_in_context_messages_async(agent_id=self.agent_id, actor=self.actor) + in_context_messages = await self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=self.actor) for msg in reversed(in_context_messages): if msg.role == MessageRole.tool and msg.content and len(msg.content) == 1 and isinstance(msg.content[0], TextContent): text_content = msg.content[0].text diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 95e721d0..0dd8877f 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -1312,12 +1312,6 @@ class AgentManager: message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids return self.message_manager.get_messages_by_ids(message_ids=message_ids, actor=actor) - @trace_method - @enforce_types - async def get_in_context_messages_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]: - agent = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=[], actor=actor) - return await self.message_manager.get_messages_by_ids_async(message_ids=agent.message_ids, actor=actor) - @trace_method @enforce_types def get_system_message(self, agent_id: str, actor: PydanticUser) -> PydanticMessage: @@ -2684,7 +2678,7 @@ class AgentManager: # Grab the in-context messages # conversion of messages to anthropic dict format, which is passed to the token counter (in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather( - self.get_in_context_messages_async(agent_id=agent_id, actor=actor), + self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor), self.passage_manager.size_async(actor=actor, agent_id=agent_id), self.message_manager.size_async(actor=actor, agent_id=agent_id), ) @@ -2832,7 +2826,7 @@ class AgentManager: # Grab the in-context messages # conversion of messages to OpenAI dict format, which is passed to the token counter (in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather( - self.get_in_context_messages_async(agent_id=agent_id, actor=actor), + self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor), self.passage_manager.size_async(actor=actor, agent_id=agent_id), self.message_manager.size_async(actor=actor, agent_id=agent_id), ) diff --git a/tests/test_managers.py b/tests/test_managers.py index f268c76b..a7f236c0 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -687,7 +687,7 @@ async def test_create_agent_passed_in_initial_messages(server: SyncServer, defau actor=default_user, ) assert await server.message_manager.size_async(agent_id=agent_state.id, actor=default_user) == 2 - init_messages = await server.agent_manager.get_in_context_messages_async(agent_id=agent_state.id, actor=default_user) + init_messages = await server.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=default_user) # Check that the system appears in the first initial message assert create_agent_request.system in init_messages[0].content[0].text @@ -715,7 +715,7 @@ async def test_create_agent_default_initial_message(server: SyncServer, default_ actor=default_user, ) assert await server.message_manager.size_async(agent_id=agent_state.id, actor=default_user) == 4 - init_messages = await server.agent_manager.get_in_context_messages_async(agent_id=agent_state.id, actor=default_user) + init_messages = await server.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=default_user) # Check that the system appears in the first initial message assert create_agent_request.system in init_messages[0].content[0].text assert create_agent_request.memory_blocks[0].value in init_messages[0].content[0].text