feat: Add partial flag on create_many_messages_async (#4836)

Add partial flag
This commit is contained in:
Matthew Zhou
2025-09-19 18:39:22 -07:00
committed by Caren Thomas
parent 3593e7cda6
commit 2f716d4961
2 changed files with 323 additions and 4 deletions

View File

@@ -1,7 +1,7 @@
import json
import uuid
from datetime import datetime
from typing import List, Optional, Sequence, Tuple
from typing import List, Optional, Sequence, Set, Tuple
from sqlalchemy import delete, exists, func, select, text
@@ -307,6 +307,51 @@ class MessageManager:
created_messages = MessageModel.batch_create(orm_messages, session, actor=actor)
return [msg.to_pydantic() for msg in created_messages]
@enforce_types
@trace_method
async def check_existing_message_ids(self, message_ids: List[str], actor: PydanticUser) -> Set[str]:
    """Return the subset of ``message_ids`` that is already stored.

    The lookup is restricted to rows whose ``organization_id`` matches
    ``actor.organization_id``.

    Args:
        message_ids: Candidate message IDs to look up.
        actor: User performing the lookup.

    Returns:
        Set of the supplied IDs that the query found. An empty input
        short-circuits to an empty set without opening a session.
    """
    if message_ids:
        lookup = select(MessageModel.id).where(
            MessageModel.id.in_(message_ids),
            MessageModel.organization_id == actor.organization_id,
        )
        async with db_registry.async_session() as session:
            rows = await session.execute(lookup)
            return set(rows.scalars().all())

    return set()
@enforce_types
@trace_method
async def filter_existing_messages(
    self, messages: List[PydanticMessage], actor: PydanticUser
) -> Tuple[List[PydanticMessage], List[PydanticMessage]]:
    """Partition ``messages`` into not-yet-stored and already-stored groups.

    Args:
        messages: Messages to classify.
        actor: User performing the action; forwarded to
            ``check_existing_message_ids``.

    Returns:
        ``(new_messages, existing_messages)``, each in input order. When no
        message carries a truthy ID, the input list itself is returned as
        the first element and the second element is empty.
    """
    candidate_ids = [candidate.id for candidate in messages if candidate.id]
    if not candidate_ids:
        # Nothing to look up: every message counts as new.
        return messages, []

    known_ids = await self.check_existing_message_ids(candidate_ids, actor)

    fresh: List[PydanticMessage] = []
    already_stored: List[PydanticMessage] = []
    for candidate in messages:
        if candidate.id in known_ids:
            already_stored.append(candidate)
        else:
            fresh.append(candidate)
    return fresh, already_stored
@enforce_types
@trace_method
async def create_many_messages_async(
@@ -316,6 +361,7 @@ class MessageManager:
strict_mode: bool = False,
project_id: Optional[str] = None,
template_id: Optional[str] = None,
allow_partial: bool = False,
) -> List[PydanticMessage]:
"""
Create multiple messages in a single database transaction asynchronously.
@@ -326,14 +372,33 @@ class MessageManager:
strict_mode: If True, wait for embedding to complete; if False, run in background
project_id: Optional project ID for the messages (for Turbopuffer indexing)
template_id: Optional template ID for the messages (for Turbopuffer indexing)
allow_partial: If True, skip messages that already exist; if False, fail on duplicates
Returns:
List of created Pydantic message models
List of created Pydantic message models (and existing ones if allow_partial=True)
"""
if not pydantic_msgs:
return []
for message in pydantic_msgs:
messages_to_create = pydantic_msgs
existing_messages = []
if allow_partial:
# filter out messages that already exist
new_messages, existing_messages = await self.filter_existing_messages(pydantic_msgs, actor)
messages_to_create = new_messages
if not messages_to_create:
# all messages already exist, fetch and return them
async with db_registry.async_session() as session:
existing_ids = [msg.id for msg in existing_messages if msg.id]
query = select(MessageModel).where(
MessageModel.id.in_(existing_ids), MessageModel.organization_id == actor.organization_id
)
result = await session.execute(query)
return [msg.to_pydantic() for msg in result.scalars()]
for message in messages_to_create:
if isinstance(message.content, list):
for content in message.content:
if content.type == MessageContentType.image and content.source.type == ImageSourceType.base64:
@@ -358,7 +423,7 @@ class MessageManager:
media_type=content.source.media_type,
detail=content.source.detail,
)
orm_messages = self._create_many_preprocess(pydantic_msgs, actor)
orm_messages = self._create_many_preprocess(messages_to_create, actor)
async with db_registry.async_session() as session:
created_messages = await MessageModel.batch_create_async(orm_messages, session, actor=actor, no_commit=True, no_refresh=True)
result = [msg.to_pydantic() for msg in created_messages]
@@ -381,6 +446,18 @@ class MessageManager:
task_name=f"embed_messages_for_agent_{agent_id}",
)
# if allow_partial, combine newly created with existing
if allow_partial and existing_messages:
# fetch the existing messages to return complete data
async with db_registry.async_session() as session:
existing_ids = [msg.id for msg in existing_messages if msg.id]
query = select(MessageModel).where(
MessageModel.id.in_(existing_ids), MessageModel.organization_id == actor.organization_id
)
existing_result = await session.execute(query)
existing_fetched = [msg.to_pydantic() for msg in existing_result.scalars()]
result.extend(existing_fetched)
return result
async def _embed_messages_background(

View File

@@ -5845,6 +5845,248 @@ def test_message_listing_text_search(server: SyncServer, hello_world_message_fix
assert len(search_results) == 0
@pytest.mark.asyncio
async def test_create_many_messages_async_basic(server: SyncServer, sarah_agent, default_user):
    """Batch-create five user messages and check each returned message."""
    manager = server.message_manager

    batch = [
        PydanticMessage(
            agent_id=sarah_agent.id,
            role=MessageRole.user,
            content=[TextContent(text=f"Test message {idx}")],
            name=None,
            tool_calls=None,
            tool_call_id=None,
        )
        for idx in range(5)
    ]

    persisted = await manager.create_many_messages_async(pydantic_msgs=batch, actor=default_user)

    assert len(persisted) == 5
    for idx, stored in enumerate(persisted):
        assert stored.id is not None
        assert stored.content[0].text == f"Test message {idx}"
        assert stored.agent_id == sarah_agent.id
@pytest.mark.asyncio
async def test_create_many_messages_async_allow_partial_false(server: SyncServer, sarah_agent, default_user):
    """Re-inserting an existing message ID with allow_partial=False must raise."""
    manager = server.message_manager

    # Seed one message so its ID is already taken.
    seed = PydanticMessage(
        agent_id=sarah_agent.id,
        role=MessageRole.user,
        content=[TextContent(text="Initial message")],
    )
    persisted = await manager.create_many_messages_async(pydantic_msgs=[seed], actor=default_user)
    assert len(persisted) == 1
    taken_id = persisted[0].id

    # A second message reusing that ID must be rejected.
    clash = PydanticMessage(
        id=taken_id,
        agent_id=sarah_agent.id,
        role=MessageRole.user,
        content=[TextContent(text="Duplicate message")],
    )
    with pytest.raises(UniqueConstraintViolationError):
        await manager.create_many_messages_async(pydantic_msgs=[clash], actor=default_user, allow_partial=False)
@pytest.mark.asyncio
async def test_create_many_messages_async_allow_partial_true_some_duplicates(server: SyncServer, sarah_agent, default_user):
    """With allow_partial=True, a batch mixing stored and new messages returns all six."""
    manager = server.message_manager

    # Three messages that will already exist when the mixed batch is sent.
    seeds = [
        PydanticMessage(
            agent_id=sarah_agent.id,
            role=MessageRole.user,
            content=[TextContent(text=f"Existing message {idx}")],
        )
        for idx in range(3)
    ]
    stored = await manager.create_many_messages_async(pydantic_msgs=seeds, actor=default_user)
    assert len(stored) == 3
    stored_ids = [item.id for item in stored]

    # Mixed batch: the three stored IDs again, followed by three new messages.
    repeats = [
        PydanticMessage(
            id=item.id,
            agent_id=sarah_agent.id,
            role=MessageRole.user,
            content=item.content,
        )
        for item in stored
    ]
    additions = [
        PydanticMessage(
            agent_id=sarah_agent.id,
            role=MessageRole.user,
            content=[TextContent(text=f"New message {idx}")],
        )
        for idx in range(3, 6)
    ]
    mixed_batch = repeats + additions

    outcome = await manager.create_many_messages_async(pydantic_msgs=mixed_batch, actor=default_user, allow_partial=True)

    assert len(outcome) == 6
    outcome_ids = {item.id for item in outcome}
    for stored_id in stored_ids:
        assert stored_id in outcome_ids
@pytest.mark.asyncio
async def test_create_many_messages_async_allow_partial_true_all_duplicates(server: SyncServer, sarah_agent, default_user):
    """With allow_partial=True, a batch made only of stored IDs returns the stored messages.

    The returned messages are matched to the originals by ID rather than by
    list position: the all-duplicates path fetches rows with an ``IN`` filter
    and no ``ORDER BY``, so the order of the returned list is not guaranteed.
    """
    message_manager = server.message_manager

    initial_messages = [
        PydanticMessage(
            agent_id=sarah_agent.id,
            role=MessageRole.user,
            content=[TextContent(text=f"Message {i}")],
        )
        for i in range(3)
    ]
    created_initial = await message_manager.create_many_messages_async(pydantic_msgs=initial_messages, actor=default_user)
    assert len(created_initial) == 3

    # Expected text per ID, taken from what was actually persisted.
    expected_text_by_id = {msg.id: msg.content[0].text for msg in created_initial}
    assert set(expected_text_by_id.values()) == {f"Message {i}" for i in range(3)}

    duplicate_messages = [
        PydanticMessage(
            id=created_msg.id,
            agent_id=sarah_agent.id,
            role=MessageRole.user,
            content=created_msg.content,
        )
        for created_msg in created_initial
    ]

    result = await message_manager.create_many_messages_async(pydantic_msgs=duplicate_messages, actor=default_user, allow_partial=True)

    # The length check plus the by-ID comparison rules out missing, extra and
    # repeated rows without depending on row order.
    assert len(result) == 3
    returned_text_by_id = {msg.id: msg.content[0].text for msg in result}
    assert returned_text_by_id == expected_text_by_id
@pytest.mark.asyncio
async def test_create_many_messages_async_empty_list(server: SyncServer, default_user):
    """An empty input batch yields an empty result list."""
    manager = server.message_manager

    outcome = await manager.create_many_messages_async(pydantic_msgs=[], actor=default_user)

    assert outcome == []
@pytest.mark.asyncio
async def test_check_existing_message_ids(server: SyncServer, sarah_agent, default_user):
    """check_existing_message_ids reports exactly the IDs that were stored."""
    manager = server.message_manager

    seeds = [
        PydanticMessage(
            agent_id=sarah_agent.id,
            role=MessageRole.user,
            content=[TextContent(text=f"Message {idx}")],
        )
        for idx in range(3)
    ]
    stored = await manager.create_many_messages_async(pydantic_msgs=seeds, actor=default_user)
    stored_ids = [item.id for item in stored]

    # Random IDs that were never inserted.
    unknown_ids = [f"message-{uuid.uuid4().hex[:8]}" for _ in range(3)]

    found = await manager.check_existing_message_ids(message_ids=stored_ids + unknown_ids, actor=default_user)

    assert found == set(stored_ids)
    for unknown_id in unknown_ids:
        assert unknown_id not in found
@pytest.mark.asyncio
async def test_filter_existing_messages(server: SyncServer, sarah_agent, default_user):
    """filter_existing_messages splits a mixed batch into new and stored messages."""
    manager = server.message_manager

    # Persist three messages so their IDs are known to the database.
    seeds = [
        PydanticMessage(
            agent_id=sarah_agent.id,
            role=MessageRole.user,
            content=[TextContent(text=f"Existing {idx}")],
        )
        for idx in range(3)
    ]
    stored = await manager.create_many_messages_async(pydantic_msgs=seeds, actor=default_user)

    # Fresh objects reusing the stored IDs, plus three messages never stored.
    repeats = [
        PydanticMessage(
            id=item.id,
            agent_id=sarah_agent.id,
            role=MessageRole.user,
            content=item.content,
        )
        for item in stored
    ]
    additions = [
        PydanticMessage(
            agent_id=sarah_agent.id,
            role=MessageRole.user,
            content=[TextContent(text=f"New {idx}")],
        )
        for idx in range(3)
    ]

    fresh, already_stored = await manager.filter_existing_messages(messages=repeats + additions, actor=default_user)

    assert len(fresh) == 3
    assert len(already_stored) == 3

    already_stored_ids = {item.id for item in already_stored}
    for item in stored:
        assert item.id in already_stored_ids
    for item in fresh:
        assert item.id not in already_stored_ids
@pytest.mark.asyncio
async def test_create_many_messages_async_with_turbopuffer(server: SyncServer, sarah_agent, default_user):
    """Batch creation with strict_mode and project/template IDs returns all messages."""
    manager = server.message_manager

    batch = [
        PydanticMessage(
            agent_id=sarah_agent.id,
            role=MessageRole.user,
            content=[TextContent(text=f"Important information about topic {idx}")],
        )
        for idx in range(3)
    ]

    persisted = await manager.create_many_messages_async(
        pydantic_msgs=batch,
        actor=default_user,
        strict_mode=True,
        project_id="test_project",
        template_id="test_template",
    )

    assert len(persisted) == 3
    for stored in persisted:
        assert stored.id is not None
        assert stored.agent_id == sarah_agent.id
# ======================================================================================================================
# Block Manager Tests - Basic
# ======================================================================================================================