feat(asyncify): remove non async memory methods (#2245)

This commit is contained in:
cthomas
2025-05-20 17:56:54 -07:00
committed by GitHub
parent 2b3a6ae248
commit 18e30dbfba
10 changed files with 43 additions and 228 deletions

View File

@@ -72,61 +72,6 @@ class BaseAgent(ABC):
return [{"role": input_message.role.value, "content": get_content(input_message)} for input_message in input_messages]
def _rebuild_memory(
    self,
    in_context_messages: List[Message],
    agent_state: AgentState,
    num_messages: int | None = None,  # storing these calculations is specific to the voice agent
    num_archival_memories: int | None = None,
) -> List[Message]:
    """Refresh the agent's memory blocks and rewrite the system message if memory changed.

    Args:
        in_context_messages: Current in-context messages; index 0 must be the system message.
        agent_state: Agent whose memory blocks are refreshed (mutated in place).
        num_messages: Optional precomputed message count, to skip a DB size query.
        num_archival_memories: Optional precomputed archival passage count.

    Returns:
        The in-context messages, with the system message replaced when the compiled
        memory no longer appears in the current system text.

    Raises:
        Re-raises any exception after logging it.
    """
    try:
        # Refresh Memory
        # TODO: This only happens for the summary block (voice?)
        # [DB Call] loading blocks (modifies: agent_state.memory.blocks)
        self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor)

        # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
        curr_system_message = in_context_messages[0]
        curr_memory_str = agent_state.memory.compile()
        curr_system_message_text = curr_system_message.content[0].text
        if curr_memory_str in curr_system_message_text:
            # NOTE: could this cause issues if a block is removed? (substring match would still work)
            logger.debug(
                f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
            )
            return in_context_messages

        memory_edit_timestamp = get_utc_time()

        # [DB Call] size of messages and archival memories
        num_messages = num_messages or self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
        num_archival_memories = num_archival_memories or self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)

        new_system_message_str = compile_system_message(
            system_prompt=agent_state.system,
            in_context_memory=agent_state.memory,
            in_context_memory_last_edit=memory_edit_timestamp,
            previous_message_count=num_messages,
            archival_memory_size=num_archival_memories,
        )

        diff = united_diff(curr_system_message_text, new_system_message_str)
        if len(diff) > 0:
            logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")

            # [DB Call] Update Messages
            new_system_message = self.message_manager.update_message_by_id(
                curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
            )
            # Skip pulling down the agent's memory again to save on a db call
            return [new_system_message] + in_context_messages[1:]
        else:
            return in_context_messages
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are not
        # swallowed into the failure log before re-raising.
        logger.exception(f"Failed to rebuild memory for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name})")
        raise
