feat: add conversation and conversation_messages tables for concurrent messaging (#8182)
This commit is contained in:
committed by
Caren Thomas
parent
c66b852978
commit
87d920782f
271
tests/integration_test_conversations_sdk.py
Normal file
271
tests/integration_test_conversations_sdk.py
Normal file
@@ -0,0 +1,271 @@
|
||||
"""
|
||||
Integration tests for the Conversations API using the SDK.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from letta_client import Letta
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(server_url: str) -> Letta:
|
||||
"""Create a Letta client."""
|
||||
return Letta(base_url=server_url)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent(client: Letta):
|
||||
"""Create a test agent."""
|
||||
agent_state = client.agents.create(
|
||||
name=f"test_conversations_{uuid.uuid4().hex[:8]}",
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-3-small",
|
||||
memory_blocks=[
|
||||
{"label": "human", "value": "Test user"},
|
||||
{"label": "persona", "value": "You are a helpful assistant."},
|
||||
],
|
||||
)
|
||||
yield agent_state
|
||||
# Cleanup
|
||||
client.agents.delete(agent_id=agent_state.id)
|
||||
|
||||
|
||||
class TestConversationsSDK:
|
||||
"""Test conversations using the SDK client."""
|
||||
|
||||
def test_create_conversation(self, client: Letta, agent):
|
||||
"""Test creating a conversation for an agent."""
|
||||
conversation = client.conversations.create(agent_id=agent.id)
|
||||
|
||||
assert conversation.id is not None
|
||||
assert conversation.id.startswith("conv-")
|
||||
assert conversation.agent_id == agent.id
|
||||
|
||||
def test_list_conversations(self, client: Letta, agent):
|
||||
"""Test listing conversations for an agent."""
|
||||
# Create multiple conversations
|
||||
conv1 = client.conversations.create(agent_id=agent.id)
|
||||
conv2 = client.conversations.create(agent_id=agent.id)
|
||||
|
||||
# List conversations
|
||||
conversations = client.conversations.list(agent_id=agent.id)
|
||||
|
||||
assert len(conversations) >= 2
|
||||
conv_ids = [c.id for c in conversations]
|
||||
assert conv1.id in conv_ids
|
||||
assert conv2.id in conv_ids
|
||||
|
||||
def test_retrieve_conversation(self, client: Letta, agent):
|
||||
"""Test retrieving a specific conversation."""
|
||||
# Create a conversation
|
||||
created = client.conversations.create(agent_id=agent.id)
|
||||
|
||||
# Retrieve it (should have empty in_context_message_ids initially)
|
||||
retrieved = client.conversations.retrieve(conversation_id=created.id)
|
||||
|
||||
assert retrieved.id == created.id
|
||||
assert retrieved.agent_id == created.agent_id
|
||||
assert retrieved.in_context_message_ids == []
|
||||
|
||||
# Send a message to the conversation
|
||||
list(
|
||||
client.conversations.messages.create(
|
||||
conversation_id=created.id,
|
||||
messages=[{"role": "user", "content": "Hello!"}],
|
||||
)
|
||||
)
|
||||
|
||||
# Retrieve again and check in_context_message_ids is populated
|
||||
retrieved_with_messages = client.conversations.retrieve(conversation_id=created.id)
|
||||
|
||||
# System message + user + assistant messages should be in the conversation
|
||||
assert len(retrieved_with_messages.in_context_message_ids) >= 3 # system + user + assistant
|
||||
# All IDs should be strings starting with "message-"
|
||||
for msg_id in retrieved_with_messages.in_context_message_ids:
|
||||
assert isinstance(msg_id, str)
|
||||
assert msg_id.startswith("message-")
|
||||
|
||||
# Verify message ordering by listing messages
|
||||
messages = client.conversations.messages.list(conversation_id=created.id)
|
||||
assert len(messages) >= 3 # system + user + assistant
|
||||
# First message should be system message (shared across conversations)
|
||||
assert messages[0].message_type == "system_message", f"First message should be system_message, got {messages[0].message_type}"
|
||||
# Second message should be user message
|
||||
assert messages[1].message_type == "user_message", f"Second message should be user_message, got {messages[1].message_type}"
|
||||
|
||||
def test_send_message_to_conversation(self, client: Letta, agent):
|
||||
"""Test sending a message to a conversation."""
|
||||
# Create a conversation
|
||||
conversation = client.conversations.create(agent_id=agent.id)
|
||||
|
||||
# Send a message (returns a stream)
|
||||
stream = client.conversations.messages.create(
|
||||
conversation_id=conversation.id,
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
)
|
||||
|
||||
# Consume the stream to get messages
|
||||
messages = list(stream)
|
||||
|
||||
# Check response contains messages
|
||||
assert len(messages) > 0
|
||||
# Should have at least an assistant message
|
||||
message_types = [m.message_type for m in messages if hasattr(m, "message_type")]
|
||||
assert "assistant_message" in message_types
|
||||
|
||||
def test_list_conversation_messages(self, client: Letta, agent):
|
||||
"""Test listing messages from a conversation."""
|
||||
# Create a conversation
|
||||
conversation = client.conversations.create(agent_id=agent.id)
|
||||
|
||||
# Send a message to create some history (consume the stream)
|
||||
stream = client.conversations.messages.create(
|
||||
conversation_id=conversation.id,
|
||||
messages=[{"role": "user", "content": "Say 'test response' back to me."}],
|
||||
)
|
||||
list(stream) # Consume stream
|
||||
|
||||
# List messages
|
||||
messages = client.conversations.messages.list(conversation_id=conversation.id)
|
||||
|
||||
assert len(messages) >= 2 # At least user + assistant
|
||||
message_types = [m.message_type for m in messages]
|
||||
assert "user_message" in message_types
|
||||
assert "assistant_message" in message_types
|
||||
|
||||
# Send another message and check that old and new messages are both listed
|
||||
first_message_count = len(messages)
|
||||
stream = client.conversations.messages.create(
|
||||
conversation_id=conversation.id,
|
||||
messages=[{"role": "user", "content": "This is a follow-up message."}],
|
||||
)
|
||||
list(stream) # Consume stream
|
||||
|
||||
# List messages again
|
||||
updated_messages = client.conversations.messages.list(conversation_id=conversation.id)
|
||||
|
||||
# Should have more messages now (at least 2 more: user + assistant)
|
||||
assert len(updated_messages) >= first_message_count + 2
|
||||
|
||||
def test_conversation_isolation(self, client: Letta, agent):
|
||||
"""Test that conversations are isolated from each other."""
|
||||
# Create two conversations
|
||||
conv1 = client.conversations.create(agent_id=agent.id)
|
||||
conv2 = client.conversations.create(agent_id=agent.id)
|
||||
|
||||
# Send different messages to each (consume streams)
|
||||
list(
|
||||
client.conversations.messages.create(
|
||||
conversation_id=conv1.id,
|
||||
messages=[{"role": "user", "content": "Remember the word: APPLE"}],
|
||||
)
|
||||
)
|
||||
list(
|
||||
client.conversations.messages.create(
|
||||
conversation_id=conv2.id,
|
||||
messages=[{"role": "user", "content": "Remember the word: BANANA"}],
|
||||
)
|
||||
)
|
||||
|
||||
# List messages from each conversation
|
||||
conv1_messages = client.conversations.messages.list(conversation_id=conv1.id)
|
||||
conv2_messages = client.conversations.messages.list(conversation_id=conv2.id)
|
||||
|
||||
# Check messages are separate
|
||||
conv1_content = " ".join([m.content for m in conv1_messages if hasattr(m, "content") and m.content])
|
||||
conv2_content = " ".join([m.content for m in conv2_messages if hasattr(m, "content") and m.content])
|
||||
|
||||
assert "APPLE" in conv1_content
|
||||
assert "BANANA" in conv2_content
|
||||
# Each conversation should only have its own word
|
||||
assert "BANANA" not in conv1_content or "APPLE" not in conv2_content
|
||||
|
||||
# Ask what word was remembered and make sure it's different for each conversation
|
||||
conv1_recall = list(
|
||||
client.conversations.messages.create(
|
||||
conversation_id=conv1.id,
|
||||
messages=[{"role": "user", "content": "What word did I ask you to remember? Reply with just the word."}],
|
||||
)
|
||||
)
|
||||
conv2_recall = list(
|
||||
client.conversations.messages.create(
|
||||
conversation_id=conv2.id,
|
||||
messages=[{"role": "user", "content": "What word did I ask you to remember? Reply with just the word."}],
|
||||
)
|
||||
)
|
||||
|
||||
# Get the assistant responses
|
||||
conv1_response = " ".join([m.content for m in conv1_recall if hasattr(m, "message_type") and m.message_type == "assistant_message"])
|
||||
conv2_response = " ".join([m.content for m in conv2_recall if hasattr(m, "message_type") and m.message_type == "assistant_message"])
|
||||
|
||||
assert "APPLE" in conv1_response.upper(), f"Conv1 should remember APPLE, got: {conv1_response}"
|
||||
assert "BANANA" in conv2_response.upper(), f"Conv2 should remember BANANA, got: {conv2_response}"
|
||||
|
||||
# Each conversation has its own system message (created on first message)
|
||||
conv1_system_id = conv1_messages[0].id
|
||||
conv2_system_id = conv2_messages[0].id
|
||||
assert conv1_system_id != conv2_system_id, "System messages should have different IDs for different conversations"
|
||||
|
||||
def test_conversation_messages_pagination(self, client: Letta, agent):
|
||||
"""Test pagination when listing conversation messages."""
|
||||
# Create a conversation
|
||||
conversation = client.conversations.create(agent_id=agent.id)
|
||||
|
||||
# Send multiple messages to create history (consume streams)
|
||||
for i in range(3):
|
||||
list(
|
||||
client.conversations.messages.create(
|
||||
conversation_id=conversation.id,
|
||||
messages=[{"role": "user", "content": f"Message number {i}"}],
|
||||
)
|
||||
)
|
||||
|
||||
# List with limit
|
||||
messages = client.conversations.messages.list(
|
||||
conversation_id=conversation.id,
|
||||
limit=2,
|
||||
)
|
||||
|
||||
# Should respect the limit
|
||||
assert len(messages) <= 2
|
||||
|
||||
def test_retrieve_conversation_stream_no_active_run(self, client: Letta, agent):
|
||||
"""Test that retrieve_conversation_stream returns error when no active run exists."""
|
||||
from letta_client import BadRequestError
|
||||
|
||||
# Create a conversation
|
||||
conversation = client.conversations.create(agent_id=agent.id)
|
||||
|
||||
# Try to retrieve stream when no run exists (should fail)
|
||||
with pytest.raises(BadRequestError) as exc_info:
|
||||
# Use the SDK's stream method
|
||||
stream = client.conversations.messages.stream(conversation_id=conversation.id)
|
||||
list(stream) # Consume the stream to trigger the error
|
||||
|
||||
# Should return 400 because no active run exists
|
||||
assert "No active runs found" in str(exc_info.value)
|
||||
|
||||
def test_retrieve_conversation_stream_after_completed_run(self, client: Letta, agent):
|
||||
"""Test that retrieve_conversation_stream returns error when run is completed."""
|
||||
from letta_client import BadRequestError
|
||||
|
||||
# Create a conversation
|
||||
conversation = client.conversations.create(agent_id=agent.id)
|
||||
|
||||
# Send a message (this creates a run that completes)
|
||||
list(
|
||||
client.conversations.messages.create(
|
||||
conversation_id=conversation.id,
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
)
|
||||
)
|
||||
|
||||
# Try to retrieve stream after the run has completed (should fail)
|
||||
with pytest.raises(BadRequestError) as exc_info:
|
||||
# Use the SDK's stream method
|
||||
stream = client.conversations.messages.stream(conversation_id=conversation.id)
|
||||
list(stream) # Consume the stream to trigger the error
|
||||
|
||||
# Should return 400 because no active run exists (run is completed)
|
||||
assert "No active runs found" in str(exc_info.value)
|
||||
@@ -917,6 +917,92 @@ async def test_tool_call(
|
||||
assert run.status == ("cancelled" if cancellation == "with_cancellation" else "completed")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_config",
|
||||
TESTED_MODEL_CONFIGS,
|
||||
ids=[handle for handle, _ in TESTED_MODEL_CONFIGS],
|
||||
)
|
||||
@pytest.mark.asyncio(loop_scope="function")
|
||||
async def test_conversation_streaming_raw_http(
|
||||
disable_e2b_api_key: Any,
|
||||
client: AsyncLetta,
|
||||
server_url: str,
|
||||
agent_state: AgentState,
|
||||
model_config: Tuple[str, dict],
|
||||
) -> None:
|
||||
"""
|
||||
Test conversation-based streaming functionality using raw HTTP requests.
|
||||
|
||||
This test verifies that:
|
||||
1. A conversation can be created for an agent
|
||||
2. Messages can be sent to the conversation via streaming
|
||||
3. The streaming response contains the expected message types
|
||||
4. Messages are properly persisted in the conversation
|
||||
|
||||
Uses raw HTTP requests instead of SDK until SDK is regenerated with conversations support.
|
||||
"""
|
||||
import httpx
|
||||
|
||||
model_handle, model_settings = model_config
|
||||
agent_state = await client.agents.update(agent_id=agent_state.id, model=model_handle, model_settings=model_settings)
|
||||
|
||||
async with httpx.AsyncClient(base_url=server_url, timeout=60.0) as http_client:
|
||||
# Create a conversation for the agent
|
||||
create_response = await http_client.post(
|
||||
"/v1/conversations/",
|
||||
params={"agent_id": agent_state.id},
|
||||
json={},
|
||||
)
|
||||
assert create_response.status_code == 200, f"Failed to create conversation: {create_response.text}"
|
||||
conversation = create_response.json()
|
||||
assert conversation["id"] is not None
|
||||
assert conversation["agent_id"] == agent_state.id
|
||||
|
||||
# Send a message to the conversation using streaming
|
||||
stream_response = await http_client.post(
|
||||
f"/v1/conversations/{conversation['id']}/messages",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": f"Reply with the message '{USER_MESSAGE_RESPONSE}'."}],
|
||||
"stream_tokens": True,
|
||||
},
|
||||
)
|
||||
assert stream_response.status_code == 200, f"Failed to send message: {stream_response.text}"
|
||||
|
||||
# Parse SSE response and accumulate messages
|
||||
messages = await accumulate_chunks(stream_response.text)
|
||||
print("MESSAGES:", messages)
|
||||
|
||||
# Verify the response contains expected message types
|
||||
assert_greeting_response(messages, model_handle, model_settings, streaming=True, token_streaming=True)
|
||||
|
||||
# Verify the conversation can be retrieved
|
||||
retrieve_response = await http_client.get(f"/v1/conversations/{conversation['id']}")
|
||||
assert retrieve_response.status_code == 200, f"Failed to retrieve conversation: {retrieve_response.text}"
|
||||
retrieved_conversation = retrieve_response.json()
|
||||
assert retrieved_conversation["id"] == conversation["id"]
|
||||
print("RETRIEVED CONVERSATION:", retrieved_conversation)
|
||||
|
||||
# Verify conversations can be listed for the agent
|
||||
list_response = await http_client.get("/v1/conversations/", params={"agent_id": agent_state.id})
|
||||
assert list_response.status_code == 200, f"Failed to list conversations: {list_response.text}"
|
||||
conversations_list = list_response.json()
|
||||
assert any(c["id"] == conversation["id"] for c in conversations_list)
|
||||
|
||||
# Verify messages can be listed from the conversation
|
||||
messages_response = await http_client.get(f"/v1/conversations/{conversation['id']}/messages")
|
||||
assert messages_response.status_code == 200, f"Failed to list conversation messages: {messages_response.text}"
|
||||
conversation_messages = messages_response.json()
|
||||
print("CONVERSATION MESSAGES:", conversation_messages)
|
||||
|
||||
# Verify we have at least the user message and assistant message
|
||||
assert len(conversation_messages) >= 2, f"Expected at least 2 messages, got {len(conversation_messages)}"
|
||||
|
||||
# Check message types are present
|
||||
message_types = [msg.get("message_type") for msg in conversation_messages]
|
||||
assert "user_message" in message_types, f"Expected user_message in {message_types}"
|
||||
assert "assistant_message" in message_types, f"Expected assistant_message in {message_types}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_handle,provider_type",
|
||||
[
|
||||
|
||||
506
tests/managers/test_conversation_manager.py
Normal file
506
tests/managers/test_conversation_manager.py
Normal file
@@ -0,0 +1,506 @@
|
||||
"""
|
||||
Tests for ConversationManager.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.conversation import CreateConversation, UpdateConversation
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.conversation_manager import ConversationManager
|
||||
|
||||
# ======================================================================================================================
|
||||
# ConversationManager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def conversation_manager():
|
||||
"""Create a ConversationManager instance."""
|
||||
return ConversationManager()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_conversation(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test creating a conversation."""
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Test conversation"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert conversation.id is not None
|
||||
assert conversation.agent_id == sarah_agent.id
|
||||
assert conversation.summary == "Test conversation"
|
||||
assert conversation.id.startswith("conv-")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_conversation_no_summary(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test creating a conversation without summary."""
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert conversation.id is not None
|
||||
assert conversation.agent_id == sarah_agent.id
|
||||
assert conversation.summary is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_by_id(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test retrieving a conversation by ID."""
|
||||
# Create a conversation
|
||||
created = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Test"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Retrieve it
|
||||
retrieved = await conversation_manager.get_conversation_by_id(
|
||||
conversation_id=created.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert retrieved.id == created.id
|
||||
assert retrieved.agent_id == created.agent_id
|
||||
assert retrieved.summary == created.summary
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_not_found(conversation_manager, server: SyncServer, default_user):
|
||||
"""Test retrieving a non-existent conversation raises error."""
|
||||
with pytest.raises(NoResultFound):
|
||||
await conversation_manager.get_conversation_by_id(
|
||||
conversation_id="conv-nonexistent",
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_conversations(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test listing conversations for an agent."""
|
||||
# Create multiple conversations
|
||||
for i in range(3):
|
||||
await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary=f"Conversation {i}"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# List them
|
||||
conversations = await conversation_manager.list_conversations(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert len(conversations) == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_conversations_with_limit(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test listing conversations with a limit."""
|
||||
# Create multiple conversations
|
||||
for i in range(5):
|
||||
await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary=f"Conversation {i}"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# List with limit
|
||||
conversations = await conversation_manager.list_conversations(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
limit=2,
|
||||
)
|
||||
|
||||
assert len(conversations) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_conversation(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test updating a conversation."""
|
||||
# Create a conversation
|
||||
created = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Original"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Update it
|
||||
updated = await conversation_manager.update_conversation(
|
||||
conversation_id=created.id,
|
||||
conversation_update=UpdateConversation(summary="Updated summary"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert updated.id == created.id
|
||||
assert updated.summary == "Updated summary"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_conversation(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test soft deleting a conversation."""
|
||||
# Create a conversation
|
||||
created = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="To delete"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Delete it
|
||||
await conversation_manager.delete_conversation(
|
||||
conversation_id=created.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Verify it's no longer accessible
|
||||
with pytest.raises(NoResultFound):
|
||||
await conversation_manager.get_conversation_by_id(
|
||||
conversation_id=created.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conversation_isolation_by_agent(conversation_manager, server: SyncServer, sarah_agent, charles_agent, default_user):
|
||||
"""Test that conversations are isolated by agent."""
|
||||
# Create conversation for sarah_agent
|
||||
await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Sarah's conversation"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Create conversation for charles_agent
|
||||
await conversation_manager.create_conversation(
|
||||
agent_id=charles_agent.id,
|
||||
conversation_create=CreateConversation(summary="Charles's conversation"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# List for sarah_agent
|
||||
sarah_convos = await conversation_manager.list_conversations(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
)
|
||||
assert len(sarah_convos) == 1
|
||||
assert sarah_convos[0].summary == "Sarah's conversation"
|
||||
|
||||
# List for charles_agent
|
||||
charles_convos = await conversation_manager.list_conversations(
|
||||
agent_id=charles_agent.id,
|
||||
actor=default_user,
|
||||
)
|
||||
assert len(charles_convos) == 1
|
||||
assert charles_convos[0].summary == "Charles's conversation"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conversation_isolation_by_organization(
|
||||
conversation_manager, server: SyncServer, sarah_agent, default_user, other_user_different_org
|
||||
):
|
||||
"""Test that conversations are isolated by organization."""
|
||||
# Create conversation
|
||||
created = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Test"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Other org user should not be able to access it
|
||||
with pytest.raises(NoResultFound):
|
||||
await conversation_manager.get_conversation_by_id(
|
||||
conversation_id=created.id,
|
||||
actor=other_user_different_org,
|
||||
)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Conversation Message Management Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_messages_to_conversation(
|
||||
conversation_manager, server: SyncServer, sarah_agent, default_user, hello_world_message_fixture
|
||||
):
|
||||
"""Test adding messages to a conversation."""
|
||||
# Create a conversation
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Test"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Add the message to the conversation
|
||||
await conversation_manager.add_messages_to_conversation(
|
||||
conversation_id=conversation.id,
|
||||
agent_id=sarah_agent.id,
|
||||
message_ids=[hello_world_message_fixture.id],
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Verify message is in conversation
|
||||
message_ids = await conversation_manager.get_message_ids_for_conversation(
|
||||
conversation_id=conversation.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert len(message_ids) == 1
|
||||
assert message_ids[0] == hello_world_message_fixture.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_messages_for_conversation(
|
||||
conversation_manager, server: SyncServer, sarah_agent, default_user, hello_world_message_fixture
|
||||
):
|
||||
"""Test getting full message objects from a conversation."""
|
||||
# Create a conversation
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Test"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Add the message
|
||||
await conversation_manager.add_messages_to_conversation(
|
||||
conversation_id=conversation.id,
|
||||
agent_id=sarah_agent.id,
|
||||
message_ids=[hello_world_message_fixture.id],
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Get full messages
|
||||
messages = await conversation_manager.get_messages_for_conversation(
|
||||
conversation_id=conversation.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].id == hello_world_message_fixture.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_ordering_in_conversation(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test that messages maintain their order in a conversation."""
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
|
||||
# Create a conversation
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Test"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Create multiple messages
|
||||
pydantic_messages = [
|
||||
PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role="user",
|
||||
content=[TextContent(text=f"Message {i}")],
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
messages = await server.message_manager.create_many_messages_async(
|
||||
pydantic_messages,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Add messages in order
|
||||
await conversation_manager.add_messages_to_conversation(
|
||||
conversation_id=conversation.id,
|
||||
agent_id=sarah_agent.id,
|
||||
message_ids=[m.id for m in messages],
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Verify order is maintained
|
||||
retrieved_ids = await conversation_manager.get_message_ids_for_conversation(
|
||||
conversation_id=conversation.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert retrieved_ids == [m.id for m in messages]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_in_context_messages(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test updating which messages are in context."""
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
|
||||
# Create a conversation
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Test"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Create messages
|
||||
pydantic_messages = [
|
||||
PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role="user",
|
||||
content=[TextContent(text=f"Message {i}")],
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
messages = await server.message_manager.create_many_messages_async(
|
||||
pydantic_messages,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Add all messages
|
||||
await conversation_manager.add_messages_to_conversation(
|
||||
conversation_id=conversation.id,
|
||||
agent_id=sarah_agent.id,
|
||||
message_ids=[m.id for m in messages],
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Update to only keep first and last in context
|
||||
await conversation_manager.update_in_context_messages(
|
||||
conversation_id=conversation.id,
|
||||
in_context_message_ids=[messages[0].id, messages[2].id],
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Verify only the selected messages are in context
|
||||
in_context_ids = await conversation_manager.get_message_ids_for_conversation(
|
||||
conversation_id=conversation.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert len(in_context_ids) == 2
|
||||
assert messages[0].id in in_context_ids
|
||||
assert messages[2].id in in_context_ids
|
||||
assert messages[1].id not in in_context_ids
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_conversation_message_ids(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test getting message IDs from an empty conversation."""
|
||||
# Create a conversation
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Empty"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Get message IDs (should be empty)
|
||||
message_ids = await conversation_manager.get_message_ids_for_conversation(
|
||||
conversation_id=conversation.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert message_ids == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_conversation_messages(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test listing messages from a conversation as LettaMessages."""
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
|
||||
# Create a conversation
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Test"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Create messages with different roles
|
||||
pydantic_messages = [
|
||||
PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role="user",
|
||||
content=[TextContent(text="Hello!")],
|
||||
),
|
||||
PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role="assistant",
|
||||
content=[TextContent(text="Hi there!")],
|
||||
),
|
||||
]
|
||||
messages = await server.message_manager.create_many_messages_async(
|
||||
pydantic_messages,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Add messages to conversation
|
||||
await conversation_manager.add_messages_to_conversation(
|
||||
conversation_id=conversation.id,
|
||||
agent_id=sarah_agent.id,
|
||||
message_ids=[m.id for m in messages],
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# List conversation messages (returns LettaMessages)
|
||||
letta_messages = await conversation_manager.list_conversation_messages(
|
||||
conversation_id=conversation.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert len(letta_messages) == 2
|
||||
# Check message types
|
||||
message_types = [m.message_type for m in letta_messages]
|
||||
assert "user_message" in message_types
|
||||
assert "assistant_message" in message_types
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_conversation_messages_pagination(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test pagination when listing conversation messages."""
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
|
||||
# Create a conversation
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Test"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Create multiple messages
|
||||
pydantic_messages = [
|
||||
PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role="user",
|
||||
content=[TextContent(text=f"Message {i}")],
|
||||
)
|
||||
for i in range(5)
|
||||
]
|
||||
messages = await server.message_manager.create_many_messages_async(
|
||||
pydantic_messages,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Add messages to conversation
|
||||
await conversation_manager.add_messages_to_conversation(
|
||||
conversation_id=conversation.id,
|
||||
agent_id=sarah_agent.id,
|
||||
message_ids=[m.id for m in messages],
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# List with limit
|
||||
letta_messages = await conversation_manager.list_conversation_messages(
|
||||
conversation_id=conversation.id,
|
||||
actor=default_user,
|
||||
limit=2,
|
||||
)
|
||||
assert len(letta_messages) == 2
|
||||
|
||||
# List with after cursor (get messages after the first one)
|
||||
letta_messages_after = await conversation_manager.list_conversation_messages(
|
||||
conversation_id=conversation.id,
|
||||
actor=default_user,
|
||||
after=messages[0].id,
|
||||
)
|
||||
assert len(letta_messages_after) == 4 # Should get messages 1-4
|
||||
Reference in New Issue
Block a user