From f91cbda6eb24df69850db5a16d3bbad928d32898 Mon Sep 17 00:00:00 2001 From: cthomas Date: Thu, 17 Jul 2025 14:01:38 -0700 Subject: [PATCH] feat: remove concurrent db connection spawning (#3380) --- letta/agent.py | 36 ++++++++++++++++++++++++--------- letta/agents/letta_agent.py | 24 ++++++++++------------ letta/agents/voice_agent.py | 24 ++++++++++------------ letta/services/agent_manager.py | 17 ++++++---------- 4 files changed, 54 insertions(+), 47 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index 87b083c5..bc40f12b 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -1304,12 +1304,11 @@ class Agent(BaseAgent): async def get_context_window_from_tiktoken_async(self) -> ContextWindowOverview: """Get the context window of the agent""" # Grab the in-context messages - # conversion of messages to OpenAI dict format, which is passed to the token counter - (in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather( - self.message_manager.get_messages_by_ids_async(message_ids=self.agent_state.message_ids, actor=self.user), - self.passage_manager.agent_passage_size_async(actor=self.user, agent_id=self.agent_state.id), - self.message_manager.size_async(actor=self.user, agent_id=self.agent_state.id), + in_context_messages = await self.message_manager.get_messages_by_ids_async( + message_ids=self.agent_state.message_ids, actor=self.user ) + + # conversion of messages to OpenAI dict format, which is passed to the token counter in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages] # Extract system, memory and external summary @@ -1396,6 +1395,15 @@ class Agent(BaseAgent): ) assert isinstance(num_tokens_used_total, int) + passage_manager_size = await self.passage_manager.agent_passage_size_async( + agent_id=self.agent_state.id, + actor=self.user, + ) + message_manager_size = await self.message_manager.size_async( + agent_id=self.agent_state.id, + actor=self.user, + ) + return ContextWindowOverview( # context window breakdown (in messages) num_messages=len(in_context_messages), @@ -1426,12 +1434,11 @@ class Agent(BaseAgent): model = self.agent_state.llm_config.model if self.agent_state.llm_config.model_endpoint_type == "anthropic" else None # Grab the in-context messages - # conversion of messages to anthropic dict format, which is passed to the token counter - (in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather( - self.message_manager.get_messages_by_ids_async(message_ids=self.agent_state.message_ids, actor=self.user), - self.passage_manager.agent_passage_size_async(actor=self.user, agent_id=self.agent_state.id), - self.message_manager.size_async(actor=self.user, agent_id=self.agent_state.id), + in_context_messages = await self.message_manager.get_messages_by_ids_async( + message_ids=self.agent_state.message_ids, actor=self.user ) + + # conversion of messages to anthropic dict format, which is passed to the token counter in_context_messages_anthropic = [m.to_anthropic_dict() for m in in_context_messages] # Extract system, memory and external summary @@ -1546,6 +1553,15 @@ class Agent(BaseAgent): ) assert isinstance(num_tokens_used_total, int) + passage_manager_size = await self.passage_manager.agent_passage_size_async( + agent_id=self.agent_state.id, + actor=self.user, + ) + message_manager_size = await self.message_manager.size_async( + agent_id=self.agent_state.id, + actor=self.user, + ) + return ContextWindowOverview( # context window breakdown (in messages) num_messages=len(in_context_messages), diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 7ab25738..0bd21f37 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -1,4 +1,3 @@ -import asyncio import json import uuid from collections.abc import AsyncGenerator @@ -948,18 +947,17 @@ class LettaAgent(BaseAgent): agent_state: AgentState, tool_rules_solver: ToolRulesSolver, ) -> tuple[dict, list[str]]: - self.num_messages, self.num_archival_memories = await asyncio.gather( - ( - self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id) - if self.num_messages is None - else asyncio.sleep(0, result=self.num_messages) - ), - ( - self.passage_manager.agent_passage_size_async(actor=self.actor, agent_id=agent_state.id) - if self.num_archival_memories is None - else asyncio.sleep(0, result=self.num_archival_memories) - ), - ) + if not self.num_messages: + self.num_messages = await self.message_manager.size_async( + agent_id=agent_state.id, + actor=self.actor, + ) + if not self.num_archival_memories: + self.num_archival_memories = await self.passage_manager.agent_passage_size_async( + agent_id=agent_state.id, + actor=self.actor, + ) + in_context_messages = await self._rebuild_memory_async( in_context_messages, agent_state, diff --git a/letta/agents/voice_agent.py b/letta/agents/voice_agent.py index 0c77626e..18e930d3 100644 --- a/letta/agents/voice_agent.py +++ b/letta/agents/voice_agent.py @@ -1,4 +1,3 @@ -import asyncio import json import uuid from datetime import datetime, timedelta, timezone @@ -308,18 +307,17 @@ class VoiceAgent(BaseAgent): in_context_messages: List[Message], agent_state: AgentState, ) -> List[Message]: - self.num_messages, self.num_archival_memories = await asyncio.gather( - ( - self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id) - if self.num_messages is None - else asyncio.sleep(0, result=self.num_messages) - ), - ( - self.passage_manager.agent_passage_size_async(actor=self.actor, agent_id=agent_state.id) - if self.num_archival_memories is None - else asyncio.sleep(0, result=self.num_archival_memories) - ), - ) + if not self.num_messages: + self.num_messages = await self.message_manager.size_async( + agent_id=agent_state.id, + actor=self.actor, + ) + if not self.num_archival_memories: + self.num_archival_memories = await self.passage_manager.agent_passage_size_async( + agent_id=agent_state.id, + actor=self.actor, + ) + return await super()._rebuild_memory_async( in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories ) diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 2b92915f..12d9de46 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -1544,15 +1544,9 @@ class AgentManager: Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages """ - num_messages_task = self.message_manager.size_async(actor=actor, agent_id=agent_id) - num_archival_memories_task = self.passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_id) - agent_state_task = self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory", "sources", "tools"], actor=actor) - - num_messages, num_archival_memories, agent_state = await asyncio.gather( - num_messages_task, - num_archival_memories_task, - agent_state_task, - ) + num_messages = await self.message_manager.size_async(actor=actor, agent_id=agent_id) + num_archival_memories = await self.passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_id) + agent_state = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory", "sources", "tools"], actor=actor) if not tool_rules_solver: tool_rules_solver = ToolRulesSolver(agent_state.tool_rules) @@ -1769,9 +1763,10 @@ class AgentManager: ) # refresh memory from DB (using block ids) - blocks = await asyncio.gather( - *[self.block_manager.get_block_by_id_async(block.id, actor=actor) for block in agent_state.memory.get_blocks()] + blocks = await self.block_manager.get_all_blocks_by_ids_async( + block_ids=[b.id for b in agent_state.memory.get_blocks()], actor=actor ) + agent_state.memory = Memory( blocks=blocks, file_blocks=agent_state.memory.file_blocks,