Revert "fix: truncate oversized text in embedding requests" (#9227)
Revert "fix: truncate oversized text in embedding requests (#9196)" This reverts commit a9c342087e022519c63d62fb76b72aed8859539b.
This commit is contained in:
@@ -12,7 +12,6 @@ import httpx
|
||||
|
||||
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
|
||||
from letta.errors import LettaInvalidArgumentError
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
from letta.otel.tracing import trace_method, log_event
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole, TagMatchMode
|
||||
@@ -159,38 +158,6 @@ def async_retry_with_backoff(
|
||||
_GLOBAL_TURBOPUFFER_SEMAPHORE = asyncio.Semaphore(5)
|
||||
|
||||
|
||||
def _split_text_in_half(text: str) -> List[str]:
|
||||
"""Split text roughly in half, trying to break at sentence/word boundaries."""
|
||||
if not text:
|
||||
return []
|
||||
|
||||
mid = len(text) // 2
|
||||
|
||||
# Look for a good break point (sentence end, newline, or space) near the middle
|
||||
best_break = mid
|
||||
search_range = min(500, mid // 2) # Search within 500 chars of middle
|
||||
|
||||
for offset in range(search_range):
|
||||
# Check both directions from middle
|
||||
for pos in [mid + offset, mid - offset]:
|
||||
if 0 <= pos < len(text) and text[pos] in ".!?\n ":
|
||||
best_break = pos + 1
|
||||
break
|
||||
else:
|
||||
continue
|
||||
break
|
||||
|
||||
first_half = text[:best_break].strip()
|
||||
second_half = text[best_break:].strip()
|
||||
|
||||
result = []
|
||||
if first_half:
|
||||
result.append(first_half)
|
||||
if second_half:
|
||||
result.append(second_half)
|
||||
return result
|
||||
|
||||
|
||||
def _run_turbopuffer_write_in_thread(
|
||||
api_key: str,
|
||||
region: str,
|
||||
@@ -288,6 +255,8 @@ class TurbopufferClient:
|
||||
Returns:
|
||||
List of embedding vectors
|
||||
"""
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
|
||||
# filter out empty strings after stripping
|
||||
filtered_texts = [text for text in texts if text.strip()]
|
||||
|
||||
@@ -302,98 +271,6 @@ class TurbopufferClient:
|
||||
embeddings = await embedding_client.request_embeddings(filtered_texts, self.default_embedding_config)
|
||||
return embeddings
|
||||
|
||||
@trace_method
|
||||
async def _generate_embeddings_with_chunking(
|
||||
self, texts: List[str], actor: "PydanticUser", max_retries: int = 5
|
||||
) -> List[Tuple[str, List[float], int]]:
|
||||
"""Generate embeddings with automatic chunking for texts that exceed context length.
|
||||
|
||||
For texts that are too long, recursively splits them until embedding succeeds.
|
||||
|
||||
Args:
|
||||
texts: List of texts to embed
|
||||
actor: User actor for embedding generation
|
||||
max_retries: Maximum split attempts per text before giving up
|
||||
|
||||
Returns:
|
||||
List of (text, embedding, original_index) tuples. A single input text may
|
||||
produce multiple output tuples if it was split.
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
embedding_client = LLMClient.create(
|
||||
provider_type=self.default_embedding_config.embedding_endpoint_type,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
results: List[Tuple[str, List[float], int]] = []
|
||||
|
||||
for original_idx, text in enumerate(texts):
|
||||
text = text.strip()
|
||||
if not text:
|
||||
continue
|
||||
|
||||
# Try to embed, splitting recursively on context length errors
|
||||
texts_to_embed = [text]
|
||||
retry_count = 0
|
||||
|
||||
while texts_to_embed and retry_count < max_retries:
|
||||
try:
|
||||
embeddings = await embedding_client.request_embeddings(texts_to_embed, self.default_embedding_config)
|
||||
|
||||
# Success - add all chunks to results
|
||||
for chunk_text, embedding in zip(texts_to_embed, embeddings):
|
||||
results.append((chunk_text, embedding, original_idx))
|
||||
texts_to_embed = [] # Clear the queue
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
is_context_length_error = "context length" in error_str or "maximum context" in error_str
|
||||
|
||||
if is_context_length_error and len(texts_to_embed) == 1:
|
||||
# Single text is too long - split it
|
||||
long_text = texts_to_embed[0]
|
||||
split_texts = _split_text_in_half(long_text)
|
||||
|
||||
if len(split_texts) <= 1:
|
||||
# Can't split further, re-raise the original exception
|
||||
logger.error(f"Cannot split text further, still exceeds context limit: {len(long_text)} chars")
|
||||
raise e
|
||||
|
||||
logger.warning(
|
||||
f"Text exceeds context limit ({len(long_text)} chars), splitting into {len(split_texts)} chunks and retrying"
|
||||
)
|
||||
texts_to_embed = split_texts
|
||||
retry_count += 1
|
||||
|
||||
elif is_context_length_error and len(texts_to_embed) > 1:
|
||||
# Multiple texts - try one at a time to find the problematic one
|
||||
logger.warning("Batch embedding failed with context error, trying individually")
|
||||
new_texts = []
|
||||
for t in texts_to_embed:
|
||||
try:
|
||||
emb = await embedding_client.request_embeddings([t], self.default_embedding_config)
|
||||
results.append((t, emb[0], original_idx))
|
||||
except Exception as inner_e:
|
||||
if "context length" in str(inner_e).lower():
|
||||
# This text is too long - split it
|
||||
new_texts.extend(_split_text_in_half(t))
|
||||
else:
|
||||
raise inner_e
|
||||
texts_to_embed = new_texts
|
||||
retry_count += 1
|
||||
else:
|
||||
# Non-context-length error, re-raise the original exception
|
||||
raise e
|
||||
|
||||
if texts_to_embed:
|
||||
# Exhausted retries
|
||||
logger.error(f"Failed to embed text after {max_retries} split attempts")
|
||||
raise RuntimeError(f"Text could not be embedded after {max_retries} split attempts")
|
||||
|
||||
return results
|
||||
|
||||
@trace_method
|
||||
async def _get_archive_namespace_name(self, archive_id: str) -> str:
|
||||
"""Get namespace name for a specific archive."""
|
||||
@@ -759,52 +636,30 @@ class TurbopufferClient:
|
||||
"""
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
|
||||
# validation checks
|
||||
if not message_ids:
|
||||
raise LettaInvalidArgumentError("message_ids must be provided for Turbopuffer insertion", argument_name="message_ids")
|
||||
if len(message_ids) != len(message_texts):
|
||||
raise LettaInvalidArgumentError(
|
||||
f"message_ids length ({len(message_ids)}) must match message_texts length ({len(message_texts)})",
|
||||
argument_name="message_ids",
|
||||
)
|
||||
if len(message_ids) != len(roles):
|
||||
raise LettaInvalidArgumentError(
|
||||
f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})", argument_name="roles"
|
||||
)
|
||||
if len(message_ids) != len(created_ats):
|
||||
raise LettaInvalidArgumentError(
|
||||
f"message_ids length ({len(message_ids)}) must match created_ats length ({len(created_ats)})", argument_name="created_ats"
|
||||
)
|
||||
if conversation_ids is not None and len(conversation_ids) != len(message_ids):
|
||||
raise LettaInvalidArgumentError(
|
||||
f"conversation_ids length ({len(conversation_ids)}) must match message_ids length ({len(message_ids)})",
|
||||
argument_name="conversation_ids",
|
||||
)
|
||||
# filter out empty message texts
|
||||
filtered_messages = [(i, text) for i, text in enumerate(message_texts) if text.strip()]
|
||||
|
||||
# Filter out empty messages and prepare metadata
|
||||
valid_messages = [] # List of (index, text)
|
||||
for i, text in enumerate(message_texts):
|
||||
text = text.strip()
|
||||
if text:
|
||||
valid_messages.append((i, text))
|
||||
|
||||
if not valid_messages:
|
||||
if not filtered_messages:
|
||||
logger.warning("All message texts were empty, skipping insertion")
|
||||
return True
|
||||
|
||||
# Generate embeddings with automatic chunking for texts that exceed context length
|
||||
# This returns (chunk_text, embedding, original_valid_idx) tuples
|
||||
texts_to_embed = [text for _, text in valid_messages]
|
||||
embedding_results = await self._generate_embeddings_with_chunking(texts_to_embed, actor)
|
||||
# generate embeddings using the default config
|
||||
filtered_texts = [text for _, text in filtered_messages]
|
||||
embeddings = await self._generate_embeddings(filtered_texts, actor)
|
||||
|
||||
namespace_name = await self._get_message_namespace_name(organization_id)
|
||||
|
||||
# Build a mapping from valid_messages index to original message metadata
|
||||
# This lets us look up the original message_id, role, etc. for each chunk
|
||||
valid_idx_to_original = {valid_idx: original_idx for valid_idx, (original_idx, _) in enumerate(valid_messages)}
|
||||
|
||||
# Track chunk indices per message for composite IDs
|
||||
message_chunk_counts: dict = {}
|
||||
# validation checks
|
||||
if not message_ids:
|
||||
raise ValueError("message_ids must be provided for Turbopuffer insertion")
|
||||
if len(message_ids) != len(message_texts):
|
||||
raise ValueError(f"message_ids length ({len(message_ids)}) must match message_texts length ({len(message_texts)})")
|
||||
if len(message_ids) != len(roles):
|
||||
raise ValueError(f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})")
|
||||
if len(message_ids) != len(created_ats):
|
||||
raise ValueError(f"message_ids length ({len(message_ids)}) must match created_ats length ({len(created_ats)})")
|
||||
if conversation_ids is not None and len(conversation_ids) != len(message_ids):
|
||||
raise ValueError(f"conversation_ids length ({len(conversation_ids)}) must match message_ids length ({len(message_ids)})")
|
||||
|
||||
# prepare column-based data for turbopuffer - optimized for batch insert
|
||||
ids = []
|
||||
@@ -812,32 +667,18 @@ class TurbopufferClient:
|
||||
texts = []
|
||||
organization_ids_list = []
|
||||
agent_ids_list = []
|
||||
message_id_list = [] # Store original message_id for deduplication
|
||||
chunk_index_list = [] # Store chunk index
|
||||
message_roles = []
|
||||
created_at_timestamps = []
|
||||
project_ids_list = []
|
||||
template_ids_list = []
|
||||
conversation_ids_list = []
|
||||
|
||||
for chunk_text, embedding, valid_idx in embedding_results:
|
||||
# Map back to original message metadata
|
||||
original_idx = valid_idx_to_original[valid_idx]
|
||||
for (original_idx, text), embedding in zip(filtered_messages, embeddings):
|
||||
message_id = message_ids[original_idx]
|
||||
role = roles[original_idx]
|
||||
created_at = created_ats[original_idx]
|
||||
conversation_id = conversation_ids[original_idx] if conversation_ids else None
|
||||
|
||||
# Track chunk index for this message
|
||||
chunk_idx = message_chunk_counts.get(message_id, 0)
|
||||
message_chunk_counts[message_id] = chunk_idx + 1
|
||||
|
||||
# Use composite ID for chunks: message_id for chunk 0, message_id_chunk_N for others
|
||||
if chunk_idx == 0:
|
||||
record_id = message_id
|
||||
else:
|
||||
record_id = f"{message_id}_chunk_{chunk_idx}"
|
||||
|
||||
# ensure the provided timestamp is timezone-aware and in UTC
|
||||
if created_at.tzinfo is None:
|
||||
# assume UTC if no timezone provided
|
||||
@@ -847,13 +688,11 @@ class TurbopufferClient:
|
||||
timestamp = created_at.astimezone(timezone.utc)
|
||||
|
||||
# append to columns
|
||||
ids.append(record_id)
|
||||
ids.append(message_id)
|
||||
vectors.append(embedding)
|
||||
texts.append(chunk_text)
|
||||
texts.append(text)
|
||||
organization_ids_list.append(organization_id)
|
||||
agent_ids_list.append(agent_id)
|
||||
message_id_list.append(message_id) # Original message ID for deduplication
|
||||
chunk_index_list.append(chunk_idx)
|
||||
message_roles.append(role.value)
|
||||
created_at_timestamps.append(timestamp)
|
||||
project_ids_list.append(project_id)
|
||||
@@ -867,8 +706,6 @@ class TurbopufferClient:
|
||||
"text": texts,
|
||||
"organization_id": organization_ids_list,
|
||||
"agent_id": agent_ids_list,
|
||||
"message_id": message_id_list, # Original message ID for deduplication
|
||||
"chunk_index": chunk_index_list, # Chunk index (0 for first/only chunk)
|
||||
"role": message_roles,
|
||||
"created_at": created_at_timestamps,
|
||||
}
|
||||
@@ -885,12 +722,6 @@ class TurbopufferClient:
|
||||
if template_id is not None:
|
||||
upsert_columns["template_id"] = template_ids_list
|
||||
|
||||
# Log if we chunked any messages
|
||||
num_messages_chunked = sum(1 for count in message_chunk_counts.values() if count > 1)
|
||||
total_chunks = sum(message_chunk_counts.values())
|
||||
if num_messages_chunked > 0:
|
||||
logger.info(f"Split {num_messages_chunked} messages into {total_chunks} chunks for embedding")
|
||||
|
||||
try:
|
||||
# Use global semaphore to limit concurrent Turbopuffer writes
|
||||
async with _GLOBAL_TURBOPUFFER_SEMAPHORE:
|
||||
@@ -1552,35 +1383,23 @@ class TurbopufferClient:
|
||||
logger.error(f"Failed to query messages from Turbopuffer: {e}")
|
||||
raise
|
||||
|
||||
def _process_message_query_results(self, result, deduplicate: bool = True) -> List[dict]:
|
||||
def _process_message_query_results(self, result) -> List[dict]:
|
||||
"""Process results from a message query into message dicts.
|
||||
|
||||
For RRF, we only need the rank order - scores are not used.
|
||||
Deduplicates by message_id to handle chunked messages - keeps the first (best-ranked) chunk.
|
||||
"""
|
||||
messages = []
|
||||
seen_message_ids = set()
|
||||
|
||||
for row in result.rows:
|
||||
# Get the original message_id (for chunked messages) or fall back to row.id
|
||||
message_id = getattr(row, "message_id", None) or row.id
|
||||
chunk_index = getattr(row, "chunk_index", 0) or 0
|
||||
|
||||
# Deduplicate by message_id - keep only the first (best-ranked) chunk
|
||||
if deduplicate and message_id in seen_message_ids:
|
||||
continue
|
||||
seen_message_ids.add(message_id)
|
||||
|
||||
# Build message dict with key fields
|
||||
message_dict = {
|
||||
"id": message_id, # Use original message_id, not chunk id
|
||||
"id": row.id,
|
||||
"text": getattr(row, "text", ""),
|
||||
"organization_id": getattr(row, "organization_id", None),
|
||||
"agent_id": getattr(row, "agent_id", None),
|
||||
"role": getattr(row, "role", None),
|
||||
"created_at": getattr(row, "created_at", None),
|
||||
"conversation_id": getattr(row, "conversation_id", None),
|
||||
"chunk_index": chunk_index, # Include chunk index for debugging
|
||||
}
|
||||
messages.append(message_dict)
|
||||
|
||||
|
||||
@@ -911,8 +911,7 @@ class OpenAIClient(LLMClientBase):
|
||||
if isinstance(result, Exception):
|
||||
current_size = len(chunk_inputs)
|
||||
|
||||
if current_batch_size > 1 and current_size > 1:
|
||||
# Multiple inputs in batch - try splitting the batch
|
||||
if current_batch_size > 1:
|
||||
new_batch_size = max(1, current_batch_size // 2)
|
||||
logger.warning(
|
||||
f"Embeddings request failed for batch starting at {start_idx} with size {current_size}. "
|
||||
|
||||
@@ -203,99 +203,3 @@ async def test_openai_embedding_minimum_chunk_failure(default_user):
|
||||
|
||||
with pytest.raises(Exception, match="API error"):
|
||||
await client.request_embeddings(test_inputs, embedding_config)
|
||||
|
||||
|
||||
def test_split_text_in_half():
|
||||
"""Test the _split_text_in_half helper function."""
|
||||
from letta.helpers.tpuf_client import _split_text_in_half
|
||||
|
||||
# Test with text that has sentence boundaries
|
||||
long_text = "This is a test sentence. " * 100
|
||||
splits = _split_text_in_half(long_text)
|
||||
assert len(splits) == 2
|
||||
assert len(splits[0]) > 0
|
||||
assert len(splits[1]) > 0
|
||||
# Should split at a sentence boundary
|
||||
assert splits[0].endswith(".")
|
||||
|
||||
# Test with text that has no good break points
|
||||
no_breaks = "a" * 1000
|
||||
splits = _split_text_in_half(no_breaks)
|
||||
assert len(splits) == 2
|
||||
assert len(splits[0]) + len(splits[1]) == 1000
|
||||
|
||||
# Test with empty text
|
||||
splits = _split_text_in_half("")
|
||||
assert splits == []
|
||||
|
||||
# Test with short text (still splits)
|
||||
short_text = "hello world"
|
||||
splits = _split_text_in_half(short_text)
|
||||
assert len(splits) == 2
|
||||
|
||||
|
||||
def test_chunked_message_query_deduplication():
|
||||
"""Test that chunked messages are deduplicated by message_id in query results."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
|
||||
# Create a mock result with multiple chunks from the same message
|
||||
mock_result = MagicMock()
|
||||
|
||||
# Simulate 3 rows: 2 chunks from message-1, 1 chunk from message-2
|
||||
# The chunks are ranked by relevance (row order = rank order)
|
||||
row1 = MagicMock()
|
||||
row1.id = "message-1_chunk_1" # Second chunk of message-1, but ranked first
|
||||
row1.message_id = "message-1"
|
||||
row1.chunk_index = 1
|
||||
row1.text = "chunk 1 text"
|
||||
row1.organization_id = "org-1"
|
||||
row1.agent_id = "agent-1"
|
||||
row1.role = "user"
|
||||
row1.created_at = None
|
||||
row1.conversation_id = None
|
||||
|
||||
row2 = MagicMock()
|
||||
row2.id = "message-2"
|
||||
row2.message_id = "message-2"
|
||||
row2.chunk_index = 0
|
||||
row2.text = "message 2 text"
|
||||
row2.organization_id = "org-1"
|
||||
row2.agent_id = "agent-1"
|
||||
row2.role = "assistant"
|
||||
row2.created_at = None
|
||||
row2.conversation_id = None
|
||||
|
||||
row3 = MagicMock()
|
||||
row3.id = "message-1" # First chunk of message-1, but ranked third
|
||||
row3.message_id = "message-1"
|
||||
row3.chunk_index = 0
|
||||
row3.text = "chunk 0 text"
|
||||
row3.organization_id = "org-1"
|
||||
row3.agent_id = "agent-1"
|
||||
row3.role = "user"
|
||||
row3.created_at = None
|
||||
row3.conversation_id = None
|
||||
|
||||
mock_result.rows = [row1, row2, row3]
|
||||
|
||||
# Process results with deduplication
|
||||
client = TurbopufferClient.__new__(TurbopufferClient) # Create without __init__
|
||||
results = client._process_message_query_results(mock_result, deduplicate=True)
|
||||
|
||||
# Should have 2 messages (message-1 deduplicated)
|
||||
assert len(results) == 2
|
||||
|
||||
# First result should be message-1 (from the best-ranked chunk)
|
||||
assert results[0]["id"] == "message-1"
|
||||
assert results[0]["text"] == "chunk 1 text" # Text from best-ranked chunk
|
||||
assert results[0]["chunk_index"] == 1
|
||||
|
||||
# Second result should be message-2
|
||||
assert results[1]["id"] == "message-2"
|
||||
assert results[1]["text"] == "message 2 text"
|
||||
|
||||
# Test without deduplication
|
||||
results_no_dedup = client._process_message_query_results(mock_result, deduplicate=False)
|
||||
assert len(results_no_dedup) == 3 # All rows returned
|
||||
|
||||
Reference in New Issue
Block a user