From 3a893daec02b22ad7128b964cf9e2bb167021ebb Mon Sep 17 00:00:00 2001 From: cthomas Date: Sun, 25 May 2025 21:23:18 -0700 Subject: [PATCH] feat(asyncify): convert reset messages endpoint (#2429) --- letta/server/rest_api/routers/v1/agents.py | 8 +++-- letta/services/agent_manager.py | 37 +++++++++++++++++----- letta/services/message_manager.py | 23 ++++++++++++++ tests/test_managers.py | 12 ++++--- 4 files changed, 64 insertions(+), 16 deletions(-) diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 36cb05a3..5d61aade 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -877,15 +877,17 @@ async def send_message_async( @router.patch("/{agent_id}/reset-messages", response_model=AgentState, operation_id="reset_messages") -def reset_messages( +async def reset_messages( agent_id: str, add_default_initial_messages: bool = Query(default=False, description="If true, adds the default initial messages after resetting."), server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """Resets the messages for an agent""" - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.agent_manager.reset_messages(agent_id=agent_id, actor=actor, add_default_initial_messages=add_default_initial_messages) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.agent_manager.reset_messages_async( + agent_id=agent_id, actor=actor, add_default_initial_messages=add_default_initial_messages + ) @router.get("/{agent_id}/groups", response_model=List[Group], operation_id="list_agent_groups") diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 2b330dfe..aeb75403 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -578,6 +578,14 @@ class AgentManager: init_messages = self._generate_initial_message_sequence(actor, agent_state, initial_message_sequence) return self.append_to_in_context_messages(init_messages, agent_id=agent_state.id, actor=actor) + @trace_method + @enforce_types + async def append_initial_message_sequence_to_in_context_messages_async( + self, actor: PydanticUser, agent_state: PydanticAgentState, initial_message_sequence: Optional[List[MessageCreate]] = None + ) -> PydanticAgentState: + init_messages = self._generate_initial_message_sequence(actor, agent_state, initial_message_sequence) + return await self.append_to_in_context_messages_async(init_messages, agent_id=agent_state.id, actor=actor) + @trace_method @enforce_types def update_agent( @@ -1502,7 +1510,20 @@ class AgentManager: @trace_method @enforce_types - def reset_messages(self, agent_id: str, actor: PydanticUser, add_default_initial_messages: bool = False) -> PydanticAgentState: + async def append_to_in_context_messages_async( + self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser + ) -> PydanticAgentState: + messages = await self.message_manager.create_many_messages_async(messages, actor=actor) + agent = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor) + message_ids = agent.message_ids or [] + message_ids += [m.id for m in messages] + return await self.set_in_context_messages_async(agent_id=agent_id, message_ids=message_ids, actor=actor) + + @trace_method + @enforce_types + async def reset_messages_async( + self, agent_id: str, actor: PydanticUser, add_default_initial_messages: bool = False + ) -> PydanticAgentState: """ Removes all in-context messages for the specified agent by: 1) Clearing the agent.messages relationship (which cascades delete-orphans). @@ -1519,22 +1540,22 @@ class AgentManager: Returns: PydanticAgentState: The updated agent state with no linked messages. """ - with db_registry.session() as session: + async with db_registry.async_session() as session: # Retrieve the existing agent (will raise NoResultFound if invalid) - agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) + agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) # Also clear out the message_ids field to keep in-context memory consistent agent.message_ids = [] # Commit the update - agent.update(db_session=session, actor=actor) + await agent.update_async(db_session=session, actor=actor) - agent_state = agent.to_pydantic() + agent_state = await agent.to_pydantic_async() - self.message_manager.delete_all_messages_for_agent(agent_id=agent_id, actor=actor) + await self.message_manager.delete_all_messages_for_agent_async(agent_id=agent_id, actor=actor) if add_default_initial_messages: - return self.append_initial_message_sequence_to_in_context_messages(actor, agent_state) + return await self.append_initial_message_sequence_to_in_context_messages_async(actor, agent_state) else: # We still want to always have a system message init_messages = initialize_message_sequence( @@ -1545,7 +1566,7 @@ class AgentManager: model=agent_state.llm_config.model, openai_message_dict=init_messages[0], ) - return self.append_to_in_context_messages([system_message], agent_id=agent_state.id, actor=actor) + return await self.append_to_in_context_messages_async([system_message], agent_id=agent_state.id, actor=actor) # TODO: I moved this from agent.py - replace all mentions of this with the agent_manager version @trace_method diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index ce70d58e..26e5326c 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -593,3 +593,26 @@ class MessageManager: # 4) return the number of rows deleted return result.rowcount + + @enforce_types + @trace_method + async def delete_all_messages_for_agent_async(self, agent_id: str, actor: PydanticUser) -> int: + """ + Efficiently deletes all messages associated with a given agent_id, + while enforcing permission checks and avoiding any ORM‑level loads. + """ + async with db_registry.async_session() as session: + # 1) verify the agent exists and the actor has access + await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) + + # 2) issue a CORE DELETE against the mapped class + stmt = ( + delete(MessageModel).where(MessageModel.agent_id == agent_id).where(MessageModel.organization_id == actor.organization_id) + ) + result = await session.execute(stmt) + + # 3) commit once + await session.commit() + + # 4) return the number of rows deleted + return result.rowcount diff --git a/tests/test_managers.py b/tests/test_managers.py index 2449d4cc..48c338ec 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -1487,7 +1487,7 @@ async def test_reset_messages_no_messages(server: SyncServer, sarah_agent, defau assert updated_agent.message_ids == ["ghost-message-id"] # Reset messages - reset_agent = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user) + reset_agent = await server.agent_manager.reset_messages_async(agent_id=sarah_agent.id, actor=default_user) assert len(reset_agent.message_ids) == 1 # Double check that physically no messages exist assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 1 @@ -1505,7 +1505,9 @@ async def test_reset_messages_default_messages(server: SyncServer, sarah_agent, assert updated_agent.message_ids == ["ghost-message-id"] # Reset messages - reset_agent = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user, add_default_initial_messages=True) + reset_agent = await server.agent_manager.reset_messages_async( + agent_id=sarah_agent.id, actor=default_user, add_default_initial_messages=True + ) assert len(reset_agent.message_ids) == 4 # Double check that physically no messages exist assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 4 @@ -1544,7 +1546,7 @@ async def test_reset_messages_with_existing_messages(server: SyncServer, sarah_a assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 6 # 2. Reset all messages - reset_agent = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user) + reset_agent = await server.agent_manager.reset_messages_async(agent_id=sarah_agent.id, actor=default_user) # 3. Verify the agent now has zero message_ids assert len(reset_agent.message_ids) == 1 @@ -1569,12 +1571,12 @@ async def test_reset_messages_idempotency(server: SyncServer, sarah_agent, defau actor=default_user, ) # First reset - reset_agent = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user) + reset_agent = await server.agent_manager.reset_messages_async(agent_id=sarah_agent.id, actor=default_user) assert len(reset_agent.message_ids) == 1 assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 1 # Second reset should do nothing new - reset_agent_again = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user) + reset_agent_again = await server.agent_manager.reset_messages_async(agent_id=sarah_agent.id, actor=default_user) assert len(reset_agent.message_ids) == 1 assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 1