feat: Add conversation_id filtering to message endpoints (#8324)

* feat: Add conversation_id filtering to message list and search endpoints

Add optional conversation_id parameter to filter messages by conversation:
- client.agents.messages.list
- client.messages.list
- client.messages.search

Changes:
- Added conversation_id field to MessageSearchRequest and SearchAllMessagesRequest schemas
- Added conversation_id filtering to list_messages in message_manager.py
- Updated get_agent_recall_async and get_all_messages_recall_async in server.py
- Added conversation_id query parameter to router endpoints
- Updated Turbopuffer client to support conversation_id filtering in searches

Fixes #8320

🤖 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Charles Packer <cpacker@users.noreply.github.com>

* add conversation_id to message and tpuf

* default messages filter for backward compatibility

* add test and auto gen

* fix integration test

* fix test

* update test

---------

Co-authored-by: letta-code <248085862+letta-code@users.noreply.github.com>
Co-authored-by: Charles Packer <cpacker@users.noreply.github.com>
Co-authored-by: christinatong01 <christina@letta.com>
This commit is contained in:
Charles Packer
2026-01-07 13:45:52 -08:00
committed by Caren Thomas
parent 737d6e2550
commit ed6284cedb
9 changed files with 334 additions and 24 deletions

View File

@@ -7442,6 +7442,24 @@
},
"description": "Group ID to filter messages by."
},
{
"name": "conversation_id",
"in": "query",
"required": false,
"schema": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"description": "Conversation ID to filter messages by.",
"title": "Conversation Id"
},
"description": "Conversation ID to filter messages by."
},
{
"name": "use_assistant_message",
"in": "query",
@@ -16137,6 +16155,24 @@
"title": "Order"
},
"description": "Sort order for messages by creation time. 'asc' for oldest first, 'desc' for newest first"
},
{
"name": "conversation_id",
"in": "query",
"required": false,
"schema": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"description": "Conversation ID to filter messages by",
"title": "Conversation Id"
},
"description": "Conversation ID to filter messages by"
}
],
"responses": {
@@ -35419,6 +35455,18 @@
"title": "Batch Item Id",
"description": "The id of the LLMBatchItem that this message is associated with"
},
"conversation_id": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"title": "Conversation Id",
"description": "The conversation this message belongs to"
},
"is_err": {
"anyOf": [
{
@@ -35494,7 +35542,7 @@
"type": "object",
"required": ["role"],
"title": "Message",
"description": " Letta's internal representation of a message. Includes methods to convert to/from LLM provider formats.\n\n Attributes:\n id (str): The unique identifier of the message.\n role (MessageRole): The role of the participant.\n text (str): The text of the message.\n user_id (str): The unique identifier of the user.\n agent_id (str): The unique identifier of the agent.\n model (str): The model used to make the function call.\n name (str): The name of the participant.\n created_at (datetime): The time the message was created.\n tool_calls (List[OpenAIToolCall,]): The list of tool calls requested.\n tool_call_id (str): The id of the tool call.\n step_id (str): The id of the step that this message was created in.\n otid (str): The offline threading id associated with this message.\n tool_returns (List[ToolReturn]): The list of tool returns requested.\n group_id (str): The multi-agent group that the message was sent in.\n sender_id (str): The id of the sender of the message, can be an identity id or agent id.\nt"
"description": " Letta's internal representation of a message. Includes methods to convert to/from LLM provider formats.\n\n Attributes:\n id (str): The unique identifier of the message.\n role (MessageRole): The role of the participant.\n text (str): The text of the message.\n user_id (str): The unique identifier of the user.\n agent_id (str): The unique identifier of the agent.\n model (str): The model used to make the function call.\n name (str): The name of the participant.\n created_at (datetime): The time the message was created.\n tool_calls (List[OpenAIToolCall,]): The list of tool calls requested.\n tool_call_id (str): The id of the tool call.\n step_id (str): The id of the step that this message was created in.\n otid (str): The offline threading id associated with this message.\n tool_returns (List[ToolReturn]): The list of tool returns requested.\n group_id (str): The multi-agent group that the message was sent in.\n sender_id (str): The id of the sender of the message, can be an identity id or agent id.\n conversation_id (str): The conversation this message belongs to.\n"
},
"MessageCreate": {
"properties": {
@@ -35676,6 +35724,18 @@
"title": "Template Id",
"description": "Filter messages by template ID"
},
"conversation_id": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"title": "Conversation Id",
"description": "Filter messages by conversation ID"
},
"limit": {
"type": "integer",
"maximum": 100,
@@ -38588,6 +38648,18 @@
"title": "Agent Id",
"description": "Filter messages by agent ID"
},
"conversation_id": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"title": "Conversation Id",
"description": "Filter messages by conversation ID"
},
"limit": {
"type": "integer",
"maximum": 100,

View File

@@ -400,6 +400,7 @@ class TurbopufferClient:
created_ats: List[datetime],
project_id: Optional[str] = None,
template_id: Optional[str] = None,
conversation_ids: Optional[List[Optional[str]]] = None,
) -> bool:
"""Insert messages into Turbopuffer.
@@ -413,6 +414,7 @@ class TurbopufferClient:
created_ats: List of creation timestamps for each message
project_id: Optional project ID for all messages
template_id: Optional template ID for all messages
conversation_ids: Optional list of conversation IDs (one per message, must match 1:1 with message_texts)
Returns:
True if successful
@@ -441,22 +443,26 @@ class TurbopufferClient:
raise ValueError(f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})")
if len(message_ids) != len(created_ats):
raise ValueError(f"message_ids length ({len(message_ids)}) must match created_ats length ({len(created_ats)})")
if conversation_ids is not None and len(conversation_ids) != len(message_ids):
raise ValueError(f"conversation_ids length ({len(conversation_ids)}) must match message_ids length ({len(message_ids)})")
# prepare column-based data for turbopuffer - optimized for batch insert
ids = []
vectors = []
texts = []
organization_ids = []
agent_ids = []
organization_ids_list = []
agent_ids_list = []
message_roles = []
created_at_timestamps = []
project_ids = []
template_ids = []
project_ids_list = []
template_ids_list = []
conversation_ids_list = []
for (original_idx, text), embedding in zip(filtered_messages, embeddings):
message_id = message_ids[original_idx]
role = roles[original_idx]
created_at = created_ats[original_idx]
conversation_id = conversation_ids[original_idx] if conversation_ids else None
# ensure the provided timestamp is timezone-aware and in UTC
if created_at.tzinfo is None:
@@ -470,31 +476,36 @@ class TurbopufferClient:
ids.append(message_id)
vectors.append(embedding)
texts.append(text)
organization_ids.append(organization_id)
agent_ids.append(agent_id)
organization_ids_list.append(organization_id)
agent_ids_list.append(agent_id)
message_roles.append(role.value)
created_at_timestamps.append(timestamp)
project_ids.append(project_id)
template_ids.append(template_id)
project_ids_list.append(project_id)
template_ids_list.append(template_id)
conversation_ids_list.append(conversation_id)
# build column-based upsert data
upsert_columns = {
"id": ids,
"vector": vectors,
"text": texts,
"organization_id": organization_ids,
"agent_id": agent_ids,
"organization_id": organization_ids_list,
"agent_id": agent_ids_list,
"role": message_roles,
"created_at": created_at_timestamps,
}
# only include conversation_id if it's provided
if conversation_ids is not None:
upsert_columns["conversation_id"] = conversation_ids_list
# only include project_id if it's provided
if project_id is not None:
upsert_columns["project_id"] = project_ids
upsert_columns["project_id"] = project_ids_list
# only include template_id if it's provided
if template_id is not None:
upsert_columns["template_id"] = template_ids
upsert_columns["template_id"] = template_ids_list
try:
# use global semaphore to limit concurrent Turbopuffer writes
@@ -506,7 +517,10 @@ class TurbopufferClient:
await namespace.write(
upsert_columns=upsert_columns,
distance_metric="cosine_distance",
schema={"text": {"type": "string", "full_text_search": True}},
schema={
"text": {"type": "string", "full_text_search": True},
"conversation_id": {"type": "string"},
},
)
logger.info(f"Successfully inserted {len(ids)} messages to Turbopuffer for agent {agent_id}")
return True
@@ -792,6 +806,7 @@ class TurbopufferClient:
roles: Optional[List[MessageRole]] = None,
project_id: Optional[str] = None,
template_id: Optional[str] = None,
conversation_id: Optional[str] = None,
vector_weight: float = 0.5,
fts_weight: float = 0.5,
start_date: Optional[datetime] = None,
@@ -809,6 +824,7 @@ class TurbopufferClient:
roles: Optional list of message roles to filter by
project_id: Optional project ID to filter messages by
template_id: Optional template ID to filter messages by
conversation_id: Optional conversation ID to filter messages by (use "default" for NULL)
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
start_date: Optional datetime to filter messages created after this date
@@ -875,6 +891,19 @@ class TurbopufferClient:
if template_id:
template_filter = ("template_id", "Eq", template_id)
# build conversation_id filter if provided
# three cases:
# 1. conversation_id=None (omitted) -> return all messages (no filter)
# 2. conversation_id="default" -> return only default messages (conversation_id IS NULL), for backward compatibility
# 3. conversation_id="xyz" -> return only messages in that conversation
conversation_filter = None
if conversation_id == "default":
# "default" is reserved for default messages only (conversation_id IS NULL)
conversation_filter = ("conversation_id", "Eq", None)
elif conversation_id is not None:
# Specific conversation
conversation_filter = ("conversation_id", "Eq", conversation_id)
# combine all filters
all_filters = [agent_filter] # always include agent_id filter
if role_filter:
@@ -883,6 +912,8 @@ class TurbopufferClient:
all_filters.append(project_filter)
if template_filter:
all_filters.append(template_filter)
if conversation_filter:
all_filters.append(conversation_filter)
if date_filters:
all_filters.extend(date_filters)
@@ -901,7 +932,7 @@ class TurbopufferClient:
query_embedding=query_embedding,
query_text=query_text,
top_k=top_k,
include_attributes=["text", "organization_id", "agent_id", "role", "created_at"],
include_attributes=["text", "organization_id", "agent_id", "role", "created_at", "conversation_id"],
filters=final_filter,
vector_weight=vector_weight,
fts_weight=fts_weight,
@@ -952,6 +983,7 @@ class TurbopufferClient:
agent_id: Optional[str] = None,
project_id: Optional[str] = None,
template_id: Optional[str] = None,
conversation_id: Optional[str] = None,
vector_weight: float = 0.5,
fts_weight: float = 0.5,
start_date: Optional[datetime] = None,
@@ -969,6 +1001,10 @@ class TurbopufferClient:
agent_id: Optional agent ID to filter messages by
project_id: Optional project ID to filter messages by
template_id: Optional template ID to filter messages by
conversation_id: Optional conversation ID to filter messages by. Special values:
- None (omitted): Return all messages
- "default": Return only default messages (conversation_id IS NULL)
- Any other value: Return messages in that specific conversation
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
start_date: Optional datetime to filter messages created after this date
@@ -1017,6 +1053,18 @@ class TurbopufferClient:
if template_id:
all_filters.append(("template_id", "Eq", template_id))
# conversation filter
# three cases:
# 1. conversation_id=None (omitted) -> return all messages (no filter)
# 2. conversation_id="default" -> return only default messages (conversation_id IS NULL), for backward compatibility
# 3. conversation_id="xyz" -> return only messages in that conversation
if conversation_id == "default":
# "default" is reserved for default messages only (conversation_id IS NULL)
all_filters.append(("conversation_id", "Eq", None))
elif conversation_id is not None:
# Specific conversation
all_filters.append(("conversation_id", "Eq", conversation_id))
# date filters
if start_date:
# Convert to UTC to match stored timestamps
@@ -1049,7 +1097,7 @@ class TurbopufferClient:
query_embedding=query_embedding,
query_text=query_text,
top_k=top_k,
include_attributes=["text", "organization_id", "agent_id", "role", "created_at"],
include_attributes=["text", "organization_id", "agent_id", "role", "created_at", "conversation_id"],
filters=final_filter,
vector_weight=vector_weight,
fts_weight=fts_weight,
@@ -1134,6 +1182,7 @@ class TurbopufferClient:
"agent_id": getattr(row, "agent_id", None),
"role": getattr(row, "role", None),
"created_at": getattr(row, "created_at", None),
"conversation_id": getattr(row, "conversation_id", None),
}
messages.append(message_dict)

View File

@@ -211,6 +211,7 @@ class Message(BaseMessage):
tool_returns (List[ToolReturn]): The list of tool returns requested.
group_id (str): The multi-agent group that the message was sent in.
sender_id (str): The id of the sender of the message, can be an identity id or agent id.
conversation_id (str): The conversation this message belongs to.
"""
@@ -237,6 +238,7 @@ class Message(BaseMessage):
group_id: Optional[str] = Field(default=None, description="The multi-agent group that the message was sent in")
sender_id: Optional[str] = Field(default=None, description="The id of the sender of the message, can be an identity id or agent id")
batch_item_id: Optional[str] = Field(default=None, description="The id of the LLMBatchItem that this message is associated with")
conversation_id: Optional[str] = Field(default=None, description="The conversation this message belongs to")
is_err: Optional[bool] = Field(
default=None, description="Whether this message is part of an error step. Used only for debugging purposes."
)
@@ -2302,6 +2304,7 @@ class MessageSearchRequest(BaseModel):
agent_id: Optional[str] = Field(None, description="Filter messages by agent ID")
project_id: Optional[str] = Field(None, description="Filter messages by project ID")
template_id: Optional[str] = Field(None, description="Filter messages by template ID")
conversation_id: Optional[str] = Field(None, description="Filter messages by conversation ID")
limit: int = Field(50, description="Maximum number of results to return", ge=1, le=100)
start_date: Optional[datetime] = Field(None, description="Filter messages created after this date")
end_date: Optional[datetime] = Field(None, description="Filter messages created on or before this date")
@@ -2311,6 +2314,7 @@ class SearchAllMessagesRequest(BaseModel):
query: str = Field(..., description="Text query for full-text search")
search_mode: Literal["vector", "fts", "hybrid"] = Field("hybrid", description="Search mode to use")
agent_id: Optional[str] = Field(None, description="Filter messages by agent ID")
conversation_id: Optional[str] = Field(None, description="Filter messages by conversation ID")
limit: int = Field(50, description="Maximum number of results to return", ge=1, le=100)
start_date: Optional[datetime] = Field(None, description="Filter messages created after this date")
end_date: Optional[datetime] = Field(None, description="Filter messages created on or before this date")

View File

@@ -1385,6 +1385,7 @@ async def list_messages(
),
order_by: Literal["created_at"] = Query("created_at", description="Field to sort by"),
group_id: str | None = Query(None, description="Group ID to filter messages by."),
conversation_id: str | None = Query(None, description="Conversation ID to filter messages by."),
use_assistant_message: bool = Query(True, description="Whether to use assistant messages", deprecated=True),
assistant_message_tool_name: str = Query(DEFAULT_MESSAGE_TOOL, description="The name of the designated message tool.", deprecated=True),
assistant_message_tool_kwarg: str = Query(DEFAULT_MESSAGE_TOOL_KWARG, description="The name of the message argument.", deprecated=True),
@@ -1404,6 +1405,7 @@ async def list_messages(
before=before,
limit=limit,
group_id=group_id,
conversation_id=conversation_id,
reverse=(order == "desc"),
return_message_object=False,
use_assistant_message=use_assistant_message,
@@ -1751,6 +1753,7 @@ async def search_messages(
agent_id=request.agent_id,
project_id=request.project_id,
template_id=request.template_id,
conversation_id=request.conversation_id,
limit=request.limit,
start_date=request.start_date,
end_date=request.end_date,

View File

@@ -40,6 +40,7 @@ async def list_all_messages(
order: Literal["asc", "desc"] = Query(
"desc", description="Sort order for messages by creation time. 'asc' for oldest first, 'desc' for newest first"
),
conversation_id: Optional[str] = Query(None, description="Conversation ID to filter messages by"),
):
"""
List messages across all agents for the current user.
@@ -51,6 +52,7 @@ async def list_all_messages(
limit=limit,
reverse=(order == "desc"),
return_message_object=False,
conversation_id=conversation_id,
actor=actor,
)
@@ -74,6 +76,7 @@ async def search_all_messages(
query_text=request.query,
search_mode=request.search_mode,
agent_id=request.agent_id,
conversation_id=request.conversation_id,
limit=request.limit,
start_date=request.start_date,
end_date=request.end_date,

View File

@@ -817,6 +817,7 @@ class SyncServer(object):
assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
include_err: Optional[bool] = None,
conversation_id: Optional[str] = None,
) -> Union[List[Message], List[LettaMessage]]:
records = await self.message_manager.list_messages(
agent_id=agent_id,
@@ -827,6 +828,7 @@ class SyncServer(object):
ascending=not reverse,
group_id=group_id,
include_err=include_err,
conversation_id=conversation_id,
)
if not return_message_object:
@@ -862,6 +864,7 @@ class SyncServer(object):
assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
include_err: Optional[bool] = None,
conversation_id: Optional[str] = None,
) -> Union[List[Message], List[LettaMessage]]:
records = await self.message_manager.list_messages(
agent_id=None,
@@ -872,6 +875,7 @@ class SyncServer(object):
ascending=not reverse,
group_id=group_id,
include_err=include_err,
conversation_id=conversation_id,
)
if not return_message_object:

View File

@@ -412,10 +412,7 @@ class MessageManager:
from letta.orm.run import Run as RunModel
async with db_registry.async_session() as session:
query = select(RunModel.id).where(
RunModel.id == run_id,
RunModel.organization_id == actor.organization_id
)
query = select(RunModel.id).where(RunModel.id == run_id, RunModel.organization_id == actor.organization_id)
result = await session.execute(query)
return result.scalar_one_or_none() is not None
@@ -545,10 +542,7 @@ class MessageManager:
async with db_registry.async_session() as session:
# Check which run_ids actually exist
query = select(RunModel.id).where(
RunModel.id.in_(unique_run_ids),
RunModel.organization_id == actor.organization_id
)
query = select(RunModel.id).where(RunModel.id.in_(unique_run_ids), RunModel.organization_id == actor.organization_id)
result = await session.execute(query)
existing_run_ids = set(result.scalars().all())
@@ -622,6 +616,7 @@ class MessageManager:
message_ids = []
roles = []
created_ats = []
conversation_ids = []
# combine assistant+tool messages before embedding
combined_messages = self._combine_assistant_tool_messages(messages)
@@ -633,6 +628,7 @@ class MessageManager:
message_ids.append(msg.id)
roles.append(msg.role)
created_ats.append(msg.created_at)
conversation_ids.append(msg.conversation_id)
if message_texts:
# insert to turbopuffer - TurbopufferClient will generate embeddings internally
@@ -647,6 +643,7 @@ class MessageManager:
created_ats=created_ats,
project_id=project_id,
template_id=template_id,
conversation_ids=conversation_ids,
)
logger.info(f"Successfully embedded {len(message_texts)} messages for agent {agent_id}")
except Exception as e:
@@ -776,6 +773,7 @@ class MessageManager:
created_ats=[message.created_at],
project_id=project_id,
template_id=template_id,
conversation_ids=[message.conversation_id],
)
logger.info(f"Successfully updated message {message.id} in Turbopuffer")
except Exception as e:
@@ -896,6 +894,7 @@ class MessageManager:
group_id: Optional[str] = None,
include_err: Optional[bool] = None,
run_id: Optional[str] = None,
conversation_id: Optional[str] = None,
) -> List[PydanticMessage]:
"""
Most performant query to list messages by directly querying the Message table.
@@ -917,6 +916,7 @@ class MessageManager:
group_id: Optional group ID to filter messages by group_id.
include_err: Optional boolean to include errors and error statuses. Used for debugging only.
run_id: Optional run ID to filter messages by run_id.
conversation_id: Optional conversation ID to filter messages by conversation_id.
Returns:
List[PydanticMessage]: A list of messages (converted via .to_pydantic()).
@@ -942,6 +942,9 @@ class MessageManager:
if run_id:
query = query.where(MessageModel.run_id == run_id)
if conversation_id:
query = query.where(MessageModel.conversation_id == conversation_id)
# if not include_err:
# query = query.where((MessageModel.is_err == False) | (MessageModel.is_err.is_(None)))
@@ -1233,6 +1236,7 @@ class MessageManager:
agent_id: Optional[str] = None,
project_id: Optional[str] = None,
template_id: Optional[str] = None,
conversation_id: Optional[str] = None,
limit: int = 50,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
@@ -1248,6 +1252,7 @@ class MessageManager:
agent_id: Optional agent ID to filter messages by
project_id: Optional project ID to filter messages by
template_id: Optional template ID to filter messages by
conversation_id: Optional conversation ID to filter messages by
limit: Maximum number of results to return
start_date: Optional filter for messages created after this date
end_date: Optional filter for messages created on or before this date (inclusive)
@@ -1277,6 +1282,7 @@ class MessageManager:
agent_id=agent_id,
project_id=project_id,
template_id=template_id,
conversation_id=conversation_id,
start_date=start_date,
end_date=end_date,
)

View File

@@ -2167,6 +2167,122 @@ async def test_message_template_id_filtering(server, sarah_agent, default_user,
)
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_conversation_id_filtering(server, sarah_agent, default_user, enable_turbopuffer, enable_message_embedding):
"""Test that conversation_id filtering works correctly in message queries, including 'default' sentinel"""
from letta.schemas.conversation import CreateConversation
from letta.schemas.letta_message_content import TextContent
from letta.services.conversation_manager import ConversationManager
conversation_manager = ConversationManager()
# Create a conversation
conversation = await conversation_manager.create_conversation(
agent_id=sarah_agent.id,
conversation_create=CreateConversation(summary="Test conversation"),
actor=default_user,
)
# Create messages with different conversation_ids
message_with_conv = PydanticMessage(
agent_id=sarah_agent.id,
role=MessageRole.user,
content=[TextContent(text="Message in specific conversation about Python")],
)
message_default_conv = PydanticMessage(
agent_id=sarah_agent.id,
role=MessageRole.user,
content=[TextContent(text="Message in default conversation about JavaScript")],
)
# Insert messages with their respective conversation IDs
tpuf_client = TurbopufferClient()
# Message with specific conversation_id
await tpuf_client.insert_messages(
agent_id=sarah_agent.id,
message_texts=[message_with_conv.content[0].text],
message_ids=[message_with_conv.id],
organization_id=default_user.organization_id,
actor=default_user,
roles=[message_with_conv.role],
created_ats=[message_with_conv.created_at],
conversation_ids=[conversation.id], # Specific conversation
)
# Message with no conversation_id (default)
await tpuf_client.insert_messages(
agent_id=sarah_agent.id,
message_texts=[message_default_conv.content[0].text],
message_ids=[message_default_conv.id],
organization_id=default_user.organization_id,
actor=default_user,
roles=[message_default_conv.role],
created_ats=[message_default_conv.created_at],
conversation_ids=[None], # Default conversation (NULL)
)
# Wait for indexing
await asyncio.sleep(1)
# Test 1: Query for specific conversation - should find only message with that conversation_id
results_conv = await tpuf_client.query_messages_by_agent_id(
agent_id=sarah_agent.id,
organization_id=default_user.organization_id,
search_mode="timestamp",
top_k=10,
conversation_id=conversation.id,
actor=default_user,
)
assert len(results_conv) == 1
assert results_conv[0][0]["id"] == message_with_conv.id
assert "Python" in results_conv[0][0]["text"]
# Test 2: Query for "default" conversation - should find only messages with NULL conversation_id
results_default = await tpuf_client.query_messages_by_agent_id(
agent_id=sarah_agent.id,
organization_id=default_user.organization_id,
search_mode="timestamp",
top_k=10,
conversation_id="default", # Sentinel for NULL
actor=default_user,
)
assert len(results_default) >= 1 # May have other default messages from setup
# Check our message is in there
default_ids = [r[0]["id"] for r in results_default]
assert message_default_conv.id in default_ids
# Verify the message content
for msg_dict, _, _ in results_default:
if msg_dict["id"] == message_default_conv.id:
assert "JavaScript" in msg_dict["text"]
break
# Test 3: Query without conversation filter - should find both
results_all = await tpuf_client.query_messages_by_agent_id(
agent_id=sarah_agent.id,
organization_id=default_user.organization_id,
search_mode="timestamp",
top_k=10,
conversation_id=None, # No filter
actor=default_user,
)
assert len(results_all) >= 2 # May have other messages from setup
message_ids = [r[0]["id"] for r in results_all]
assert message_with_conv.id in message_ids
assert message_default_conv.id in message_ids
# Clean up
await tpuf_client.delete_messages(
agent_id=sarah_agent.id, organization_id=default_user.organization_id, message_ids=[message_with_conv.id, message_default_conv.id]
)
@pytest.mark.asyncio
async def test_system_messages_not_embedded_during_agent_creation(server, default_user, enable_message_embedding):
"""Test that system messages are filtered out before being passed to the embedding pipeline during agent creation"""

View File

@@ -412,6 +412,59 @@ async def test_message_delete(server: SyncServer, hello_world_message_fixture, d
assert retrieved is None
@pytest.mark.asyncio
async def test_message_conversation_id_persistence(server: SyncServer, sarah_agent, default_user):
"""Test that conversation_id is properly persisted and retrieved from DB to Pydantic object"""
from letta.schemas.conversation import CreateConversation
from letta.services.conversation_manager import ConversationManager
conversation_manager = ConversationManager()
# Test 1: Create a message without conversation_id (should be None - the default/backward-compat case)
message_no_conv = PydanticMessage(
agent_id=sarah_agent.id,
role=MessageRole.user,
content=[TextContent(text="Test message without conversation")],
)
created_no_conv = await server.message_manager.create_many_messages_async([message_no_conv], actor=default_user)
assert len(created_no_conv) == 1
assert created_no_conv[0].conversation_id is None
# Verify retrieval also has None - this confirms ORM-to-Pydantic conversion works for None
retrieved_no_conv = await server.message_manager.get_message_by_id_async(created_no_conv[0].id, actor=default_user)
assert retrieved_no_conv is not None
assert retrieved_no_conv.conversation_id is None
assert retrieved_no_conv.id == created_no_conv[0].id
# Test 2: Create a conversation and a message with that conversation_id
conversation = await conversation_manager.create_conversation(
agent_id=sarah_agent.id,
conversation_create=CreateConversation(summary="Test conversation"),
actor=default_user,
)
message_with_conv = PydanticMessage(
agent_id=sarah_agent.id,
role=MessageRole.user,
content=[TextContent(text="Test message with conversation")],
conversation_id=conversation.id,
)
created_with_conv = await server.message_manager.create_many_messages_async([message_with_conv], actor=default_user)
assert len(created_with_conv) == 1
assert created_with_conv[0].conversation_id == conversation.id
# Verify retrieval has the correct conversation_id - this confirms ORM-to-Pydantic conversion works for non-None
retrieved_with_conv = await server.message_manager.get_message_by_id_async(created_with_conv[0].id, actor=default_user)
assert retrieved_with_conv is not None
assert retrieved_with_conv.conversation_id == conversation.id
assert retrieved_with_conv.id == created_with_conv[0].id
# Test 3: Verify the field exists on the Pydantic model
assert hasattr(retrieved_with_conv, "conversation_id")
@pytest.mark.asyncio
async def test_message_size(server: SyncServer, hello_world_message_fixture, default_user):
"""Test counting messages with filters"""