feat: Add partial flag on create_many_messages_async (#4836)
Add partial flag
This commit is contained in:
committed by
Caren Thomas
parent
3593e7cda6
commit
2f716d4961
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Sequence, Tuple
|
||||
from typing import List, Optional, Sequence, Set, Tuple
|
||||
|
||||
from sqlalchemy import delete, exists, func, select, text
|
||||
|
||||
@@ -307,6 +307,51 @@ class MessageManager:
|
||||
created_messages = MessageModel.batch_create(orm_messages, session, actor=actor)
|
||||
return [msg.to_pydantic() for msg in created_messages]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def check_existing_message_ids(self, message_ids: List[str], actor: PydanticUser) -> Set[str]:
|
||||
"""Check which message IDs already exist in the database.
|
||||
|
||||
Args:
|
||||
message_ids: List of message IDs to check
|
||||
actor: User performing the action
|
||||
|
||||
Returns:
|
||||
Set of message IDs that already exist in the database
|
||||
"""
|
||||
if not message_ids:
|
||||
return set()
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
query = select(MessageModel.id).where(MessageModel.id.in_(message_ids), MessageModel.organization_id == actor.organization_id)
|
||||
result = await session.execute(query)
|
||||
return set(result.scalars().all())
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def filter_existing_messages(
|
||||
self, messages: List[PydanticMessage], actor: PydanticUser
|
||||
) -> Tuple[List[PydanticMessage], List[PydanticMessage]]:
|
||||
"""Filter messages into new and existing based on their IDs.
|
||||
|
||||
Args:
|
||||
messages: List of messages to filter
|
||||
actor: User performing the action
|
||||
|
||||
Returns:
|
||||
Tuple of (new_messages, existing_messages)
|
||||
"""
|
||||
message_ids = [msg.id for msg in messages if msg.id]
|
||||
if not message_ids:
|
||||
return messages, []
|
||||
|
||||
existing_ids = await self.check_existing_message_ids(message_ids, actor)
|
||||
|
||||
new_messages = [msg for msg in messages if msg.id not in existing_ids]
|
||||
existing_messages = [msg for msg in messages if msg.id in existing_ids]
|
||||
|
||||
return new_messages, existing_messages
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def create_many_messages_async(
|
||||
@@ -316,6 +361,7 @@ class MessageManager:
|
||||
strict_mode: bool = False,
|
||||
project_id: Optional[str] = None,
|
||||
template_id: Optional[str] = None,
|
||||
allow_partial: bool = False,
|
||||
) -> List[PydanticMessage]:
|
||||
"""
|
||||
Create multiple messages in a single database transaction asynchronously.
|
||||
@@ -326,14 +372,33 @@ class MessageManager:
|
||||
strict_mode: If True, wait for embedding to complete; if False, run in background
|
||||
project_id: Optional project ID for the messages (for Turbopuffer indexing)
|
||||
template_id: Optional template ID for the messages (for Turbopuffer indexing)
|
||||
allow_partial: If True, skip messages that already exist; if False, fail on duplicates
|
||||
|
||||
Returns:
|
||||
List of created Pydantic message models
|
||||
List of created Pydantic message models (and existing ones if allow_partial=True)
|
||||
"""
|
||||
if not pydantic_msgs:
|
||||
return []
|
||||
|
||||
for message in pydantic_msgs:
|
||||
messages_to_create = pydantic_msgs
|
||||
existing_messages = []
|
||||
|
||||
if allow_partial:
|
||||
# filter out messages that already exist
|
||||
new_messages, existing_messages = await self.filter_existing_messages(pydantic_msgs, actor)
|
||||
messages_to_create = new_messages
|
||||
|
||||
if not messages_to_create:
|
||||
# all messages already exist, fetch and return them
|
||||
async with db_registry.async_session() as session:
|
||||
existing_ids = [msg.id for msg in existing_messages if msg.id]
|
||||
query = select(MessageModel).where(
|
||||
MessageModel.id.in_(existing_ids), MessageModel.organization_id == actor.organization_id
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return [msg.to_pydantic() for msg in result.scalars()]
|
||||
|
||||
for message in messages_to_create:
|
||||
if isinstance(message.content, list):
|
||||
for content in message.content:
|
||||
if content.type == MessageContentType.image and content.source.type == ImageSourceType.base64:
|
||||
@@ -358,7 +423,7 @@ class MessageManager:
|
||||
media_type=content.source.media_type,
|
||||
detail=content.source.detail,
|
||||
)
|
||||
orm_messages = self._create_many_preprocess(pydantic_msgs, actor)
|
||||
orm_messages = self._create_many_preprocess(messages_to_create, actor)
|
||||
async with db_registry.async_session() as session:
|
||||
created_messages = await MessageModel.batch_create_async(orm_messages, session, actor=actor, no_commit=True, no_refresh=True)
|
||||
result = [msg.to_pydantic() for msg in created_messages]
|
||||
@@ -381,6 +446,18 @@ class MessageManager:
|
||||
task_name=f"embed_messages_for_agent_{agent_id}",
|
||||
)
|
||||
|
||||
# if allow_partial, combine newly created with existing
|
||||
if allow_partial and existing_messages:
|
||||
# fetch the existing messages to return complete data
|
||||
async with db_registry.async_session() as session:
|
||||
existing_ids = [msg.id for msg in existing_messages if msg.id]
|
||||
query = select(MessageModel).where(
|
||||
MessageModel.id.in_(existing_ids), MessageModel.organization_id == actor.organization_id
|
||||
)
|
||||
existing_result = await session.execute(query)
|
||||
existing_fetched = [msg.to_pydantic() for msg in existing_result.scalars()]
|
||||
result.extend(existing_fetched)
|
||||
|
||||
return result
|
||||
|
||||
async def _embed_messages_background(
|
||||
|
||||
@@ -5845,6 +5845,248 @@ def test_message_listing_text_search(server: SyncServer, hello_world_message_fix
|
||||
assert len(search_results) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_many_messages_async_basic(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test basic batch creation of messages"""
|
||||
message_manager = server.message_manager
|
||||
|
||||
messages = []
|
||||
for i in range(5):
|
||||
msg = PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(text=f"Test message {i}")],
|
||||
name=None,
|
||||
tool_calls=None,
|
||||
tool_call_id=None,
|
||||
)
|
||||
messages.append(msg)
|
||||
|
||||
created_messages = await message_manager.create_many_messages_async(pydantic_msgs=messages, actor=default_user)
|
||||
|
||||
assert len(created_messages) == 5
|
||||
for i, msg in enumerate(created_messages):
|
||||
assert msg.id is not None
|
||||
assert msg.content[0].text == f"Test message {i}"
|
||||
assert msg.agent_id == sarah_agent.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_many_messages_async_allow_partial_false(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test that allow_partial=False (default) fails on duplicate IDs"""
|
||||
message_manager = server.message_manager
|
||||
|
||||
initial_msg = PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(text="Initial message")],
|
||||
)
|
||||
|
||||
created = await message_manager.create_many_messages_async(pydantic_msgs=[initial_msg], actor=default_user)
|
||||
assert len(created) == 1
|
||||
created_msg = created[0]
|
||||
|
||||
duplicate_msg = PydanticMessage(
|
||||
id=created_msg.id,
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(text="Duplicate message")],
|
||||
)
|
||||
|
||||
with pytest.raises(UniqueConstraintViolationError):
|
||||
await message_manager.create_many_messages_async(pydantic_msgs=[duplicate_msg], actor=default_user, allow_partial=False)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_many_messages_async_allow_partial_true_some_duplicates(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test that allow_partial=True handles partial duplicates correctly"""
|
||||
message_manager = server.message_manager
|
||||
|
||||
initial_messages = []
|
||||
for i in range(3):
|
||||
msg = PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(text=f"Existing message {i}")],
|
||||
)
|
||||
initial_messages.append(msg)
|
||||
|
||||
created_initial = await message_manager.create_many_messages_async(pydantic_msgs=initial_messages, actor=default_user)
|
||||
assert len(created_initial) == 3
|
||||
existing_ids = [msg.id for msg in created_initial]
|
||||
|
||||
mixed_messages = []
|
||||
for created_msg in created_initial:
|
||||
duplicate_msg = PydanticMessage(
|
||||
id=created_msg.id,
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.user,
|
||||
content=created_msg.content,
|
||||
)
|
||||
mixed_messages.append(duplicate_msg)
|
||||
for i in range(3, 6):
|
||||
msg = PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(text=f"New message {i}")],
|
||||
)
|
||||
mixed_messages.append(msg)
|
||||
|
||||
result = await message_manager.create_many_messages_async(pydantic_msgs=mixed_messages, actor=default_user, allow_partial=True)
|
||||
|
||||
assert len(result) == 6
|
||||
|
||||
result_ids = {msg.id for msg in result}
|
||||
for existing_id in existing_ids:
|
||||
assert existing_id in result_ids
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_many_messages_async_allow_partial_true_all_duplicates(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test that allow_partial=True handles all duplicates correctly"""
|
||||
message_manager = server.message_manager
|
||||
|
||||
initial_messages = []
|
||||
for i in range(3):
|
||||
msg = PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(text=f"Message {i}")],
|
||||
)
|
||||
initial_messages.append(msg)
|
||||
|
||||
created_initial = await message_manager.create_many_messages_async(pydantic_msgs=initial_messages, actor=default_user)
|
||||
assert len(created_initial) == 3
|
||||
|
||||
duplicate_messages = []
|
||||
for created_msg in created_initial:
|
||||
duplicate_msg = PydanticMessage(
|
||||
id=created_msg.id,
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.user,
|
||||
content=created_msg.content,
|
||||
)
|
||||
duplicate_messages.append(duplicate_msg)
|
||||
|
||||
result = await message_manager.create_many_messages_async(pydantic_msgs=duplicate_messages, actor=default_user, allow_partial=True)
|
||||
|
||||
assert len(result) == 3
|
||||
for i, msg in enumerate(result):
|
||||
assert msg.id == created_initial[i].id
|
||||
assert msg.content[0].text == f"Message {i}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_many_messages_async_empty_list(server: SyncServer, default_user):
|
||||
"""Test that empty list returns empty list"""
|
||||
message_manager = server.message_manager
|
||||
|
||||
result = await message_manager.create_many_messages_async(pydantic_msgs=[], actor=default_user)
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_existing_message_ids(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test the check_existing_message_ids convenience function"""
|
||||
message_manager = server.message_manager
|
||||
|
||||
messages = []
|
||||
for i in range(3):
|
||||
msg = PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(text=f"Message {i}")],
|
||||
)
|
||||
messages.append(msg)
|
||||
|
||||
created_messages = await message_manager.create_many_messages_async(pydantic_msgs=messages, actor=default_user)
|
||||
existing_ids = [msg.id for msg in created_messages]
|
||||
|
||||
non_existent_ids = [f"message-{uuid.uuid4().hex[:8]}" for _ in range(3)]
|
||||
all_ids = existing_ids + non_existent_ids
|
||||
|
||||
existing = await message_manager.check_existing_message_ids(message_ids=all_ids, actor=default_user)
|
||||
|
||||
assert existing == set(existing_ids)
|
||||
for non_existent_id in non_existent_ids:
|
||||
assert non_existent_id not in existing
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filter_existing_messages(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test the filter_existing_messages helper function"""
|
||||
message_manager = server.message_manager
|
||||
|
||||
initial_messages = []
|
||||
for i in range(3):
|
||||
msg = PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(text=f"Existing {i}")],
|
||||
)
|
||||
initial_messages.append(msg)
|
||||
|
||||
created_existing = await message_manager.create_many_messages_async(pydantic_msgs=initial_messages, actor=default_user)
|
||||
|
||||
existing_messages = []
|
||||
for created_msg in created_existing:
|
||||
msg = PydanticMessage(
|
||||
id=created_msg.id,
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.user,
|
||||
content=created_msg.content,
|
||||
)
|
||||
existing_messages.append(msg)
|
||||
|
||||
new_messages = []
|
||||
for i in range(3):
|
||||
msg = PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(text=f"New {i}")],
|
||||
)
|
||||
new_messages.append(msg)
|
||||
|
||||
all_messages = existing_messages + new_messages
|
||||
|
||||
new_filtered, existing_filtered = await message_manager.filter_existing_messages(messages=all_messages, actor=default_user)
|
||||
|
||||
assert len(new_filtered) == 3
|
||||
assert len(existing_filtered) == 3
|
||||
|
||||
existing_filtered_ids = {msg.id for msg in existing_filtered}
|
||||
for created_msg in created_existing:
|
||||
assert created_msg.id in existing_filtered_ids
|
||||
|
||||
for msg in new_filtered:
|
||||
assert msg.id not in existing_filtered_ids
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_many_messages_async_with_turbopuffer(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test batch creation with turbopuffer embedding (if enabled)"""
|
||||
message_manager = server.message_manager
|
||||
|
||||
messages = []
|
||||
for i in range(3):
|
||||
msg = PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(text=f"Important information about topic {i}")],
|
||||
)
|
||||
messages.append(msg)
|
||||
|
||||
created_messages = await message_manager.create_many_messages_async(
|
||||
pydantic_msgs=messages, actor=default_user, strict_mode=True, project_id="test_project", template_id="test_template"
|
||||
)
|
||||
|
||||
assert len(created_messages) == 3
|
||||
for msg in created_messages:
|
||||
assert msg.id is not None
|
||||
assert msg.agent_id == sarah_agent.id
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Block Manager Tests - Basic
|
||||
# ======================================================================================================================
|
||||
|
||||
Reference in New Issue
Block a user