Files
letta-server/tests/managers/test_conversation_manager.py

507 lines
17 KiB
Python

"""
Tests for ConversationManager.
"""
import pytest
from letta.orm.errors import NoResultFound
from letta.schemas.conversation import CreateConversation, UpdateConversation
from letta.server.server import SyncServer
from letta.services.conversation_manager import ConversationManager
# ======================================================================================================================
# ConversationManager Tests
# ======================================================================================================================
@pytest.fixture
def conversation_manager():
    """Provide a ConversationManager instance to each test that requests it."""
    manager = ConversationManager()
    return manager
@pytest.mark.asyncio
async def test_create_conversation(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """Creating a conversation with a summary returns a record carrying that summary."""
    request = CreateConversation(summary="Test conversation")
    result = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=request,
        actor=default_user,
    )

    # The manager assigns an ID, and it uses the "conv-" prefix.
    assert result.id is not None
    assert result.id.startswith("conv-")
    # Fields from the request are reflected on the returned record.
    assert result.agent_id == sarah_agent.id
    assert result.summary == "Test conversation"
@pytest.mark.asyncio
async def test_create_conversation_no_summary(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """A conversation created without a summary has ``summary is None``."""
    empty_request = CreateConversation()
    result = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=empty_request,
        actor=default_user,
    )

    assert result.id is not None
    assert result.agent_id == sarah_agent.id
    # Omitting the summary leaves it unset (None).
    assert result.summary is None
@pytest.mark.asyncio
async def test_get_conversation_by_id(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """A stored conversation can be read back by ID with the same fields."""
    original = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="Test"),
        actor=default_user,
    )

    fetched = await conversation_manager.get_conversation_by_id(
        conversation_id=original.id,
        actor=default_user,
    )

    # Every field populated by the create call should round-trip unchanged.
    for field_name in ("id", "agent_id", "summary"):
        assert getattr(fetched, field_name) == getattr(original, field_name)
@pytest.mark.asyncio
async def test_get_conversation_not_found(conversation_manager, server: SyncServer, default_user):
    """Looking up an ID that was never created raises NoResultFound."""
    missing_id = "conv-nonexistent"
    with pytest.raises(NoResultFound):
        await conversation_manager.get_conversation_by_id(conversation_id=missing_id, actor=default_user)
@pytest.mark.asyncio
async def test_list_conversations(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """Test listing conversations for an agent.

    Creates three conversations and checks that listing returns exactly those
    three, all attached to the agent. Checking only the count would also pass
    if unrelated records were returned.
    """
    expected_summaries = [f"Conversation {i}" for i in range(3)]
    created_ids = set()
    for summary in expected_summaries:
        created = await conversation_manager.create_conversation(
            agent_id=sarah_agent.id,
            conversation_create=CreateConversation(summary=summary),
            actor=default_user,
        )
        created_ids.add(created.id)

    conversations = await conversation_manager.list_conversations(
        agent_id=sarah_agent.id,
        actor=default_user,
    )

    assert len(conversations) == 3
    # The listed records must be the ones just created, not merely three of something.
    assert {c.id for c in conversations} == created_ids
    assert all(c.agent_id == sarah_agent.id for c in conversations)
    # Compare summaries order-insensitively; this test does not pin list ordering.
    assert sorted(c.summary for c in conversations) == sorted(expected_summaries)
@pytest.mark.asyncio
async def test_list_conversations_with_limit(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """Passing ``limit`` caps how many conversations are returned."""
    total, page_size = 5, 2
    for n in range(total):
        await conversation_manager.create_conversation(
            agent_id=sarah_agent.id,
            conversation_create=CreateConversation(summary=f"Conversation {n}"),
            actor=default_user,
        )

    page = await conversation_manager.list_conversations(
        agent_id=sarah_agent.id,
        actor=default_user,
        limit=page_size,
    )

    assert len(page) == page_size
@pytest.mark.asyncio
async def test_update_conversation(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """Test updating a conversation.

    Verifies both the object returned by ``update_conversation`` and a fresh
    read via ``get_conversation_by_id``. Checking only the returned value would
    miss an update that is never persisted.
    """
    created = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="Original"),
        actor=default_user,
    )

    updated = await conversation_manager.update_conversation(
        conversation_id=created.id,
        conversation_update=UpdateConversation(summary="Updated summary"),
        actor=default_user,
    )

    assert updated.id == created.id
    assert updated.summary == "Updated summary"
    # Fields not included in the update should be left untouched.
    assert updated.agent_id == created.agent_id

    # Re-read to confirm the change was stored, not just reflected in the return value.
    refetched = await conversation_manager.get_conversation_by_id(
        conversation_id=created.id,
        actor=default_user,
    )
    assert refetched.summary == "Updated summary"
    assert refetched.agent_id == sarah_agent.id
@pytest.mark.asyncio
async def test_delete_conversation(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """After deletion, the conversation can no longer be fetched by ID."""
    doomed = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="To delete"),
        actor=default_user,
    )
    doomed_id = doomed.id

    await conversation_manager.delete_conversation(conversation_id=doomed_id, actor=default_user)

    with pytest.raises(NoResultFound):
        await conversation_manager.get_conversation_by_id(conversation_id=doomed_id, actor=default_user)
