feat: query param parity for conversation messages (#8730)
* base * add tests * generate
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from datetime import timedelta
|
||||
from typing import Annotated, List, Optional
|
||||
from typing import Annotated, List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query
|
||||
from pydantic import Field
|
||||
@@ -96,19 +96,26 @@ async def list_conversation_messages(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
headers: HeaderParams = Depends(get_headers),
|
||||
before: Optional[str] = Query(
|
||||
None, description="Message ID cursor for pagination. Returns messages that come before this message ID in the conversation"
|
||||
None, description="Message ID cursor for pagination. Returns messages that come before this message ID in the specified sort order"
|
||||
),
|
||||
after: Optional[str] = Query(
|
||||
None, description="Message ID cursor for pagination. Returns messages that come after this message ID in the conversation"
|
||||
None, description="Message ID cursor for pagination. Returns messages that come after this message ID in the specified sort order"
|
||||
),
|
||||
limit: Optional[int] = Query(100, description="Maximum number of messages to return"),
|
||||
order: Literal["asc", "desc"] = Query(
|
||||
"desc", description="Sort order for messages by creation time. 'asc' for oldest first, 'desc' for newest first"
|
||||
),
|
||||
order_by: Literal["created_at"] = Query("created_at", description="Field to sort by"),
|
||||
group_id: Optional[str] = Query(None, description="Group ID to filter messages by."),
|
||||
include_err: Optional[bool] = Query(
|
||||
None, description="Whether to include error messages and error statuses. For debugging purposes only."
|
||||
),
|
||||
):
|
||||
"""
|
||||
List all messages in a conversation.
|
||||
|
||||
Returns LettaMessage objects (UserMessage, AssistantMessage, etc.) for all
|
||||
messages in the conversation, ordered by position (oldest first),
|
||||
with support for cursor-based pagination.
|
||||
messages in the conversation, with support for cursor-based pagination.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await conversation_manager.list_conversation_messages(
|
||||
@@ -117,6 +124,9 @@ async def list_conversation_messages(
|
||||
limit=limit,
|
||||
before=before,
|
||||
after=after,
|
||||
reverse=(order == "desc"),
|
||||
group_id=group_id,
|
||||
include_err=include_err,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -334,13 +334,15 @@ class ConversationManager:
|
||||
limit: Optional[int] = 100,
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
reverse: bool = False,
|
||||
group_id: Optional[str] = None,
|
||||
include_err: Optional[bool] = None,
|
||||
) -> List[LettaMessage]:
|
||||
"""
|
||||
List all messages in a conversation with pagination support.
|
||||
|
||||
Unlike get_messages_for_conversation, this returns ALL messages
|
||||
(not just in_context) and supports cursor-based pagination.
|
||||
Messages are always ordered by position (oldest first).
|
||||
|
||||
Args:
|
||||
conversation_id: The conversation to list messages for
|
||||
@@ -348,6 +350,9 @@ class ConversationManager:
|
||||
limit: Maximum number of messages to return
|
||||
before: Return messages before this message ID
|
||||
after: Return messages after this message ID
|
||||
reverse: If True, return messages in descending order (newest first)
|
||||
group_id: Optional group ID to filter messages by
|
||||
include_err: Optional boolean to include error messages and error statuses
|
||||
|
||||
Returns:
|
||||
List of LettaMessage objects
|
||||
@@ -367,6 +372,10 @@ class ConversationManager:
|
||||
)
|
||||
)
|
||||
|
||||
# Filter by group_id if provided
|
||||
if group_id:
|
||||
query = query.where(MessageModel.group_id == group_id)
|
||||
|
||||
# Handle cursor-based pagination
|
||||
if before:
|
||||
# Get the position of the cursor message
|
||||
@@ -390,8 +399,11 @@ class ConversationManager:
|
||||
if cursor_position is not None:
|
||||
query = query.where(ConversationMessageModel.position > cursor_position)
|
||||
|
||||
# Order by position (oldest first)
|
||||
query = query.order_by(ConversationMessageModel.position.asc())
|
||||
# Order by position
|
||||
if reverse:
|
||||
query = query.order_by(ConversationMessageModel.position.desc())
|
||||
else:
|
||||
query = query.order_by(ConversationMessageModel.position.asc())
|
||||
|
||||
# Apply limit
|
||||
if limit is not None:
|
||||
@@ -401,7 +413,9 @@ class ConversationManager:
|
||||
messages = [msg.to_pydantic() for msg in result.scalars().all()]
|
||||
|
||||
# Convert to LettaMessages
|
||||
return PydanticMessage.to_letta_messages_from_list(messages, reverse=False, text_is_assistant_message=True)
|
||||
return PydanticMessage.to_letta_messages_from_list(
|
||||
messages, reverse=reverse, include_err=include_err, text_is_assistant_message=True
|
||||
)
|
||||
|
||||
# ==================== Isolated Blocks Methods ====================
|
||||
|
||||
|
||||
@@ -396,3 +396,172 @@ class TestConversationsSDK:
|
||||
)
|
||||
)
|
||||
assert len(messages) > 0, "Should be able to send message after concurrent requests complete"
|
||||
|
||||
def test_list_conversation_messages_order_asc(self, client: Letta, agent):
|
||||
"""Test listing messages in ascending order (oldest first)."""
|
||||
conversation = client.conversations.create(agent_id=agent.id)
|
||||
|
||||
# Send messages to create history
|
||||
list(
|
||||
client.conversations.messages.create(
|
||||
conversation_id=conversation.id,
|
||||
messages=[{"role": "user", "content": "First message"}],
|
||||
)
|
||||
)
|
||||
list(
|
||||
client.conversations.messages.create(
|
||||
conversation_id=conversation.id,
|
||||
messages=[{"role": "user", "content": "Second message"}],
|
||||
)
|
||||
)
|
||||
|
||||
# List messages in ascending order (oldest first)
|
||||
messages_asc = client.conversations.messages.list(
|
||||
conversation_id=conversation.id,
|
||||
order="asc",
|
||||
)
|
||||
|
||||
# First message should be system message (oldest)
|
||||
assert messages_asc[0].message_type == "system_message"
|
||||
|
||||
# Get user messages and verify order
|
||||
user_messages = [m for m in messages_asc if m.message_type == "user_message"]
|
||||
assert len(user_messages) >= 2
|
||||
# First user message should contain "First message"
|
||||
assert "First" in user_messages[0].content
|
||||
|
||||
def test_list_conversation_messages_order_desc(self, client: Letta, agent):
|
||||
"""Test listing messages in descending order (newest first)."""
|
||||
conversation = client.conversations.create(agent_id=agent.id)
|
||||
|
||||
# Send messages to create history
|
||||
list(
|
||||
client.conversations.messages.create(
|
||||
conversation_id=conversation.id,
|
||||
messages=[{"role": "user", "content": "First message"}],
|
||||
)
|
||||
)
|
||||
list(
|
||||
client.conversations.messages.create(
|
||||
conversation_id=conversation.id,
|
||||
messages=[{"role": "user", "content": "Second message"}],
|
||||
)
|
||||
)
|
||||
|
||||
# List messages in descending order (newest first) - this is the default
|
||||
messages_desc = client.conversations.messages.list(
|
||||
conversation_id=conversation.id,
|
||||
order="desc",
|
||||
)
|
||||
|
||||
# Get user messages and verify order
|
||||
user_messages = [m for m in messages_desc if m.message_type == "user_message"]
|
||||
assert len(user_messages) >= 2
|
||||
# First user message in desc order should contain "Second message" (newest)
|
||||
assert "Second" in user_messages[0].content
|
||||
|
||||
def test_list_conversation_messages_order_affects_pagination(self, client: Letta, agent):
|
||||
"""Test that order parameter affects pagination correctly."""
|
||||
conversation = client.conversations.create(agent_id=agent.id)
|
||||
|
||||
# Send multiple messages
|
||||
for i in range(3):
|
||||
list(
|
||||
client.conversations.messages.create(
|
||||
conversation_id=conversation.id,
|
||||
messages=[{"role": "user", "content": f"Message {i}"}],
|
||||
)
|
||||
)
|
||||
|
||||
# Get all messages in descending order with limit
|
||||
messages_desc = client.conversations.messages.list(
|
||||
conversation_id=conversation.id,
|
||||
order="desc",
|
||||
limit=5,
|
||||
)
|
||||
|
||||
# Get all messages in ascending order with limit
|
||||
messages_asc = client.conversations.messages.list(
|
||||
conversation_id=conversation.id,
|
||||
order="asc",
|
||||
limit=5,
|
||||
)
|
||||
|
||||
# The first messages should be different based on order
|
||||
assert messages_desc[0].id != messages_asc[0].id
|
||||
|
||||
def test_list_conversation_messages_with_before_cursor(self, client: Letta, agent):
|
||||
"""Test pagination with before cursor."""
|
||||
conversation = client.conversations.create(agent_id=agent.id)
|
||||
|
||||
# Send messages to create history
|
||||
list(
|
||||
client.conversations.messages.create(
|
||||
conversation_id=conversation.id,
|
||||
messages=[{"role": "user", "content": "First message"}],
|
||||
)
|
||||
)
|
||||
list(
|
||||
client.conversations.messages.create(
|
||||
conversation_id=conversation.id,
|
||||
messages=[{"role": "user", "content": "Second message"}],
|
||||
)
|
||||
)
|
||||
|
||||
# Get all messages first
|
||||
all_messages = client.conversations.messages.list(
|
||||
conversation_id=conversation.id,
|
||||
order="asc",
|
||||
)
|
||||
assert len(all_messages) >= 4 # system + user + assistant + user + assistant
|
||||
|
||||
# Use the last message ID as cursor
|
||||
last_message_id = all_messages[-1].id
|
||||
messages_before = client.conversations.messages.list(
|
||||
conversation_id=conversation.id,
|
||||
order="asc",
|
||||
before=last_message_id,
|
||||
)
|
||||
|
||||
# Should have fewer messages (all except the last one)
|
||||
assert len(messages_before) < len(all_messages)
|
||||
# Should not contain the cursor message
|
||||
assert last_message_id not in [m.id for m in messages_before]
|
||||
|
||||
def test_list_conversation_messages_with_after_cursor(self, client: Letta, agent):
|
||||
"""Test pagination with after cursor."""
|
||||
conversation = client.conversations.create(agent_id=agent.id)
|
||||
|
||||
# Send messages to create history
|
||||
list(
|
||||
client.conversations.messages.create(
|
||||
conversation_id=conversation.id,
|
||||
messages=[{"role": "user", "content": "First message"}],
|
||||
)
|
||||
)
|
||||
list(
|
||||
client.conversations.messages.create(
|
||||
conversation_id=conversation.id,
|
||||
messages=[{"role": "user", "content": "Second message"}],
|
||||
)
|
||||
)
|
||||
|
||||
# Get all messages first
|
||||
all_messages = client.conversations.messages.list(
|
||||
conversation_id=conversation.id,
|
||||
order="asc",
|
||||
)
|
||||
assert len(all_messages) >= 4
|
||||
|
||||
# Use the first message ID as cursor
|
||||
first_message_id = all_messages[0].id
|
||||
messages_after = client.conversations.messages.list(
|
||||
conversation_id=conversation.id,
|
||||
order="asc",
|
||||
after=first_message_id,
|
||||
)
|
||||
|
||||
# Should have fewer messages (all except the first one)
|
||||
assert len(messages_after) < len(all_messages)
|
||||
# Should not contain the cursor message
|
||||
assert first_message_id not in [m.id for m in messages_after]
|
||||
|
||||
@@ -749,3 +749,291 @@ async def test_delete_conversation_cleans_up_isolated_blocks(conversation_manage
|
||||
# Verify the isolated block was hard-deleted
|
||||
deleted_block = await server.block_manager.get_block_by_id_async(isolated_block_id, default_user)
|
||||
assert deleted_block is None
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# list_conversation_messages with order/reverse Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_conversation_messages_ascending_order(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test listing messages in ascending order (oldest first)."""
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
|
||||
# Create a conversation
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Test"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Create messages in a known order
|
||||
pydantic_messages = [
|
||||
PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role="user",
|
||||
content=[TextContent(text=f"Message {i}")],
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
messages = await server.message_manager.create_many_messages_async(
|
||||
pydantic_messages,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Add messages to conversation
|
||||
await conversation_manager.add_messages_to_conversation(
|
||||
conversation_id=conversation.id,
|
||||
agent_id=sarah_agent.id,
|
||||
message_ids=[m.id for m in messages],
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# List messages in ascending order (reverse=False)
|
||||
letta_messages = await conversation_manager.list_conversation_messages(
|
||||
conversation_id=conversation.id,
|
||||
actor=default_user,
|
||||
reverse=False,
|
||||
)
|
||||
|
||||
# First message should be "Message 0" (oldest)
|
||||
assert len(letta_messages) == 3
|
||||
assert "Message 0" in letta_messages[0].content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_conversation_messages_descending_order(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test listing messages in descending order (newest first)."""
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
|
||||
# Create a conversation
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Test"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Create messages in a known order
|
||||
pydantic_messages = [
|
||||
PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role="user",
|
||||
content=[TextContent(text=f"Message {i}")],
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
messages = await server.message_manager.create_many_messages_async(
|
||||
pydantic_messages,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Add messages to conversation
|
||||
await conversation_manager.add_messages_to_conversation(
|
||||
conversation_id=conversation.id,
|
||||
agent_id=sarah_agent.id,
|
||||
message_ids=[m.id for m in messages],
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# List messages in descending order (reverse=True)
|
||||
letta_messages = await conversation_manager.list_conversation_messages(
|
||||
conversation_id=conversation.id,
|
||||
actor=default_user,
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
# First message should be "Message 2" (newest)
|
||||
assert len(letta_messages) == 3
|
||||
assert "Message 2" in letta_messages[0].content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_conversation_messages_with_group_id_filter(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test filtering messages by group_id."""
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
|
||||
# Create a conversation
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Test"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Create messages with different group_ids
|
||||
group_a_id = "group-aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
group_b_id = "group-bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
|
||||
|
||||
messages_group_a = [
|
||||
PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role="user",
|
||||
content=[TextContent(text="Group A message 1")],
|
||||
group_id=group_a_id,
|
||||
),
|
||||
PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role="user",
|
||||
content=[TextContent(text="Group A message 2")],
|
||||
group_id=group_a_id,
|
||||
),
|
||||
]
|
||||
messages_group_b = [
|
||||
PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role="user",
|
||||
content=[TextContent(text="Group B message 1")],
|
||||
group_id=group_b_id,
|
||||
),
|
||||
]
|
||||
|
||||
created_a = await server.message_manager.create_many_messages_async(messages_group_a, actor=default_user)
|
||||
created_b = await server.message_manager.create_many_messages_async(messages_group_b, actor=default_user)
|
||||
|
||||
# Add all messages to conversation
|
||||
all_message_ids = [m.id for m in created_a] + [m.id for m in created_b]
|
||||
await conversation_manager.add_messages_to_conversation(
|
||||
conversation_id=conversation.id,
|
||||
agent_id=sarah_agent.id,
|
||||
message_ids=all_message_ids,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# List messages filtered by group A
|
||||
messages_a = await conversation_manager.list_conversation_messages(
|
||||
conversation_id=conversation.id,
|
||||
actor=default_user,
|
||||
group_id=group_a_id,
|
||||
)
|
||||
|
||||
assert len(messages_a) == 2
|
||||
for msg in messages_a:
|
||||
assert "Group A" in msg.content
|
||||
|
||||
# List messages filtered by group B
|
||||
messages_b = await conversation_manager.list_conversation_messages(
|
||||
conversation_id=conversation.id,
|
||||
actor=default_user,
|
||||
group_id=group_b_id,
|
||||
)
|
||||
|
||||
assert len(messages_b) == 1
|
||||
assert "Group B" in messages_b[0].content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_conversation_messages_no_group_id_returns_all(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test that not providing group_id returns all messages."""
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
|
||||
# Create a conversation
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Test"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Create messages with different group_ids
|
||||
group_a_id = "group-aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
group_b_id = "group-bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
|
||||
|
||||
pydantic_messages = [
|
||||
PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role="user",
|
||||
content=[TextContent(text="Group A message")],
|
||||
group_id=group_a_id,
|
||||
),
|
||||
PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role="user",
|
||||
content=[TextContent(text="Group B message")],
|
||||
group_id=group_b_id,
|
||||
),
|
||||
PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role="user",
|
||||
content=[TextContent(text="No group message")],
|
||||
group_id=None,
|
||||
),
|
||||
]
|
||||
messages = await server.message_manager.create_many_messages_async(pydantic_messages, actor=default_user)
|
||||
|
||||
# Add all messages to conversation
|
||||
await conversation_manager.add_messages_to_conversation(
|
||||
conversation_id=conversation.id,
|
||||
agent_id=sarah_agent.id,
|
||||
message_ids=[m.id for m in messages],
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# List all messages without group_id filter
|
||||
all_messages = await conversation_manager.list_conversation_messages(
|
||||
conversation_id=conversation.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert len(all_messages) == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_conversation_messages_order_with_pagination(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test that order affects pagination correctly."""
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
|
||||
# Create a conversation
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Test"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Create messages
|
||||
pydantic_messages = [
|
||||
PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role="user",
|
||||
content=[TextContent(text=f"Message {i}")],
|
||||
)
|
||||
for i in range(5)
|
||||
]
|
||||
messages = await server.message_manager.create_many_messages_async(
|
||||
pydantic_messages,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Add messages to conversation
|
||||
await conversation_manager.add_messages_to_conversation(
|
||||
conversation_id=conversation.id,
|
||||
agent_id=sarah_agent.id,
|
||||
message_ids=[m.id for m in messages],
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Get first page in ascending order with limit
|
||||
page_asc = await conversation_manager.list_conversation_messages(
|
||||
conversation_id=conversation.id,
|
||||
actor=default_user,
|
||||
reverse=False,
|
||||
limit=2,
|
||||
)
|
||||
|
||||
# Get first page in descending order with limit
|
||||
page_desc = await conversation_manager.list_conversation_messages(
|
||||
conversation_id=conversation.id,
|
||||
actor=default_user,
|
||||
reverse=True,
|
||||
limit=2,
|
||||
)
|
||||
|
||||
# The first messages should be different
|
||||
assert page_asc[0].content != page_desc[0].content
|
||||
# In ascending, first should be "Message 0"
|
||||
assert "Message 0" in page_asc[0].content
|
||||
# In descending, first should be "Message 4"
|
||||
assert "Message 4" in page_desc[0].content
|
||||
|
||||
Reference in New Issue
Block a user