From 037c20ae1b1d0492704fb2bf291e5970fa066fca Mon Sep 17 00:00:00 2001 From: jnjpng Date: Wed, 14 Jan 2026 17:44:48 -0800 Subject: [PATCH] feat: query param parity for conversation messages (#8730) * base * add tests * generate --- .../rest_api/routers/v1/conversations.py | 20 +- letta/services/conversation_manager.py | 22 +- tests/integration_test_conversations_sdk.py | 169 ++++++++++ tests/managers/test_conversation_manager.py | 288 ++++++++++++++++++ 4 files changed, 490 insertions(+), 9 deletions(-) diff --git a/letta/server/rest_api/routers/v1/conversations.py b/letta/server/rest_api/routers/v1/conversations.py index c489f398..58aad829 100644 --- a/letta/server/rest_api/routers/v1/conversations.py +++ b/letta/server/rest_api/routers/v1/conversations.py @@ -1,5 +1,5 @@ from datetime import timedelta -from typing import Annotated, List, Optional +from typing import Annotated, List, Literal, Optional from fastapi import APIRouter, Body, Depends, HTTPException, Query from pydantic import Field @@ -96,19 +96,26 @@ async def list_conversation_messages( server: SyncServer = Depends(get_letta_server), headers: HeaderParams = Depends(get_headers), before: Optional[str] = Query( - None, description="Message ID cursor for pagination. Returns messages that come before this message ID in the conversation" + None, description="Message ID cursor for pagination. Returns messages that come before this message ID in the specified sort order" ), after: Optional[str] = Query( - None, description="Message ID cursor for pagination. Returns messages that come after this message ID in the conversation" + None, description="Message ID cursor for pagination. Returns messages that come after this message ID in the specified sort order" ), limit: Optional[int] = Query(100, description="Maximum number of messages to return"), + order: Literal["asc", "desc"] = Query( + "desc", description="Sort order for messages by creation time. 'asc' for oldest first, 'desc' for newest first" + ), + order_by: Literal["created_at"] = Query("created_at", description="Field to sort by"), + group_id: Optional[str] = Query(None, description="Group ID to filter messages by."), + include_err: Optional[bool] = Query( + None, description="Whether to include error messages and error statuses. For debugging purposes only." + ), ): """ List all messages in a conversation. Returns LettaMessage objects (UserMessage, AssistantMessage, etc.) for all - messages in the conversation, ordered by position (oldest first), - with support for cursor-based pagination. + messages in the conversation, with support for cursor-based pagination. """ actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) return await conversation_manager.list_conversation_messages( @@ -117,6 +124,9 @@ async def list_conversation_messages( limit=limit, before=before, after=after, + reverse=(order == "desc"), + group_id=group_id, + include_err=include_err, ) diff --git a/letta/services/conversation_manager.py b/letta/services/conversation_manager.py index 62239fb4..efe7c5f3 100644 --- a/letta/services/conversation_manager.py +++ b/letta/services/conversation_manager.py @@ -334,13 +334,15 @@ class ConversationManager: limit: Optional[int] = 100, before: Optional[str] = None, after: Optional[str] = None, + reverse: bool = False, + group_id: Optional[str] = None, + include_err: Optional[bool] = None, ) -> List[LettaMessage]: """ List all messages in a conversation with pagination support. Unlike get_messages_for_conversation, this returns ALL messages (not just in_context) and supports cursor-based pagination. - Messages are always ordered by position (oldest first). Args: conversation_id: The conversation to list messages for @@ -348,6 +350,9 @@ class ConversationManager: limit: Maximum number of messages to return before: Return messages before this message ID after: Return messages after this message ID + reverse: If True, return messages in descending order (newest first) + group_id: Optional group ID to filter messages by + include_err: Optional boolean to include error messages and error statuses Returns: List of LettaMessage objects @@ -367,6 +372,10 @@ class ConversationManager: ) ) + # Filter by group_id if provided + if group_id: + query = query.where(MessageModel.group_id == group_id) + # Handle cursor-based pagination if before: # Get the position of the cursor message @@ -390,8 +399,11 @@ class ConversationManager: if cursor_position is not None: query = query.where(ConversationMessageModel.position > cursor_position) - # Order by position (oldest first) - query = query.order_by(ConversationMessageModel.position.asc()) + # Order by position + if reverse: + query = query.order_by(ConversationMessageModel.position.desc()) + else: + query = query.order_by(ConversationMessageModel.position.asc()) # Apply limit if limit is not None: @@ -401,7 +413,9 @@ class ConversationManager: messages = [msg.to_pydantic() for msg in result.scalars().all()] # Convert to LettaMessages - return PydanticMessage.to_letta_messages_from_list(messages, reverse=False, text_is_assistant_message=True) + return PydanticMessage.to_letta_messages_from_list( + messages, reverse=reverse, include_err=include_err, text_is_assistant_message=True + ) # ==================== Isolated Blocks Methods ==================== diff --git a/tests/integration_test_conversations_sdk.py b/tests/integration_test_conversations_sdk.py index 6284b705..c62a53c2 100644 --- a/tests/integration_test_conversations_sdk.py +++ b/tests/integration_test_conversations_sdk.py @@ -396,3 +396,172 @@ class TestConversationsSDK: ) ) assert len(messages) > 0, "Should be able to send message after concurrent requests complete" + + def test_list_conversation_messages_order_asc(self, client: Letta, agent): + """Test listing messages in ascending order (oldest first).""" + conversation = client.conversations.create(agent_id=agent.id) + + # Send messages to create history + list( + client.conversations.messages.create( + conversation_id=conversation.id, + messages=[{"role": "user", "content": "First message"}], + ) + ) + list( + client.conversations.messages.create( + conversation_id=conversation.id, + messages=[{"role": "user", "content": "Second message"}], + ) + ) + + # List messages in ascending order (oldest first) + messages_asc = client.conversations.messages.list( + conversation_id=conversation.id, + order="asc", + ) + + # First message should be system message (oldest) + assert messages_asc[0].message_type == "system_message" + + # Get user messages and verify order + user_messages = [m for m in messages_asc if m.message_type == "user_message"] + assert len(user_messages) >= 2 + # First user message should contain "First message" + assert "First" in user_messages[0].content + + def test_list_conversation_messages_order_desc(self, client: Letta, agent): + """Test listing messages in descending order (newest first).""" + conversation = client.conversations.create(agent_id=agent.id) + + # Send messages to create history + list( + client.conversations.messages.create( + conversation_id=conversation.id, + messages=[{"role": "user", "content": "First message"}], + ) + ) + list( + client.conversations.messages.create( + conversation_id=conversation.id, + messages=[{"role": "user", "content": "Second message"}], + ) + ) + + # List messages in descending order (newest first) - this is the default + messages_desc = client.conversations.messages.list( + conversation_id=conversation.id, + order="desc", + ) + + # Get user messages and verify order + user_messages = [m for m in messages_desc if m.message_type == "user_message"] + assert len(user_messages) >= 2 + # First user message in desc order should contain "Second message" (newest) + assert "Second" in user_messages[0].content + + def test_list_conversation_messages_order_affects_pagination(self, client: Letta, agent): + """Test that order parameter affects pagination correctly.""" + conversation = client.conversations.create(agent_id=agent.id) + + # Send multiple messages + for i in range(3): + list( + client.conversations.messages.create( + conversation_id=conversation.id, + messages=[{"role": "user", "content": f"Message {i}"}], + ) + ) + + # Get all messages in descending order with limit + messages_desc = client.conversations.messages.list( + conversation_id=conversation.id, + order="desc", + limit=5, + ) + + # Get all messages in ascending order with limit + messages_asc = client.conversations.messages.list( + conversation_id=conversation.id, + order="asc", + limit=5, + ) + + # The first messages should be different based on order + assert messages_desc[0].id != messages_asc[0].id + + def test_list_conversation_messages_with_before_cursor(self, client: Letta, agent): + """Test pagination with before cursor.""" + conversation = client.conversations.create(agent_id=agent.id) + + # Send messages to create history + list( + client.conversations.messages.create( + conversation_id=conversation.id, + messages=[{"role": "user", "content": "First message"}], + ) + ) + list( + client.conversations.messages.create( + conversation_id=conversation.id, + messages=[{"role": "user", "content": "Second message"}], + ) + ) + + # Get all messages first + all_messages = client.conversations.messages.list( + conversation_id=conversation.id, + order="asc", + ) + assert len(all_messages) >= 4 # system + user + assistant + user + assistant + + # Use the last message ID as cursor + last_message_id = all_messages[-1].id + messages_before = client.conversations.messages.list( + conversation_id=conversation.id, + order="asc", + before=last_message_id, + ) + + # Should have fewer messages (all except the last one) + assert len(messages_before) < len(all_messages) + # Should not contain the cursor message + assert last_message_id not in [m.id for m in messages_before] + + def test_list_conversation_messages_with_after_cursor(self, client: Letta, agent): + """Test pagination with after cursor.""" + conversation = client.conversations.create(agent_id=agent.id) + + # Send messages to create history + list( + client.conversations.messages.create( + conversation_id=conversation.id, + messages=[{"role": "user", "content": "First message"}], + ) + ) + list( + client.conversations.messages.create( + conversation_id=conversation.id, + messages=[{"role": "user", "content": "Second message"}], + ) + ) + + # Get all messages first + all_messages = client.conversations.messages.list( + conversation_id=conversation.id, + order="asc", + ) + assert len(all_messages) >= 4 + + # Use the first message ID as cursor + first_message_id = all_messages[0].id + messages_after = client.conversations.messages.list( + conversation_id=conversation.id, + order="asc", + after=first_message_id, + ) + + # Should have fewer messages (all except the first one) + assert len(messages_after) < len(all_messages) + # Should not contain the cursor message + assert first_message_id not in [m.id for m in messages_after] diff --git a/tests/managers/test_conversation_manager.py b/tests/managers/test_conversation_manager.py index 43999091..03329a9a 100644 --- a/tests/managers/test_conversation_manager.py +++ b/tests/managers/test_conversation_manager.py @@ -749,3 +749,291 @@ async def test_delete_conversation_cleans_up_isolated_blocks(conversation_manage # Verify the isolated block was hard-deleted deleted_block = await server.block_manager.get_block_by_id_async(isolated_block_id, default_user) assert deleted_block is None + + +# ====================================================================================================================== +# list_conversation_messages with order/reverse Tests +# ====================================================================================================================== + + +@pytest.mark.asyncio +async def test_list_conversation_messages_ascending_order(conversation_manager, server: SyncServer, sarah_agent, default_user): + """Test listing messages in ascending order (oldest first).""" + from letta.schemas.letta_message_content import TextContent + from letta.schemas.message import Message as PydanticMessage + + # Create a conversation + conversation = await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation(summary="Test"), + actor=default_user, + ) + + # Create messages in a known order + pydantic_messages = [ + PydanticMessage( + agent_id=sarah_agent.id, + role="user", + content=[TextContent(text=f"Message {i}")], + ) + for i in range(3) + ] + messages = await server.message_manager.create_many_messages_async( + pydantic_messages, + actor=default_user, + ) + + # Add messages to conversation + await conversation_manager.add_messages_to_conversation( + conversation_id=conversation.id, + agent_id=sarah_agent.id, + message_ids=[m.id for m in messages], + actor=default_user, + ) + + # List messages in ascending order (reverse=False) + letta_messages = await conversation_manager.list_conversation_messages( + conversation_id=conversation.id, + actor=default_user, + reverse=False, + ) + + # First message should be "Message 0" (oldest) + assert len(letta_messages) == 3 + assert "Message 0" in letta_messages[0].content + + +@pytest.mark.asyncio +async def test_list_conversation_messages_descending_order(conversation_manager, server: SyncServer, sarah_agent, default_user): + """Test listing messages in descending order (newest first).""" + from letta.schemas.letta_message_content import TextContent + from letta.schemas.message import Message as PydanticMessage + + # Create a conversation + conversation = await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation(summary="Test"), + actor=default_user, + ) + + # Create messages in a known order + pydantic_messages = [ + PydanticMessage( + agent_id=sarah_agent.id, + role="user", + content=[TextContent(text=f"Message {i}")], + ) + for i in range(3) + ] + messages = await server.message_manager.create_many_messages_async( + pydantic_messages, + actor=default_user, + ) + + # Add messages to conversation + await conversation_manager.add_messages_to_conversation( + conversation_id=conversation.id, + agent_id=sarah_agent.id, + message_ids=[m.id for m in messages], + actor=default_user, + ) + + # List messages in descending order (reverse=True) + letta_messages = await conversation_manager.list_conversation_messages( + conversation_id=conversation.id, + actor=default_user, + reverse=True, + ) + + # First message should be "Message 2" (newest) + assert len(letta_messages) == 3 + assert "Message 2" in letta_messages[0].content + + +@pytest.mark.asyncio +async def test_list_conversation_messages_with_group_id_filter(conversation_manager, server: SyncServer, sarah_agent, default_user): + """Test filtering messages by group_id.""" + from letta.schemas.letta_message_content import TextContent + from letta.schemas.message import Message as PydanticMessage + + # Create a conversation + conversation = await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation(summary="Test"), + actor=default_user, + ) + + # Create messages with different group_ids + group_a_id = "group-aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + group_b_id = "group-bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb" + + messages_group_a = [ + PydanticMessage( + agent_id=sarah_agent.id, + role="user", + content=[TextContent(text="Group A message 1")], + group_id=group_a_id, + ), + PydanticMessage( + agent_id=sarah_agent.id, + role="user", + content=[TextContent(text="Group A message 2")], + group_id=group_a_id, + ), + ] + messages_group_b = [ + PydanticMessage( + agent_id=sarah_agent.id, + role="user", + content=[TextContent(text="Group B message 1")], + group_id=group_b_id, + ), + ] + + created_a = await server.message_manager.create_many_messages_async(messages_group_a, actor=default_user) + created_b = await server.message_manager.create_many_messages_async(messages_group_b, actor=default_user) + + # Add all messages to conversation + all_message_ids = [m.id for m in created_a] + [m.id for m in created_b] + await conversation_manager.add_messages_to_conversation( + conversation_id=conversation.id, + agent_id=sarah_agent.id, + message_ids=all_message_ids, + actor=default_user, + ) + + # List messages filtered by group A + messages_a = await conversation_manager.list_conversation_messages( + conversation_id=conversation.id, + actor=default_user, + group_id=group_a_id, + ) + + assert len(messages_a) == 2 + for msg in messages_a: + assert "Group A" in msg.content + + # List messages filtered by group B + messages_b = await conversation_manager.list_conversation_messages( + conversation_id=conversation.id, + actor=default_user, + group_id=group_b_id, + ) + + assert len(messages_b) == 1 + assert "Group B" in messages_b[0].content + + +@pytest.mark.asyncio +async def test_list_conversation_messages_no_group_id_returns_all(conversation_manager, server: SyncServer, sarah_agent, default_user): + """Test that not providing group_id returns all messages.""" + from letta.schemas.letta_message_content import TextContent + from letta.schemas.message import Message as PydanticMessage + + # Create a conversation + conversation = await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation(summary="Test"), + actor=default_user, + ) + + # Create messages with different group_ids + group_a_id = "group-aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + group_b_id = "group-bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb" + + pydantic_messages = [ + PydanticMessage( + agent_id=sarah_agent.id, + role="user", + content=[TextContent(text="Group A message")], + group_id=group_a_id, + ), + PydanticMessage( + agent_id=sarah_agent.id, + role="user", + content=[TextContent(text="Group B message")], + group_id=group_b_id, + ), + PydanticMessage( + agent_id=sarah_agent.id, + role="user", + content=[TextContent(text="No group message")], + group_id=None, + ), + ] + messages = await server.message_manager.create_many_messages_async(pydantic_messages, actor=default_user) + + # Add all messages to conversation + await conversation_manager.add_messages_to_conversation( + conversation_id=conversation.id, + agent_id=sarah_agent.id, + message_ids=[m.id for m in messages], + actor=default_user, + ) + + # List all messages without group_id filter + all_messages = await conversation_manager.list_conversation_messages( + conversation_id=conversation.id, + actor=default_user, + ) + + assert len(all_messages) == 3 + + +@pytest.mark.asyncio +async def test_list_conversation_messages_order_with_pagination(conversation_manager, server: SyncServer, sarah_agent, default_user): + """Test that order affects pagination correctly.""" + from letta.schemas.letta_message_content import TextContent + from letta.schemas.message import Message as PydanticMessage + + # Create a conversation + conversation = await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation(summary="Test"), + actor=default_user, + ) + + # Create messages + pydantic_messages = [ + PydanticMessage( + agent_id=sarah_agent.id, + role="user", + content=[TextContent(text=f"Message {i}")], + ) + for i in range(5) + ] + messages = await server.message_manager.create_many_messages_async( + pydantic_messages, + actor=default_user, + ) + + # Add messages to conversation + await conversation_manager.add_messages_to_conversation( + conversation_id=conversation.id, + agent_id=sarah_agent.id, + message_ids=[m.id for m in messages], + actor=default_user, + ) + + # Get first page in ascending order with limit + page_asc = await conversation_manager.list_conversation_messages( + conversation_id=conversation.id, + actor=default_user, + reverse=False, + limit=2, + ) + + # Get first page in descending order with limit + page_desc = await conversation_manager.list_conversation_messages( + conversation_id=conversation.id, + actor=default_user, + reverse=True, + limit=2, + ) + + # The first messages should be different + assert page_asc[0].content != page_desc[0].content + # In ascending, first should be "Message 0" + assert "Message 0" in page_asc[0].content + # In descending, first should be "Message 4" + assert "Message 4" in page_desc[0].content