@pytest.mark.asyncio
async def test_conversation_isolation_by_agent(conversation_manager, server: SyncServer, sarah_agent, charles_agent, default_user):
    """Listing by agent returns only that agent's conversations."""
    owners = [
        (sarah_agent, "Sarah's conversation"),
        (charles_agent, "Charles's conversation"),
    ]

    # Create one conversation per agent first...
    for agent, summary in owners:
        await conversation_manager.create_conversation(
            agent_id=agent.id,
            conversation_create=CreateConversation(summary=summary),
            actor=default_user,
        )

    # ...then check that each agent sees only its own.
    for agent, summary in owners:
        listed = await conversation_manager.list_conversations(agent_id=agent.id, actor=default_user)
        assert len(listed) == 1
        assert listed[0].summary == summary
@pytest.mark.asyncio
async def test_conversation_isolation_by_organization(
    conversation_manager, server: SyncServer, sarah_agent, default_user, other_user_different_org
):
    """Test that conversations are isolated by organization.

    A positive control (the owning user can read the conversation) runs before
    the negative check. The NoResultFound raised for the other organization's
    user is then attributable to org scoping, not to a broken create or lookup.
    """
    created = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="Test"),
        actor=default_user,
    )

    # Positive control: the creating user's organization can see it.
    visible = await conversation_manager.get_conversation_by_id(
        conversation_id=created.id,
        actor=default_user,
    )
    assert visible.id == created.id

    # A user from another organization must not be able to access it.
    with pytest.raises(NoResultFound):
        await conversation_manager.get_conversation_by_id(
            conversation_id=created.id,
            actor=other_user_different_org,
        )
# ======================================================================================================================
# Conversation Message Management Tests
# ======================================================================================================================
@pytest.mark.asyncio
async def test_add_messages_to_conversation(
    conversation_manager, server: SyncServer, sarah_agent, default_user, hello_world_message_fixture
):
    """A message attached to a conversation appears in its message-ID list."""
    convo = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="Test"),
        actor=default_user,
    )
    target_id = hello_world_message_fixture.id

    await conversation_manager.add_messages_to_conversation(
        conversation_id=convo.id,
        agent_id=sarah_agent.id,
        message_ids=[target_id],
        actor=default_user,
    )

    stored_ids = await conversation_manager.get_message_ids_for_conversation(
        conversation_id=convo.id,
        actor=default_user,
    )
    assert len(stored_ids) == 1
    assert stored_ids[0] == target_id
@pytest.mark.asyncio
async def test_get_messages_for_conversation(
    conversation_manager, server: SyncServer, sarah_agent, default_user, hello_world_message_fixture
):
    """Full message objects, not just IDs, can be read back from a conversation."""
    convo = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="Test"),
        actor=default_user,
    )
    target_id = hello_world_message_fixture.id

    await conversation_manager.add_messages_to_conversation(
        conversation_id=convo.id,
        agent_id=sarah_agent.id,
        message_ids=[target_id],
        actor=default_user,
    )

    loaded = await conversation_manager.get_messages_for_conversation(
        conversation_id=convo.id,
        actor=default_user,
    )
    assert len(loaded) == 1
    assert loaded[0].id == target_id
