Revert "fix: truncate oversized text in embedding requests" (#9227)

Revert "fix: truncate oversized text in embedding requests (#9196)"

This reverts commit a9c342087e022519c63d62fb76b72aed8859539b.
Kian Jones
2026-01-30 16:35:21 -08:00
committed by Caren Thomas
parent 68eb076135
commit 01cb00ae10
3 changed files with 25 additions and 303 deletions
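For context, the reverted change had added client-side splitting of oversized texts before embedding; the revert restores the simpler flow shown in the first file's diff. A minimal sketch of that restored path, assuming the LLMClient interface visible in this diff (the standalone function name is hypothetical):

```python
from letta.llm_api.llm_client import LLMClient

async def generate_embeddings(texts, embedding_config, actor):
    # Drop empty strings, then embed everything in a single request;
    # oversized inputs are no longer split or truncated client-side.
    filtered = [t for t in texts if t.strip()]
    if not filtered:
        return []
    client = LLMClient.create(
        provider_type=embedding_config.embedding_endpoint_type,
        actor=actor,
    )
    return await client.request_embeddings(filtered, embedding_config)
```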


@@ -12,7 +12,6 @@ import httpx
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
from letta.errors import LettaInvalidArgumentError
from letta.llm_api.llm_client import LLMClient
from letta.otel.tracing import trace_method, log_event
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole, TagMatchMode
@@ -159,38 +158,6 @@ def async_retry_with_backoff(
_GLOBAL_TURBOPUFFER_SEMAPHORE = asyncio.Semaphore(5)
def _split_text_in_half(text: str) -> List[str]:
"""Split text roughly in half, trying to break at sentence/word boundaries."""
if not text:
return []
mid = len(text) // 2
# Look for a good break point (sentence end, newline, or space) near the middle
best_break = mid
search_range = min(500, mid // 2) # Search within 500 chars of middle
for offset in range(search_range):
# Check both directions from middle
for pos in [mid + offset, mid - offset]:
if 0 <= pos < len(text) and text[pos] in ".!?\n ":
best_break = pos + 1
break
else:
continue
break
first_half = text[:best_break].strip()
second_half = text[best_break:].strip()
result = []
if first_half:
result.append(first_half)
if second_half:
result.append(second_half)
return result
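To illustrate the helper deleted above, a short usage sketch based only on the behavior asserted by its (also removed) test at the bottom of this commit:

```python
# Illustration only: expected behavior of the removed _split_text_in_half helper.
halves = _split_text_in_half("This is a test sentence. " * 100)
assert len(halves) == 2 and halves[0].endswith(".")  # breaks at a sentence boundary

halves = _split_text_in_half("a" * 1000)
assert len(halves) == 2 and len(halves[0]) + len(halves[1]) == 1000  # no break points: split at midpoint

assert _split_text_in_half("") == []  # empty input yields no chunks
```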
def _run_turbopuffer_write_in_thread(
api_key: str,
region: str,
@@ -288,6 +255,8 @@ class TurbopufferClient:
Returns:
List of embedding vectors
"""
from letta.llm_api.llm_client import LLMClient
# filter out empty strings after stripping
filtered_texts = [text for text in texts if text.strip()]
@@ -302,98 +271,6 @@ class TurbopufferClient:
embeddings = await embedding_client.request_embeddings(filtered_texts, self.default_embedding_config)
return embeddings
@trace_method
async def _generate_embeddings_with_chunking(
self, texts: List[str], actor: "PydanticUser", max_retries: int = 5
) -> List[Tuple[str, List[float], int]]:
"""Generate embeddings with automatic chunking for texts that exceed context length.
For texts that are too long, recursively splits them until embedding succeeds.
Args:
texts: List of texts to embed
actor: User actor for embedding generation
max_retries: Maximum split attempts per text before giving up
Returns:
List of (text, embedding, original_index) tuples. A single input text may
produce multiple output tuples if it was split.
"""
if not texts:
return []
embedding_client = LLMClient.create(
provider_type=self.default_embedding_config.embedding_endpoint_type,
actor=actor,
)
results: List[Tuple[str, List[float], int]] = []
for original_idx, text in enumerate(texts):
text = text.strip()
if not text:
continue
# Try to embed, splitting recursively on context length errors
texts_to_embed = [text]
retry_count = 0
while texts_to_embed and retry_count < max_retries:
try:
embeddings = await embedding_client.request_embeddings(texts_to_embed, self.default_embedding_config)
# Success - add all chunks to results
for chunk_text, embedding in zip(texts_to_embed, embeddings):
results.append((chunk_text, embedding, original_idx))
texts_to_embed = [] # Clear the queue
except Exception as e:
error_str = str(e).lower()
is_context_length_error = "context length" in error_str or "maximum context" in error_str
if is_context_length_error and len(texts_to_embed) == 1:
# Single text is too long - split it
long_text = texts_to_embed[0]
split_texts = _split_text_in_half(long_text)
if len(split_texts) <= 1:
# Can't split further, re-raise the original exception
logger.error(f"Cannot split text further, still exceeds context limit: {len(long_text)} chars")
raise e
logger.warning(
f"Text exceeds context limit ({len(long_text)} chars), splitting into {len(split_texts)} chunks and retrying"
)
texts_to_embed = split_texts
retry_count += 1
elif is_context_length_error and len(texts_to_embed) > 1:
# Multiple texts - try one at a time to find the problematic one
logger.warning("Batch embedding failed with context error, trying individually")
new_texts = []
for t in texts_to_embed:
try:
emb = await embedding_client.request_embeddings([t], self.default_embedding_config)
results.append((t, emb[0], original_idx))
except Exception as inner_e:
if "context length" in str(inner_e).lower():
# This text is too long - split it
new_texts.extend(_split_text_in_half(t))
else:
raise inner_e
texts_to_embed = new_texts
retry_count += 1
else:
# Non-context-length error, re-raise the original exception
raise e
if texts_to_embed:
# Exhausted retries
logger.error(f"Failed to embed text after {max_retries} split attempts")
raise RuntimeError(f"Text could not be embedded after {max_retries} split attempts")
return results
@trace_method
async def _get_archive_namespace_name(self, archive_id: str) -> str:
"""Get namespace name for a specific archive."""
@@ -759,52 +636,30 @@ class TurbopufferClient:
"""
from turbopuffer import AsyncTurbopuffer
# validation checks
if not message_ids:
raise LettaInvalidArgumentError("message_ids must be provided for Turbopuffer insertion", argument_name="message_ids")
if len(message_ids) != len(message_texts):
raise LettaInvalidArgumentError(
f"message_ids length ({len(message_ids)}) must match message_texts length ({len(message_texts)})",
argument_name="message_ids",
)
if len(message_ids) != len(roles):
raise LettaInvalidArgumentError(
f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})", argument_name="roles"
)
if len(message_ids) != len(created_ats):
raise LettaInvalidArgumentError(
f"message_ids length ({len(message_ids)}) must match created_ats length ({len(created_ats)})", argument_name="created_ats"
)
if conversation_ids is not None and len(conversation_ids) != len(message_ids):
raise LettaInvalidArgumentError(
f"conversation_ids length ({len(conversation_ids)}) must match message_ids length ({len(message_ids)})",
argument_name="conversation_ids",
)
# filter out empty message texts
filtered_messages = [(i, text) for i, text in enumerate(message_texts) if text.strip()]
# Filter out empty messages and prepare metadata
valid_messages = [] # List of (index, text)
for i, text in enumerate(message_texts):
text = text.strip()
if text:
valid_messages.append((i, text))
if not valid_messages:
if not filtered_messages:
logger.warning("All message texts were empty, skipping insertion")
return True
# Generate embeddings with automatic chunking for texts that exceed context length
# This returns (chunk_text, embedding, original_valid_idx) tuples
texts_to_embed = [text for _, text in valid_messages]
embedding_results = await self._generate_embeddings_with_chunking(texts_to_embed, actor)
# generate embeddings using the default config
filtered_texts = [text for _, text in filtered_messages]
embeddings = await self._generate_embeddings(filtered_texts, actor)
namespace_name = await self._get_message_namespace_name(organization_id)
# Build a mapping from valid_messages index to original message metadata
# This lets us look up the original message_id, role, etc. for each chunk
valid_idx_to_original = {valid_idx: original_idx for valid_idx, (original_idx, _) in enumerate(valid_messages)}
# Track chunk indices per message for composite IDs
message_chunk_counts: dict = {}
# validation checks
if not message_ids:
raise ValueError("message_ids must be provided for Turbopuffer insertion")
if len(message_ids) != len(message_texts):
raise ValueError(f"message_ids length ({len(message_ids)}) must match message_texts length ({len(message_texts)})")
if len(message_ids) != len(roles):
raise ValueError(f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})")
if len(message_ids) != len(created_ats):
raise ValueError(f"message_ids length ({len(message_ids)}) must match created_ats length ({len(created_ats)})")
if conversation_ids is not None and len(conversation_ids) != len(message_ids):
raise ValueError(f"conversation_ids length ({len(conversation_ids)}) must match message_ids length ({len(message_ids)})")
# prepare column-based data for turbopuffer - optimized for batch insert
ids = []
@@ -812,32 +667,18 @@ class TurbopufferClient:
texts = []
organization_ids_list = []
agent_ids_list = []
message_id_list = [] # Store original message_id for deduplication
chunk_index_list = [] # Store chunk index
message_roles = []
created_at_timestamps = []
project_ids_list = []
template_ids_list = []
conversation_ids_list = []
for chunk_text, embedding, valid_idx in embedding_results:
# Map back to original message metadata
original_idx = valid_idx_to_original[valid_idx]
for (original_idx, text), embedding in zip(filtered_messages, embeddings):
message_id = message_ids[original_idx]
role = roles[original_idx]
created_at = created_ats[original_idx]
conversation_id = conversation_ids[original_idx] if conversation_ids else None
# Track chunk index for this message
chunk_idx = message_chunk_counts.get(message_id, 0)
message_chunk_counts[message_id] = chunk_idx + 1
# Use composite ID for chunks: message_id for chunk 0, message_id_chunk_N for others
if chunk_idx == 0:
record_id = message_id
else:
record_id = f"{message_id}_chunk_{chunk_idx}"
# ensure the provided timestamp is timezone-aware and in UTC
if created_at.tzinfo is None:
# assume UTC if no timezone provided
@@ -847,13 +688,11 @@ class TurbopufferClient:
timestamp = created_at.astimezone(timezone.utc)
# append to columns
ids.append(record_id)
ids.append(message_id)
vectors.append(embedding)
texts.append(chunk_text)
texts.append(text)
organization_ids_list.append(organization_id)
agent_ids_list.append(agent_id)
message_id_list.append(message_id) # Original message ID for deduplication
chunk_index_list.append(chunk_idx)
message_roles.append(role.value)
created_at_timestamps.append(timestamp)
project_ids_list.append(project_id)
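As an aside, the record-ID scheme removed in the hunks above keyed a message's first chunk by its bare message ID and later chunks by a suffixed composite ID. A minimal sketch of that scheme (the function name is hypothetical):

```python
def record_id_for(message_id: str, chunk_idx: int) -> str:
    # Chunk 0 reuses the message ID so unchunked messages keep their original key;
    # later chunks get a composite "<message_id>_chunk_<n>" ID.
    return message_id if chunk_idx == 0 else f"{message_id}_chunk_{chunk_idx}"
```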
@@ -867,8 +706,6 @@ class TurbopufferClient:
"text": texts,
"organization_id": organization_ids_list,
"agent_id": agent_ids_list,
"message_id": message_id_list, # Original message ID for deduplication
"chunk_index": chunk_index_list, # Chunk index (0 for first/only chunk)
"role": message_roles,
"created_at": created_at_timestamps,
}
@@ -885,12 +722,6 @@ class TurbopufferClient:
if template_id is not None:
upsert_columns["template_id"] = template_ids_list
# Log if we chunked any messages
num_messages_chunked = sum(1 for count in message_chunk_counts.values() if count > 1)
total_chunks = sum(message_chunk_counts.values())
if num_messages_chunked > 0:
logger.info(f"Split {num_messages_chunked} messages into {total_chunks} chunks for embedding")
try:
# Use global semaphore to limit concurrent Turbopuffer writes
async with _GLOBAL_TURBOPUFFER_SEMAPHORE:
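The global write semaphore referenced here survives the revert (it is defined near the top of the file with a limit of 5). A self-contained sketch of the pattern, with a hypothetical wrapper name:

```python
import asyncio

_GLOBAL_TURBOPUFFER_SEMAPHORE = asyncio.Semaphore(5)  # at most 5 concurrent writes

async def limited_write(do_write):
    # Hold the shared semaphore for the duration of one Turbopuffer write
    # so bursts of inserts cannot overwhelm the backend.
    async with _GLOBAL_TURBOPUFFER_SEMAPHORE:
        return await do_write()
```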
@@ -1552,35 +1383,23 @@ class TurbopufferClient:
logger.error(f"Failed to query messages from Turbopuffer: {e}")
raise
def _process_message_query_results(self, result, deduplicate: bool = True) -> List[dict]:
def _process_message_query_results(self, result) -> List[dict]:
"""Process results from a message query into message dicts.
For RRF, we only need the rank order - scores are not used.
Deduplicates by message_id to handle chunked messages - keeps the first (best-ranked) chunk.
"""
messages = []
seen_message_ids = set()
for row in result.rows:
# Get the original message_id (for chunked messages) or fall back to row.id
message_id = getattr(row, "message_id", None) or row.id
chunk_index = getattr(row, "chunk_index", 0) or 0
# Deduplicate by message_id - keep only the first (best-ranked) chunk
if deduplicate and message_id in seen_message_ids:
continue
seen_message_ids.add(message_id)
# Build message dict with key fields
message_dict = {
"id": message_id, # Use original message_id, not chunk id
"id": row.id,
"text": getattr(row, "text", ""),
"organization_id": getattr(row, "organization_id", None),
"agent_id": getattr(row, "agent_id", None),
"role": getattr(row, "role", None),
"created_at": getattr(row, "created_at", None),
"conversation_id": getattr(row, "conversation_id", None),
"chunk_index": chunk_index, # Include chunk index for debugging
}
messages.append(message_dict)
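The deduplication removed above kept only the best-ranked chunk per message. A hedged sketch of that idea as a standalone helper (not the class method itself):

```python
def dedupe_rows_by_message_id(rows):
    # Rows are assumed to arrive in rank order; the first chunk seen for a
    # message_id wins and later chunks of the same message are dropped.
    seen, kept = set(), []
    for row in rows:
        message_id = getattr(row, "message_id", None) or row.id
        if message_id in seen:
            continue
        seen.add(message_id)
        kept.append(row)
    return kept
```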


@@ -911,8 +911,7 @@ class OpenAIClient(LLMClientBase):
if isinstance(result, Exception):
current_size = len(chunk_inputs)
if current_batch_size > 1 and current_size > 1:
# Multiple inputs in batch - try splitting the batch
if current_batch_size > 1:
new_batch_size = max(1, current_batch_size // 2)
logger.warning(
f"Embeddings request failed for batch starting at {start_idx} with size {current_size}. "


@@ -203,99 +203,3 @@ async def test_openai_embedding_minimum_chunk_failure(default_user):
with pytest.raises(Exception, match="API error"):
await client.request_embeddings(test_inputs, embedding_config)
def test_split_text_in_half():
"""Test the _split_text_in_half helper function."""
from letta.helpers.tpuf_client import _split_text_in_half
# Test with text that has sentence boundaries
long_text = "This is a test sentence. " * 100
splits = _split_text_in_half(long_text)
assert len(splits) == 2
assert len(splits[0]) > 0
assert len(splits[1]) > 0
# Should split at a sentence boundary
assert splits[0].endswith(".")
# Test with text that has no good break points
no_breaks = "a" * 1000
splits = _split_text_in_half(no_breaks)
assert len(splits) == 2
assert len(splits[0]) + len(splits[1]) == 1000
# Test with empty text
splits = _split_text_in_half("")
assert splits == []
# Test with short text (still splits)
short_text = "hello world"
splits = _split_text_in_half(short_text)
assert len(splits) == 2
def test_chunked_message_query_deduplication():
"""Test that chunked messages are deduplicated by message_id in query results."""
from unittest.mock import MagicMock
from letta.helpers.tpuf_client import TurbopufferClient
# Create a mock result with multiple chunks from the same message
mock_result = MagicMock()
# Simulate 3 rows: 2 chunks from message-1, 1 chunk from message-2
# The chunks are ranked by relevance (row order = rank order)
row1 = MagicMock()
row1.id = "message-1_chunk_1" # Second chunk of message-1, but ranked first
row1.message_id = "message-1"
row1.chunk_index = 1
row1.text = "chunk 1 text"
row1.organization_id = "org-1"
row1.agent_id = "agent-1"
row1.role = "user"
row1.created_at = None
row1.conversation_id = None
row2 = MagicMock()
row2.id = "message-2"
row2.message_id = "message-2"
row2.chunk_index = 0
row2.text = "message 2 text"
row2.organization_id = "org-1"
row2.agent_id = "agent-1"
row2.role = "assistant"
row2.created_at = None
row2.conversation_id = None
row3 = MagicMock()
row3.id = "message-1" # First chunk of message-1, but ranked third
row3.message_id = "message-1"
row3.chunk_index = 0
row3.text = "chunk 0 text"
row3.organization_id = "org-1"
row3.agent_id = "agent-1"
row3.role = "user"
row3.created_at = None
row3.conversation_id = None
mock_result.rows = [row1, row2, row3]
# Process results with deduplication
client = TurbopufferClient.__new__(TurbopufferClient) # Create without __init__
results = client._process_message_query_results(mock_result, deduplicate=True)
# Should have 2 messages (message-1 deduplicated)
assert len(results) == 2
# First result should be message-1 (from the best-ranked chunk)
assert results[0]["id"] == "message-1"
assert results[0]["text"] == "chunk 1 text" # Text from best-ranked chunk
assert results[0]["chunk_index"] == 1
# Second result should be message-2
assert results[1]["id"] == "message-2"
assert results[1]["text"] == "message 2 text"
# Test without deduplication
results_no_dedup = client._process_message_query_results(mock_result, deduplicate=False)
assert len(results_no_dedup) == 3 # All rows returned