fix: rebuilding memory async (#2149)

This commit is contained in:
Andy Li
2025-05-13 12:14:51 -07:00
committed by GitHub
parent 6e2e496272
commit ca895f1987
2 changed files with 19 additions and 2 deletions

View File

@@ -127,7 +127,13 @@ class BaseAgent(ABC):
logger.exception(f"Failed to rebuild memory for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name})")
raise
async def _rebuild_memory_async(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]:
async def _rebuild_memory_async(
self,
in_context_messages: List[Message],
agent_state: AgentState,
num_messages: int | None = None, # storing these calculations is specific to the voice agent
num_archival_memories: int | None = None,
) -> List[Message]:
"""
Async version of function above. For now before breaking up components, changes should be made in both places.
"""

View File

@@ -60,6 +60,10 @@ class LettaAgent(BaseAgent):
self.last_function_response = self._load_last_function_response()
# Cached archival memory/message size
self.num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_id)
self.num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_id)
@trace_method
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor)
@@ -164,6 +168,11 @@ class LettaAgent(BaseAgent):
message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)]
self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor)
# TODO: This may be out of sync, if in between steps users add files
# NOTE (cliandy): temporary for now for particlar use cases.
self.num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
self.num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)
# TODO: Also yield out a letta usage stats SSE
yield f"data: {MessageStreamStatus.done.model_dump_json()}\n\n"
@@ -179,7 +188,9 @@ class LettaAgent(BaseAgent):
stream: bool,
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
if settings.experimental_enable_async_db_engine:
in_context_messages = await self._rebuild_memory_async(in_context_messages, agent_state)
in_context_messages = await self._rebuild_memory_async(
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
)
else:
if settings.experimental_skip_rebuild_memory and agent_state.llm_config.model_endpoint_type == "google_vertex":
logger.info("Skipping memory rebuild")