fix: rebuilding memory async (#2149)
This commit is contained in:
@@ -127,7 +127,13 @@ class BaseAgent(ABC):
|
||||
logger.exception(f"Failed to rebuild memory for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name})")
|
||||
raise
|
||||
|
||||
async def _rebuild_memory_async(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]:
|
||||
async def _rebuild_memory_async(
|
||||
self,
|
||||
in_context_messages: List[Message],
|
||||
agent_state: AgentState,
|
||||
num_messages: int | None = None, # storing these calculations is specific to the voice agent
|
||||
num_archival_memories: int | None = None,
|
||||
) -> List[Message]:
|
||||
"""
|
||||
Async version of function above. For now before breaking up components, changes should be made in both places.
|
||||
"""
|
||||
|
||||
@@ -60,6 +60,10 @@ class LettaAgent(BaseAgent):
|
||||
|
||||
self.last_function_response = self._load_last_function_response()
|
||||
|
||||
# Cached archival memory/message size
|
||||
self.num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_id)
|
||||
self.num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_id)
|
||||
|
||||
@trace_method
|
||||
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
|
||||
agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor)
|
||||
@@ -164,6 +168,11 @@ class LettaAgent(BaseAgent):
|
||||
message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)]
|
||||
self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor)
|
||||
|
||||
# TODO: This may be out of sync, if in between steps users add files
|
||||
# NOTE (cliandy): temporary for now for particlar use cases.
|
||||
self.num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
|
||||
self.num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)
|
||||
|
||||
# TODO: Also yield out a letta usage stats SSE
|
||||
|
||||
yield f"data: {MessageStreamStatus.done.model_dump_json()}\n\n"
|
||||
@@ -179,7 +188,9 @@ class LettaAgent(BaseAgent):
|
||||
stream: bool,
|
||||
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
|
||||
if settings.experimental_enable_async_db_engine:
|
||||
in_context_messages = await self._rebuild_memory_async(in_context_messages, agent_state)
|
||||
in_context_messages = await self._rebuild_memory_async(
|
||||
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
|
||||
)
|
||||
else:
|
||||
if settings.experimental_skip_rebuild_memory and agent_state.llm_config.model_endpoint_type == "google_vertex":
|
||||
logger.info("Skipping memory rebuild")
|
||||
|
||||
Reference in New Issue
Block a user