507 lines
17 KiB
Python
507 lines
17 KiB
Python
"""
Tests for ConversationManager.
"""
|
|
|
|
import pytest
|
|
|
|
from letta.orm.errors import NoResultFound
|
|
from letta.schemas.conversation import CreateConversation, UpdateConversation
|
|
from letta.server.server import SyncServer
|
|
from letta.services.conversation_manager import ConversationManager
|
|
|
|
# ======================================================================================================================
# ConversationManager Tests
# ======================================================================================================================
|
|
|
|
|
|
@pytest.fixture
def conversation_manager():
    """Build a ConversationManager for tests that request this fixture."""
    manager = ConversationManager()
    return manager
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_conversation(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """Creating a conversation with a summary returns a populated record.

    Checks that the result has an ID starting with ``conv-``, is tied to the
    agent it was created for, and carries the summary that was supplied.
    """
    expected_summary = "Test conversation"
    request = CreateConversation(summary=expected_summary)

    created = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=request,
        actor=default_user,
    )

    assert created.id is not None
    assert created.agent_id == sarah_agent.id
    assert created.summary == expected_summary
    assert created.id.startswith("conv-")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_conversation_no_summary(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """Creating a conversation from an empty request leaves the summary unset."""
    # No fields are supplied, so the summary is expected to come back as None.
    empty_request = CreateConversation()

    created = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=empty_request,
        actor=default_user,
    )

    assert created.id is not None
    assert created.agent_id == sarah_agent.id
    assert created.summary is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_conversation_by_id(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """A conversation fetched by ID matches the one that was created."""
    # Arrange: persist a conversation to look up.
    original = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="Test"),
        actor=default_user,
    )

    # Act: read it back through its ID.
    fetched = await conversation_manager.get_conversation_by_id(
        conversation_id=original.id,
        actor=default_user,
    )

    # Assert: identity, owning agent and summary all round-trip.
    assert fetched.id == original.id
    assert fetched.agent_id == original.agent_id
    assert fetched.summary == original.summary
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_conversation_not_found(conversation_manager, server: SyncServer, default_user):
    """Looking up an ID that was never created raises NoResultFound."""
    missing_id = "conv-nonexistent"

    with pytest.raises(NoResultFound):
        await conversation_manager.get_conversation_by_id(
            conversation_id=missing_id,
            actor=default_user,
        )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_list_conversations(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """Listing an agent's conversations returns every one that was created."""
    total = 3

    # Seed the agent with `total` conversations, each with a distinct summary.
    for index in range(total):
        payload = CreateConversation(summary=f"Conversation {index}")
        await conversation_manager.create_conversation(
            agent_id=sarah_agent.id,
            conversation_create=payload,
            actor=default_user,
        )

    listed = await conversation_manager.list_conversations(
        agent_id=sarah_agent.id,
        actor=default_user,
    )

    assert len(listed) == total
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_list_conversations_with_limit(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """Listing with ``limit`` returns no more than the requested number of conversations."""
    created_count = 5
    page_size = 2

    # Create more conversations than the limit so the cap is actually exercised.
    for index in range(created_count):
        payload = CreateConversation(summary=f"Conversation {index}")
        await conversation_manager.create_conversation(
            agent_id=sarah_agent.id,
            conversation_create=payload,
            actor=default_user,
        )

    limited = await conversation_manager.list_conversations(
        agent_id=sarah_agent.id,
        actor=default_user,
        limit=page_size,
    )

    assert len(limited) == page_size
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_update_conversation(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """Updating a conversation's summary keeps its ID and stores the new text."""
    new_summary = "Updated summary"

    # Start from a conversation whose summary differs from the new one.
    original = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="Original"),
        actor=default_user,
    )

    changes = UpdateConversation(summary=new_summary)
    result = await conversation_manager.update_conversation(
        conversation_id=original.id,
        conversation_update=changes,
        actor=default_user,
    )

    assert result.id == original.id
    assert result.summary == new_summary
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_delete_conversation(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """A deleted conversation can no longer be fetched by ID.

    The upstream description calls this a soft delete; this test only observes
    that a later lookup raises NoResultFound.
    """
    doomed = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="To delete"),
        actor=default_user,
    )
    doomed_id = doomed.id

    await conversation_manager.delete_conversation(
        conversation_id=doomed_id,
        actor=default_user,
    )

    # The same actor that created the conversation must not be able to read it back.
    with pytest.raises(NoResultFound):
        await conversation_manager.get_conversation_by_id(
            conversation_id=doomed_id,
            actor=default_user,
        )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_conversation_isolation_by_agent(conversation_manager, server: SyncServer, sarah_agent, charles_agent, default_user):
    """Each agent's listing contains only the conversation created for that agent."""
    owners = (
        (sarah_agent, "Sarah's conversation"),
        (charles_agent, "Charles's conversation"),
    )

    # Seed one conversation per agent: Sarah first, then Charles.
    for agent, summary in owners:
        await conversation_manager.create_conversation(
            agent_id=agent.id,
            conversation_create=CreateConversation(summary=summary),
            actor=default_user,
        )

    # Listing by agent must return exactly that agent's own conversation.
    for agent, summary in owners:
        listed = await conversation_manager.list_conversations(
            agent_id=agent.id,
            actor=default_user,
        )
        assert len(listed) == 1
        assert listed[0].summary == summary
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_conversation_isolation_by_organization(
    conversation_manager, server: SyncServer, sarah_agent, default_user, other_user_different_org
):
    """An actor from a different organization cannot fetch the conversation."""
    # The conversation is created by the default user.
    owned = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="Test"),
        actor=default_user,
    )

    # A lookup by the other-organization user is expected to raise NoResultFound.
    outsider = other_user_different_org
    with pytest.raises(NoResultFound):
        await conversation_manager.get_conversation_by_id(
            conversation_id=owned.id,
            actor=outsider,
        )
