feat: Add conversation_id filtering to message endpoints (#8324)
* feat: Add conversation_id filtering to message list and search endpoints Add optional conversation_id parameter to filter messages by conversation: - client.agents.messages.list - client.messages.list - client.messages.search Changes: - Added conversation_id field to MessageSearchRequest and SearchAllMessagesRequest schemas - Added conversation_id filtering to list_messages in message_manager.py - Updated get_agent_recall_async and get_all_messages_recall_async in server.py - Added conversation_id query parameter to router endpoints - Updated Turbopuffer client to support conversation_id filtering in searches Fixes #8320 🤖 Generated with [Letta Code](https://letta.com) Co-Authored-By: Charles Packer <cpacker@users.noreply.github.com> * add conversation_id to message and tpuf * default messages filter for backward compatibility * add test and auto gen * fix integration test * fix test * update test --------- Co-authored-by: letta-code <248085862+letta-code@users.noreply.github.com> Co-authored-by: Charles Packer <cpacker@users.noreply.github.com> Co-authored-by: christinatong01 <christina@letta.com>
This commit is contained in:
committed by
Caren Thomas
parent
737d6e2550
commit
ed6284cedb
@@ -7442,6 +7442,24 @@
|
||||
},
|
||||
"description": "Group ID to filter messages by."
|
||||
},
|
||||
{
|
||||
"name": "conversation_id",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"description": "Conversation ID to filter messages by.",
|
||||
"title": "Conversation Id"
|
||||
},
|
||||
"description": "Conversation ID to filter messages by."
|
||||
},
|
||||
{
|
||||
"name": "use_assistant_message",
|
||||
"in": "query",
|
||||
@@ -16137,6 +16155,24 @@
|
||||
"title": "Order"
|
||||
},
|
||||
"description": "Sort order for messages by creation time. 'asc' for oldest first, 'desc' for newest first"
|
||||
},
|
||||
{
|
||||
"name": "conversation_id",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"description": "Conversation ID to filter messages by",
|
||||
"title": "Conversation Id"
|
||||
},
|
||||
"description": "Conversation ID to filter messages by"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
@@ -35419,6 +35455,18 @@
|
||||
"title": "Batch Item Id",
|
||||
"description": "The id of the LLMBatchItem that this message is associated with"
|
||||
},
|
||||
"conversation_id": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Conversation Id",
|
||||
"description": "The conversation this message belongs to"
|
||||
},
|
||||
"is_err": {
|
||||
"anyOf": [
|
||||
{
|
||||
@@ -35494,7 +35542,7 @@
|
||||
"type": "object",
|
||||
"required": ["role"],
|
||||
"title": "Message",
|
||||
"description": " Letta's internal representation of a message. Includes methods to convert to/from LLM provider formats.\n\n Attributes:\n id (str): The unique identifier of the message.\n role (MessageRole): The role of the participant.\n text (str): The text of the message.\n user_id (str): The unique identifier of the user.\n agent_id (str): The unique identifier of the agent.\n model (str): The model used to make the function call.\n name (str): The name of the participant.\n created_at (datetime): The time the message was created.\n tool_calls (List[OpenAIToolCall,]): The list of tool calls requested.\n tool_call_id (str): The id of the tool call.\n step_id (str): The id of the step that this message was created in.\n otid (str): The offline threading id associated with this message.\n tool_returns (List[ToolReturn]): The list of tool returns requested.\n group_id (str): The multi-agent group that the message was sent in.\n sender_id (str): The id of the sender of the message, can be an identity id or agent id.\nt"
|
||||
"description": " Letta's internal representation of a message. Includes methods to convert to/from LLM provider formats.\n\n Attributes:\n id (str): The unique identifier of the message.\n role (MessageRole): The role of the participant.\n text (str): The text of the message.\n user_id (str): The unique identifier of the user.\n agent_id (str): The unique identifier of the agent.\n model (str): The model used to make the function call.\n name (str): The name of the participant.\n created_at (datetime): The time the message was created.\n tool_calls (List[OpenAIToolCall,]): The list of tool calls requested.\n tool_call_id (str): The id of the tool call.\n step_id (str): The id of the step that this message was created in.\n otid (str): The offline threading id associated with this message.\n tool_returns (List[ToolReturn]): The list of tool returns requested.\n group_id (str): The multi-agent group that the message was sent in.\n sender_id (str): The id of the sender of the message, can be an identity id or agent id.\n conversation_id (str): The conversation this message belongs to.\nt"
|
||||
},
|
||||
"MessageCreate": {
|
||||
"properties": {
|
||||
@@ -35676,6 +35724,18 @@
|
||||
"title": "Template Id",
|
||||
"description": "Filter messages by template ID"
|
||||
},
|
||||
"conversation_id": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Conversation Id",
|
||||
"description": "Filter messages by conversation ID"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"maximum": 100,
|
||||
@@ -38588,6 +38648,18 @@
|
||||
"title": "Agent Id",
|
||||
"description": "Filter messages by agent ID"
|
||||
},
|
||||
"conversation_id": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Conversation Id",
|
||||
"description": "Filter messages by conversation ID"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"maximum": 100,
|
||||
|
||||
@@ -400,6 +400,7 @@ class TurbopufferClient:
|
||||
created_ats: List[datetime],
|
||||
project_id: Optional[str] = None,
|
||||
template_id: Optional[str] = None,
|
||||
conversation_ids: Optional[List[Optional[str]]] = None,
|
||||
) -> bool:
|
||||
"""Insert messages into Turbopuffer.
|
||||
|
||||
@@ -413,6 +414,7 @@ class TurbopufferClient:
|
||||
created_ats: List of creation timestamps for each message
|
||||
project_id: Optional project ID for all messages
|
||||
template_id: Optional template ID for all messages
|
||||
conversation_ids: Optional list of conversation IDs (one per message, must match 1:1 with message_texts)
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
@@ -441,22 +443,26 @@ class TurbopufferClient:
|
||||
raise ValueError(f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})")
|
||||
if len(message_ids) != len(created_ats):
|
||||
raise ValueError(f"message_ids length ({len(message_ids)}) must match created_ats length ({len(created_ats)})")
|
||||
if conversation_ids is not None and len(conversation_ids) != len(message_ids):
|
||||
raise ValueError(f"conversation_ids length ({len(conversation_ids)}) must match message_ids length ({len(message_ids)})")
|
||||
|
||||
# prepare column-based data for turbopuffer - optimized for batch insert
|
||||
ids = []
|
||||
vectors = []
|
||||
texts = []
|
||||
organization_ids = []
|
||||
agent_ids = []
|
||||
organization_ids_list = []
|
||||
agent_ids_list = []
|
||||
message_roles = []
|
||||
created_at_timestamps = []
|
||||
project_ids = []
|
||||
template_ids = []
|
||||
project_ids_list = []
|
||||
template_ids_list = []
|
||||
conversation_ids_list = []
|
||||
|
||||
for (original_idx, text), embedding in zip(filtered_messages, embeddings):
|
||||
message_id = message_ids[original_idx]
|
||||
role = roles[original_idx]
|
||||
created_at = created_ats[original_idx]
|
||||
conversation_id = conversation_ids[original_idx] if conversation_ids else None
|
||||
|
||||
# ensure the provided timestamp is timezone-aware and in UTC
|
||||
if created_at.tzinfo is None:
|
||||
@@ -470,31 +476,36 @@ class TurbopufferClient:
|
||||
ids.append(message_id)
|
||||
vectors.append(embedding)
|
||||
texts.append(text)
|
||||
organization_ids.append(organization_id)
|
||||
agent_ids.append(agent_id)
|
||||
organization_ids_list.append(organization_id)
|
||||
agent_ids_list.append(agent_id)
|
||||
message_roles.append(role.value)
|
||||
created_at_timestamps.append(timestamp)
|
||||
project_ids.append(project_id)
|
||||
template_ids.append(template_id)
|
||||
project_ids_list.append(project_id)
|
||||
template_ids_list.append(template_id)
|
||||
conversation_ids_list.append(conversation_id)
|
||||
|
||||
# build column-based upsert data
|
||||
upsert_columns = {
|
||||
"id": ids,
|
||||
"vector": vectors,
|
||||
"text": texts,
|
||||
"organization_id": organization_ids,
|
||||
"agent_id": agent_ids,
|
||||
"organization_id": organization_ids_list,
|
||||
"agent_id": agent_ids_list,
|
||||
"role": message_roles,
|
||||
"created_at": created_at_timestamps,
|
||||
}
|
||||
|
||||
# only include conversation_id if it's provided
|
||||
if conversation_ids is not None:
|
||||
upsert_columns["conversation_id"] = conversation_ids_list
|
||||
|
||||
# only include project_id if it's provided
|
||||
if project_id is not None:
|
||||
upsert_columns["project_id"] = project_ids
|
||||
upsert_columns["project_id"] = project_ids_list
|
||||
|
||||
# only include template_id if it's provided
|
||||
if template_id is not None:
|
||||
upsert_columns["template_id"] = template_ids
|
||||
upsert_columns["template_id"] = template_ids_list
|
||||
|
||||
try:
|
||||
# use global semaphore to limit concurrent Turbopuffer writes
|
||||
@@ -506,7 +517,10 @@ class TurbopufferClient:
|
||||
await namespace.write(
|
||||
upsert_columns=upsert_columns,
|
||||
distance_metric="cosine_distance",
|
||||
schema={"text": {"type": "string", "full_text_search": True}},
|
||||
schema={
|
||||
"text": {"type": "string", "full_text_search": True},
|
||||
"conversation_id": {"type": "string"},
|
||||
},
|
||||
)
|
||||
logger.info(f"Successfully inserted {len(ids)} messages to Turbopuffer for agent {agent_id}")
|
||||
return True
|
||||
@@ -792,6 +806,7 @@ class TurbopufferClient:
|
||||
roles: Optional[List[MessageRole]] = None,
|
||||
project_id: Optional[str] = None,
|
||||
template_id: Optional[str] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
vector_weight: float = 0.5,
|
||||
fts_weight: float = 0.5,
|
||||
start_date: Optional[datetime] = None,
|
||||
@@ -809,6 +824,7 @@ class TurbopufferClient:
|
||||
roles: Optional list of message roles to filter by
|
||||
project_id: Optional project ID to filter messages by
|
||||
template_id: Optional template ID to filter messages by
|
||||
conversation_id: Optional conversation ID to filter messages by (use "default" for NULL)
|
||||
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
|
||||
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
|
||||
start_date: Optional datetime to filter messages created after this date
|
||||
@@ -875,6 +891,19 @@ class TurbopufferClient:
|
||||
if template_id:
|
||||
template_filter = ("template_id", "Eq", template_id)
|
||||
|
||||
# build conversation_id filter if provided
|
||||
# three cases:
|
||||
# 1. conversation_id=None (omitted) -> return all messages (no filter)
|
||||
# 2. conversation_id="default" -> return only default messages (conversation_id is none), for backward compatibility
|
||||
# 3. conversation_id="xyz" -> return only messages in that conversation
|
||||
conversation_filter = None
|
||||
if conversation_id == "default":
|
||||
# "default" is reserved for default messages only (conversation_id is none)
|
||||
conversation_filter = ("conversation_id", "Eq", None)
|
||||
elif conversation_id is not None:
|
||||
# Specific conversation
|
||||
conversation_filter = ("conversation_id", "Eq", conversation_id)
|
||||
|
||||
# combine all filters
|
||||
all_filters = [agent_filter] # always include agent_id filter
|
||||
if role_filter:
|
||||
@@ -883,6 +912,8 @@ class TurbopufferClient:
|
||||
all_filters.append(project_filter)
|
||||
if template_filter:
|
||||
all_filters.append(template_filter)
|
||||
if conversation_filter:
|
||||
all_filters.append(conversation_filter)
|
||||
if date_filters:
|
||||
all_filters.extend(date_filters)
|
||||
|
||||
@@ -901,7 +932,7 @@ class TurbopufferClient:
|
||||
query_embedding=query_embedding,
|
||||
query_text=query_text,
|
||||
top_k=top_k,
|
||||
include_attributes=["text", "organization_id", "agent_id", "role", "created_at"],
|
||||
include_attributes=["text", "organization_id", "agent_id", "role", "created_at", "conversation_id"],
|
||||
filters=final_filter,
|
||||
vector_weight=vector_weight,
|
||||
fts_weight=fts_weight,
|
||||
@@ -952,6 +983,7 @@ class TurbopufferClient:
|
||||
agent_id: Optional[str] = None,
|
||||
project_id: Optional[str] = None,
|
||||
template_id: Optional[str] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
vector_weight: float = 0.5,
|
||||
fts_weight: float = 0.5,
|
||||
start_date: Optional[datetime] = None,
|
||||
@@ -969,6 +1001,10 @@ class TurbopufferClient:
|
||||
agent_id: Optional agent ID to filter messages by
|
||||
project_id: Optional project ID to filter messages by
|
||||
template_id: Optional template ID to filter messages by
|
||||
conversation_id: Optional conversation ID to filter messages by. Special values:
|
||||
- None (omitted): Return all messages
|
||||
- "default": Return only default messages (conversation_id IS NULL)
|
||||
- Any other value: Return messages in that specific conversation
|
||||
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
|
||||
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
|
||||
start_date: Optional datetime to filter messages created after this date
|
||||
@@ -1017,6 +1053,18 @@ class TurbopufferClient:
|
||||
if template_id:
|
||||
all_filters.append(("template_id", "Eq", template_id))
|
||||
|
||||
# conversation filter
|
||||
# three cases:
|
||||
# 1. conversation_id=None (omitted) -> return all messages (no filter)
|
||||
# 2. conversation_id="default" -> return only default messages (conversation_id is none), for backward compatibility
|
||||
# 3. conversation_id="xyz" -> return only messages in that conversation
|
||||
if conversation_id == "default":
|
||||
# "default" is reserved for default messages only (conversation_id is none)
|
||||
all_filters.append(("conversation_id", "Eq", None))
|
||||
elif conversation_id is not None:
|
||||
# Specific conversation
|
||||
all_filters.append(("conversation_id", "Eq", conversation_id))
|
||||
|
||||
# date filters
|
||||
if start_date:
|
||||
# Convert to UTC to match stored timestamps
|
||||
@@ -1049,7 +1097,7 @@ class TurbopufferClient:
|
||||
query_embedding=query_embedding,
|
||||
query_text=query_text,
|
||||
top_k=top_k,
|
||||
include_attributes=["text", "organization_id", "agent_id", "role", "created_at"],
|
||||
include_attributes=["text", "organization_id", "agent_id", "role", "created_at", "conversation_id"],
|
||||
filters=final_filter,
|
||||
vector_weight=vector_weight,
|
||||
fts_weight=fts_weight,
|
||||
@@ -1134,6 +1182,7 @@ class TurbopufferClient:
|
||||
"agent_id": getattr(row, "agent_id", None),
|
||||
"role": getattr(row, "role", None),
|
||||
"created_at": getattr(row, "created_at", None),
|
||||
"conversation_id": getattr(row, "conversation_id", None),
|
||||
}
|
||||
messages.append(message_dict)
|
||||
|
||||
|
||||
@@ -211,6 +211,7 @@ class Message(BaseMessage):
|
||||
tool_returns (List[ToolReturn]): The list of tool returns requested.
|
||||
group_id (str): The multi-agent group that the message was sent in.
|
||||
sender_id (str): The id of the sender of the message, can be an identity id or agent id.
|
||||
conversation_id (str): The conversation this message belongs to.
|
||||
t
|
||||
"""
|
||||
|
||||
@@ -237,6 +238,7 @@ class Message(BaseMessage):
|
||||
group_id: Optional[str] = Field(default=None, description="The multi-agent group that the message was sent in")
|
||||
sender_id: Optional[str] = Field(default=None, description="The id of the sender of the message, can be an identity id or agent id")
|
||||
batch_item_id: Optional[str] = Field(default=None, description="The id of the LLMBatchItem that this message is associated with")
|
||||
conversation_id: Optional[str] = Field(default=None, description="The conversation this message belongs to")
|
||||
is_err: Optional[bool] = Field(
|
||||
default=None, description="Whether this message is part of an error step. Used only for debugging purposes."
|
||||
)
|
||||
@@ -2302,6 +2304,7 @@ class MessageSearchRequest(BaseModel):
|
||||
agent_id: Optional[str] = Field(None, description="Filter messages by agent ID")
|
||||
project_id: Optional[str] = Field(None, description="Filter messages by project ID")
|
||||
template_id: Optional[str] = Field(None, description="Filter messages by template ID")
|
||||
conversation_id: Optional[str] = Field(None, description="Filter messages by conversation ID")
|
||||
limit: int = Field(50, description="Maximum number of results to return", ge=1, le=100)
|
||||
start_date: Optional[datetime] = Field(None, description="Filter messages created after this date")
|
||||
end_date: Optional[datetime] = Field(None, description="Filter messages created on or before this date")
|
||||
@@ -2311,6 +2314,7 @@ class SearchAllMessagesRequest(BaseModel):
|
||||
query: str = Field(..., description="Text query for full-text search")
|
||||
search_mode: Literal["vector", "fts", "hybrid"] = Field("hybrid", description="Search mode to use")
|
||||
agent_id: Optional[str] = Field(None, description="Filter messages by agent ID")
|
||||
conversation_id: Optional[str] = Field(None, description="Filter messages by conversation ID")
|
||||
limit: int = Field(50, description="Maximum number of results to return", ge=1, le=100)
|
||||
start_date: Optional[datetime] = Field(None, description="Filter messages created after this date")
|
||||
end_date: Optional[datetime] = Field(None, description="Filter messages created on or before this date")
|
||||
|
||||
@@ -1385,6 +1385,7 @@ async def list_messages(
|
||||
),
|
||||
order_by: Literal["created_at"] = Query("created_at", description="Field to sort by"),
|
||||
group_id: str | None = Query(None, description="Group ID to filter messages by."),
|
||||
conversation_id: str | None = Query(None, description="Conversation ID to filter messages by."),
|
||||
use_assistant_message: bool = Query(True, description="Whether to use assistant messages", deprecated=True),
|
||||
assistant_message_tool_name: str = Query(DEFAULT_MESSAGE_TOOL, description="The name of the designated message tool.", deprecated=True),
|
||||
assistant_message_tool_kwarg: str = Query(DEFAULT_MESSAGE_TOOL_KWARG, description="The name of the message argument.", deprecated=True),
|
||||
@@ -1404,6 +1405,7 @@ async def list_messages(
|
||||
before=before,
|
||||
limit=limit,
|
||||
group_id=group_id,
|
||||
conversation_id=conversation_id,
|
||||
reverse=(order == "desc"),
|
||||
return_message_object=False,
|
||||
use_assistant_message=use_assistant_message,
|
||||
@@ -1751,6 +1753,7 @@ async def search_messages(
|
||||
agent_id=request.agent_id,
|
||||
project_id=request.project_id,
|
||||
template_id=request.template_id,
|
||||
conversation_id=request.conversation_id,
|
||||
limit=request.limit,
|
||||
start_date=request.start_date,
|
||||
end_date=request.end_date,
|
||||
|
||||
@@ -40,6 +40,7 @@ async def list_all_messages(
|
||||
order: Literal["asc", "desc"] = Query(
|
||||
"desc", description="Sort order for messages by creation time. 'asc' for oldest first, 'desc' for newest first"
|
||||
),
|
||||
conversation_id: Optional[str] = Query(None, description="Conversation ID to filter messages by"),
|
||||
):
|
||||
"""
|
||||
List messages across all agents for the current user.
|
||||
@@ -51,6 +52,7 @@ async def list_all_messages(
|
||||
limit=limit,
|
||||
reverse=(order == "desc"),
|
||||
return_message_object=False,
|
||||
conversation_id=conversation_id,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
@@ -74,6 +76,7 @@ async def search_all_messages(
|
||||
query_text=request.query,
|
||||
search_mode=request.search_mode,
|
||||
agent_id=request.agent_id,
|
||||
conversation_id=request.conversation_id,
|
||||
limit=request.limit,
|
||||
start_date=request.start_date,
|
||||
end_date=request.end_date,
|
||||
|
||||
@@ -817,6 +817,7 @@ class SyncServer(object):
|
||||
assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
|
||||
assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
include_err: Optional[bool] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
) -> Union[List[Message], List[LettaMessage]]:
|
||||
records = await self.message_manager.list_messages(
|
||||
agent_id=agent_id,
|
||||
@@ -827,6 +828,7 @@ class SyncServer(object):
|
||||
ascending=not reverse,
|
||||
group_id=group_id,
|
||||
include_err=include_err,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
if not return_message_object:
|
||||
@@ -862,6 +864,7 @@ class SyncServer(object):
|
||||
assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
|
||||
assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
include_err: Optional[bool] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
) -> Union[List[Message], List[LettaMessage]]:
|
||||
records = await self.message_manager.list_messages(
|
||||
agent_id=None,
|
||||
@@ -872,6 +875,7 @@ class SyncServer(object):
|
||||
ascending=not reverse,
|
||||
group_id=group_id,
|
||||
include_err=include_err,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
if not return_message_object:
|
||||
|
||||
@@ -412,10 +412,7 @@ class MessageManager:
|
||||
from letta.orm.run import Run as RunModel
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
query = select(RunModel.id).where(
|
||||
RunModel.id == run_id,
|
||||
RunModel.organization_id == actor.organization_id
|
||||
)
|
||||
query = select(RunModel.id).where(RunModel.id == run_id, RunModel.organization_id == actor.organization_id)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none() is not None
|
||||
|
||||
@@ -545,10 +542,7 @@ class MessageManager:
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
# Check which run_ids actually exist
|
||||
query = select(RunModel.id).where(
|
||||
RunModel.id.in_(unique_run_ids),
|
||||
RunModel.organization_id == actor.organization_id
|
||||
)
|
||||
query = select(RunModel.id).where(RunModel.id.in_(unique_run_ids), RunModel.organization_id == actor.organization_id)
|
||||
result = await session.execute(query)
|
||||
existing_run_ids = set(result.scalars().all())
|
||||
|
||||
@@ -622,6 +616,7 @@ class MessageManager:
|
||||
message_ids = []
|
||||
roles = []
|
||||
created_ats = []
|
||||
conversation_ids = []
|
||||
|
||||
# combine assistant+tool messages before embedding
|
||||
combined_messages = self._combine_assistant_tool_messages(messages)
|
||||
@@ -633,6 +628,7 @@ class MessageManager:
|
||||
message_ids.append(msg.id)
|
||||
roles.append(msg.role)
|
||||
created_ats.append(msg.created_at)
|
||||
conversation_ids.append(msg.conversation_id)
|
||||
|
||||
if message_texts:
|
||||
# insert to turbopuffer - TurbopufferClient will generate embeddings internally
|
||||
@@ -647,6 +643,7 @@ class MessageManager:
|
||||
created_ats=created_ats,
|
||||
project_id=project_id,
|
||||
template_id=template_id,
|
||||
conversation_ids=conversation_ids,
|
||||
)
|
||||
logger.info(f"Successfully embedded {len(message_texts)} messages for agent {agent_id}")
|
||||
except Exception as e:
|
||||
@@ -776,6 +773,7 @@ class MessageManager:
|
||||
created_ats=[message.created_at],
|
||||
project_id=project_id,
|
||||
template_id=template_id,
|
||||
conversation_ids=[message.conversation_id],
|
||||
)
|
||||
logger.info(f"Successfully updated message {message.id} in Turbopuffer")
|
||||
except Exception as e:
|
||||
@@ -896,6 +894,7 @@ class MessageManager:
|
||||
group_id: Optional[str] = None,
|
||||
include_err: Optional[bool] = None,
|
||||
run_id: Optional[str] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
) -> List[PydanticMessage]:
|
||||
"""
|
||||
Most performant query to list messages by directly querying the Message table.
|
||||
@@ -917,6 +916,7 @@ class MessageManager:
|
||||
group_id: Optional group ID to filter messages by group_id.
|
||||
include_err: Optional boolean to include errors and error statuses. Used for debugging only.
|
||||
run_id: Optional run ID to filter messages by run_id.
|
||||
conversation_id: Optional conversation ID to filter messages by conversation_id.
|
||||
|
||||
Returns:
|
||||
List[PydanticMessage]: A list of messages (converted via .to_pydantic()).
|
||||
@@ -942,6 +942,9 @@ class MessageManager:
|
||||
if run_id:
|
||||
query = query.where(MessageModel.run_id == run_id)
|
||||
|
||||
if conversation_id:
|
||||
query = query.where(MessageModel.conversation_id == conversation_id)
|
||||
|
||||
# if not include_err:
|
||||
# query = query.where((MessageModel.is_err == False) | (MessageModel.is_err.is_(None)))
|
||||
|
||||
@@ -1233,6 +1236,7 @@ class MessageManager:
|
||||
agent_id: Optional[str] = None,
|
||||
project_id: Optional[str] = None,
|
||||
template_id: Optional[str] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
@@ -1248,6 +1252,7 @@ class MessageManager:
|
||||
agent_id: Optional agent ID to filter messages by
|
||||
project_id: Optional project ID to filter messages by
|
||||
template_id: Optional template ID to filter messages by
|
||||
conversation_id: Optional conversation ID to filter messages by
|
||||
limit: Maximum number of results to return
|
||||
start_date: Optional filter for messages created after this date
|
||||
end_date: Optional filter for messages created on or before this date (inclusive)
|
||||
@@ -1277,6 +1282,7 @@ class MessageManager:
|
||||
agent_id=agent_id,
|
||||
project_id=project_id,
|
||||
template_id=template_id,
|
||||
conversation_id=conversation_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
|
||||
@@ -2167,6 +2167,122 @@ async def test_message_template_id_filtering(server, sarah_agent, default_user,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_message_conversation_id_filtering(server, sarah_agent, default_user, enable_turbopuffer, enable_message_embedding):
|
||||
"""Test that conversation_id filtering works correctly in message queries, including 'default' sentinel"""
|
||||
from letta.schemas.conversation import CreateConversation
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.services.conversation_manager import ConversationManager
|
||||
|
||||
conversation_manager = ConversationManager()
|
||||
|
||||
# Create a conversation
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Test conversation"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Create messages with different conversation_ids
|
||||
message_with_conv = PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(text="Message in specific conversation about Python")],
|
||||
)
|
||||
|
||||
message_default_conv = PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(text="Message in default conversation about JavaScript")],
|
||||
)
|
||||
|
||||
# Insert messages with their respective conversation IDs
|
||||
tpuf_client = TurbopufferClient()
|
||||
|
||||
# Message with specific conversation_id
|
||||
await tpuf_client.insert_messages(
|
||||
agent_id=sarah_agent.id,
|
||||
message_texts=[message_with_conv.content[0].text],
|
||||
message_ids=[message_with_conv.id],
|
||||
organization_id=default_user.organization_id,
|
||||
actor=default_user,
|
||||
roles=[message_with_conv.role],
|
||||
created_ats=[message_with_conv.created_at],
|
||||
conversation_ids=[conversation.id], # Specific conversation
|
||||
)
|
||||
|
||||
# Message with no conversation_id (default)
|
||||
await tpuf_client.insert_messages(
|
||||
agent_id=sarah_agent.id,
|
||||
message_texts=[message_default_conv.content[0].text],
|
||||
message_ids=[message_default_conv.id],
|
||||
organization_id=default_user.organization_id,
|
||||
actor=default_user,
|
||||
roles=[message_default_conv.role],
|
||||
created_ats=[message_default_conv.created_at],
|
||||
conversation_ids=[None], # Default conversation (NULL)
|
||||
)
|
||||
|
||||
# Wait for indexing
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Test 1: Query for specific conversation - should find only message with that conversation_id
|
||||
results_conv = await tpuf_client.query_messages_by_agent_id(
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
search_mode="timestamp",
|
||||
top_k=10,
|
||||
conversation_id=conversation.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert len(results_conv) == 1
|
||||
assert results_conv[0][0]["id"] == message_with_conv.id
|
||||
assert "Python" in results_conv[0][0]["text"]
|
||||
|
||||
# Test 2: Query for "default" conversation - should find only messages with NULL conversation_id
|
||||
results_default = await tpuf_client.query_messages_by_agent_id(
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
search_mode="timestamp",
|
||||
top_k=10,
|
||||
conversation_id="default", # Sentinel for NULL
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert len(results_default) >= 1 # May have other default messages from setup
|
||||
# Check our message is in there
|
||||
default_ids = [r[0]["id"] for r in results_default]
|
||||
assert message_default_conv.id in default_ids
|
||||
|
||||
# Verify the message content
|
||||
for msg_dict, _, _ in results_default:
|
||||
if msg_dict["id"] == message_default_conv.id:
|
||||
assert "JavaScript" in msg_dict["text"]
|
||||
break
|
||||
|
||||
# Test 3: Query without conversation filter - should find both
|
||||
results_all = await tpuf_client.query_messages_by_agent_id(
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
search_mode="timestamp",
|
||||
top_k=10,
|
||||
conversation_id=None, # No filter
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert len(results_all) >= 2 # May have other messages from setup
|
||||
message_ids = [r[0]["id"] for r in results_all]
|
||||
assert message_with_conv.id in message_ids
|
||||
assert message_default_conv.id in message_ids
|
||||
|
||||
# Clean up
|
||||
await tpuf_client.delete_messages(
|
||||
agent_id=sarah_agent.id, organization_id=default_user.organization_id, message_ids=[message_with_conv.id, message_default_conv.id]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_system_messages_not_embedded_during_agent_creation(server, default_user, enable_message_embedding):
|
||||
"""Test that system messages are filtered out before being passed to the embedding pipeline during agent creation"""
|
||||
|
||||
@@ -412,6 +412,59 @@ async def test_message_delete(server: SyncServer, hello_world_message_fixture, d
|
||||
assert retrieved is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_conversation_id_persistence(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test that conversation_id is properly persisted and retrieved from DB to Pydantic object"""
|
||||
from letta.schemas.conversation import CreateConversation
|
||||
from letta.services.conversation_manager import ConversationManager
|
||||
|
||||
conversation_manager = ConversationManager()
|
||||
|
||||
# Test 1: Create a message without conversation_id (should be None - the default/backward-compat case)
|
||||
message_no_conv = PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(text="Test message without conversation")],
|
||||
)
|
||||
|
||||
created_no_conv = await server.message_manager.create_many_messages_async([message_no_conv], actor=default_user)
|
||||
assert len(created_no_conv) == 1
|
||||
assert created_no_conv[0].conversation_id is None
|
||||
|
||||
# Verify retrieval also has None - this confirms ORM-to-Pydantic conversion works for None
|
||||
retrieved_no_conv = await server.message_manager.get_message_by_id_async(created_no_conv[0].id, actor=default_user)
|
||||
assert retrieved_no_conv is not None
|
||||
assert retrieved_no_conv.conversation_id is None
|
||||
assert retrieved_no_conv.id == created_no_conv[0].id
|
||||
|
||||
# Test 2: Create a conversation and a message with that conversation_id
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Test conversation"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
message_with_conv = PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(text="Test message with conversation")],
|
||||
conversation_id=conversation.id,
|
||||
)
|
||||
|
||||
created_with_conv = await server.message_manager.create_many_messages_async([message_with_conv], actor=default_user)
|
||||
assert len(created_with_conv) == 1
|
||||
assert created_with_conv[0].conversation_id == conversation.id
|
||||
|
||||
# Verify retrieval has the correct conversation_id - this confirms ORM-to-Pydantic conversion works for non-None
|
||||
retrieved_with_conv = await server.message_manager.get_message_by_id_async(created_with_conv[0].id, actor=default_user)
|
||||
assert retrieved_with_conv is not None
|
||||
assert retrieved_with_conv.conversation_id == conversation.id
|
||||
assert retrieved_with_conv.id == created_with_conv[0].id
|
||||
|
||||
# Test 3: Verify the field exists on the Pydantic model
|
||||
assert hasattr(retrieved_with_conv, "conversation_id")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_size(server: SyncServer, hello_world_message_fixture, default_user):
|
||||
"""Test counting messages with filters"""
|
||||
|
||||
Reference in New Issue
Block a user