feat: remove concurrent db connection spawning (#3380)
This commit is contained in:
@@ -1304,12 +1304,11 @@ class Agent(BaseAgent):
|
||||
async def get_context_window_from_tiktoken_async(self) -> ContextWindowOverview:
|
||||
"""Get the context window of the agent"""
|
||||
# Grab the in-context messages
|
||||
# conversion of messages to OpenAI dict format, which is passed to the token counter
|
||||
(in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
|
||||
self.message_manager.get_messages_by_ids_async(message_ids=self.agent_state.message_ids, actor=self.user),
|
||||
self.passage_manager.agent_passage_size_async(actor=self.user, agent_id=self.agent_state.id),
|
||||
self.message_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
|
||||
in_context_messages = await self.message_manager.get_messages_by_ids_async(
|
||||
message_ids=self.agent_state.message_ids, actor=self.user
|
||||
)
|
||||
|
||||
# conversion of messages to OpenAI dict format, which is passed to the token counter
|
||||
in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
|
||||
|
||||
# Extract system, memory and external summary
|
||||
@@ -1396,6 +1395,15 @@ class Agent(BaseAgent):
|
||||
)
|
||||
assert isinstance(num_tokens_used_total, int)
|
||||
|
||||
passage_manager_size = await self.passage_manager.agent_passage_size_async(
|
||||
agent_id=self.agent_state.id,
|
||||
actor=self.user,
|
||||
)
|
||||
message_manager_size = await self.message_manager.size_async(
|
||||
agent_id=self.agent_state.id,
|
||||
actor=self.user,
|
||||
)
|
||||
|
||||
return ContextWindowOverview(
|
||||
# context window breakdown (in messages)
|
||||
num_messages=len(in_context_messages),
|
||||
@@ -1426,12 +1434,11 @@ class Agent(BaseAgent):
|
||||
model = self.agent_state.llm_config.model if self.agent_state.llm_config.model_endpoint_type == "anthropic" else None
|
||||
|
||||
# Grab the in-context messages
|
||||
# conversion of messages to anthropic dict format, which is passed to the token counter
|
||||
(in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
|
||||
self.message_manager.get_messages_by_ids_async(message_ids=self.agent_state.message_ids, actor=self.user),
|
||||
self.passage_manager.agent_passage_size_async(actor=self.user, agent_id=self.agent_state.id),
|
||||
self.message_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
|
||||
in_context_messages = await self.message_manager.get_messages_by_ids_async(
|
||||
message_ids=self.agent_state.message_ids, actor=self.user
|
||||
)
|
||||
|
||||
# conversion of messages to anthropic dict format, which is passed to the token counter
|
||||
in_context_messages_anthropic = [m.to_anthropic_dict() for m in in_context_messages]
|
||||
|
||||
# Extract system, memory and external summary
|
||||
@@ -1546,6 +1553,15 @@ class Agent(BaseAgent):
|
||||
)
|
||||
assert isinstance(num_tokens_used_total, int)
|
||||
|
||||
passage_manager_size = await self.passage_manager.agent_passage_size_async(
|
||||
agent_id=self.agent_state.id,
|
||||
actor=self.user,
|
||||
)
|
||||
message_manager_size = await self.message_manager.size_async(
|
||||
agent_id=self.agent_state.id,
|
||||
actor=self.user,
|
||||
)
|
||||
|
||||
return ContextWindowOverview(
|
||||
# context window breakdown (in messages)
|
||||
num_messages=len(in_context_messages),
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
@@ -948,18 +947,17 @@ class LettaAgent(BaseAgent):
|
||||
agent_state: AgentState,
|
||||
tool_rules_solver: ToolRulesSolver,
|
||||
) -> tuple[dict, list[str]]:
|
||||
self.num_messages, self.num_archival_memories = await asyncio.gather(
|
||||
(
|
||||
self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
if self.num_messages is None
|
||||
else asyncio.sleep(0, result=self.num_messages)
|
||||
),
|
||||
(
|
||||
self.passage_manager.agent_passage_size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
if self.num_archival_memories is None
|
||||
else asyncio.sleep(0, result=self.num_archival_memories)
|
||||
),
|
||||
)
|
||||
if not self.num_messages:
|
||||
self.num_messages = await self.message_manager.size_async(
|
||||
agent_id=agent_state.id,
|
||||
actor=self.actor,
|
||||
)
|
||||
if not self.num_archival_memories:
|
||||
self.num_archival_memories = await self.passage_manager.agent_passage_size_async(
|
||||
agent_id=agent_state.id,
|
||||
actor=self.actor,
|
||||
)
|
||||
|
||||
in_context_messages = await self._rebuild_memory_async(
|
||||
in_context_messages,
|
||||
agent_state,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
@@ -308,18 +307,17 @@ class VoiceAgent(BaseAgent):
|
||||
in_context_messages: List[Message],
|
||||
agent_state: AgentState,
|
||||
) -> List[Message]:
|
||||
self.num_messages, self.num_archival_memories = await asyncio.gather(
|
||||
(
|
||||
self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
if self.num_messages is None
|
||||
else asyncio.sleep(0, result=self.num_messages)
|
||||
),
|
||||
(
|
||||
self.passage_manager.agent_passage_size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
if self.num_archival_memories is None
|
||||
else asyncio.sleep(0, result=self.num_archival_memories)
|
||||
),
|
||||
)
|
||||
if not self.num_messages:
|
||||
self.num_messages = await self.message_manager.size_async(
|
||||
agent_id=agent_state.id,
|
||||
actor=self.actor,
|
||||
)
|
||||
if not self.num_archival_memories:
|
||||
self.num_archival_memories = await self.passage_manager.agent_passage_size_async(
|
||||
agent_id=agent_state.id,
|
||||
actor=self.actor,
|
||||
)
|
||||
|
||||
return await super()._rebuild_memory_async(
|
||||
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
|
||||
)
|
||||
|
||||
@@ -1544,15 +1544,9 @@ class AgentManager:
|
||||
|
||||
Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages
|
||||
"""
|
||||
num_messages_task = self.message_manager.size_async(actor=actor, agent_id=agent_id)
|
||||
num_archival_memories_task = self.passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_id)
|
||||
agent_state_task = self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory", "sources", "tools"], actor=actor)
|
||||
|
||||
num_messages, num_archival_memories, agent_state = await asyncio.gather(
|
||||
num_messages_task,
|
||||
num_archival_memories_task,
|
||||
agent_state_task,
|
||||
)
|
||||
num_messages = await self.message_manager.size_async(actor=actor, agent_id=agent_id)
|
||||
num_archival_memories = await self.passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_id)
|
||||
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory", "sources", "tools"], actor=actor)
|
||||
|
||||
if not tool_rules_solver:
|
||||
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
||||
@@ -1769,9 +1763,10 @@ class AgentManager:
|
||||
)
|
||||
|
||||
# refresh memory from DB (using block ids)
|
||||
blocks = await asyncio.gather(
|
||||
*[self.block_manager.get_block_by_id_async(block.id, actor=actor) for block in agent_state.memory.get_blocks()]
|
||||
blocks = await self.block_manager.get_all_blocks_by_ids_async(
|
||||
block_ids=[b.id for b in agent_state.memory.get_blocks()], actor=actor
|
||||
)
|
||||
|
||||
agent_state.memory = Memory(
|
||||
blocks=blocks,
|
||||
file_blocks=agent_state.memory.file_blocks,
|
||||
|
||||
Reference in New Issue
Block a user