diff --git a/fern/openapi.json b/fern/openapi.json index 35def53d..0fc4f219 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -7442,6 +7442,24 @@ }, "description": "Group ID to filter messages by." }, + { + "name": "conversation_id", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "description": "Conversation ID to filter messages by.", + "title": "Conversation Id" + }, + "description": "Conversation ID to filter messages by." + }, { "name": "use_assistant_message", "in": "query", @@ -16137,6 +16155,24 @@ "title": "Order" }, "description": "Sort order for messages by creation time. 'asc' for oldest first, 'desc' for newest first" + }, + { + "name": "conversation_id", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "description": "Conversation ID to filter messages by", + "title": "Conversation Id" + }, + "description": "Conversation ID to filter messages by" } ], "responses": { @@ -35419,6 +35455,18 @@ "title": "Batch Item Id", "description": "The id of the LLMBatchItem that this message is associated with" }, + "conversation_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Conversation Id", + "description": "The conversation this message belongs to" + }, "is_err": { "anyOf": [ { @@ -35494,7 +35542,7 @@ "type": "object", "required": ["role"], "title": "Message", - "description": " Letta's internal representation of a message. Includes methods to convert to/from LLM provider formats.\n\n Attributes:\n id (str): The unique identifier of the message.\n role (MessageRole): The role of the participant.\n text (str): The text of the message.\n user_id (str): The unique identifier of the user.\n agent_id (str): The unique identifier of the agent.\n model (str): The model used to make the function call.\n name (str): The name of the participant.\n created_at (datetime): The time the message was created.\n tool_calls (List[OpenAIToolCall,]): The list of tool calls requested.\n tool_call_id (str): The id of the tool call.\n step_id (str): The id of the step that this message was created in.\n otid (str): The offline threading id associated with this message.\n tool_returns (List[ToolReturn]): The list of tool returns requested.\n group_id (str): The multi-agent group that the message was sent in.\n sender_id (str): The id of the sender of the message, can be an identity id or agent id.\nt" + "description": " Letta's internal representation of a message. Includes methods to convert to/from LLM provider formats.\n\n Attributes:\n id (str): The unique identifier of the message.\n role (MessageRole): The role of the participant.\n text (str): The text of the message.\n user_id (str): The unique identifier of the user.\n agent_id (str): The unique identifier of the agent.\n model (str): The model used to make the function call.\n name (str): The name of the participant.\n created_at (datetime): The time the message was created.\n tool_calls (List[OpenAIToolCall,]): The list of tool calls requested.\n tool_call_id (str): The id of the tool call.\n step_id (str): The id of the step that this message was created in.\n otid (str): The offline threading id associated with this message.\n tool_returns (List[ToolReturn]): The list of tool returns requested.\n group_id (str): The multi-agent group that the message was sent in.\n sender_id (str): The id of the sender of the message, can be an identity id or agent id.\n conversation_id (str): The conversation this message belongs to.\nt" }, "MessageCreate": { "properties": { @@ -35676,6 +35724,18 @@ "title": "Template Id", "description": "Filter messages by template ID" }, + "conversation_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Conversation Id", + "description": "Filter messages by conversation ID" + }, "limit": { "type": "integer", "maximum": 100, @@ -38588,6 +38648,18 @@ "title": "Agent Id", "description": "Filter messages by agent ID" }, + "conversation_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Conversation Id", + "description": "Filter messages by conversation ID" + }, "limit": { "type": "integer", "maximum": 100, diff --git a/letta/helpers/tpuf_client.py b/letta/helpers/tpuf_client.py index acc2dde1..c5e7f925 100644 --- a/letta/helpers/tpuf_client.py +++ b/letta/helpers/tpuf_client.py @@ -400,6 +400,7 @@ class TurbopufferClient: created_ats: List[datetime], project_id: Optional[str] = None, template_id: Optional[str] = None, + conversation_ids: Optional[List[Optional[str]]] = None, ) -> bool: """Insert messages into Turbopuffer. @@ -413,6 +414,7 @@ class TurbopufferClient: created_ats: List of creation timestamps for each message project_id: Optional project ID for all messages template_id: Optional template ID for all messages + conversation_ids: Optional list of conversation IDs (one per message, must match 1:1 with message_texts) Returns: True if successful @@ -441,22 +443,26 @@ class TurbopufferClient: raise ValueError(f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})") if len(message_ids) != len(created_ats): raise ValueError(f"message_ids length ({len(message_ids)}) must match created_ats length ({len(created_ats)})") + if conversation_ids is not None and len(conversation_ids) != len(message_ids): + raise ValueError(f"conversation_ids length ({len(conversation_ids)}) must match message_ids length ({len(message_ids)})") # prepare column-based data for turbopuffer - optimized for batch insert ids = [] vectors = [] texts = [] - organization_ids = [] - agent_ids = [] + organization_ids_list = [] + agent_ids_list = [] message_roles = [] created_at_timestamps = [] - project_ids = [] - template_ids = [] + project_ids_list = [] + template_ids_list = [] + conversation_ids_list = [] for (original_idx, text), embedding in zip(filtered_messages, embeddings): message_id = message_ids[original_idx] role = roles[original_idx] created_at = created_ats[original_idx] + conversation_id = conversation_ids[original_idx] if conversation_ids else None # ensure the provided timestamp is timezone-aware and in UTC if created_at.tzinfo is None: @@ -470,31 +476,36 @@ class TurbopufferClient: ids.append(message_id) vectors.append(embedding) texts.append(text) - organization_ids.append(organization_id) - agent_ids.append(agent_id) + organization_ids_list.append(organization_id) + agent_ids_list.append(agent_id) message_roles.append(role.value) created_at_timestamps.append(timestamp) - project_ids.append(project_id) - template_ids.append(template_id) + project_ids_list.append(project_id) + template_ids_list.append(template_id) + conversation_ids_list.append(conversation_id) # build column-based upsert data upsert_columns = { "id": ids, "vector": vectors, "text": texts, - "organization_id": organization_ids, - "agent_id": agent_ids, + "organization_id": organization_ids_list, + "agent_id": agent_ids_list, "role": message_roles, "created_at": created_at_timestamps, } + # only include conversation_id if it's provided + if conversation_ids is not None: + upsert_columns["conversation_id"] = conversation_ids_list + # only include project_id if it's provided if project_id is not None: - upsert_columns["project_id"] = project_ids + upsert_columns["project_id"] = project_ids_list # only include template_id if it's provided if template_id is not None: - upsert_columns["template_id"] = template_ids + upsert_columns["template_id"] = template_ids_list try: # use global semaphore to limit concurrent Turbopuffer writes @@ -506,7 +517,10 @@ class TurbopufferClient: await namespace.write( upsert_columns=upsert_columns, distance_metric="cosine_distance", - schema={"text": {"type": "string", "full_text_search": True}}, + schema={ + "text": {"type": "string", "full_text_search": True}, + "conversation_id": {"type": "string"}, + }, ) logger.info(f"Successfully inserted {len(ids)} messages to Turbopuffer for agent {agent_id}") return True @@ -792,6 +806,7 @@ class TurbopufferClient: roles: Optional[List[MessageRole]] = None, project_id: Optional[str] = None, template_id: Optional[str] = None, + conversation_id: Optional[str] = None, vector_weight: float = 0.5, fts_weight: float = 0.5, start_date: Optional[datetime] = None, @@ -809,6 +824,7 @@ class TurbopufferClient: roles: Optional list of message roles to filter by project_id: Optional project ID to filter messages by template_id: Optional template ID to filter messages by + conversation_id: Optional conversation ID to filter messages by (use "default" for NULL) vector_weight: Weight for vector search results in hybrid mode (default: 0.5) fts_weight: Weight for FTS results in hybrid mode (default: 0.5) start_date: Optional datetime to filter messages created after this date @@ -875,6 +891,19 @@ class TurbopufferClient: if template_id: template_filter = ("template_id", "Eq", template_id) + # build conversation_id filter if provided + # three cases: + # 1. conversation_id=None (omitted) -> return all messages (no filter) + # 2. conversation_id="default" -> return only default messages (conversation_id is none), for backward compatibility + # 3. conversation_id="xyz" -> return only messages in that conversation + conversation_filter = None + if conversation_id == "default": + # "default" is reserved for default messages only (conversation_id is none) + conversation_filter = ("conversation_id", "Eq", None) + elif conversation_id is not None: + # Specific conversation + conversation_filter = ("conversation_id", "Eq", conversation_id) + # combine all filters all_filters = [agent_filter] # always include agent_id filter if role_filter: @@ -883,6 +912,8 @@ class TurbopufferClient: all_filters.append(project_filter) if template_filter: all_filters.append(template_filter) + if conversation_filter: + all_filters.append(conversation_filter) if date_filters: all_filters.extend(date_filters) @@ -901,7 +932,7 @@ class TurbopufferClient: query_embedding=query_embedding, query_text=query_text, top_k=top_k, - include_attributes=["text", "organization_id", "agent_id", "role", "created_at"], + include_attributes=["text", "organization_id", "agent_id", "role", "created_at", "conversation_id"], filters=final_filter, vector_weight=vector_weight, fts_weight=fts_weight, @@ -952,6 +983,7 @@ class TurbopufferClient: agent_id: Optional[str] = None, project_id: Optional[str] = None, template_id: Optional[str] = None, + conversation_id: Optional[str] = None, vector_weight: float = 0.5, fts_weight: float = 0.5, start_date: Optional[datetime] = None, @@ -969,6 +1001,10 @@ class TurbopufferClient: agent_id: Optional agent ID to filter messages by project_id: Optional project ID to filter messages by template_id: Optional template ID to filter messages by + conversation_id: Optional conversation ID to filter messages by. Special values: + - None (omitted): Return all messages + - "default": Return only default messages (conversation_id IS NULL) + - Any other value: Return messages in that specific conversation vector_weight: Weight for vector search results in hybrid mode (default: 0.5) fts_weight: Weight for FTS results in hybrid mode (default: 0.5) start_date: Optional datetime to filter messages created after this date @@ -1017,6 +1053,18 @@ class TurbopufferClient: if template_id: all_filters.append(("template_id", "Eq", template_id)) + # conversation filter + # three cases: + # 1. conversation_id=None (omitted) -> return all messages (no filter) + # 2. conversation_id="default" -> return only default messages (conversation_id is none), for backward compatibility + # 3. conversation_id="xyz" -> return only messages in that conversation + if conversation_id == "default": + # "default" is reserved for default messages only (conversation_id is none) + all_filters.append(("conversation_id", "Eq", None)) + elif conversation_id is not None: + # Specific conversation + all_filters.append(("conversation_id", "Eq", conversation_id)) + # date filters if start_date: # Convert to UTC to match stored timestamps @@ -1049,7 +1097,7 @@ class TurbopufferClient: query_embedding=query_embedding, query_text=query_text, top_k=top_k, - include_attributes=["text", "organization_id", "agent_id", "role", "created_at"], + include_attributes=["text", "organization_id", "agent_id", "role", "created_at", "conversation_id"], filters=final_filter, vector_weight=vector_weight, fts_weight=fts_weight, @@ -1134,6 +1182,7 @@ class TurbopufferClient: "agent_id": getattr(row, "agent_id", None), "role": getattr(row, "role", None), "created_at": getattr(row, "created_at", None), + "conversation_id": getattr(row, "conversation_id", None), } messages.append(message_dict) diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 2ebde224..b5d92466 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -211,6 +211,7 @@ class Message(BaseMessage): tool_returns (List[ToolReturn]): The list of tool returns requested. group_id (str): The multi-agent group that the message was sent in. sender_id (str): The id of the sender of the message, can be an identity id or agent id. + conversation_id (str): The conversation this message belongs to. t """ @@ -237,6 +238,7 @@ class Message(BaseMessage): group_id: Optional[str] = Field(default=None, description="The multi-agent group that the message was sent in") sender_id: Optional[str] = Field(default=None, description="The id of the sender of the message, can be an identity id or agent id") batch_item_id: Optional[str] = Field(default=None, description="The id of the LLMBatchItem that this message is associated with") + conversation_id: Optional[str] = Field(default=None, description="The conversation this message belongs to") is_err: Optional[bool] = Field( default=None, description="Whether this message is part of an error step. Used only for debugging purposes." ) @@ -2302,6 +2304,7 @@ class MessageSearchRequest(BaseModel): agent_id: Optional[str] = Field(None, description="Filter messages by agent ID") project_id: Optional[str] = Field(None, description="Filter messages by project ID") template_id: Optional[str] = Field(None, description="Filter messages by template ID") + conversation_id: Optional[str] = Field(None, description="Filter messages by conversation ID") limit: int = Field(50, description="Maximum number of results to return", ge=1, le=100) start_date: Optional[datetime] = Field(None, description="Filter messages created after this date") end_date: Optional[datetime] = Field(None, description="Filter messages created on or before this date") @@ -2311,6 +2314,7 @@ class SearchAllMessagesRequest(BaseModel): query: str = Field(..., description="Text query for full-text search") search_mode: Literal["vector", "fts", "hybrid"] = Field("hybrid", description="Search mode to use") agent_id: Optional[str] = Field(None, description="Filter messages by agent ID") + conversation_id: Optional[str] = Field(None, description="Filter messages by conversation ID") limit: int = Field(50, description="Maximum number of results to return", ge=1, le=100) start_date: Optional[datetime] = Field(None, description="Filter messages created after this date") end_date: Optional[datetime] = Field(None, description="Filter messages created on or before this date") diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 3421d790..80943e9b 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -1385,6 +1385,7 @@ async def list_messages( ), order_by: Literal["created_at"] = Query("created_at", description="Field to sort by"), group_id: str | None = Query(None, description="Group ID to filter messages by."), + conversation_id: str | None = Query(None, description="Conversation ID to filter messages by."), use_assistant_message: bool = Query(True, description="Whether to use assistant messages", deprecated=True), assistant_message_tool_name: str = Query(DEFAULT_MESSAGE_TOOL, description="The name of the designated message tool.", deprecated=True), assistant_message_tool_kwarg: str = Query(DEFAULT_MESSAGE_TOOL_KWARG, description="The name of the message argument.", deprecated=True), @@ -1404,6 +1405,7 @@ async def list_messages( before=before, limit=limit, group_id=group_id, + conversation_id=conversation_id, reverse=(order == "desc"), return_message_object=False, use_assistant_message=use_assistant_message, @@ -1751,6 +1753,7 @@ async def search_messages( agent_id=request.agent_id, project_id=request.project_id, template_id=request.template_id, + conversation_id=request.conversation_id, limit=request.limit, start_date=request.start_date, end_date=request.end_date, diff --git a/letta/server/rest_api/routers/v1/messages.py b/letta/server/rest_api/routers/v1/messages.py index e7e1a6ec..fa622fb2 100644 --- a/letta/server/rest_api/routers/v1/messages.py +++ b/letta/server/rest_api/routers/v1/messages.py @@ -40,6 +40,7 @@ async def list_all_messages( order: Literal["asc", "desc"] = Query( "desc", description="Sort order for messages by creation time. 'asc' for oldest first, 'desc' for newest first" ), + conversation_id: Optional[str] = Query(None, description="Conversation ID to filter messages by"), ): """ List messages across all agents for the current user. @@ -51,6 +52,7 @@ async def list_all_messages( limit=limit, reverse=(order == "desc"), return_message_object=False, + conversation_id=conversation_id, actor=actor, ) @@ -74,6 +76,7 @@ async def search_all_messages( query_text=request.query, search_mode=request.search_mode, agent_id=request.agent_id, + conversation_id=request.conversation_id, limit=request.limit, start_date=request.start_date, end_date=request.end_date, diff --git a/letta/server/server.py b/letta/server/server.py index 25dac35f..db089c8c 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -817,6 +817,7 @@ class SyncServer(object): assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL, assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG, include_err: Optional[bool] = None, + conversation_id: Optional[str] = None, ) -> Union[List[Message], List[LettaMessage]]: records = await self.message_manager.list_messages( agent_id=agent_id, @@ -827,6 +828,7 @@ class SyncServer(object): ascending=not reverse, group_id=group_id, include_err=include_err, + conversation_id=conversation_id, ) if not return_message_object: @@ -862,6 +864,7 @@ class SyncServer(object): assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL, assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG, include_err: Optional[bool] = None, + conversation_id: Optional[str] = None, ) -> Union[List[Message], List[LettaMessage]]: records = await self.message_manager.list_messages( agent_id=None, @@ -872,6 +875,7 @@ class SyncServer(object): ascending=not reverse, group_id=group_id, include_err=include_err, + conversation_id=conversation_id, ) if not return_message_object: diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index e420d85c..eb68200a 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -412,10 +412,7 @@ class MessageManager: from letta.orm.run import Run as RunModel async with db_registry.async_session() as session: - query = select(RunModel.id).where( - RunModel.id == run_id, - RunModel.organization_id == actor.organization_id - ) + query = select(RunModel.id).where(RunModel.id == run_id, RunModel.organization_id == actor.organization_id) result = await session.execute(query) return result.scalar_one_or_none() is not None @@ -545,10 +542,7 @@ class MessageManager: async with db_registry.async_session() as session: # Check which run_ids actually exist - query = select(RunModel.id).where( - RunModel.id.in_(unique_run_ids), - RunModel.organization_id == actor.organization_id - ) + query = select(RunModel.id).where(RunModel.id.in_(unique_run_ids), RunModel.organization_id == actor.organization_id) result = await session.execute(query) existing_run_ids = set(result.scalars().all()) @@ -622,6 +616,7 @@ class MessageManager: message_ids = [] roles = [] created_ats = [] + conversation_ids = [] # combine assistant+tool messages before embedding combined_messages = self._combine_assistant_tool_messages(messages) @@ -633,6 +628,7 @@ class MessageManager: message_ids.append(msg.id) roles.append(msg.role) created_ats.append(msg.created_at) + conversation_ids.append(msg.conversation_id) if message_texts: # insert to turbopuffer - TurbopufferClient will generate embeddings internally @@ -647,6 +643,7 @@ class MessageManager: created_ats=created_ats, project_id=project_id, template_id=template_id, + conversation_ids=conversation_ids, ) logger.info(f"Successfully embedded {len(message_texts)} messages for agent {agent_id}") except Exception as e: @@ -776,6 +773,7 @@ class MessageManager: created_ats=[message.created_at], project_id=project_id, template_id=template_id, + conversation_ids=[message.conversation_id], ) logger.info(f"Successfully updated message {message.id} in Turbopuffer") except Exception as e: @@ -896,6 +894,7 @@ class MessageManager: group_id: Optional[str] = None, include_err: Optional[bool] = None, run_id: Optional[str] = None, + conversation_id: Optional[str] = None, ) -> List[PydanticMessage]: """ Most performant query to list messages by directly querying the Message table. @@ -917,6 +916,7 @@ class MessageManager: group_id: Optional group ID to filter messages by group_id. include_err: Optional boolean to include errors and error statuses. Used for debugging only. run_id: Optional run ID to filter messages by run_id. + conversation_id: Optional conversation ID to filter messages by conversation_id. Returns: List[PydanticMessage]: A list of messages (converted via .to_pydantic()). @@ -942,6 +942,9 @@ class MessageManager: if run_id: query = query.where(MessageModel.run_id == run_id) + if conversation_id: + query = query.where(MessageModel.conversation_id == conversation_id) + # if not include_err: # query = query.where((MessageModel.is_err == False) | (MessageModel.is_err.is_(None))) @@ -1233,6 +1236,7 @@ class MessageManager: agent_id: Optional[str] = None, project_id: Optional[str] = None, template_id: Optional[str] = None, + conversation_id: Optional[str] = None, limit: int = 50, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, @@ -1248,6 +1252,7 @@ class MessageManager: agent_id: Optional agent ID to filter messages by project_id: Optional project ID to filter messages by template_id: Optional template ID to filter messages by + conversation_id: Optional conversation ID to filter messages by limit: Maximum number of results to return start_date: Optional filter for messages created after this date end_date: Optional filter for messages created on or before this date (inclusive) @@ -1277,6 +1282,7 @@ class MessageManager: agent_id=agent_id, project_id=project_id, template_id=template_id, + conversation_id=conversation_id, start_date=start_date, end_date=end_date, ) diff --git a/tests/integration_test_turbopuffer.py b/tests/integration_test_turbopuffer.py index 859f91f8..031d67db 100644 --- a/tests/integration_test_turbopuffer.py +++ b/tests/integration_test_turbopuffer.py @@ -2167,6 +2167,122 @@ async def test_message_template_id_filtering(server, sarah_agent, default_user, ) +@pytest.mark.asyncio +@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured") +async def test_message_conversation_id_filtering(server, sarah_agent, default_user, enable_turbopuffer, enable_message_embedding): + """Test that conversation_id filtering works correctly in message queries, including 'default' sentinel""" + from letta.schemas.conversation import CreateConversation + from letta.schemas.letta_message_content import TextContent + from letta.services.conversation_manager import ConversationManager + + conversation_manager = ConversationManager() + + # Create a conversation + conversation = await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation(summary="Test conversation"), + actor=default_user, + ) + + # Create messages with different conversation_ids + message_with_conv = PydanticMessage( + agent_id=sarah_agent.id, + role=MessageRole.user, + content=[TextContent(text="Message in specific conversation about Python")], + ) + + message_default_conv = PydanticMessage( + agent_id=sarah_agent.id, + role=MessageRole.user, + content=[TextContent(text="Message in default conversation about JavaScript")], + ) + + # Insert messages with their respective conversation IDs + tpuf_client = TurbopufferClient() + + # Message with specific conversation_id + await tpuf_client.insert_messages( + agent_id=sarah_agent.id, + message_texts=[message_with_conv.content[0].text], + message_ids=[message_with_conv.id], + organization_id=default_user.organization_id, + actor=default_user, + roles=[message_with_conv.role], + created_ats=[message_with_conv.created_at], + conversation_ids=[conversation.id], # Specific conversation + ) + + # Message with no conversation_id (default) + await tpuf_client.insert_messages( + agent_id=sarah_agent.id, + message_texts=[message_default_conv.content[0].text], + message_ids=[message_default_conv.id], + organization_id=default_user.organization_id, + actor=default_user, + roles=[message_default_conv.role], + created_ats=[message_default_conv.created_at], + conversation_ids=[None], # Default conversation (NULL) + ) + + # Wait for indexing + await asyncio.sleep(1) + + # Test 1: Query for specific conversation - should find only message with that conversation_id + results_conv = await tpuf_client.query_messages_by_agent_id( + agent_id=sarah_agent.id, + organization_id=default_user.organization_id, + search_mode="timestamp", + top_k=10, + conversation_id=conversation.id, + actor=default_user, + ) + + assert len(results_conv) == 1 + assert results_conv[0][0]["id"] == message_with_conv.id + assert "Python" in results_conv[0][0]["text"] + + # Test 2: Query for "default" conversation - should find only messages with NULL conversation_id + results_default = await tpuf_client.query_messages_by_agent_id( + agent_id=sarah_agent.id, + organization_id=default_user.organization_id, + search_mode="timestamp", + top_k=10, + conversation_id="default", # Sentinel for NULL + actor=default_user, + ) + + assert len(results_default) >= 1 # May have other default messages from setup + # Check our message is in there + default_ids = [r[0]["id"] for r in results_default] + assert message_default_conv.id in default_ids + + # Verify the message content + for msg_dict, _, _ in results_default: + if msg_dict["id"] == message_default_conv.id: + assert "JavaScript" in msg_dict["text"] + break + + # Test 3: Query without conversation filter - should find both + results_all = await tpuf_client.query_messages_by_agent_id( + agent_id=sarah_agent.id, + organization_id=default_user.organization_id, + search_mode="timestamp", + top_k=10, + conversation_id=None, # No filter + actor=default_user, + ) + + assert len(results_all) >= 2 # May have other messages from setup + message_ids = [r[0]["id"] for r in results_all] + assert message_with_conv.id in message_ids + assert message_default_conv.id in message_ids + + # Clean up + await tpuf_client.delete_messages( + agent_id=sarah_agent.id, organization_id=default_user.organization_id, message_ids=[message_with_conv.id, message_default_conv.id] + ) + + @pytest.mark.asyncio async def test_system_messages_not_embedded_during_agent_creation(server, default_user, enable_message_embedding): """Test that system messages are filtered out before being passed to the embedding pipeline during agent creation""" diff --git a/tests/managers/test_message_manager.py b/tests/managers/test_message_manager.py index b62b32a8..8bea9347 100644 --- a/tests/managers/test_message_manager.py +++ b/tests/managers/test_message_manager.py @@ -412,6 +412,59 @@ async def test_message_delete(server: SyncServer, hello_world_message_fixture, d assert retrieved is None +@pytest.mark.asyncio +async def test_message_conversation_id_persistence(server: SyncServer, sarah_agent, default_user): + """Test that conversation_id is properly persisted and retrieved from DB to Pydantic object""" + from letta.schemas.conversation import CreateConversation + from letta.services.conversation_manager import ConversationManager + + conversation_manager = ConversationManager() + + # Test 1: Create a message without conversation_id (should be None - the default/backward-compat case) + message_no_conv = PydanticMessage( + agent_id=sarah_agent.id, + role=MessageRole.user, + content=[TextContent(text="Test message without conversation")], + ) + + created_no_conv = await server.message_manager.create_many_messages_async([message_no_conv], actor=default_user) + assert len(created_no_conv) == 1 + assert created_no_conv[0].conversation_id is None + + # Verify retrieval also has None - this confirms ORM-to-Pydantic conversion works for None + retrieved_no_conv = await server.message_manager.get_message_by_id_async(created_no_conv[0].id, actor=default_user) + assert retrieved_no_conv is not None + assert retrieved_no_conv.conversation_id is None + assert retrieved_no_conv.id == created_no_conv[0].id + + # Test 2: Create a conversation and a message with that conversation_id + conversation = await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation(summary="Test conversation"), + actor=default_user, + ) + + message_with_conv = PydanticMessage( + agent_id=sarah_agent.id, + role=MessageRole.user, + content=[TextContent(text="Test message with conversation")], + conversation_id=conversation.id, + ) + + created_with_conv = await server.message_manager.create_many_messages_async([message_with_conv], actor=default_user) + assert len(created_with_conv) == 1 + assert created_with_conv[0].conversation_id == conversation.id + + # Verify retrieval has the correct conversation_id - this confirms ORM-to-Pydantic conversion works for non-None + retrieved_with_conv = await server.message_manager.get_message_by_id_async(created_with_conv[0].id, actor=default_user) + assert retrieved_with_conv is not None + assert retrieved_with_conv.conversation_id == conversation.id + assert retrieved_with_conv.id == created_with_conv[0].id + + # Test 3: Verify the field exists on the Pydantic model + assert hasattr(retrieved_with_conv, "conversation_id") + + @pytest.mark.asyncio async def test_message_size(server: SyncServer, hello_world_message_fixture, default_user): """Test counting messages with filters"""