feat: remove concurrent db connection spawning (#3380)

This commit is contained in:
cthomas
2025-07-17 14:01:38 -07:00
committed by GitHub
parent a6f7114c96
commit f91cbda6eb
4 changed files with 54 additions and 47 deletions

View File

@@ -1304,12 +1304,11 @@ class Agent(BaseAgent):
async def get_context_window_from_tiktoken_async(self) -> ContextWindowOverview:
"""Get the context window of the agent"""
# Grab the in-context messages
# conversion of messages to OpenAI dict format, which is passed to the token counter
(in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
self.message_manager.get_messages_by_ids_async(message_ids=self.agent_state.message_ids, actor=self.user),
self.passage_manager.agent_passage_size_async(actor=self.user, agent_id=self.agent_state.id),
self.message_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
in_context_messages = await self.message_manager.get_messages_by_ids_async(
message_ids=self.agent_state.message_ids, actor=self.user
)
# conversion of messages to OpenAI dict format, which is passed to the token counter
in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
# Extract system, memory and external summary
@@ -1396,6 +1395,15 @@ class Agent(BaseAgent):
)
assert isinstance(num_tokens_used_total, int)
passage_manager_size = await self.passage_manager.agent_passage_size_async(
agent_id=self.agent_state.id,
actor=self.user,
)
message_manager_size = await self.message_manager.size_async(
agent_id=self.agent_state.id,
actor=self.user,
)
return ContextWindowOverview(
# context window breakdown (in messages)
num_messages=len(in_context_messages),
@@ -1426,12 +1434,11 @@ class Agent(BaseAgent):
model = self.agent_state.llm_config.model if self.agent_state.llm_config.model_endpoint_type == "anthropic" else None
# Grab the in-context messages
# conversion of messages to anthropic dict format, which is passed to the token counter
(in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
self.message_manager.get_messages_by_ids_async(message_ids=self.agent_state.message_ids, actor=self.user),
self.passage_manager.agent_passage_size_async(actor=self.user, agent_id=self.agent_state.id),
self.message_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
in_context_messages = await self.message_manager.get_messages_by_ids_async(
message_ids=self.agent_state.message_ids, actor=self.user
)
# conversion of messages to anthropic dict format, which is passed to the token counter
in_context_messages_anthropic = [m.to_anthropic_dict() for m in in_context_messages]
# Extract system, memory and external summary
@@ -1546,6 +1553,15 @@ class Agent(BaseAgent):
)
assert isinstance(num_tokens_used_total, int)
passage_manager_size = await self.passage_manager.agent_passage_size_async(
agent_id=self.agent_state.id,
actor=self.user,
)
message_manager_size = await self.message_manager.size_async(
agent_id=self.agent_state.id,
actor=self.user,
)
return ContextWindowOverview(
# context window breakdown (in messages)
num_messages=len(in_context_messages),

View File

@@ -1,4 +1,3 @@
import asyncio
import json
import uuid
from collections.abc import AsyncGenerator
@@ -948,18 +947,17 @@ class LettaAgent(BaseAgent):
agent_state: AgentState,
tool_rules_solver: ToolRulesSolver,
) -> tuple[dict, list[str]]:
self.num_messages, self.num_archival_memories = await asyncio.gather(
(
self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)
if self.num_messages is None
else asyncio.sleep(0, result=self.num_messages)
),
(
self.passage_manager.agent_passage_size_async(actor=self.actor, agent_id=agent_state.id)
if self.num_archival_memories is None
else asyncio.sleep(0, result=self.num_archival_memories)
),
)
if not self.num_messages:
self.num_messages = await self.message_manager.size_async(
agent_id=agent_state.id,
actor=self.actor,
)
if not self.num_archival_memories:
self.num_archival_memories = await self.passage_manager.agent_passage_size_async(
agent_id=agent_state.id,
actor=self.actor,
)
in_context_messages = await self._rebuild_memory_async(
in_context_messages,
agent_state,

View File

@@ -1,4 +1,3 @@
import asyncio
import json
import uuid
from datetime import datetime, timedelta, timezone
@@ -308,18 +307,17 @@ class VoiceAgent(BaseAgent):
in_context_messages: List[Message],
agent_state: AgentState,
) -> List[Message]:
self.num_messages, self.num_archival_memories = await asyncio.gather(
(
self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)
if self.num_messages is None
else asyncio.sleep(0, result=self.num_messages)
),
(
self.passage_manager.agent_passage_size_async(actor=self.actor, agent_id=agent_state.id)
if self.num_archival_memories is None
else asyncio.sleep(0, result=self.num_archival_memories)
),
)
if not self.num_messages:
self.num_messages = await self.message_manager.size_async(
agent_id=agent_state.id,
actor=self.actor,
)
if not self.num_archival_memories:
self.num_archival_memories = await self.passage_manager.agent_passage_size_async(
agent_id=agent_state.id,
actor=self.actor,
)
return await super()._rebuild_memory_async(
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
)

View File

@@ -1544,15 +1544,9 @@ class AgentManager:
Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages
"""
num_messages_task = self.message_manager.size_async(actor=actor, agent_id=agent_id)
num_archival_memories_task = self.passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_id)
agent_state_task = self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory", "sources", "tools"], actor=actor)
num_messages, num_archival_memories, agent_state = await asyncio.gather(
num_messages_task,
num_archival_memories_task,
agent_state_task,
)
num_messages = await self.message_manager.size_async(actor=actor, agent_id=agent_id)
num_archival_memories = await self.passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_id)
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory", "sources", "tools"], actor=actor)
if not tool_rules_solver:
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
@@ -1769,9 +1763,10 @@ class AgentManager:
)
# refresh memory from DB (using block ids)
blocks = await asyncio.gather(
*[self.block_manager.get_block_by_id_async(block.id, actor=actor) for block in agent_state.memory.get_blocks()]
blocks = await self.block_manager.get_all_blocks_by_ids_async(
block_ids=[b.id for b in agent_state.memory.get_blocks()], actor=actor
)
agent_state.memory = Memory(
blocks=blocks,
file_blocks=agent_state.memory.file_blocks,