From 2f716d49617c2f510c1d6a4bdafbb9c33bc87e75 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Fri, 19 Sep 2025 18:39:22 -0700 Subject: [PATCH] feat: Add partial flag on `create_many_messages_async` (#4836) Add partial flag --- letta/services/message_manager.py | 85 ++++++++++- tests/test_managers.py | 242 ++++++++++++++++++++++++++++++ 2 files changed, 323 insertions(+), 4 deletions(-) diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index f5b7f7be..9c6bfa1d 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -1,7 +1,7 @@ import json import uuid from datetime import datetime -from typing import List, Optional, Sequence, Tuple +from typing import List, Optional, Sequence, Set, Tuple from sqlalchemy import delete, exists, func, select, text @@ -307,6 +307,51 @@ class MessageManager: created_messages = MessageModel.batch_create(orm_messages, session, actor=actor) return [msg.to_pydantic() for msg in created_messages] + @enforce_types + @trace_method + async def check_existing_message_ids(self, message_ids: List[str], actor: PydanticUser) -> Set[str]: + """Check which message IDs already exist in the database. + + Args: + message_ids: List of message IDs to check + actor: User performing the action + + Returns: + Set of message IDs that already exist in the database + """ + if not message_ids: + return set() + + async with db_registry.async_session() as session: + query = select(MessageModel.id).where(MessageModel.id.in_(message_ids), MessageModel.organization_id == actor.organization_id) + result = await session.execute(query) + return set(result.scalars().all()) + + @enforce_types + @trace_method + async def filter_existing_messages( + self, messages: List[PydanticMessage], actor: PydanticUser + ) -> Tuple[List[PydanticMessage], List[PydanticMessage]]: + """Filter messages into new and existing based on their IDs. + + Args: + messages: List of messages to filter + actor: User performing the action + + Returns: + Tuple of (new_messages, existing_messages) + """ + message_ids = [msg.id for msg in messages if msg.id] + if not message_ids: + return messages, [] + + existing_ids = await self.check_existing_message_ids(message_ids, actor) + + new_messages = [msg for msg in messages if msg.id not in existing_ids] + existing_messages = [msg for msg in messages if msg.id in existing_ids] + + return new_messages, existing_messages + @enforce_types @trace_method async def create_many_messages_async( @@ -316,6 +361,7 @@ class MessageManager: strict_mode: bool = False, project_id: Optional[str] = None, template_id: Optional[str] = None, + allow_partial: bool = False, ) -> List[PydanticMessage]: """ Create multiple messages in a single database transaction asynchronously. @@ -326,14 +372,33 @@ class MessageManager: strict_mode: If True, wait for embedding to complete; if False, run in background project_id: Optional project ID for the messages (for Turbopuffer indexing) template_id: Optional template ID for the messages (for Turbopuffer indexing) + allow_partial: If True, skip messages that already exist; if False, fail on duplicates Returns: - List of created Pydantic message models + List of created Pydantic message models (and existing ones if allow_partial=True) """ if not pydantic_msgs: return [] - for message in pydantic_msgs: + messages_to_create = pydantic_msgs + existing_messages = [] + + if allow_partial: + # filter out messages that already exist + new_messages, existing_messages = await self.filter_existing_messages(pydantic_msgs, actor) + messages_to_create = new_messages + + if not messages_to_create: + # all messages already exist, fetch and return them + async with db_registry.async_session() as session: + existing_ids = [msg.id for msg in existing_messages if msg.id] + query = select(MessageModel).where( + MessageModel.id.in_(existing_ids), MessageModel.organization_id == actor.organization_id + ) + result = await session.execute(query) + return [msg.to_pydantic() for msg in result.scalars()] + + for message in messages_to_create: if isinstance(message.content, list): for content in message.content: if content.type == MessageContentType.image and content.source.type == ImageSourceType.base64: @@ -358,7 +423,7 @@ class MessageManager: media_type=content.source.media_type, detail=content.source.detail, ) - orm_messages = self._create_many_preprocess(pydantic_msgs, actor) + orm_messages = self._create_many_preprocess(messages_to_create, actor) async with db_registry.async_session() as session: created_messages = await MessageModel.batch_create_async(orm_messages, session, actor=actor, no_commit=True, no_refresh=True) result = [msg.to_pydantic() for msg in created_messages] @@ -381,6 +446,18 @@ class MessageManager: task_name=f"embed_messages_for_agent_{agent_id}", ) + # if allow_partial, combine newly created with existing + if allow_partial and existing_messages: + # fetch the existing messages to return complete data + async with db_registry.async_session() as session: + existing_ids = [msg.id for msg in existing_messages if msg.id] + query = select(MessageModel).where( + MessageModel.id.in_(existing_ids), MessageModel.organization_id == actor.organization_id + ) + existing_result = await session.execute(query) + existing_fetched = [msg.to_pydantic() for msg in existing_result.scalars()] + result.extend(existing_fetched) + return result async def _embed_messages_background( diff --git a/tests/test_managers.py b/tests/test_managers.py index 17ec2934..a4d6a654 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -5845,6 +5845,248 @@ def test_message_listing_text_search(server: SyncServer, hello_world_message_fix assert len(search_results) == 0 +@pytest.mark.asyncio +async def test_create_many_messages_async_basic(server: SyncServer, sarah_agent, default_user): + """Test basic batch creation of messages""" + message_manager = server.message_manager + + messages = [] + for i in range(5): + msg = PydanticMessage( + agent_id=sarah_agent.id, + role=MessageRole.user, + content=[TextContent(text=f"Test message {i}")], + name=None, + tool_calls=None, + tool_call_id=None, + ) + messages.append(msg) + + created_messages = await message_manager.create_many_messages_async(pydantic_msgs=messages, actor=default_user) + + assert len(created_messages) == 5 + for i, msg in enumerate(created_messages): + assert msg.id is not None + assert msg.content[0].text == f"Test message {i}" + assert msg.agent_id == sarah_agent.id + + +@pytest.mark.asyncio +async def test_create_many_messages_async_allow_partial_false(server: SyncServer, sarah_agent, default_user): + """Test that allow_partial=False (default) fails on duplicate IDs""" + message_manager = server.message_manager + + initial_msg = PydanticMessage( + agent_id=sarah_agent.id, + role=MessageRole.user, + content=[TextContent(text="Initial message")], + ) + + created = await message_manager.create_many_messages_async(pydantic_msgs=[initial_msg], actor=default_user) + assert len(created) == 1 + created_msg = created[0] + + duplicate_msg = PydanticMessage( + id=created_msg.id, + agent_id=sarah_agent.id, + role=MessageRole.user, + content=[TextContent(text="Duplicate message")], + ) + + with pytest.raises(UniqueConstraintViolationError): + await message_manager.create_many_messages_async(pydantic_msgs=[duplicate_msg], actor=default_user, allow_partial=False) + + +@pytest.mark.asyncio +async def test_create_many_messages_async_allow_partial_true_some_duplicates(server: SyncServer, sarah_agent, default_user): + """Test that allow_partial=True handles partial duplicates correctly""" + message_manager = server.message_manager + + initial_messages = [] + for i in range(3): + msg = PydanticMessage( + agent_id=sarah_agent.id, + role=MessageRole.user, + content=[TextContent(text=f"Existing message {i}")], + ) + initial_messages.append(msg) + + created_initial = await message_manager.create_many_messages_async(pydantic_msgs=initial_messages, actor=default_user) + assert len(created_initial) == 3 + existing_ids = [msg.id for msg in created_initial] + + mixed_messages = [] + for created_msg in created_initial: + duplicate_msg = PydanticMessage( + id=created_msg.id, + agent_id=sarah_agent.id, + role=MessageRole.user, + content=created_msg.content, + ) + mixed_messages.append(duplicate_msg) + for i in range(3, 6): + msg = PydanticMessage( + agent_id=sarah_agent.id, + role=MessageRole.user, + content=[TextContent(text=f"New message {i}")], + ) + mixed_messages.append(msg) + + result = await message_manager.create_many_messages_async(pydantic_msgs=mixed_messages, actor=default_user, allow_partial=True) + + assert len(result) == 6 + + result_ids = {msg.id for msg in result} + for existing_id in existing_ids: + assert existing_id in result_ids + + +@pytest.mark.asyncio +async def test_create_many_messages_async_allow_partial_true_all_duplicates(server: SyncServer, sarah_agent, default_user): + """Test that allow_partial=True handles all duplicates correctly""" + message_manager = server.message_manager + + initial_messages = [] + for i in range(3): + msg = PydanticMessage( + agent_id=sarah_agent.id, + role=MessageRole.user, + content=[TextContent(text=f"Message {i}")], + ) + initial_messages.append(msg) + + created_initial = await message_manager.create_many_messages_async(pydantic_msgs=initial_messages, actor=default_user) + assert len(created_initial) == 3 + + duplicate_messages = [] + for created_msg in created_initial: + duplicate_msg = PydanticMessage( + id=created_msg.id, + agent_id=sarah_agent.id, + role=MessageRole.user, + content=created_msg.content, + ) + duplicate_messages.append(duplicate_msg) + + result = await message_manager.create_many_messages_async(pydantic_msgs=duplicate_messages, actor=default_user, allow_partial=True) + + assert len(result) == 3 + for i, msg in enumerate(result): + assert msg.id == created_initial[i].id + assert msg.content[0].text == f"Message {i}" + + +@pytest.mark.asyncio +async def test_create_many_messages_async_empty_list(server: SyncServer, default_user): + """Test that empty list returns empty list""" + message_manager = server.message_manager + + result = await message_manager.create_many_messages_async(pydantic_msgs=[], actor=default_user) + + assert result == [] + + +@pytest.mark.asyncio +async def test_check_existing_message_ids(server: SyncServer, sarah_agent, default_user): + """Test the check_existing_message_ids convenience function""" + message_manager = server.message_manager + + messages = [] + for i in range(3): + msg = PydanticMessage( + agent_id=sarah_agent.id, + role=MessageRole.user, + content=[TextContent(text=f"Message {i}")], + ) + messages.append(msg) + + created_messages = await message_manager.create_many_messages_async(pydantic_msgs=messages, actor=default_user) + existing_ids = [msg.id for msg in created_messages] + + non_existent_ids = [f"message-{uuid.uuid4().hex[:8]}" for _ in range(3)] + all_ids = existing_ids + non_existent_ids + + existing = await message_manager.check_existing_message_ids(message_ids=all_ids, actor=default_user) + + assert existing == set(existing_ids) + for non_existent_id in non_existent_ids: + assert non_existent_id not in existing + + +@pytest.mark.asyncio +async def test_filter_existing_messages(server: SyncServer, sarah_agent, default_user): + """Test the filter_existing_messages helper function""" + message_manager = server.message_manager + + initial_messages = [] + for i in range(3): + msg = PydanticMessage( + agent_id=sarah_agent.id, + role=MessageRole.user, + content=[TextContent(text=f"Existing {i}")], + ) + initial_messages.append(msg) + + created_existing = await message_manager.create_many_messages_async(pydantic_msgs=initial_messages, actor=default_user) + + existing_messages = [] + for created_msg in created_existing: + msg = PydanticMessage( + id=created_msg.id, + agent_id=sarah_agent.id, + role=MessageRole.user, + content=created_msg.content, + ) + existing_messages.append(msg) + + new_messages = [] + for i in range(3): + msg = PydanticMessage( + agent_id=sarah_agent.id, + role=MessageRole.user, + content=[TextContent(text=f"New {i}")], + ) + new_messages.append(msg) + + all_messages = existing_messages + new_messages + + new_filtered, existing_filtered = await message_manager.filter_existing_messages(messages=all_messages, actor=default_user) + + assert len(new_filtered) == 3 + assert len(existing_filtered) == 3 + + existing_filtered_ids = {msg.id for msg in existing_filtered} + for created_msg in created_existing: + assert created_msg.id in existing_filtered_ids + + for msg in new_filtered: + assert msg.id not in existing_filtered_ids + + +@pytest.mark.asyncio +async def test_create_many_messages_async_with_turbopuffer(server: SyncServer, sarah_agent, default_user): + """Test batch creation with turbopuffer embedding (if enabled)""" + message_manager = server.message_manager + + messages = [] + for i in range(3): + msg = PydanticMessage( + agent_id=sarah_agent.id, + role=MessageRole.user, + content=[TextContent(text=f"Important information about topic {i}")], + ) + messages.append(msg) + + created_messages = await message_manager.create_many_messages_async( + pydantic_msgs=messages, actor=default_user, strict_mode=True, project_id="test_project", template_id="test_template" + ) + + assert len(created_messages) == 3 + for msg in created_messages: + assert msg.id is not None + assert msg.agent_id == sarah_agent.id + + # ====================================================================================================================== # Block Manager Tests - Basic # ======================================================================================================================