|
|
|
|
|
|
# ======================================================================================================================
# Conversation Message Management Tests
# ======================================================================================================================
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_add_messages_to_conversation(
    conversation_manager, server: SyncServer, sarah_agent, default_user, hello_world_message_fixture
):
    """A message added to a conversation appears in that conversation's message IDs."""
    target = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="Test"),
        actor=default_user,
    )

    # Attach the single fixture message.
    await conversation_manager.add_messages_to_conversation(
        conversation_id=target.id,
        agent_id=sarah_agent.id,
        message_ids=[hello_world_message_fixture.id],
        actor=default_user,
    )

    stored_ids = await conversation_manager.get_message_ids_for_conversation(
        conversation_id=target.id,
        actor=default_user,
    )

    # Exactly one ID comes back, and it is the fixture message's ID.
    assert len(stored_ids) == 1
    assert stored_ids[0] == hello_world_message_fixture.id
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_messages_for_conversation(
    conversation_manager, server: SyncServer, sarah_agent, default_user, hello_world_message_fixture
):
    """Fetching a conversation's messages returns objects for the messages added to it."""
    target = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="Test"),
        actor=default_user,
    )

    # Attach the single fixture message.
    await conversation_manager.add_messages_to_conversation(
        conversation_id=target.id,
        agent_id=sarah_agent.id,
        message_ids=[hello_world_message_fixture.id],
        actor=default_user,
    )

    # Ask for the message objects rather than just their IDs.
    fetched = await conversation_manager.get_messages_for_conversation(
        conversation_id=target.id,
        actor=default_user,
    )

    assert len(fetched) == 1
    assert fetched[0].id == hello_world_message_fixture.id
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_message_ordering_in_conversation(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """Message IDs are returned in the same order in which they were added."""
    from letta.schemas.letta_message_content import TextContent
    from letta.schemas.message import Message as PydanticMessage

    target = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="Test"),
        actor=default_user,
    )

    # Build three user messages with distinguishable text, then persist them.
    drafts = [
        PydanticMessage(
            agent_id=sarah_agent.id,
            role="user",
            content=[TextContent(text=f"Message {index}")],
        )
        for index in range(3)
    ]
    persisted = await server.message_manager.create_many_messages_async(
        drafts,
        actor=default_user,
    )

    # Attach them in the order the message manager returned them.
    await conversation_manager.add_messages_to_conversation(
        conversation_id=target.id,
        agent_id=sarah_agent.id,
        message_ids=[msg.id for msg in persisted],
        actor=default_user,
    )

    ids_in_conversation = await conversation_manager.get_message_ids_for_conversation(
        conversation_id=target.id,
        actor=default_user,
    )

    # List equality checks both membership and ordering.
    assert ids_in_conversation == [msg.id for msg in persisted]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_update_in_context_messages(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """Updating the in-context set narrows the conversation's message IDs to the chosen ones."""
    from letta.schemas.letta_message_content import TextContent
    from letta.schemas.message import Message as PydanticMessage

    target = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="Test"),
        actor=default_user,
    )

    # Persist three user messages and attach all of them to the conversation.
    drafts = [
        PydanticMessage(
            agent_id=sarah_agent.id,
            role="user",
            content=[TextContent(text=f"Message {index}")],
        )
        for index in range(3)
    ]
    persisted = await server.message_manager.create_many_messages_async(
        drafts,
        actor=default_user,
    )
    await conversation_manager.add_messages_to_conversation(
        conversation_id=target.id,
        agent_id=sarah_agent.id,
        message_ids=[msg.id for msg in persisted],
        actor=default_user,
    )

    # Keep only the first and last message in context; the middle one is dropped.
    kept_ids = [persisted[0].id, persisted[2].id]
    await conversation_manager.update_in_context_messages(
        conversation_id=target.id,
        in_context_message_ids=kept_ids,
        actor=default_user,
    )

    remaining_ids = await conversation_manager.get_message_ids_for_conversation(
        conversation_id=target.id,
        actor=default_user,
    )

    assert len(remaining_ids) == 2
    assert persisted[0].id in remaining_ids
    assert persisted[2].id in remaining_ids
    assert persisted[1].id not in remaining_ids
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_empty_conversation_message_ids(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """A conversation with no messages added yields an empty list of message IDs."""
    blank = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="Empty"),
        actor=default_user,
    )

    # Nothing was attached, so the lookup should return an empty list.
    found_ids = await conversation_manager.get_message_ids_for_conversation(
        conversation_id=blank.id,
        actor=default_user,
    )

    assert found_ids == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_list_conversation_messages(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """Listing conversation messages returns one entry per stored message, typed by role.

    The listed entries expose ``message_type``; a user message and an assistant
    message are stored, and both corresponding types must appear.
    """
    from letta.schemas.letta_message_content import TextContent
    from letta.schemas.message import Message as PydanticMessage

    target = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="Test"),
        actor=default_user,
    )

    # One message from each role so both message types are represented.
    user_draft = PydanticMessage(
        agent_id=sarah_agent.id,
        role="user",
        content=[TextContent(text="Hello!")],
    )
    assistant_draft = PydanticMessage(
        agent_id=sarah_agent.id,
        role="assistant",
        content=[TextContent(text="Hi there!")],
    )
    persisted = await server.message_manager.create_many_messages_async(
        [user_draft, assistant_draft],
        actor=default_user,
    )

    await conversation_manager.add_messages_to_conversation(
        conversation_id=target.id,
        agent_id=sarah_agent.id,
        message_ids=[msg.id for msg in persisted],
        actor=default_user,
    )

    # List conversation messages (returns LettaMessages)
    listed = await conversation_manager.list_conversation_messages(
        conversation_id=target.id,
        actor=default_user,
    )

    assert len(listed) == 2

    # Both roles must be represented among the returned message types.
    seen_types = [entry.message_type for entry in listed]
    assert "user_message" in seen_types
    assert "assistant_message" in seen_types
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_list_conversation_messages_pagination(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """Listing conversation messages honours ``limit`` and the ``after`` cursor."""
    from letta.schemas.letta_message_content import TextContent
    from letta.schemas.message import Message as PydanticMessage

    target = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="Test"),
        actor=default_user,
    )

    # Persist five user messages and attach all of them to the conversation.
    drafts = [
        PydanticMessage(
            agent_id=sarah_agent.id,
            role="user",
            content=[TextContent(text=f"Message {index}")],
        )
        for index in range(5)
    ]
    persisted = await server.message_manager.create_many_messages_async(
        drafts,
        actor=default_user,
    )
    await conversation_manager.add_messages_to_conversation(
        conversation_id=target.id,
        agent_id=sarah_agent.id,
        message_ids=[msg.id for msg in persisted],
        actor=default_user,
    )

    # A limit smaller than the message count caps the page size.
    page_size = 2
    first_page = await conversation_manager.list_conversation_messages(
        conversation_id=target.id,
        actor=default_user,
        limit=page_size,
    )
    assert len(first_page) == page_size

    # Using the first message as the cursor skips it and returns the rest.
    cursor = persisted[0].id
    after_first = await conversation_manager.list_conversation_messages(
        conversation_id=target.id,
        actor=default_user,
        after=cursor,
    )
    assert len(after_first) == 4  # Should get messages 1-4
|