async def _rebuild_memory_async(
self,
in_context_messages: List[Message],

View File

@@ -37,7 +37,6 @@ from letta.services.passage_manager import PassageManager
from letta.services.step_manager import NoopStepManager, StepManager
from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
from letta.settings import settings
from letta.system import package_function_response
from letta.tracing import log_event, trace_method, tracer
@@ -97,16 +96,12 @@ class LettaAgent(BaseAgent):
for _ in range(max_steps):
step_id = generate_step_id()
in_context_messages = current_in_context_messages + new_in_context_messages
if settings.experimental_enable_async_db_engine:
in_context_messages = await self._rebuild_memory_async(
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
)
else:
if settings.experimental_skip_rebuild_memory and agent_state.llm_config.model_endpoint_type == "google_vertex":
logger.info("Skipping memory rebuild")
else:
in_context_messages = self._rebuild_memory(in_context_messages, agent_state)
in_context_messages = await self._rebuild_memory_async(
current_in_context_messages + new_in_context_messages,
agent_state,
num_messages=self.num_messages,
num_archival_memories=self.num_archival_memories,
)
log_event("agent.stream_no_tokens.messages.refreshed") # [1^]
request_data = await self._create_llm_request_data_async(
@@ -200,16 +195,12 @@ class LettaAgent(BaseAgent):
for _ in range(max_steps):
step_id = generate_step_id()
in_context_messages = current_in_context_messages + new_in_context_messages
if settings.experimental_enable_async_db_engine:
in_context_messages = await self._rebuild_memory_async(
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
)
else:
if settings.experimental_skip_rebuild_memory and agent_state.llm_config.model_endpoint_type == "google_vertex":
logger.info("Skipping memory rebuild")
else:
in_context_messages = self._rebuild_memory(in_context_messages, agent_state)
in_context_messages = await self._rebuild_memory_async(
current_in_context_messages + new_in_context_messages,
agent_state,
num_messages=self.num_messages,
num_archival_memories=self.num_archival_memories,
)
log_event("agent.step.messages.refreshed") # [1^]
request_data = await self._create_llm_request_data_async(
@@ -299,17 +290,12 @@ class LettaAgent(BaseAgent):
for _ in range(max_steps):
step_id = generate_step_id()
in_context_messages = current_in_context_messages + new_in_context_messages
if settings.experimental_enable_async_db_engine:
in_context_messages = await self._rebuild_memory_async(
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
)
else:
if settings.experimental_skip_rebuild_memory and agent_state.llm_config.model_endpoint_type == "google_vertex":
logger.info("Skipping memory rebuild")
else:
in_context_messages = self._rebuild_memory(in_context_messages, agent_state)
in_context_messages = await self._rebuild_memory_async(
current_in_context_messages + new_in_context_messages,
agent_state,
num_messages=self.num_messages,
num_archival_memories=self.num_archival_memories,
)
log_event("agent.step.messages.refreshed") # [1^]
request_data = await self._create_llm_request_data_async(
@@ -439,19 +425,13 @@ class LettaAgent(BaseAgent):
agent_state: AgentState,
tool_rules_solver: ToolRulesSolver,
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
if settings.experimental_enable_async_db_engine:
self.num_messages = self.num_messages or (await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id))
self.num_archival_memories = self.num_archival_memories or (
await self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id)
)
in_context_messages = await self._rebuild_memory_async(
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
)
else:
if settings.experimental_skip_rebuild_memory and agent_state.llm_config.model_endpoint_type == "google_vertex":
logger.info("Skipping memory rebuild")
else:
in_context_messages = self._rebuild_memory(in_context_messages, agent_state)
self.num_messages = self.num_messages or (await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id))
self.num_archival_memories = self.num_archival_memories or (
await self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id)
)
in_context_messages = await self._rebuild_memory_async(
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
)
tools = [
t

View File

@@ -556,16 +556,6 @@ class LettaAgentBatch(BaseAgent):
in_context_messages = await self._rebuild_memory_async(current_in_context_messages + new_in_context_messages, agent_state)
return in_context_messages
# TODO: Make this a bulk function
def _rebuild_memory(
    self,
    in_context_messages: List[Message],
    agent_state: AgentState,
    num_messages: int | None = None,
    num_archival_memories: int | None = None,
) -> List[Message]:
    """Delegate memory rebuilding to the base class.

    The previous version accepted ``num_messages`` and ``num_archival_memories``
    but silently dropped them, forcing the base implementation to re-query both
    sizes; forward them so a caller-supplied count is actually used.
    """
    return super()._rebuild_memory(
        in_context_messages, agent_state, num_messages=num_messages, num_archival_memories=num_archival_memories
    )
# Not used in batch.
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
    """Single-agent stepping is unsupported in batch mode; always raises NotImplementedError."""
    raise NotImplementedError

View File

@@ -154,7 +154,7 @@ class VoiceAgent(BaseAgent):
# TODO: Define max steps here
for _ in range(max_steps):
# Rebuild memory each loop
in_context_messages = self._rebuild_memory(in_context_messages, agent_state)
in_context_messages = await self._rebuild_memory_async(in_context_messages, agent_state)
openai_messages = convert_in_context_letta_messages_to_openai(in_context_messages, exclude_system_messages=True)
openai_messages.extend(in_memory_message_history)
@@ -292,14 +292,14 @@ class VoiceAgent(BaseAgent):
agent_id=self.agent_id, message_ids=[m.id for m in new_in_context_messages], actor=self.actor
)
def _rebuild_memory(
async def _rebuild_memory_async(
self,
in_context_messages: List[Message],
agent_state: AgentState,
num_messages: int | None = None,
num_archival_memories: int | None = None,
) -> List[Message]:
return super()._rebuild_memory(
return super()._rebuild_memory_async(
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
)

View File

@@ -2773,11 +2773,8 @@ class LocalClient(AbstractClient):
# humans / personas
def get_block_id(self, name: str, label: str) -> str:
block = self.server.block_manager.get_blocks(actor=self.user, template_name=name, label=label, is_template=True)
if not block:
return None
return block[0].id
def get_block_id(self, name: str, label: str) -> str | None:
return None
def create_human(self, name: str, text: str):
"""
@@ -2812,7 +2809,7 @@ class LocalClient(AbstractClient):
Returns:
humans (List[Human]): List of human blocks
"""
return self.server.block_manager.get_blocks(actor=self.user, label="human", is_template=True)
return []
def list_personas(self) -> List[Persona]:
"""
@@ -2821,7 +2818,7 @@ class LocalClient(AbstractClient):
Returns:
personas (List[Persona]): List of persona blocks
"""
return self.server.block_manager.get_blocks(actor=self.user, label="persona", is_template=True)
return []
def update_human(self, human_id: str, text: str):
"""
@@ -2879,7 +2876,7 @@ class LocalClient(AbstractClient):
assert id, f"Human ID must be provided"
return Human(**self.server.block_manager.get_block_by_id(id, actor=self.user).model_dump())
def get_persona_id(self, name: str) -> str:
def get_persona_id(self, name: str) -> str | None:
"""
Get the ID of a persona block template
@@ -2889,12 +2886,9 @@ class LocalClient(AbstractClient):
Returns:
id (str): ID of the persona block
"""
persona = self.server.block_manager.get_blocks(actor=self.user, template_name=name, label="persona", is_template=True)
if not persona:
return None
return persona[0].id
return None
def get_human_id(self, name: str) -> str:
def get_human_id(self, name: str) -> str | None:
"""
Get the ID of a human block template
@@ -2904,10 +2898,7 @@ class LocalClient(AbstractClient):
Returns:
id (str): ID of the human block
"""
human = self.server.block_manager.get_blocks(actor=self.user, template_name=name, label="human", is_template=True)
if not human:
return None
return human[0].id
return None
def delete_persona(self, id: str):
"""
@@ -3381,7 +3372,7 @@ class LocalClient(AbstractClient):
Returns:
blocks (List[Block]): List of blocks
"""
return self.server.block_manager.get_blocks(actor=self.user, label=label, is_template=templates_only)
return []
def create_block(
self, label: str, value: str, limit: Optional[int] = None, template_name: Optional[str] = None, is_template: bool = False

View File

@@ -1468,17 +1468,6 @@ class AgentManager:
return agent_state
@enforce_types
def refresh_memory(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
    """Reload the agent's memory blocks from the database, in place.

    Args:
        agent_state: Agent state whose ``memory.blocks`` is refreshed (mutated in place).
        actor: User performing the lookup (scopes block access).

    Returns:
        The same ``agent_state``, with freshly loaded blocks when any are attached.
    """
    block_ids = [b.id for b in agent_state.memory.blocks]
    if not block_ids:
        return agent_state
    # Reuse the ids computed above instead of re-deriving them a second time.
    agent_state.memory.blocks = self.block_manager.get_all_blocks_by_ids(block_ids=block_ids, actor=actor)
    return agent_state
@enforce_types
async def refresh_memory_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
block_ids = [b.id for b in agent_state.memory.blocks]

View File

@@ -82,43 +82,6 @@ class BlockManager:
block.hard_delete(db_session=session, actor=actor)
return block.to_pydantic()
@enforce_types
def get_blocks(
    self,
    actor: PydanticUser,
    label: Optional[str] = None,
    is_template: Optional[bool] = None,
    template_name: Optional[str] = None,
    identifier_keys: Optional[List[str]] = None,
    identity_id: Optional[str] = None,
    id: Optional[str] = None,
    after: Optional[str] = None,
    limit: Optional[int] = 50,
) -> List[PydanticBlock]:
    """List blocks in the actor's organization, narrowed by any supplied filters."""
    # Every query is scoped to the actor's organization; each remaining filter is
    # applied only when the caller supplied it (is_template distinguishes an
    # explicit False from "not provided", the string filters use truthiness).
    candidate_filters = (
        ("label", label, bool(label)),
        ("is_template", is_template, is_template is not None),
        ("template_name", template_name, bool(template_name)),
        ("id", id, bool(id)),
    )
    filters = {"organization_id": actor.organization_id}
    for column, value, include in candidate_filters:
        if include:
            filters[column] = value

    with db_registry.session() as session:
        rows = BlockModel.list(
            db_session=session,
            after=after,
            limit=limit,
            identifier_keys=identifier_keys,
            identity_id=identity_id,
            **filters,
        )
        return [row.to_pydantic() for row in rows]
@enforce_types
async def get_blocks_async(
self,
@@ -191,15 +154,6 @@ class BlockManager:
except NoResultFound:
return None
@enforce_types
def get_all_blocks_by_ids(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[Optional[PydanticBlock]]:
    """Retrieve blocks by their ids.

    Args:
        block_ids: Block identifiers to look up.
        actor: Optional user scoping the read.

    Returns:
        The resolved blocks, padded with ``None`` for every id that was not found
        (backwards compatibility: the previous implementation added ``None`` for
        every missing block) — hence the ``Optional`` element type, which the old
        ``List[PydanticBlock]`` annotation understated.
    """
    with db_registry.session() as session:
        blocks = [block.to_pydantic() for block in BlockModel.read_multiple(db_session=session, identifiers=block_ids, actor=actor)]
        # [None] * k is the idiomatic (and O(k) without a generator) way to pad.
        blocks.extend([None] * (len(block_ids) - len(blocks)))
        return blocks
@enforce_types
async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]:
"""Retrieve blocks by their ids without loading unnecessary relationships. Async implementation."""
@@ -247,18 +201,6 @@ class BlockManager:
return pydantic_blocks
@enforce_types
def get_agents_for_block(self, block_id: str, actor: PydanticUser) -> List[PydanticAgentState]:
    """Return the agent states currently attached to the block *block_id*."""
    with db_registry.session() as session:
        block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
        # Convert each ORM agent attached to the block straight to its Pydantic form.
        return [agent.to_pydantic() for agent in block.agents]
@enforce_types
async def get_agents_for_block_async(self, block_id: str, actor: PydanticUser) -> List[PydanticAgentState]:
"""

View File

@@ -106,7 +106,7 @@ async def test_sleeptime_group_chat(server, actor):
# 3. Verify shared blocks
sleeptime_agent_id = group.agent_ids[0]
shared_block = server.agent_manager.get_block_with_label(agent_id=main_agent.id, block_label="human", actor=actor)
agents = server.block_manager.get_agents_for_block(block_id=shared_block.id, actor=actor)
agents = await server.block_manager.get_agents_for_block_async(block_id=shared_block.id, actor=actor)
assert len(agents) == 2
assert sleeptime_agent_id in [agent.id for agent in agents]
assert main_agent.id in [agent.id for agent in agents]
@@ -220,7 +220,7 @@ async def test_sleeptime_group_chat_v2(server, actor):
# 3. Verify shared blocks
sleeptime_agent_id = group.agent_ids[0]
shared_block = server.agent_manager.get_block_with_label(agent_id=main_agent.id, block_label="human", actor=actor)
agents = server.block_manager.get_agents_for_block(block_id=shared_block.id, actor=actor)
agents = await server.block_manager.get_agents_for_block_async(block_id=shared_block.id, actor=actor)
assert len(agents) == 2
assert sleeptime_agent_id in [agent.id for agent in agents]
assert main_agent.id in [agent.id for agent in agents]

View File

@@ -511,7 +511,7 @@ async def test_init_voice_convo_agent(voice_agent, server, actor):
# 3. Verify shared blocks
sleeptime_agent_id = group.agent_ids[0]
shared_block = server.agent_manager.get_block_with_label(agent_id=voice_agent.id, block_label="human", actor=actor)
agents = server.block_manager.get_agents_for_block(block_id=shared_block.id, actor=actor)
agents = await server.block_manager.get_agents_for_block_async(block_id=shared_block.id, actor=actor)
assert len(agents) == 2
assert sleeptime_agent_id in [agent.id for agent in agents]
assert voice_agent.id in [agent.id for agent in agents]

View File

@@ -1750,29 +1750,6 @@ def test_get_block_with_label(server: SyncServer, sarah_agent, default_block, de
assert block.label == default_block.label
def test_refresh_memory(server: SyncServer, default_user):
    """refresh_memory on an agent with no attached blocks leaves memory empty,
    even when unrelated blocks exist in the organization."""
    # Create a block that is NOT attached to the agent. Keep the side effect,
    # but don't bind the result to an unused local.
    server.block_manager.create_or_update_block(
        PydanticBlock(
            label="test",
            value="test",
            limit=1000,
        ),
        actor=default_user,
    )
    agent = server.agent_manager.create_agent(
        CreateAgent(
            name="test",
            llm_config=LLMConfig.default_config("gpt-4o-mini"),
            embedding_config=EmbeddingConfig.default_config(provider="openai"),
            include_base_tools=False,
        ),
        actor=default_user,
    )
    assert len(agent.memory.blocks) == 0
    agent = server.agent_manager.refresh_memory(agent_state=agent, actor=default_user)
    assert len(agent.memory.blocks) == 0
@pytest.mark.asyncio
async def test_refresh_memory_async(server: SyncServer, default_user, event_loop):
block = server.block_manager.create_or_update_block(
@@ -2826,7 +2803,8 @@ async def test_delete_block_detaches_from_agent(server: SyncServer, sarah_agent,
assert not (block.id in [b.id for b in agent_state.memory.blocks])
def test_get_agents_for_block(server: SyncServer, sarah_agent, charles_agent, default_user):
@pytest.mark.asyncio
async def test_get_agents_for_block(server: SyncServer, sarah_agent, charles_agent, default_user, event_loop):
# Create and delete a block
block = server.block_manager.create_or_update_block(PydanticBlock(label="alien", value="Sample content"), actor=default_user)
sarah_agent = server.agent_manager.attach_block(agent_id=sarah_agent.id, block_id=block.id, actor=default_user)
@@ -2837,7 +2815,7 @@ def test_get_agents_for_block(server: SyncServer, sarah_agent, charles_agent, de
assert block.id in [b.id for b in charles_agent.memory.blocks]
# Get the agents for that block
agent_states = server.block_manager.get_agents_for_block(block_id=block.id, actor=default_user)
agent_states = await server.block_manager.get_agents_for_block_async(block_id=block.id, actor=default_user)
assert len(agent_states) == 2
# Check both agents are in the list