diff --git a/letta/agents/base_agent.py b/letta/agents/base_agent.py index 018d6300..a349366d 100644 --- a/letta/agents/base_agent.py +++ b/letta/agents/base_agent.py @@ -72,61 +72,6 @@ class BaseAgent(ABC): return [{"role": input_message.role.value, "content": get_content(input_message)} for input_message in input_messages] - def _rebuild_memory( - self, - in_context_messages: List[Message], - agent_state: AgentState, - num_messages: int | None = None, # storing these calculations is specific to the voice agent - num_archival_memories: int | None = None, - ) -> List[Message]: - try: - # Refresh Memory - # TODO: This only happens for the summary block (voice?) - # [DB Call] loading blocks (modifies: agent_state.memory.blocks) - self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor) - - # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this - curr_system_message = in_context_messages[0] - curr_memory_str = agent_state.memory.compile() - curr_system_message_text = curr_system_message.content[0].text - if curr_memory_str in curr_system_message_text: - # NOTE: could this cause issues if a block is removed? 
(substring match would still work) - logger.debug( - f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild" - ) - return in_context_messages - - memory_edit_timestamp = get_utc_time() - - # [DB Call] size of messages and archival memories - num_messages = num_messages or self.message_manager.size(actor=self.actor, agent_id=agent_state.id) - num_archival_memories = num_archival_memories or self.passage_manager.size(actor=self.actor, agent_id=agent_state.id) - - new_system_message_str = compile_system_message( - system_prompt=agent_state.system, - in_context_memory=agent_state.memory, - in_context_memory_last_edit=memory_edit_timestamp, - previous_message_count=num_messages, - archival_memory_size=num_archival_memories, - ) - - diff = united_diff(curr_system_message_text, new_system_message_str) - if len(diff) > 0: - logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}") - - # [DB Call] Update Messages - new_system_message = self.message_manager.update_message_by_id( - curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor - ) - # Skip pulling down the agent's memory again to save on a db call - return [new_system_message] + in_context_messages[1:] - - else: - return in_context_messages - except: - logger.exception(f"Failed to rebuild memory for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name})") - raise - async def _rebuild_memory_async( self, in_context_messages: List[Message], diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 72e75ac4..367605a8 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -37,7 +37,6 @@ from letta.services.passage_manager import PassageManager from letta.services.step_manager import NoopStepManager, StepManager from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager from 
letta.services.tool_executor.tool_execution_manager import ToolExecutionManager -from letta.settings import settings from letta.system import package_function_response from letta.tracing import log_event, trace_method, tracer @@ -97,16 +96,12 @@ class LettaAgent(BaseAgent): for _ in range(max_steps): step_id = generate_step_id() - in_context_messages = current_in_context_messages + new_in_context_messages - if settings.experimental_enable_async_db_engine: - in_context_messages = await self._rebuild_memory_async( - in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories - ) - else: - if settings.experimental_skip_rebuild_memory and agent_state.llm_config.model_endpoint_type == "google_vertex": - logger.info("Skipping memory rebuild") - else: - in_context_messages = self._rebuild_memory(in_context_messages, agent_state) + in_context_messages = await self._rebuild_memory_async( + current_in_context_messages + new_in_context_messages, + agent_state, + num_messages=self.num_messages, + num_archival_memories=self.num_archival_memories, + ) log_event("agent.stream_no_tokens.messages.refreshed") # [1^] request_data = await self._create_llm_request_data_async( @@ -200,16 +195,12 @@ class LettaAgent(BaseAgent): for _ in range(max_steps): step_id = generate_step_id() - in_context_messages = current_in_context_messages + new_in_context_messages - if settings.experimental_enable_async_db_engine: - in_context_messages = await self._rebuild_memory_async( - in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories - ) - else: - if settings.experimental_skip_rebuild_memory and agent_state.llm_config.model_endpoint_type == "google_vertex": - logger.info("Skipping memory rebuild") - else: - in_context_messages = self._rebuild_memory(in_context_messages, agent_state) + in_context_messages = await self._rebuild_memory_async( + current_in_context_messages + 
new_in_context_messages, + agent_state, + num_messages=self.num_messages, + num_archival_memories=self.num_archival_memories, + ) log_event("agent.step.messages.refreshed") # [1^] request_data = await self._create_llm_request_data_async( @@ -299,17 +290,12 @@ class LettaAgent(BaseAgent): for _ in range(max_steps): step_id = generate_step_id() - - in_context_messages = current_in_context_messages + new_in_context_messages - if settings.experimental_enable_async_db_engine: - in_context_messages = await self._rebuild_memory_async( - in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories - ) - else: - if settings.experimental_skip_rebuild_memory and agent_state.llm_config.model_endpoint_type == "google_vertex": - logger.info("Skipping memory rebuild") - else: - in_context_messages = self._rebuild_memory(in_context_messages, agent_state) + in_context_messages = await self._rebuild_memory_async( + current_in_context_messages + new_in_context_messages, + agent_state, + num_messages=self.num_messages, + num_archival_memories=self.num_archival_memories, + ) log_event("agent.step.messages.refreshed") # [1^] request_data = await self._create_llm_request_data_async( @@ -439,19 +425,13 @@ class LettaAgent(BaseAgent): agent_state: AgentState, tool_rules_solver: ToolRulesSolver, ) -> ChatCompletion | AsyncStream[ChatCompletionChunk]: - if settings.experimental_enable_async_db_engine: - self.num_messages = self.num_messages or (await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)) - self.num_archival_memories = self.num_archival_memories or ( - await self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id) - ) - in_context_messages = await self._rebuild_memory_async( - in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories - ) - else: - if settings.experimental_skip_rebuild_memory and 
agent_state.llm_config.model_endpoint_type == "google_vertex": - logger.info("Skipping memory rebuild") - else: - in_context_messages = self._rebuild_memory(in_context_messages, agent_state) + self.num_messages = self.num_messages or (await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)) + self.num_archival_memories = self.num_archival_memories or ( + await self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id) + ) + in_context_messages = await self._rebuild_memory_async( + in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories + ) tools = [ t diff --git a/letta/agents/letta_agent_batch.py b/letta/agents/letta_agent_batch.py index 46800bcc..e154107d 100644 --- a/letta/agents/letta_agent_batch.py +++ b/letta/agents/letta_agent_batch.py @@ -556,16 +556,6 @@ class LettaAgentBatch(BaseAgent): in_context_messages = await self._rebuild_memory_async(current_in_context_messages + new_in_context_messages, agent_state) return in_context_messages - # TODO: Make this a bullk function - def _rebuild_memory( - self, - in_context_messages: List[Message], - agent_state: AgentState, - num_messages: int | None = None, - num_archival_memories: int | None = None, - ) -> List[Message]: - return super()._rebuild_memory(in_context_messages, agent_state) - # Not used in batch. 
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse: raise NotImplementedError diff --git a/letta/agents/voice_agent.py b/letta/agents/voice_agent.py index 959a25a9..7f124038 100644 --- a/letta/agents/voice_agent.py +++ b/letta/agents/voice_agent.py @@ -154,7 +154,7 @@ class VoiceAgent(BaseAgent): # TODO: Define max steps here for _ in range(max_steps): # Rebuild memory each loop - in_context_messages = self._rebuild_memory(in_context_messages, agent_state) + in_context_messages = await self._rebuild_memory_async(in_context_messages, agent_state) openai_messages = convert_in_context_letta_messages_to_openai(in_context_messages, exclude_system_messages=True) openai_messages.extend(in_memory_message_history) @@ -292,14 +292,14 @@ class VoiceAgent(BaseAgent): agent_id=self.agent_id, message_ids=[m.id for m in new_in_context_messages], actor=self.actor ) - def _rebuild_memory( + async def _rebuild_memory_async( self, in_context_messages: List[Message], agent_state: AgentState, num_messages: int | None = None, num_archival_memories: int | None = None, ) -> List[Message]: - return super()._rebuild_memory( + return await super()._rebuild_memory_async( in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories ) diff --git a/letta/client/client.py b/letta/client/client.py index 802ca451..90e39400 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -2773,11 +2773,8 @@ class LocalClient(AbstractClient): # humans / personas - def get_block_id(self, name: str, label: str) -> str: - block = self.server.block_manager.get_blocks(actor=self.user, template_name=name, label=label, is_template=True) - if not block: - return None - return block[0].id + def get_block_id(self, name: str, label: str) -> str | None: + return None def create_human(self, name: str, text: str): """ @@ -2812,7 +2809,7 @@ class LocalClient(AbstractClient): Returns: humans (List[Human]): List of human
blocks """ - return self.server.block_manager.get_blocks(actor=self.user, label="human", is_template=True) + return [] def list_personas(self) -> List[Persona]: """ @@ -2821,7 +2818,7 @@ class LocalClient(AbstractClient): Returns: personas (List[Persona]): List of persona blocks """ - return self.server.block_manager.get_blocks(actor=self.user, label="persona", is_template=True) + return [] def update_human(self, human_id: str, text: str): """ @@ -2879,7 +2876,7 @@ class LocalClient(AbstractClient): assert id, f"Human ID must be provided" return Human(**self.server.block_manager.get_block_by_id(id, actor=self.user).model_dump()) - def get_persona_id(self, name: str) -> str: + def get_persona_id(self, name: str) -> str | None: """ Get the ID of a persona block template @@ -2889,12 +2886,9 @@ class LocalClient(AbstractClient): Returns: id (str): ID of the persona block """ - persona = self.server.block_manager.get_blocks(actor=self.user, template_name=name, label="persona", is_template=True) - if not persona: - return None - return persona[0].id + return None - def get_human_id(self, name: str) -> str: + def get_human_id(self, name: str) -> str | None: """ Get the ID of a human block template @@ -2904,10 +2898,7 @@ class LocalClient(AbstractClient): Returns: id (str): ID of the human block """ - human = self.server.block_manager.get_blocks(actor=self.user, template_name=name, label="human", is_template=True) - if not human: - return None - return human[0].id + return None def delete_persona(self, id: str): """ @@ -3381,7 +3372,7 @@ class LocalClient(AbstractClient): Returns: blocks (List[Block]): List of blocks """ - return self.server.block_manager.get_blocks(actor=self.user, label=label, is_template=templates_only) + return [] def create_block( self, label: str, value: str, limit: Optional[int] = None, template_name: Optional[str] = None, is_template: bool = False diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 
e32c8dfb..66e52bc6 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -1468,17 +1468,6 @@ class AgentManager: return agent_state - @enforce_types - def refresh_memory(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState: - block_ids = [b.id for b in agent_state.memory.blocks] - if not block_ids: - return agent_state - - agent_state.memory.blocks = self.block_manager.get_all_blocks_by_ids( - block_ids=[b.id for b in agent_state.memory.blocks], actor=actor - ) - return agent_state - @enforce_types async def refresh_memory_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState: block_ids = [b.id for b in agent_state.memory.blocks] diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index 5ec95e05..8661f1fd 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -82,43 +82,6 @@ class BlockManager: block.hard_delete(db_session=session, actor=actor) return block.to_pydantic() - @enforce_types - def get_blocks( - self, - actor: PydanticUser, - label: Optional[str] = None, - is_template: Optional[bool] = None, - template_name: Optional[str] = None, - identifier_keys: Optional[List[str]] = None, - identity_id: Optional[str] = None, - id: Optional[str] = None, - after: Optional[str] = None, - limit: Optional[int] = 50, - ) -> List[PydanticBlock]: - """Retrieve blocks based on various optional filters.""" - with db_registry.session() as session: - # Prepare filters - filters = {"organization_id": actor.organization_id} - if label: - filters["label"] = label - if is_template is not None: - filters["is_template"] = is_template - if template_name: - filters["template_name"] = template_name - if id: - filters["id"] = id - - blocks = BlockModel.list( - db_session=session, - after=after, - limit=limit, - identifier_keys=identifier_keys, - identity_id=identity_id, - **filters, - ) - - return [block.to_pydantic() for block in 
blocks] - @enforce_types async def get_blocks_async( self, @@ -191,15 +154,6 @@ class BlockManager: except NoResultFound: return None - @enforce_types - def get_all_blocks_by_ids(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]: - """Retrieve blocks by their ids.""" - with db_registry.session() as session: - blocks = [block.to_pydantic() for block in BlockModel.read_multiple(db_session=session, identifiers=block_ids, actor=actor)] - # backwards compatibility. previous implementation added None for every block not found. - blocks.extend([None for _ in range(len(block_ids) - len(blocks))]) - return blocks - @enforce_types async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]: """Retrieve blocks by their ids without loading unnecessary relationships. Async implementation.""" @@ -247,18 +201,6 @@ class BlockManager: return pydantic_blocks - @enforce_types - def get_agents_for_block(self, block_id: str, actor: PydanticUser) -> List[PydanticAgentState]: - """ - Retrieve all agents associated with a given block. - """ - with db_registry.session() as session: - block = BlockModel.read(db_session=session, identifier=block_id, actor=actor) - agents_orm = block.agents - agents_pydantic = [agent.to_pydantic() for agent in agents_orm] - - return agents_pydantic - @enforce_types async def get_agents_for_block_async(self, block_id: str, actor: PydanticUser) -> List[PydanticAgentState]: """ diff --git a/tests/integration_test_sleeptime_agent.py b/tests/integration_test_sleeptime_agent.py index 18d72a79..edae7374 100644 --- a/tests/integration_test_sleeptime_agent.py +++ b/tests/integration_test_sleeptime_agent.py @@ -106,7 +106,7 @@ async def test_sleeptime_group_chat(server, actor): # 3. 
Verify shared blocks sleeptime_agent_id = group.agent_ids[0] shared_block = server.agent_manager.get_block_with_label(agent_id=main_agent.id, block_label="human", actor=actor) - agents = server.block_manager.get_agents_for_block(block_id=shared_block.id, actor=actor) + agents = await server.block_manager.get_agents_for_block_async(block_id=shared_block.id, actor=actor) assert len(agents) == 2 assert sleeptime_agent_id in [agent.id for agent in agents] assert main_agent.id in [agent.id for agent in agents] @@ -220,7 +220,7 @@ async def test_sleeptime_group_chat_v2(server, actor): # 3. Verify shared blocks sleeptime_agent_id = group.agent_ids[0] shared_block = server.agent_manager.get_block_with_label(agent_id=main_agent.id, block_label="human", actor=actor) - agents = server.block_manager.get_agents_for_block(block_id=shared_block.id, actor=actor) + agents = await server.block_manager.get_agents_for_block_async(block_id=shared_block.id, actor=actor) assert len(agents) == 2 assert sleeptime_agent_id in [agent.id for agent in agents] assert main_agent.id in [agent.id for agent in agents] diff --git a/tests/integration_test_voice_agent.py b/tests/integration_test_voice_agent.py index 44e50480..835a09e6 100644 --- a/tests/integration_test_voice_agent.py +++ b/tests/integration_test_voice_agent.py @@ -511,7 +511,7 @@ async def test_init_voice_convo_agent(voice_agent, server, actor): # 3. 
Verify shared blocks sleeptime_agent_id = group.agent_ids[0] shared_block = server.agent_manager.get_block_with_label(agent_id=voice_agent.id, block_label="human", actor=actor) - agents = server.block_manager.get_agents_for_block(block_id=shared_block.id, actor=actor) + agents = await server.block_manager.get_agents_for_block_async(block_id=shared_block.id, actor=actor) assert len(agents) == 2 assert sleeptime_agent_id in [agent.id for agent in agents] assert voice_agent.id in [agent.id for agent in agents] diff --git a/tests/test_managers.py b/tests/test_managers.py index 235a12dd..58deb748 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -1750,29 +1750,6 @@ def test_get_block_with_label(server: SyncServer, sarah_agent, default_block, de assert block.label == default_block.label -def test_refresh_memory(server: SyncServer, default_user): - block = server.block_manager.create_or_update_block( - PydanticBlock( - label="test", - value="test", - limit=1000, - ), - actor=default_user, - ) - agent = server.agent_manager.create_agent( - CreateAgent( - name="test", - llm_config=LLMConfig.default_config("gpt-4o-mini"), - embedding_config=EmbeddingConfig.default_config(provider="openai"), - include_base_tools=False, - ), - actor=default_user, - ) - assert len(agent.memory.blocks) == 0 - agent = server.agent_manager.refresh_memory(agent_state=agent, actor=default_user) - assert len(agent.memory.blocks) == 0 - - @pytest.mark.asyncio async def test_refresh_memory_async(server: SyncServer, default_user, event_loop): block = server.block_manager.create_or_update_block( @@ -2826,7 +2803,8 @@ async def test_delete_block_detaches_from_agent(server: SyncServer, sarah_agent, assert not (block.id in [b.id for b in agent_state.memory.blocks]) -def test_get_agents_for_block(server: SyncServer, sarah_agent, charles_agent, default_user): +@pytest.mark.asyncio +async def test_get_agents_for_block(server: SyncServer, sarah_agent, charles_agent, default_user, event_loop): 
# Create and delete a block block = server.block_manager.create_or_update_block(PydanticBlock(label="alien", value="Sample content"), actor=default_user) sarah_agent = server.agent_manager.attach_block(agent_id=sarah_agent.id, block_id=block.id, actor=default_user) @@ -2837,7 +2815,7 @@ def test_get_agents_for_block(server: SyncServer, sarah_agent, charles_agent, de assert block.id in [b.id for b in charles_agent.memory.blocks] # Get the agents for that block - agent_states = server.block_manager.get_agents_for_block(block_id=block.id, actor=default_user) + agent_states = await server.block_manager.get_agents_for_block_async(block_id=block.id, actor=default_user) assert len(agent_states) == 2 # Check both agents are in the list