feat: Add reset messages route for agents (#601)
Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
This commit is contained in:
@@ -156,6 +156,18 @@ def remove_tool_from_agent(
|
||||
return server.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/reset-messages", response_model=AgentState, operation_id="reset_messages")
|
||||
def reset_messages(
|
||||
agent_id: str,
|
||||
add_default_initial_messages: bool = Query(default=False, description="If true, adds the default initial messages after resetting."),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""Resets the messages for an agent"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.agent_manager.reset_messages(agent_id=agent_id, actor=actor, add_default_initial_messages=add_default_initial_messages)
|
||||
|
||||
|
||||
@router.get("/{agent_id}", response_model=AgentState, operation_id="get_agent")
|
||||
def get_agent_state(
|
||||
agent_id: str,
|
||||
|
||||
@@ -22,6 +22,7 @@ from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.schemas.source import Source as PydanticSource
|
||||
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
|
||||
@@ -125,13 +126,17 @@ class AgentManager:
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# TODO: See if we can merge this into the above SQL create call for performance reasons
|
||||
# Generate a sequence of initial messages to put in the buffer
|
||||
return self.append_initial_message_sequence_to_in_context_messages(actor, agent_state, agent_create.initial_message_sequence)
|
||||
|
||||
@enforce_types
|
||||
def append_initial_message_sequence_to_in_context_messages(
|
||||
self, actor: PydanticUser, agent_state: PydanticAgentState, initial_message_sequence: Optional[List[MessageCreate]] = None
|
||||
) -> PydanticAgentState:
|
||||
init_messages = initialize_message_sequence(
|
||||
agent_state=agent_state, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True
|
||||
)
|
||||
|
||||
if agent_create.initial_message_sequence is not None:
|
||||
if initial_message_sequence is not None:
|
||||
# We always need the system prompt up front
|
||||
system_message_obj = PydanticMessage.dict_to_message(
|
||||
agent_id=agent_state.id,
|
||||
@@ -142,7 +147,7 @@ class AgentManager:
|
||||
# Don't use anything else in the pregen sequence, instead use the provided sequence
|
||||
init_messages = [system_message_obj]
|
||||
init_messages.extend(
|
||||
package_initial_message_sequence(agent_state.id, agent_create.initial_message_sequence, agent_state.llm_config.model, actor)
|
||||
package_initial_message_sequence(agent_state.id, initial_message_sequence, agent_state.llm_config.model, actor)
|
||||
)
|
||||
else:
|
||||
init_messages = [
|
||||
@@ -468,6 +473,45 @@ class AgentManager:
|
||||
message_ids += [m.id for m in messages]
|
||||
return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def reset_messages(self, agent_id: str, actor: PydanticUser, add_default_initial_messages: bool = False) -> PydanticAgentState:
|
||||
"""
|
||||
Removes all in-context messages for the specified agent by:
|
||||
1) Clearing the agent.messages relationship (which cascades delete-orphans).
|
||||
2) Resetting the message_ids list to empty.
|
||||
3) Committing the transaction.
|
||||
|
||||
This action is destructive and cannot be undone once committed.
|
||||
|
||||
Args:
|
||||
add_default_initial_messages: If true, adds the default initial messages after resetting.
|
||||
agent_id (str): The ID of the agent whose messages will be reset.
|
||||
actor (PydanticUser): The user performing this action.
|
||||
|
||||
Returns:
|
||||
PydanticAgentState: The updated agent state with no linked messages.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
# Retrieve the existing agent (will raise NoResultFound if invalid)
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
# Because of cascade="all, delete-orphan" on agent.messages, setting
|
||||
# this relationship to an empty list will physically remove them from the DB.
|
||||
agent.messages = []
|
||||
|
||||
# Also clear out the message_ids field to keep in-context memory consistent
|
||||
agent.message_ids = []
|
||||
|
||||
# Commit the update
|
||||
agent.update(db_session=session, actor=actor)
|
||||
|
||||
agent_state = agent.to_pydantic()
|
||||
|
||||
if add_default_initial_messages:
|
||||
return self.append_initial_message_sequence_to_in_context_messages(actor, agent_state)
|
||||
else:
|
||||
return agent_state
|
||||
|
||||
# ======================================================================================================================
|
||||
# Source Management
|
||||
# ======================================================================================================================
|
||||
|
||||
@@ -914,6 +914,109 @@ def test_list_agents_by_tags_pagination(server: SyncServer, default_user, defaul
|
||||
assert agent2.id in all_ids
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# AgentManager Tests - Messages Relationship
|
||||
# ======================================================================================================================
|
||||
def test_reset_messages_no_messages(server: SyncServer, sarah_agent, default_user):
|
||||
"""
|
||||
Test that resetting messages on an agent that has zero messages
|
||||
does not fail and clears out message_ids if somehow it's non-empty.
|
||||
"""
|
||||
# Force a weird scenario: Suppose the message_ids field was set non-empty (without actual messages).
|
||||
server.agent_manager.update_agent(sarah_agent.id, UpdateAgent(message_ids=["ghost-message-id"]), actor=default_user)
|
||||
updated_agent = server.agent_manager.get_agent_by_id(sarah_agent.id, default_user)
|
||||
assert updated_agent.message_ids == ["ghost-message-id"]
|
||||
|
||||
# Reset messages
|
||||
reset_agent = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user)
|
||||
assert reset_agent.message_ids == []
|
||||
# Double check that physically no messages exist
|
||||
assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 0
|
||||
|
||||
|
||||
def test_reset_messages_default_messages(server: SyncServer, sarah_agent, default_user):
|
||||
"""
|
||||
Test that resetting messages on an agent that has zero messages
|
||||
does not fail and clears out message_ids if somehow it's non-empty.
|
||||
"""
|
||||
# Force a weird scenario: Suppose the message_ids field was set non-empty (without actual messages).
|
||||
server.agent_manager.update_agent(sarah_agent.id, UpdateAgent(message_ids=["ghost-message-id"]), actor=default_user)
|
||||
updated_agent = server.agent_manager.get_agent_by_id(sarah_agent.id, default_user)
|
||||
assert updated_agent.message_ids == ["ghost-message-id"]
|
||||
|
||||
# Reset messages
|
||||
reset_agent = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user, add_default_initial_messages=True)
|
||||
assert len(reset_agent.message_ids) == 4
|
||||
# Double check that physically no messages exist
|
||||
assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 4
|
||||
|
||||
|
||||
def test_reset_messages_with_existing_messages(server: SyncServer, sarah_agent, default_user):
|
||||
"""
|
||||
Test that resetting messages on an agent with actual messages
|
||||
deletes them from the database and clears message_ids.
|
||||
"""
|
||||
# 1. Create multiple messages for the agent
|
||||
msg1 = server.message_manager.create_message(
|
||||
PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
role="user",
|
||||
text="Hello, Sarah!",
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
msg2 = server.message_manager.create_message(
|
||||
PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
role="assistant",
|
||||
text="Hello, user!",
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Verify the messages were created
|
||||
agent_before = server.agent_manager.get_agent_by_id(sarah_agent.id, default_user)
|
||||
# This is 4 because creating the message does not necessarily add it to the in context message ids
|
||||
assert len(agent_before.message_ids) == 4
|
||||
assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 6
|
||||
|
||||
# 2. Reset all messages
|
||||
reset_agent = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user)
|
||||
|
||||
# 3. Verify the agent now has zero message_ids
|
||||
assert reset_agent.message_ids == []
|
||||
|
||||
# 4. Verify the messages are physically removed
|
||||
assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 0
|
||||
|
||||
|
||||
def test_reset_messages_idempotency(server: SyncServer, sarah_agent, default_user):
|
||||
"""
|
||||
Test that calling reset_messages multiple times has no adverse effect.
|
||||
"""
|
||||
# Create a single message
|
||||
server.message_manager.create_message(
|
||||
PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
role="user",
|
||||
text="Hello, Sarah!",
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
# First reset
|
||||
reset_agent = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user)
|
||||
assert reset_agent.message_ids == []
|
||||
assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 0
|
||||
|
||||
# Second reset should do nothing new
|
||||
reset_agent_again = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user)
|
||||
assert reset_agent_again.message_ids == []
|
||||
assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 0
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# AgentManager Tests - Blocks Relationship
|
||||
# ======================================================================================================================
|
||||
|
||||
Reference in New Issue
Block a user