feat: optimize in context message fetch (#2511)

This commit is contained in:
cthomas
2025-05-28 23:30:33 -07:00
committed by GitHub
parent ab2c8f7e2d
commit bdfcb9f87c
4 changed files with 9 additions and 15 deletions

View File

@@ -1286,7 +1286,7 @@ class Agent(BaseAgent):
# Grab the in-context messages
# conversion of messages to OpenAI dict format, which is passed to the token counter
(in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
self.agent_manager.get_in_context_messages_async(agent_id=self.agent_state.id, actor=self.user),
self.message_manager.get_messages_by_ids_async(message_ids=self.agent_state.message_ids, actor=self.user),
self.passage_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
self.message_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
)
@@ -1408,7 +1408,7 @@ class Agent(BaseAgent):
# Grab the in-context messages
# conversion of messages to anthropic dict format, which is passed to the token counter
(in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
self.agent_manager.get_in_context_messages_async(agent_id=self.agent_state.id, actor=self.user),
self.message_manager.get_messages_by_ids_async(message_ids=self.agent_state.message_ids, actor=self.user),
self.passage_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
self.message_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
)

View File

@@ -516,7 +516,7 @@ class LettaAgent(BaseAgent):
# Mirror the sync agent loop: get allowed tools or allow all if none are allowed
if self.last_function_response is None:
self.last_function_response = await self._load_last_function_response_async()
self.last_function_response = await self._load_last_function_response_async(agent_state)
valid_tool_names = tool_rules_solver.get_allowed_tool_names(
available_tools=set([t.name for t in tools]),
last_function_response=self.last_function_response,
@@ -737,9 +737,9 @@ class LettaAgent(BaseAgent):
return results
@trace_method
async def _load_last_function_response_async(self):
async def _load_last_function_response_async(self, agent_state: AgentState):
"""Load the last function response from message history"""
in_context_messages = await self.agent_manager.get_in_context_messages_async(agent_id=self.agent_id, actor=self.actor)
in_context_messages = await self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=self.actor)
for msg in reversed(in_context_messages):
if msg.role == MessageRole.tool and msg.content and len(msg.content) == 1 and isinstance(msg.content[0], TextContent):
text_content = msg.content[0].text

View File

@@ -1312,12 +1312,6 @@ class AgentManager:
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
return self.message_manager.get_messages_by_ids(message_ids=message_ids, actor=actor)
@trace_method
@enforce_types
async def get_in_context_messages_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]:
agent = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=[], actor=actor)
return await self.message_manager.get_messages_by_ids_async(message_ids=agent.message_ids, actor=actor)
@trace_method
@enforce_types
def get_system_message(self, agent_id: str, actor: PydanticUser) -> PydanticMessage:
@@ -2684,7 +2678,7 @@ class AgentManager:
# Grab the in-context messages
# conversion of messages to anthropic dict format, which is passed to the token counter
(in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
self.get_in_context_messages_async(agent_id=agent_id, actor=actor),
self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor),
self.passage_manager.size_async(actor=actor, agent_id=agent_id),
self.message_manager.size_async(actor=actor, agent_id=agent_id),
)
@@ -2832,7 +2826,7 @@ class AgentManager:
# Grab the in-context messages
# conversion of messages to OpenAI dict format, which is passed to the token counter
(in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
self.get_in_context_messages_async(agent_id=agent_id, actor=actor),
self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor),
self.passage_manager.size_async(actor=actor, agent_id=agent_id),
self.message_manager.size_async(actor=actor, agent_id=agent_id),
)

View File

@@ -687,7 +687,7 @@ async def test_create_agent_passed_in_initial_messages(server: SyncServer, defau
actor=default_user,
)
assert await server.message_manager.size_async(agent_id=agent_state.id, actor=default_user) == 2
init_messages = await server.agent_manager.get_in_context_messages_async(agent_id=agent_state.id, actor=default_user)
init_messages = await server.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=default_user)
# Check that the system appears in the first initial message
assert create_agent_request.system in init_messages[0].content[0].text
@@ -715,7 +715,7 @@ async def test_create_agent_default_initial_message(server: SyncServer, default_
actor=default_user,
)
assert await server.message_manager.size_async(agent_id=agent_state.id, actor=default_user) == 4
init_messages = await server.agent_manager.get_in_context_messages_async(agent_id=agent_state.id, actor=default_user)
init_messages = await server.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=default_user)
# Check that the system appears in the first initial message
assert create_agent_request.system in init_messages[0].content[0].text
assert create_agent_request.memory_blocks[0].value in init_messages[0].content[0].text