fix: truncate oversized text in embedding requests (#9196)
fix: handle oversized text in embedding requests with recursive chunking

When message text exceeds the embedding model's context length, recursively split it until all chunks can be embedded successfully.

Changes:
- `tpuf_client.py`: Add `_split_text_in_half()` helper for recursive splitting
- `tpuf_client.py`: Add `_generate_embeddings_with_chunking()` that retries with splits on context length errors
- `tpuf_client.py`: Store `message_id` and `chunk_index` columns in Turbopuffer
- `tpuf_client.py`: Deduplicate query results by `message_id`
- `tpuf_client.py`: Use `LettaInvalidArgumentError` instead of `ValueError`
- `tpuf_client.py`: Move LLMClient import to top of file
- `openai_client.py`: Remove fixed truncation (chunking handles this now)
- Add tests for `_split_text_in_half` and chunked query deduplication

🤖 Generated with [Letta Code](https://letta.com)

Co-authored-by: Letta <noreply@letta.com>
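Illustration (not part of the commit): a minimal sketch of the split-and-retry strategy described above. `embed_one()`, `ContextLengthError`, and `MAX_CHARS` are hypothetical stand-ins for the real embedding client, its error type, and its context window.

```python
from typing import List, Tuple

MAX_CHARS = 8_000  # hypothetical context limit (characters), stand-in for the model's token limit


class ContextLengthError(Exception):
    """Stand-in for the provider's 'maximum context length' error."""


def embed_one(text: str) -> List[float]:
    """Hypothetical embedding call that rejects oversized input."""
    if len(text) > MAX_CHARS:
        raise ContextLengthError("maximum context length exceeded")
    return [0.0]  # placeholder vector


def split_in_half(text: str) -> List[str]:
    """Naive midpoint split; the real helper also looks for sentence/word boundaries."""
    mid = len(text) // 2
    return [part for part in (text[:mid].strip(), text[mid:].strip()) if part]


def embed_with_chunking(text: str, max_retries: int = 5) -> List[Tuple[str, List[float]]]:
    """Embed `text`, recursively halving any chunk that still exceeds the limit."""
    pending, done = [text], []
    for _ in range(max_retries):
        too_long: List[str] = []
        for chunk in pending:
            try:
                done.append((chunk, embed_one(chunk)))
            except ContextLengthError:
                halves = split_in_half(chunk)
                if len(halves) <= 1:
                    raise  # cannot split further, give up
                too_long.extend(halves)
        if not too_long:
            return done
        pending = too_long
    raise RuntimeError(f"text could not be embedded after {max_retries} split attempts")


if __name__ == "__main__":
    chunks = embed_with_chunking("word " * 5_000)  # ~25k chars -> needs two rounds of halving
    print(f"embedded as {len(chunks)} chunks")
```

The committed `_generate_embeddings_with_chunking()` follows the same loop, but operates on batches, distinguishes context-length errors from other failures, and records the original message index for each chunk so it can be mapped back to its metadata.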
@@ -12,6 +12,7 @@ import httpx
 
 from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
 from letta.errors import LettaInvalidArgumentError
+from letta.llm_api.llm_client import LLMClient
 from letta.otel.tracing import trace_method, log_event
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.enums import MessageRole, TagMatchMode

@@ -158,6 +159,38 @@ def async_retry_with_backoff(
 _GLOBAL_TURBOPUFFER_SEMAPHORE = asyncio.Semaphore(5)
 
 
+def _split_text_in_half(text: str) -> List[str]:
+    """Split text roughly in half, trying to break at sentence/word boundaries."""
+    if not text:
+        return []
+
+    mid = len(text) // 2
+
+    # Look for a good break point (sentence end, newline, or space) near the middle
+    best_break = mid
+    search_range = min(500, mid // 2)  # Search within 500 chars of middle
+
+    for offset in range(search_range):
+        # Check both directions from middle
+        for pos in [mid + offset, mid - offset]:
+            if 0 <= pos < len(text) and text[pos] in ".!?\n ":
+                best_break = pos + 1
+                break
+        else:
+            continue
+        break
+
+    first_half = text[:best_break].strip()
+    second_half = text[best_break:].strip()
+
+    result = []
+    if first_half:
+        result.append(first_half)
+    if second_half:
+        result.append(second_half)
+    return result
+
+
 def _run_turbopuffer_write_in_thread(
     api_key: str,
     region: str,

@@ -255,8 +288,6 @@ class TurbopufferClient:
         Returns:
             List of embedding vectors
         """
-        from letta.llm_api.llm_client import LLMClient
-
         # filter out empty strings after stripping
         filtered_texts = [text for text in texts if text.strip()]
 
@@ -271,6 +302,98 @@ class TurbopufferClient:
         embeddings = await embedding_client.request_embeddings(filtered_texts, self.default_embedding_config)
         return embeddings
 
+    @trace_method
+    async def _generate_embeddings_with_chunking(
+        self, texts: List[str], actor: "PydanticUser", max_retries: int = 5
+    ) -> List[Tuple[str, List[float], int]]:
+        """Generate embeddings with automatic chunking for texts that exceed context length.
+
+        For texts that are too long, recursively splits them until embedding succeeds.
+
+        Args:
+            texts: List of texts to embed
+            actor: User actor for embedding generation
+            max_retries: Maximum split attempts per text before giving up
+
+        Returns:
+            List of (text, embedding, original_index) tuples. A single input text may
+            produce multiple output tuples if it was split.
+        """
+        if not texts:
+            return []
+
+        embedding_client = LLMClient.create(
+            provider_type=self.default_embedding_config.embedding_endpoint_type,
+            actor=actor,
+        )
+
+        results: List[Tuple[str, List[float], int]] = []
+
+        for original_idx, text in enumerate(texts):
+            text = text.strip()
+            if not text:
+                continue
+
+            # Try to embed, splitting recursively on context length errors
+            texts_to_embed = [text]
+            retry_count = 0
+
+            while texts_to_embed and retry_count < max_retries:
+                try:
+                    embeddings = await embedding_client.request_embeddings(texts_to_embed, self.default_embedding_config)
+
+                    # Success - add all chunks to results
+                    for chunk_text, embedding in zip(texts_to_embed, embeddings):
+                        results.append((chunk_text, embedding, original_idx))
+                    texts_to_embed = []  # Clear the queue
+
+                except Exception as e:
+                    error_str = str(e).lower()
+                    is_context_length_error = "context length" in error_str or "maximum context" in error_str
+
+                    if is_context_length_error and len(texts_to_embed) == 1:
+                        # Single text is too long - split it
+                        long_text = texts_to_embed[0]
+                        split_texts = _split_text_in_half(long_text)
+
+                        if len(split_texts) <= 1:
+                            # Can't split further, re-raise the original exception
+                            logger.error(f"Cannot split text further, still exceeds context limit: {len(long_text)} chars")
+                            raise e
+
+                        logger.warning(
+                            f"Text exceeds context limit ({len(long_text)} chars), splitting into {len(split_texts)} chunks and retrying"
+                        )
+                        texts_to_embed = split_texts
+                        retry_count += 1
+
+                    elif is_context_length_error and len(texts_to_embed) > 1:
+                        # Multiple texts - try one at a time to find the problematic one
+                        logger.warning("Batch embedding failed with context error, trying individually")
+                        new_texts = []
+                        for t in texts_to_embed:
+                            try:
+                                emb = await embedding_client.request_embeddings([t], self.default_embedding_config)
+                                results.append((t, emb[0], original_idx))
+                            except Exception as inner_e:
+                                if "context length" in str(inner_e).lower():
+                                    # This text is too long - split it
+                                    new_texts.extend(_split_text_in_half(t))
+                                else:
+                                    raise inner_e
+                        texts_to_embed = new_texts
+                        retry_count += 1
+
+                    else:
+                        # Non-context-length error, re-raise the original exception
+                        raise e
+
+            if texts_to_embed:
+                # Exhausted retries
+                logger.error(f"Failed to embed text after {max_retries} split attempts")
+                raise RuntimeError(f"Text could not be embedded after {max_retries} split attempts")
+
+        return results
+
     @trace_method
     async def _get_archive_namespace_name(self, archive_id: str) -> str:
         """Get namespace name for a specific archive."""

@@ -636,30 +759,52 @@ class TurbopufferClient:
         """
         from turbopuffer import AsyncTurbopuffer
 
-        # filter out empty message texts
-        filtered_messages = [(i, text) for i, text in enumerate(message_texts) if text.strip()]
-        if not filtered_messages:
+        # validation checks
+        if not message_ids:
+            raise LettaInvalidArgumentError("message_ids must be provided for Turbopuffer insertion", argument_name="message_ids")
+        if len(message_ids) != len(message_texts):
+            raise LettaInvalidArgumentError(
+                f"message_ids length ({len(message_ids)}) must match message_texts length ({len(message_texts)})",
+                argument_name="message_ids",
+            )
+        if len(message_ids) != len(roles):
+            raise LettaInvalidArgumentError(
+                f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})", argument_name="roles"
+            )
+        if len(message_ids) != len(created_ats):
+            raise LettaInvalidArgumentError(
+                f"message_ids length ({len(message_ids)}) must match created_ats length ({len(created_ats)})", argument_name="created_ats"
+            )
+        if conversation_ids is not None and len(conversation_ids) != len(message_ids):
+            raise LettaInvalidArgumentError(
+                f"conversation_ids length ({len(conversation_ids)}) must match message_ids length ({len(message_ids)})",
+                argument_name="conversation_ids",
+            )
+
+        # Filter out empty messages and prepare metadata
+        valid_messages = []  # List of (index, text)
+        for i, text in enumerate(message_texts):
+            text = text.strip()
+            if text:
+                valid_messages.append((i, text))
+
+        if not valid_messages:
             logger.warning("All message texts were empty, skipping insertion")
             return True
 
-        # generate embeddings using the default config
-        filtered_texts = [text for _, text in filtered_messages]
-        embeddings = await self._generate_embeddings(filtered_texts, actor)
+        # Generate embeddings with automatic chunking for texts that exceed context length
+        # This returns (chunk_text, embedding, original_valid_idx) tuples
+        texts_to_embed = [text for _, text in valid_messages]
+        embedding_results = await self._generate_embeddings_with_chunking(texts_to_embed, actor)
 
         namespace_name = await self._get_message_namespace_name(organization_id)
 
-        # validation checks
-        if not message_ids:
-            raise ValueError("message_ids must be provided for Turbopuffer insertion")
-        if len(message_ids) != len(message_texts):
-            raise ValueError(f"message_ids length ({len(message_ids)}) must match message_texts length ({len(message_texts)})")
-        if len(message_ids) != len(roles):
-            raise ValueError(f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})")
-        if len(message_ids) != len(created_ats):
-            raise ValueError(f"message_ids length ({len(message_ids)}) must match created_ats length ({len(created_ats)})")
-        if conversation_ids is not None and len(conversation_ids) != len(message_ids):
-            raise ValueError(f"conversation_ids length ({len(conversation_ids)}) must match message_ids length ({len(message_ids)})")
+        # Build a mapping from valid_messages index to original message metadata
+        # This lets us look up the original message_id, role, etc. for each chunk
+        valid_idx_to_original = {valid_idx: original_idx for valid_idx, (original_idx, _) in enumerate(valid_messages)}
+
+        # Track chunk indices per message for composite IDs
+        message_chunk_counts: dict = {}
 
         # prepare column-based data for turbopuffer - optimized for batch insert
         ids = []

@@ -667,18 +812,32 @@ class TurbopufferClient:
         texts = []
         organization_ids_list = []
         agent_ids_list = []
+        message_id_list = []  # Store original message_id for deduplication
+        chunk_index_list = []  # Store chunk index
         message_roles = []
         created_at_timestamps = []
         project_ids_list = []
         template_ids_list = []
         conversation_ids_list = []
 
-        for (original_idx, text), embedding in zip(filtered_messages, embeddings):
+        for chunk_text, embedding, valid_idx in embedding_results:
+            # Map back to original message metadata
+            original_idx = valid_idx_to_original[valid_idx]
             message_id = message_ids[original_idx]
             role = roles[original_idx]
             created_at = created_ats[original_idx]
             conversation_id = conversation_ids[original_idx] if conversation_ids else None
 
+            # Track chunk index for this message
+            chunk_idx = message_chunk_counts.get(message_id, 0)
+            message_chunk_counts[message_id] = chunk_idx + 1
+
+            # Use composite ID for chunks: message_id for chunk 0, message_id_chunk_N for others
+            if chunk_idx == 0:
+                record_id = message_id
+            else:
+                record_id = f"{message_id}_chunk_{chunk_idx}"
+
             # ensure the provided timestamp is timezone-aware and in UTC
             if created_at.tzinfo is None:
                 # assume UTC if no timezone provided

@@ -688,11 +847,13 @@ class TurbopufferClient:
                 timestamp = created_at.astimezone(timezone.utc)
 
             # append to columns
-            ids.append(message_id)
+            ids.append(record_id)
             vectors.append(embedding)
-            texts.append(text)
+            texts.append(chunk_text)
             organization_ids_list.append(organization_id)
             agent_ids_list.append(agent_id)
+            message_id_list.append(message_id)  # Original message ID for deduplication
+            chunk_index_list.append(chunk_idx)
             message_roles.append(role.value)
             created_at_timestamps.append(timestamp)
             project_ids_list.append(project_id)

@@ -706,6 +867,8 @@ class TurbopufferClient:
             "text": texts,
             "organization_id": organization_ids_list,
             "agent_id": agent_ids_list,
+            "message_id": message_id_list,  # Original message ID for deduplication
+            "chunk_index": chunk_index_list,  # Chunk index (0 for first/only chunk)
             "role": message_roles,
             "created_at": created_at_timestamps,
         }

@@ -722,6 +885,12 @@ class TurbopufferClient:
         if template_id is not None:
             upsert_columns["template_id"] = template_ids_list
 
+        # Log if we chunked any messages
+        num_messages_chunked = sum(1 for count in message_chunk_counts.values() if count > 1)
+        total_chunks = sum(message_chunk_counts.values())
+        if num_messages_chunked > 0:
+            logger.info(f"Split {num_messages_chunked} messages into {total_chunks} chunks for embedding")
+
         try:
             # Use global semaphore to limit concurrent Turbopuffer writes
             async with _GLOBAL_TURBOPUFFER_SEMAPHORE:

@@ -1383,23 +1552,35 @@ class TurbopufferClient:
             logger.error(f"Failed to query messages from Turbopuffer: {e}")
             raise
 
-    def _process_message_query_results(self, result) -> List[dict]:
+    def _process_message_query_results(self, result, deduplicate: bool = True) -> List[dict]:
         """Process results from a message query into message dicts.
 
         For RRF, we only need the rank order - scores are not used.
+        Deduplicates by message_id to handle chunked messages - keeps the first (best-ranked) chunk.
         """
         messages = []
+        seen_message_ids = set()
+
         for row in result.rows:
+            # Get the original message_id (for chunked messages) or fall back to row.id
+            message_id = getattr(row, "message_id", None) or row.id
+            chunk_index = getattr(row, "chunk_index", 0) or 0
+
+            # Deduplicate by message_id - keep only the first (best-ranked) chunk
+            if deduplicate and message_id in seen_message_ids:
+                continue
+            seen_message_ids.add(message_id)
+
             # Build message dict with key fields
             message_dict = {
-                "id": row.id,
+                "id": message_id,  # Use original message_id, not chunk id
                 "text": getattr(row, "text", ""),
                 "organization_id": getattr(row, "organization_id", None),
                 "agent_id": getattr(row, "agent_id", None),
                 "role": getattr(row, "role", None),
                 "created_at": getattr(row, "created_at", None),
                 "conversation_id": getattr(row, "conversation_id", None),
+                "chunk_index": chunk_index,  # Include chunk index for debugging
             }
             messages.append(message_dict)
 

@@ -911,7 +911,8 @@ class OpenAIClient(LLMClientBase):
             if isinstance(result, Exception):
                 current_size = len(chunk_inputs)
 
-                if current_batch_size > 1:
+                if current_batch_size > 1 and current_size > 1:
+                    # Multiple inputs in batch - try splitting the batch
                     new_batch_size = max(1, current_batch_size // 2)
                     logger.warning(
                         f"Embeddings request failed for batch starting at {start_idx} with size {current_size}. "

@@ -203,3 +203,99 @@ async def test_openai_embedding_minimum_chunk_failure(default_user):
 
     with pytest.raises(Exception, match="API error"):
         await client.request_embeddings(test_inputs, embedding_config)
+
+
+def test_split_text_in_half():
+    """Test the _split_text_in_half helper function."""
+    from letta.helpers.tpuf_client import _split_text_in_half
+
+    # Test with text that has sentence boundaries
+    long_text = "This is a test sentence. " * 100
+    splits = _split_text_in_half(long_text)
+    assert len(splits) == 2
+    assert len(splits[0]) > 0
+    assert len(splits[1]) > 0
+    # Should split at a sentence boundary
+    assert splits[0].endswith(".")
+
+    # Test with text that has no good break points
+    no_breaks = "a" * 1000
+    splits = _split_text_in_half(no_breaks)
+    assert len(splits) == 2
+    assert len(splits[0]) + len(splits[1]) == 1000
+
+    # Test with empty text
+    splits = _split_text_in_half("")
+    assert splits == []
+
+    # Test with short text (still splits)
+    short_text = "hello world"
+    splits = _split_text_in_half(short_text)
+    assert len(splits) == 2
+
+
+def test_chunked_message_query_deduplication():
+    """Test that chunked messages are deduplicated by message_id in query results."""
+    from unittest.mock import MagicMock
+
+    from letta.helpers.tpuf_client import TurbopufferClient
+
+    # Create a mock result with multiple chunks from the same message
+    mock_result = MagicMock()
+
+    # Simulate 3 rows: 2 chunks from message-1, 1 chunk from message-2
+    # The chunks are ranked by relevance (row order = rank order)
+    row1 = MagicMock()
+    row1.id = "message-1_chunk_1"  # Second chunk of message-1, but ranked first
+    row1.message_id = "message-1"
+    row1.chunk_index = 1
+    row1.text = "chunk 1 text"
+    row1.organization_id = "org-1"
+    row1.agent_id = "agent-1"
+    row1.role = "user"
+    row1.created_at = None
+    row1.conversation_id = None
+
+    row2 = MagicMock()
+    row2.id = "message-2"
+    row2.message_id = "message-2"
+    row2.chunk_index = 0
+    row2.text = "message 2 text"
+    row2.organization_id = "org-1"
+    row2.agent_id = "agent-1"
+    row2.role = "assistant"
+    row2.created_at = None
+    row2.conversation_id = None
+
+    row3 = MagicMock()
+    row3.id = "message-1"  # First chunk of message-1, but ranked third
+    row3.message_id = "message-1"
+    row3.chunk_index = 0
+    row3.text = "chunk 0 text"
+    row3.organization_id = "org-1"
+    row3.agent_id = "agent-1"
+    row3.role = "user"
+    row3.created_at = None
+    row3.conversation_id = None
+
+    mock_result.rows = [row1, row2, row3]
+
+    # Process results with deduplication
+    client = TurbopufferClient.__new__(TurbopufferClient)  # Create without __init__
+    results = client._process_message_query_results(mock_result, deduplicate=True)
+
+    # Should have 2 messages (message-1 deduplicated)
+    assert len(results) == 2
+
+    # First result should be message-1 (from the best-ranked chunk)
+    assert results[0]["id"] == "message-1"
+    assert results[0]["text"] == "chunk 1 text"  # Text from best-ranked chunk
+    assert results[0]["chunk_index"] == 1
+
+    # Second result should be message-2
+    assert results[1]["id"] == "message-2"
+    assert results[1]["text"] == "message 2 text"
+
+    # Test without deduplication
+    results_no_dedup = client._process_message_query_results(mock_result, deduplicate=False)
+    assert len(results_no_dedup) == 3  # All rows returned