feat(asyncify): remove non async memory methods (#2245)
This commit is contained in:
@@ -72,61 +72,6 @@ class BaseAgent(ABC):
|
||||
|
||||
return [{"role": input_message.role.value, "content": get_content(input_message)} for input_message in input_messages]
|
||||
|
||||
def _rebuild_memory(
    self,
    in_context_messages: List[Message],
    agent_state: AgentState,
    num_messages: int | None = None,  # storing these calculations is specific to the voice agent
    num_archival_memories: int | None = None,
) -> List[Message]:
    """Refresh the agent's memory and rewrite the system message if it is stale.

    The first in-context message is treated as the system message. If the
    freshly compiled memory string is already contained in its text, the
    messages are returned untouched. Otherwise the system prompt is
    recompiled and, when it differs from the current text, persisted
    through the message manager.

    Args:
        in_context_messages: Messages currently in context; index 0 is read
            as the system message.
        agent_state: Agent state passed to ``agent_manager.refresh_memory``
            (which reloads ``agent_state.memory.blocks``) and used to
            compile the new system prompt.
        num_messages: Optional precomputed message count. When falsy it is
            fetched from ``self.message_manager``.
        num_archival_memories: Optional precomputed archival-memory count.
            When falsy it is fetched from ``self.passage_manager``.

    Returns:
        ``in_context_messages`` unchanged when no rebuild is needed;
        otherwise a new list whose first element is the updated system
        message followed by the remaining original messages.

    Raises:
        Exception: Any error raised while rebuilding is logged with its
            traceback and re-raised unchanged.
    """
    try:
        # Refresh Memory
        # TODO: This only happens for the summary block (voice?)
        # [DB Call] loading blocks (modifies: agent_state.memory.blocks)
        self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor)

        # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
        curr_system_message = in_context_messages[0]
        curr_memory_str = agent_state.memory.compile()
        curr_system_message_text = curr_system_message.content[0].text
        if curr_memory_str in curr_system_message_text:
            # NOTE: could this cause issues if a block is removed? (substring match would still work)
            logger.debug(
                f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
            )
            return in_context_messages

        memory_edit_timestamp = get_utc_time()

        # [DB Call] size of messages and archival memories
        # NOTE: `or` means a supplied count of 0 is treated as missing and re-queried.
        num_messages = num_messages or self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
        num_archival_memories = num_archival_memories or self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)

        new_system_message_str = compile_system_message(
            system_prompt=agent_state.system,
            in_context_memory=agent_state.memory,
            in_context_memory_last_edit=memory_edit_timestamp,
            previous_message_count=num_messages,
            archival_memory_size=num_archival_memories,
        )

        diff = united_diff(curr_system_message_text, new_system_message_str)
        if not diff:
            # Recompiled prompt is textually identical; nothing to persist.
            return in_context_messages

        logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")

        # [DB Call] Update Messages
        new_system_message = self.message_manager.update_message_by_id(
            curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
        )
        # Skip pulling down the agent's memory again to save on a db call
        return [new_system_message] + in_context_messages[1:]
    except Exception:
        # Narrowed from a bare `except:` so interpreter-exit signals such as
        # KeyboardInterrupt/SystemExit are not logged as rebuild failures.
        logger.exception(f"Failed to rebuild memory for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name})")
        raise