@pytest.mark.asyncio
async def test_message_ordering_in_conversation(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """Message IDs come back in the same order they were added."""
    from letta.schemas.letta_message_content import TextContent
    from letta.schemas.message import Message as PydanticMessage

    convo = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="Test"),
        actor=default_user,
    )

    texts = [f"Message {idx}" for idx in range(3)]
    drafts = [
        PydanticMessage(agent_id=sarah_agent.id, role="user", content=[TextContent(text=text)])
        for text in texts
    ]
    persisted = await server.message_manager.create_many_messages_async(drafts, actor=default_user)
    ordered_ids = [msg.id for msg in persisted]

    await conversation_manager.add_messages_to_conversation(
        conversation_id=convo.id,
        agent_id=sarah_agent.id,
        message_ids=ordered_ids,
        actor=default_user,
    )

    returned_ids = await conversation_manager.get_message_ids_for_conversation(
        conversation_id=convo.id,
        actor=default_user,
    )
    assert returned_ids == ordered_ids
@pytest.mark.asyncio
async def test_update_in_context_messages(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """Replacing the in-context set keeps only the chosen messages."""
    from letta.schemas.letta_message_content import TextContent
    from letta.schemas.message import Message as PydanticMessage

    convo = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="Test"),
        actor=default_user,
    )

    drafts = [
        PydanticMessage(agent_id=sarah_agent.id, role="user", content=[TextContent(text=f"Message {idx}")])
        for idx in range(3)
    ]
    first, middle, last = await server.message_manager.create_many_messages_async(drafts, actor=default_user)

    await conversation_manager.add_messages_to_conversation(
        conversation_id=convo.id,
        agent_id=sarah_agent.id,
        message_ids=[first.id, middle.id, last.id],
        actor=default_user,
    )

    # Drop the middle message from the context window.
    await conversation_manager.update_in_context_messages(
        conversation_id=convo.id,
        in_context_message_ids=[first.id, last.id],
        actor=default_user,
    )

    remaining = await conversation_manager.get_message_ids_for_conversation(
        conversation_id=convo.id,
        actor=default_user,
    )
    assert len(remaining) == 2
    assert first.id in remaining
    assert last.id in remaining
    assert middle.id not in remaining
@pytest.mark.asyncio
async def test_empty_conversation_message_ids(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """A conversation with no messages yields an empty ID list."""
    fresh = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="Empty"),
        actor=default_user,
    )

    ids = await conversation_manager.get_message_ids_for_conversation(conversation_id=fresh.id, actor=default_user)

    assert ids == []
@pytest.mark.asyncio
async def test_list_conversation_messages(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """Conversation messages are listed as LettaMessages of the matching types."""
    from letta.schemas.letta_message_content import TextContent
    from letta.schemas.message import Message as PydanticMessage

    convo = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="Test"),
        actor=default_user,
    )

    # One user turn and one assistant turn.
    role_and_text = [("user", "Hello!"), ("assistant", "Hi there!")]
    drafts = [
        PydanticMessage(agent_id=sarah_agent.id, role=role, content=[TextContent(text=text)])
        for role, text in role_and_text
    ]
    persisted = await server.message_manager.create_many_messages_async(drafts, actor=default_user)

    await conversation_manager.add_messages_to_conversation(
        conversation_id=convo.id,
        agent_id=sarah_agent.id,
        message_ids=[msg.id for msg in persisted],
        actor=default_user,
    )

    listed = await conversation_manager.list_conversation_messages(
        conversation_id=convo.id,
        actor=default_user,
    )
    assert len(listed) == 2

    kinds = {item.message_type for item in listed}
    assert "user_message" in kinds
    assert "assistant_message" in kinds
@pytest.mark.asyncio
async def test_list_conversation_messages_pagination(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """``limit`` and the ``after`` cursor both narrow the listed messages."""
    from letta.schemas.letta_message_content import TextContent
    from letta.schemas.message import Message as PydanticMessage

    convo = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="Test"),
        actor=default_user,
    )

    drafts = [
        PydanticMessage(agent_id=sarah_agent.id, role="user", content=[TextContent(text=f"Message {idx}")])
        for idx in range(5)
    ]
    persisted = await server.message_manager.create_many_messages_async(drafts, actor=default_user)

    await conversation_manager.add_messages_to_conversation(
        conversation_id=convo.id,
        agent_id=sarah_agent.id,
        message_ids=[msg.id for msg in persisted],
        actor=default_user,
    )

    # A page size of 2 returns exactly 2 items.
    first_page = await conversation_manager.list_conversation_messages(
        conversation_id=convo.id,
        actor=default_user,
        limit=2,
    )
    assert len(first_page) == 2

    # The cursor excludes the first message, leaving messages 1-4.
    after_first = await conversation_manager.list_conversation_messages(
        conversation_id=convo.id,
        actor=default_user,
        after=persisted[0].id,
    )
    assert len(after_first) == 4