feat(asyncify): convert reset messages endpoint (#2429)
This commit is contained in:
@@ -877,15 +877,17 @@ async def send_message_async(
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/reset-messages", response_model=AgentState, operation_id="reset_messages")
|
||||
def reset_messages(
|
||||
async def reset_messages(
|
||||
agent_id: str,
|
||||
add_default_initial_messages: bool = Query(default=False, description="If true, adds the default initial messages after resetting."),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""Resets the messages for an agent"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.agent_manager.reset_messages(agent_id=agent_id, actor=actor, add_default_initial_messages=add_default_initial_messages)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.agent_manager.reset_messages_async(
|
||||
agent_id=agent_id, actor=actor, add_default_initial_messages=add_default_initial_messages
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/groups", response_model=List[Group], operation_id="list_agent_groups")
|
||||
|
||||
@@ -578,6 +578,14 @@ class AgentManager:
|
||||
init_messages = self._generate_initial_message_sequence(actor, agent_state, initial_message_sequence)
|
||||
return self.append_to_in_context_messages(init_messages, agent_id=agent_state.id, actor=actor)
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
async def append_initial_message_sequence_to_in_context_messages_async(
|
||||
self, actor: PydanticUser, agent_state: PydanticAgentState, initial_message_sequence: Optional[List[MessageCreate]] = None
|
||||
) -> PydanticAgentState:
|
||||
init_messages = self._generate_initial_message_sequence(actor, agent_state, initial_message_sequence)
|
||||
return await self.append_to_in_context_messages_async(init_messages, agent_id=agent_state.id, actor=actor)
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
def update_agent(
|
||||
@@ -1502,7 +1510,20 @@ class AgentManager:
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
def reset_messages(self, agent_id: str, actor: PydanticUser, add_default_initial_messages: bool = False) -> PydanticAgentState:
|
||||
async def append_to_in_context_messages_async(
|
||||
self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser
|
||||
) -> PydanticAgentState:
|
||||
messages = await self.message_manager.create_many_messages_async(messages, actor=actor)
|
||||
agent = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
||||
message_ids = agent.message_ids or []
|
||||
message_ids += [m.id for m in messages]
|
||||
return await self.set_in_context_messages_async(agent_id=agent_id, message_ids=message_ids, actor=actor)
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
async def reset_messages_async(
|
||||
self, agent_id: str, actor: PydanticUser, add_default_initial_messages: bool = False
|
||||
) -> PydanticAgentState:
|
||||
"""
|
||||
Removes all in-context messages for the specified agent by:
|
||||
1) Clearing the agent.messages relationship (which cascades delete-orphans).
|
||||
@@ -1519,22 +1540,22 @@ class AgentManager:
|
||||
Returns:
|
||||
PydanticAgentState: The updated agent state with no linked messages.
|
||||
"""
|
||||
with db_registry.session() as session:
|
||||
async with db_registry.async_session() as session:
|
||||
# Retrieve the existing agent (will raise NoResultFound if invalid)
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
# Also clear out the message_ids field to keep in-context memory consistent
|
||||
agent.message_ids = []
|
||||
|
||||
# Commit the update
|
||||
agent.update(db_session=session, actor=actor)
|
||||
await agent.update_async(db_session=session, actor=actor)
|
||||
|
||||
agent_state = agent.to_pydantic()
|
||||
agent_state = await agent.to_pydantic_async()
|
||||
|
||||
self.message_manager.delete_all_messages_for_agent(agent_id=agent_id, actor=actor)
|
||||
await self.message_manager.delete_all_messages_for_agent_async(agent_id=agent_id, actor=actor)
|
||||
|
||||
if add_default_initial_messages:
|
||||
return self.append_initial_message_sequence_to_in_context_messages(actor, agent_state)
|
||||
return await self.append_initial_message_sequence_to_in_context_messages_async(actor, agent_state)
|
||||
else:
|
||||
# We still want to always have a system message
|
||||
init_messages = initialize_message_sequence(
|
||||
@@ -1545,7 +1566,7 @@ class AgentManager:
|
||||
model=agent_state.llm_config.model,
|
||||
openai_message_dict=init_messages[0],
|
||||
)
|
||||
return self.append_to_in_context_messages([system_message], agent_id=agent_state.id, actor=actor)
|
||||
return await self.append_to_in_context_messages_async([system_message], agent_id=agent_state.id, actor=actor)
|
||||
|
||||
# TODO: I moved this from agent.py - replace all mentions of this with the agent_manager version
|
||||
@trace_method
|
||||
|
||||
@@ -593,3 +593,26 @@ class MessageManager:
|
||||
|
||||
# 4) return the number of rows deleted
|
||||
return result.rowcount
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def delete_all_messages_for_agent_async(self, agent_id: str, actor: PydanticUser) -> int:
|
||||
"""
|
||||
Efficiently deletes all messages associated with a given agent_id,
|
||||
while enforcing permission checks and avoiding any ORM‑level loads.
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
# 1) verify the agent exists and the actor has access
|
||||
await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
# 2) issue a CORE DELETE against the mapped class
|
||||
stmt = (
|
||||
delete(MessageModel).where(MessageModel.agent_id == agent_id).where(MessageModel.organization_id == actor.organization_id)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
|
||||
# 3) commit once
|
||||
await session.commit()
|
||||
|
||||
# 4) return the number of rows deleted
|
||||
return result.rowcount
|
||||
|
||||
@@ -1487,7 +1487,7 @@ async def test_reset_messages_no_messages(server: SyncServer, sarah_agent, defau
|
||||
assert updated_agent.message_ids == ["ghost-message-id"]
|
||||
|
||||
# Reset messages
|
||||
reset_agent = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user)
|
||||
reset_agent = await server.agent_manager.reset_messages_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
assert len(reset_agent.message_ids) == 1
|
||||
# Double check that physically no messages exist
|
||||
assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 1
|
||||
@@ -1505,7 +1505,9 @@ async def test_reset_messages_default_messages(server: SyncServer, sarah_agent,
|
||||
assert updated_agent.message_ids == ["ghost-message-id"]
|
||||
|
||||
# Reset messages
|
||||
reset_agent = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user, add_default_initial_messages=True)
|
||||
reset_agent = await server.agent_manager.reset_messages_async(
|
||||
agent_id=sarah_agent.id, actor=default_user, add_default_initial_messages=True
|
||||
)
|
||||
assert len(reset_agent.message_ids) == 4
|
||||
# Double check that physically no messages exist
|
||||
assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 4
|
||||
@@ -1544,7 +1546,7 @@ async def test_reset_messages_with_existing_messages(server: SyncServer, sarah_a
|
||||
assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 6
|
||||
|
||||
# 2. Reset all messages
|
||||
reset_agent = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user)
|
||||
reset_agent = await server.agent_manager.reset_messages_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
|
||||
# 3. Verify the agent now has zero message_ids
|
||||
assert len(reset_agent.message_ids) == 1
|
||||
@@ -1569,12 +1571,12 @@ async def test_reset_messages_idempotency(server: SyncServer, sarah_agent, defau
|
||||
actor=default_user,
|
||||
)
|
||||
# First reset
|
||||
reset_agent = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user)
|
||||
reset_agent = await server.agent_manager.reset_messages_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
assert len(reset_agent.message_ids) == 1
|
||||
assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 1
|
||||
|
||||
# Second reset should do nothing new
|
||||
reset_agent_again = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user)
|
||||
reset_agent_again = await server.agent_manager.reset_messages_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
assert len(reset_agent.message_ids) == 1
|
||||
assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user