|
||||
|
||||
async def _rebuild_memory_async(
|
||||
self,
|
||||
in_context_messages: List[Message],
|
||||
|
||||
@@ -37,7 +37,6 @@ from letta.services.passage_manager import PassageManager
|
||||
from letta.services.step_manager import NoopStepManager, StepManager
|
||||
from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
|
||||
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
|
||||
from letta.settings import settings
|
||||
from letta.system import package_function_response
|
||||
from letta.tracing import log_event, trace_method, tracer
|
||||
|
||||
@@ -97,16 +96,12 @@ class LettaAgent(BaseAgent):
|
||||
for _ in range(max_steps):
|
||||
step_id = generate_step_id()
|
||||
|
||||
in_context_messages = current_in_context_messages + new_in_context_messages
|
||||
if settings.experimental_enable_async_db_engine:
|
||||
in_context_messages = await self._rebuild_memory_async(
|
||||
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
|
||||
)
|
||||
else:
|
||||
if settings.experimental_skip_rebuild_memory and agent_state.llm_config.model_endpoint_type == "google_vertex":
|
||||
logger.info("Skipping memory rebuild")
|
||||
else:
|
||||
in_context_messages = self._rebuild_memory(in_context_messages, agent_state)
|
||||
in_context_messages = await self._rebuild_memory_async(
|
||||
current_in_context_messages + new_in_context_messages,
|
||||
agent_state,
|
||||
num_messages=self.num_messages,
|
||||
num_archival_memories=self.num_archival_memories,
|
||||
)
|
||||
log_event("agent.stream_no_tokens.messages.refreshed") # [1^]
|
||||
|
||||
request_data = await self._create_llm_request_data_async(
|
||||
@@ -200,16 +195,12 @@ class LettaAgent(BaseAgent):
|
||||
for _ in range(max_steps):
|
||||
step_id = generate_step_id()
|
||||
|
||||
in_context_messages = current_in_context_messages + new_in_context_messages
|
||||
if settings.experimental_enable_async_db_engine:
|
||||
in_context_messages = await self._rebuild_memory_async(
|
||||
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
|
||||
)
|
||||
else:
|
||||
if settings.experimental_skip_rebuild_memory and agent_state.llm_config.model_endpoint_type == "google_vertex":
|
||||
logger.info("Skipping memory rebuild")
|
||||
else:
|
||||
in_context_messages = self._rebuild_memory(in_context_messages, agent_state)
|
||||
in_context_messages = await self._rebuild_memory_async(
|
||||
current_in_context_messages + new_in_context_messages,
|
||||
agent_state,
|
||||
num_messages=self.num_messages,
|
||||
num_archival_memories=self.num_archival_memories,
|
||||
)
|
||||
log_event("agent.step.messages.refreshed") # [1^]
|
||||
|
||||
request_data = await self._create_llm_request_data_async(
|
||||
@@ -299,17 +290,12 @@ class LettaAgent(BaseAgent):
|
||||
|
||||
for _ in range(max_steps):
|
||||
step_id = generate_step_id()
|
||||
|
||||
in_context_messages = current_in_context_messages + new_in_context_messages
|
||||
if settings.experimental_enable_async_db_engine:
|
||||
in_context_messages = await self._rebuild_memory_async(
|
||||
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
|
||||
)
|
||||
else:
|
||||
if settings.experimental_skip_rebuild_memory and agent_state.llm_config.model_endpoint_type == "google_vertex":
|
||||
logger.info("Skipping memory rebuild")
|
||||
else:
|
||||
in_context_messages = self._rebuild_memory(in_context_messages, agent_state)
|
||||
in_context_messages = await self._rebuild_memory_async(
|
||||
current_in_context_messages + new_in_context_messages,
|
||||
agent_state,
|
||||
num_messages=self.num_messages,
|
||||
num_archival_memories=self.num_archival_memories,
|
||||
)
|
||||
log_event("agent.step.messages.refreshed") # [1^]
|
||||
|
||||
request_data = await self._create_llm_request_data_async(
|
||||
@@ -439,19 +425,13 @@ class LettaAgent(BaseAgent):
|
||||
agent_state: AgentState,
|
||||
tool_rules_solver: ToolRulesSolver,
|
||||
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
|
||||
if settings.experimental_enable_async_db_engine:
|
||||
self.num_messages = self.num_messages or (await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id))
|
||||
self.num_archival_memories = self.num_archival_memories or (
|
||||
await self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
)
|
||||
in_context_messages = await self._rebuild_memory_async(
|
||||
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
|
||||
)
|
||||
else:
|
||||
if settings.experimental_skip_rebuild_memory and agent_state.llm_config.model_endpoint_type == "google_vertex":
|
||||
logger.info("Skipping memory rebuild")
|
||||
else:
|
||||
in_context_messages = self._rebuild_memory(in_context_messages, agent_state)
|
||||
self.num_messages = self.num_messages or (await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id))
|
||||
self.num_archival_memories = self.num_archival_memories or (
|
||||
await self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
)
|
||||
in_context_messages = await self._rebuild_memory_async(
|
||||
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
|
||||
)
|
||||
|
||||
tools = [
|
||||
t
|
||||
|
||||
@@ -556,16 +556,6 @@ class LettaAgentBatch(BaseAgent):
|
||||
in_context_messages = await self._rebuild_memory_async(current_in_context_messages + new_in_context_messages, agent_state)
|
||||
return in_context_messages
|
||||
|
||||
# TODO: Make this a bulk function
|
||||
def _rebuild_memory(
    self,
    in_context_messages: List[Message],
    agent_state: AgentState,
    num_messages: int | None = None,
    num_archival_memories: int | None = None,
) -> List[Message]:
    """Rebuild memory by delegating to the parent class implementation.

    ``num_messages`` and ``num_archival_memories`` are accepted but are not
    forwarded; the parent is called with only the messages and agent state.
    """
    rebuilt_messages = super()._rebuild_memory(in_context_messages, agent_state)
    return rebuilt_messages
|
||||
|
||||
# Not used in batch.
|
||||
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -154,7 +154,7 @@ class VoiceAgent(BaseAgent):
|
||||
# TODO: Define max steps here
|
||||
for _ in range(max_steps):
|
||||
# Rebuild memory each loop
|
||||
in_context_messages = self._rebuild_memory(in_context_messages, agent_state)
|
||||
in_context_messages = await self._rebuild_memory_async(in_context_messages, agent_state)
|
||||
openai_messages = convert_in_context_letta_messages_to_openai(in_context_messages, exclude_system_messages=True)
|
||||
openai_messages.extend(in_memory_message_history)
|
||||
|
||||
@@ -292,14 +292,14 @@ class VoiceAgent(BaseAgent):
|
||||
agent_id=self.agent_id, message_ids=[m.id for m in new_in_context_messages], actor=self.actor
|
||||
)
|
||||
|
||||
async def _rebuild_memory_async(
    self,
    in_context_messages: List[Message],
    agent_state: AgentState,
    num_messages: int | None = None,
    num_archival_memories: int | None = None,
) -> List[Message]:
    """Rebuild memory via the parent class, using the counts stored on this agent.

    The ``num_messages`` and ``num_archival_memories`` arguments are ignored;
    ``self.num_messages`` and ``self.num_archival_memories`` are forwarded
    to the parent implementation instead.

    Returns:
        The message list produced by the parent ``_rebuild_memory_async``.
    """
    # The parent method is a coroutine function, so its result must be awaited
    # here. Returning it un-awaited would hand callers a coroutine object
    # instead of the rebuilt message list.
    return await super()._rebuild_memory_async(
        in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
    )
|
||||
|
||||
|
||||
@@ -2773,11 +2773,8 @@ class LocalClient(AbstractClient):
|
||||
|
||||
# humans / personas
|
||||
|
||||
def get_block_id(self, name: str, label: str) -> str:
|
||||
block = self.server.block_manager.get_blocks(actor=self.user, template_name=name, label=label, is_template=True)
|
||||
if not block:
|
||||
return None
|
||||
return block[0].id
|
||||
def get_block_id(self, name: str, label: str) -> str | None:
|
||||
return None
|
||||
|
||||
def create_human(self, name: str, text: str):
|
||||
"""
|
||||
@@ -2812,7 +2809,7 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
humans (List[Human]): List of human blocks
|
||||
"""
|
||||
return self.server.block_manager.get_blocks(actor=self.user, label="human", is_template=True)
|
||||
return []
|
||||
|
||||
def list_personas(self) -> List[Persona]:
|
||||
"""
|
||||
@@ -2821,7 +2818,7 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
personas (List[Persona]): List of persona blocks
|
||||
"""
|
||||
return self.server.block_manager.get_blocks(actor=self.user, label="persona", is_template=True)
|
||||
return []
|
||||
|
||||
def update_human(self, human_id: str, text: str):
|
||||
"""
|
||||
@@ -2879,7 +2876,7 @@ class LocalClient(AbstractClient):
|
||||
assert id, f"Human ID must be provided"
|
||||
return Human(**self.server.block_manager.get_block_by_id(id, actor=self.user).model_dump())
|
||||
|
||||
def get_persona_id(self, name: str) -> str:
|
||||
def get_persona_id(self, name: str) -> str | None:
|
||||
"""
|
||||
Get the ID of a persona block template
|
||||
|
||||
@@ -2889,12 +2886,9 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
id (str): ID of the persona block
|
||||
"""
|
||||
persona = self.server.block_manager.get_blocks(actor=self.user, template_name=name, label="persona", is_template=True)
|
||||
if not persona:
|
||||
return None
|
||||
return persona[0].id
|
||||
return None
|
||||
|
||||
def get_human_id(self, name: str) -> str:
|
||||
def get_human_id(self, name: str) -> str | None:
|
||||
"""
|
||||
Get the ID of a human block template
|
||||
|
||||
@@ -2904,10 +2898,7 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
id (str): ID of the human block
|
||||
"""
|
||||
human = self.server.block_manager.get_blocks(actor=self.user, template_name=name, label="human", is_template=True)
|
||||
if not human:
|
||||
return None
|
||||
return human[0].id
|
||||
return None
|
||||
|
||||
def delete_persona(self, id: str):
|
||||
"""
|
||||
@@ -3381,7 +3372,7 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
blocks (List[Block]): List of blocks
|
||||
"""
|
||||
return self.server.block_manager.get_blocks(actor=self.user, label=label, is_template=templates_only)
|
||||
return []
|
||||
|
||||
def create_block(
|
||||
self, label: str, value: str, limit: Optional[int] = None, template_name: Optional[str] = None, is_template: bool = False
|
||||
|
||||
@@ -1468,17 +1468,6 @@ class AgentManager:
|
||||
|
||||
return agent_state
|
||||
|
||||
@enforce_types
def refresh_memory(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
    """Reload the agent's memory blocks from the block manager.

    Args:
        agent_state: Agent state whose ``memory.blocks`` are replaced in
            place with the result of ``get_all_blocks_by_ids``.
        actor: User passed through to the block manager lookup.

    Returns:
        The same ``agent_state`` object. It is returned untouched when the
        agent has no memory blocks.
    """
    block_ids = [b.id for b in agent_state.memory.blocks]
    if not block_ids:
        return agent_state

    # Reuse the ids collected above instead of rebuilding the same list.
    agent_state.memory.blocks = self.block_manager.get_all_blocks_by_ids(block_ids=block_ids, actor=actor)
    return agent_state
|
||||
|
||||
@enforce_types
|
||||
async def refresh_memory_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
|
||||
block_ids = [b.id for b in agent_state.memory.blocks]
|
||||
|
||||
@@ -82,43 +82,6 @@ class BlockManager:
|
||||
block.hard_delete(db_session=session, actor=actor)
|
||||
return block.to_pydantic()
|
||||
|
||||
@enforce_types
def get_blocks(
    self,
    actor: PydanticUser,
    label: Optional[str] = None,
    is_template: Optional[bool] = None,
    template_name: Optional[str] = None,
    identifier_keys: Optional[List[str]] = None,
    identity_id: Optional[str] = None,
    id: Optional[str] = None,
    after: Optional[str] = None,
    limit: Optional[int] = 50,
) -> List[PydanticBlock]:
    """List blocks in the actor's organization, narrowed by optional filters.

    ``label``, ``template_name`` and ``id`` are applied only when truthy;
    ``is_template`` is applied whenever it is not ``None`` (so ``False`` is
    a real filter). ``after``, ``limit``, ``identifier_keys`` and
    ``identity_id`` are handed to ``BlockModel.list`` as given.

    Returns:
        The matching blocks converted with ``to_pydantic()``.
    """
    with db_registry.session() as session:
        # (column, value, should_apply) in the order the filters are added.
        optional_filters = (
            ("label", label, bool(label)),
            ("is_template", is_template, is_template is not None),
            ("template_name", template_name, bool(template_name)),
            ("id", id, bool(id)),
        )

        # The organization scope is always present.
        query_filters = {"organization_id": actor.organization_id}
        for column, value, should_apply in optional_filters:
            if should_apply:
                query_filters[column] = value

        matching_rows = BlockModel.list(
            db_session=session,
            after=after,
            limit=limit,
            identifier_keys=identifier_keys,
            identity_id=identity_id,
            **query_filters,
        )

        return [row.to_pydantic() for row in matching_rows]
|
||||
|
||||
@enforce_types
|
||||
async def get_blocks_async(
|
||||
self,
|
||||
@@ -191,15 +154,6 @@ class BlockManager:
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@enforce_types
def get_all_blocks_by_ids(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]:
    """Fetch the blocks with the given ids.

    The result is padded at the end with ``None`` so that it has at least
    as many entries as ``block_ids``; one ``None`` is appended per id for
    which no block came back.
    """
    with db_registry.session() as session:
        found_rows = BlockModel.read_multiple(db_session=session, identifiers=block_ids, actor=actor)
        blocks = [row.to_pydantic() for row in found_rows]
        # backwards compatibility. previous implementation added None for every block not found.
        missing_count = len(block_ids) - len(blocks)
        blocks.extend([None] * missing_count)
        return blocks
|
||||
|
||||
@enforce_types
|
||||
async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]:
|
||||
"""Retrieve blocks by their ids without loading unnecessary relationships. Async implementation."""
|
||||
@@ -247,18 +201,6 @@ class BlockManager:
|
||||
|
||||
return pydantic_blocks
|
||||
|
||||
@enforce_types
def get_agents_for_block(self, block_id: str, actor: PydanticUser) -> List[PydanticAgentState]:
    """Return every agent associated with the block identified by ``block_id``.

    The block is read with ``BlockModel.read`` and each entry of its
    ``agents`` relationship is converted with ``to_pydantic()``.
    """
    with db_registry.session() as session:
        block_row = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
        linked_agents = [agent_row.to_pydantic() for agent_row in block_row.agents]

        return linked_agents
|
||||
|
||||
@enforce_types
|
||||
async def get_agents_for_block_async(self, block_id: str, actor: PydanticUser) -> List[PydanticAgentState]:
|
||||
"""
|
||||
|
||||
@@ -106,7 +106,7 @@ async def test_sleeptime_group_chat(server, actor):
|
||||
# 3. Verify shared blocks
|
||||
sleeptime_agent_id = group.agent_ids[0]
|
||||
shared_block = server.agent_manager.get_block_with_label(agent_id=main_agent.id, block_label="human", actor=actor)
|
||||
agents = server.block_manager.get_agents_for_block(block_id=shared_block.id, actor=actor)
|
||||
agents = await server.block_manager.get_agents_for_block_async(block_id=shared_block.id, actor=actor)
|
||||
assert len(agents) == 2
|
||||
assert sleeptime_agent_id in [agent.id for agent in agents]
|
||||
assert main_agent.id in [agent.id for agent in agents]
|
||||
@@ -220,7 +220,7 @@ async def test_sleeptime_group_chat_v2(server, actor):
|
||||
# 3. Verify shared blocks
|
||||
sleeptime_agent_id = group.agent_ids[0]
|
||||
shared_block = server.agent_manager.get_block_with_label(agent_id=main_agent.id, block_label="human", actor=actor)
|
||||
agents = server.block_manager.get_agents_for_block(block_id=shared_block.id, actor=actor)
|
||||
agents = await server.block_manager.get_agents_for_block_async(block_id=shared_block.id, actor=actor)
|
||||
assert len(agents) == 2
|
||||
assert sleeptime_agent_id in [agent.id for agent in agents]
|
||||
assert main_agent.id in [agent.id for agent in agents]
|
||||
|
||||
@@ -511,7 +511,7 @@ async def test_init_voice_convo_agent(voice_agent, server, actor):
|
||||
# 3. Verify shared blocks
|
||||
sleeptime_agent_id = group.agent_ids[0]
|
||||
shared_block = server.agent_manager.get_block_with_label(agent_id=voice_agent.id, block_label="human", actor=actor)
|
||||
agents = server.block_manager.get_agents_for_block(block_id=shared_block.id, actor=actor)
|
||||
agents = await server.block_manager.get_agents_for_block_async(block_id=shared_block.id, actor=actor)
|
||||
assert len(agents) == 2
|
||||
assert sleeptime_agent_id in [agent.id for agent in agents]
|
||||
assert voice_agent.id in [agent.id for agent in agents]
|
||||
|
||||
@@ -1750,29 +1750,6 @@ def test_get_block_with_label(server: SyncServer, sarah_agent, default_block, de
|
||||
assert block.label == default_block.label
|
||||
|
||||
|
||||
def test_refresh_memory(server: SyncServer, default_user):
    """Refreshing memory on an agent with no blocks leaves it with no blocks."""
    # A block is created but never attached to the agent.
    block_spec = PydanticBlock(
        label="test",
        value="test",
        limit=1000,
    )
    block = server.block_manager.create_or_update_block(block_spec, actor=default_user)

    agent_spec = CreateAgent(
        name="test",
        llm_config=LLMConfig.default_config("gpt-4o-mini"),
        embedding_config=EmbeddingConfig.default_config(provider="openai"),
        include_base_tools=False,
    )
    agent = server.agent_manager.create_agent(agent_spec, actor=default_user)
    assert len(agent.memory.blocks) == 0

    agent = server.agent_manager.refresh_memory(agent_state=agent, actor=default_user)
    assert len(agent.memory.blocks) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_memory_async(server: SyncServer, default_user, event_loop):
|
||||
block = server.block_manager.create_or_update_block(
|
||||
@@ -2826,7 +2803,8 @@ async def test_delete_block_detaches_from_agent(server: SyncServer, sarah_agent,
|
||||
assert not (block.id in [b.id for b in agent_state.memory.blocks])
|
||||
|
||||
|
||||
def test_get_agents_for_block(server: SyncServer, sarah_agent, charles_agent, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_agents_for_block(server: SyncServer, sarah_agent, charles_agent, default_user, event_loop):
|
||||
# Create and delete a block
|
||||
block = server.block_manager.create_or_update_block(PydanticBlock(label="alien", value="Sample content"), actor=default_user)
|
||||
sarah_agent = server.agent_manager.attach_block(agent_id=sarah_agent.id, block_id=block.id, actor=default_user)
|
||||
@@ -2837,7 +2815,7 @@ def test_get_agents_for_block(server: SyncServer, sarah_agent, charles_agent, de
|
||||
assert block.id in [b.id for b in charles_agent.memory.blocks]
|
||||
|
||||
# Get the agents for that block
|
||||
agent_states = server.block_manager.get_agents_for_block(block_id=block.id, actor=default_user)
|
||||
agent_states = await server.block_manager.get_agents_for_block_async(block_id=block.id, actor=default_user)
|
||||
assert len(agent_states) == 2
|
||||
|
||||
# Check both agents are in the list
|
||||
|
||||
Reference in New Issue
Block a user