From 01cb00ae1045a2e3e612c71a7614e9e869fbe9e0 Mon Sep 17 00:00:00 2001 From: Kian Jones <11655409+kianjones9@users.noreply.github.com> Date: Fri, 30 Jan 2026 16:35:21 -0800 Subject: [PATCH] Revert "fix: truncate oversized text in embedding requests" (#9227) Revert "fix: truncate oversized text in embedding requests (#9196)" This reverts commit a9c342087e022519c63d62fb76b72aed8859539b. --- letta/helpers/tpuf_client.py | 229 ++++----------------------------- letta/llm_api/openai_client.py | 3 +- tests/test_embeddings.py | 96 -------------- 3 files changed, 25 insertions(+), 303 deletions(-) diff --git a/letta/helpers/tpuf_client.py b/letta/helpers/tpuf_client.py index a59ffe5b..17ac59fa 100644 --- a/letta/helpers/tpuf_client.py +++ b/letta/helpers/tpuf_client.py @@ -12,7 +12,6 @@ import httpx from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE from letta.errors import LettaInvalidArgumentError -from letta.llm_api.llm_client import LLMClient from letta.otel.tracing import trace_method, log_event from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageRole, TagMatchMode @@ -159,38 +158,6 @@ def async_retry_with_backoff( _GLOBAL_TURBOPUFFER_SEMAPHORE = asyncio.Semaphore(5) -def _split_text_in_half(text: str) -> List[str]: - """Split text roughly in half, trying to break at sentence/word boundaries.""" - if not text: - return [] - - mid = len(text) // 2 - - # Look for a good break point (sentence end, newline, or space) near the middle - best_break = mid - search_range = min(500, mid // 2) # Search within 500 chars of middle - - for offset in range(search_range): - # Check both directions from middle - for pos in [mid + offset, mid - offset]: - if 0 <= pos < len(text) and text[pos] in ".!?\n ": - best_break = pos + 1 - break - else: - continue - break - - first_half = text[:best_break].strip() - second_half = text[best_break:].strip() - - result = [] - if first_half: - result.append(first_half) - if second_half: - result.append(second_half) - return result - - def _run_turbopuffer_write_in_thread( api_key: str, region: str, @@ -288,6 +255,8 @@ class TurbopufferClient: Returns: List of embedding vectors """ + from letta.llm_api.llm_client import LLMClient + # filter out empty strings after stripping filtered_texts = [text for text in texts if text.strip()] @@ -302,98 +271,6 @@ class TurbopufferClient: embeddings = await embedding_client.request_embeddings(filtered_texts, self.default_embedding_config) return embeddings - @trace_method - async def _generate_embeddings_with_chunking( - self, texts: List[str], actor: "PydanticUser", max_retries: int = 5 - ) -> List[Tuple[str, List[float], int]]: - """Generate embeddings with automatic chunking for texts that exceed context length. - - For texts that are too long, recursively splits them until embedding succeeds. - - Args: - texts: List of texts to embed - actor: User actor for embedding generation - max_retries: Maximum split attempts per text before giving up - - Returns: - List of (text, embedding, original_index) tuples. A single input text may - produce multiple output tuples if it was split. - """ - if not texts: - return [] - - embedding_client = LLMClient.create( - provider_type=self.default_embedding_config.embedding_endpoint_type, - actor=actor, - ) - - results: List[Tuple[str, List[float], int]] = [] - - for original_idx, text in enumerate(texts): - text = text.strip() - if not text: - continue - - # Try to embed, splitting recursively on context length errors - texts_to_embed = [text] - retry_count = 0 - - while texts_to_embed and retry_count < max_retries: - try: - embeddings = await embedding_client.request_embeddings(texts_to_embed, self.default_embedding_config) - - # Success - add all chunks to results - for chunk_text, embedding in zip(texts_to_embed, embeddings): - results.append((chunk_text, embedding, original_idx)) - texts_to_embed = [] # Clear the queue - - except Exception as e: - error_str = str(e).lower() - is_context_length_error = "context length" in error_str or "maximum context" in error_str - - if is_context_length_error and len(texts_to_embed) == 1: - # Single text is too long - split it - long_text = texts_to_embed[0] - split_texts = _split_text_in_half(long_text) - - if len(split_texts) <= 1: - # Can't split further, re-raise the original exception - logger.error(f"Cannot split text further, still exceeds context limit: {len(long_text)} chars") - raise e - - logger.warning( - f"Text exceeds context limit ({len(long_text)} chars), splitting into {len(split_texts)} chunks and retrying" - ) - texts_to_embed = split_texts - retry_count += 1 - - elif is_context_length_error and len(texts_to_embed) > 1: - # Multiple texts - try one at a time to find the problematic one - logger.warning("Batch embedding failed with context error, trying individually") - new_texts = [] - for t in texts_to_embed: - try: - emb = await embedding_client.request_embeddings([t], self.default_embedding_config) - results.append((t, emb[0], original_idx)) - except Exception as inner_e: - if "context length" in str(inner_e).lower(): - # This text is too long - split it - new_texts.extend(_split_text_in_half(t)) - else: - raise inner_e - texts_to_embed = new_texts - retry_count += 1 - else: - # Non-context-length error, re-raise the original exception - raise e - - if texts_to_embed: - # Exhausted retries - logger.error(f"Failed to embed text after {max_retries} split attempts") - raise RuntimeError(f"Text could not be embedded after {max_retries} split attempts") - - return results - @trace_method async def _get_archive_namespace_name(self, archive_id: str) -> str: """Get namespace name for a specific archive.""" @@ -759,52 +636,30 @@ class TurbopufferClient: """ from turbopuffer import AsyncTurbopuffer - # validation checks - if not message_ids: - raise LettaInvalidArgumentError("message_ids must be provided for Turbopuffer insertion", argument_name="message_ids") - if len(message_ids) != len(message_texts): - raise LettaInvalidArgumentError( - f"message_ids length ({len(message_ids)}) must match message_texts length ({len(message_texts)})", - argument_name="message_ids", - ) - if len(message_ids) != len(roles): - raise LettaInvalidArgumentError( - f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})", argument_name="roles" - ) - if len(message_ids) != len(created_ats): - raise LettaInvalidArgumentError( - f"message_ids length ({len(message_ids)}) must match created_ats length ({len(created_ats)})", argument_name="created_ats" - ) - if conversation_ids is not None and len(conversation_ids) != len(message_ids): - raise LettaInvalidArgumentError( - f"conversation_ids length ({len(conversation_ids)}) must match message_ids length ({len(message_ids)})", - argument_name="conversation_ids", - ) + # filter out empty message texts + filtered_messages = [(i, text) for i, text in enumerate(message_texts) if text.strip()] - # Filter out empty messages and prepare metadata - valid_messages = [] # List of (index, text) - for i, text in enumerate(message_texts): - text = text.strip() - if text: - valid_messages.append((i, text)) - - if not valid_messages: + if not filtered_messages: logger.warning("All message texts were empty, skipping insertion") return True - # Generate embeddings with automatic chunking for texts that exceed context length - # This returns (chunk_text, embedding, original_valid_idx) tuples - texts_to_embed = [text for _, text in valid_messages] - embedding_results = await self._generate_embeddings_with_chunking(texts_to_embed, actor) + # generate embeddings using the default config + filtered_texts = [text for _, text in filtered_messages] + embeddings = await self._generate_embeddings(filtered_texts, actor) namespace_name = await self._get_message_namespace_name(organization_id) - # Build a mapping from valid_messages index to original message metadata - # This lets us look up the original message_id, role, etc. for each chunk - valid_idx_to_original = {valid_idx: original_idx for valid_idx, (original_idx, _) in enumerate(valid_messages)} - - # Track chunk indices per message for composite IDs - message_chunk_counts: dict = {} + # validation checks + if not message_ids: + raise ValueError("message_ids must be provided for Turbopuffer insertion") + if len(message_ids) != len(message_texts): + raise ValueError(f"message_ids length ({len(message_ids)}) must match message_texts length ({len(message_texts)})") + if len(message_ids) != len(roles): + raise ValueError(f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})") + if len(message_ids) != len(created_ats): + raise ValueError(f"message_ids length ({len(message_ids)}) must match created_ats length ({len(created_ats)})") + if conversation_ids is not None and len(conversation_ids) != len(message_ids): + raise ValueError(f"conversation_ids length ({len(conversation_ids)}) must match message_ids length ({len(message_ids)})") # prepare column-based data for turbopuffer - optimized for batch insert ids = [] @@ -812,32 +667,18 @@ class TurbopufferClient: texts = [] organization_ids_list = [] agent_ids_list = [] - message_id_list = [] # Store original message_id for deduplication - chunk_index_list = [] # Store chunk index message_roles = [] created_at_timestamps = [] project_ids_list = [] template_ids_list = [] conversation_ids_list = [] - for chunk_text, embedding, valid_idx in embedding_results: - # Map back to original message metadata - original_idx = valid_idx_to_original[valid_idx] + for (original_idx, text), embedding in zip(filtered_messages, embeddings): message_id = message_ids[original_idx] role = roles[original_idx] created_at = created_ats[original_idx] conversation_id = conversation_ids[original_idx] if conversation_ids else None - # Track chunk index for this message - chunk_idx = message_chunk_counts.get(message_id, 0) - message_chunk_counts[message_id] = chunk_idx + 1 - - # Use composite ID for chunks: message_id for chunk 0, message_id_chunk_N for others - if chunk_idx == 0: - record_id = message_id - else: - record_id = f"{message_id}_chunk_{chunk_idx}" - # ensure the provided timestamp is timezone-aware and in UTC if created_at.tzinfo is None: # assume UTC if no timezone provided @@ -847,13 +688,11 @@ class TurbopufferClient: timestamp = created_at.astimezone(timezone.utc) # append to columns - ids.append(record_id) + ids.append(message_id) vectors.append(embedding) - texts.append(chunk_text) + texts.append(text) organization_ids_list.append(organization_id) agent_ids_list.append(agent_id) - message_id_list.append(message_id) # Original message ID for deduplication - chunk_index_list.append(chunk_idx) message_roles.append(role.value) created_at_timestamps.append(timestamp) project_ids_list.append(project_id) @@ -867,8 +706,6 @@ class TurbopufferClient: "text": texts, "organization_id": organization_ids_list, "agent_id": agent_ids_list, - "message_id": message_id_list, # Original message ID for deduplication - "chunk_index": chunk_index_list, # Chunk index (0 for first/only chunk) "role": message_roles, "created_at": created_at_timestamps, } @@ -885,12 +722,6 @@ class TurbopufferClient: if template_id is not None: upsert_columns["template_id"] = template_ids_list - # Log if we chunked any messages - num_messages_chunked = sum(1 for count in message_chunk_counts.values() if count > 1) - total_chunks = sum(message_chunk_counts.values()) - if num_messages_chunked > 0: - logger.info(f"Split {num_messages_chunked} messages into {total_chunks} chunks for embedding") - try: # Use global semaphore to limit concurrent Turbopuffer writes async with _GLOBAL_TURBOPUFFER_SEMAPHORE: @@ -1552,35 +1383,23 @@ class TurbopufferClient: logger.error(f"Failed to query messages from Turbopuffer: {e}") raise - def _process_message_query_results(self, result, deduplicate: bool = True) -> List[dict]: + def _process_message_query_results(self, result) -> List[dict]: """Process results from a message query into message dicts. For RRF, we only need the rank order - scores are not used. - Deduplicates by message_id to handle chunked messages - keeps the first (best-ranked) chunk. """ messages = [] - seen_message_ids = set() for row in result.rows: - # Get the original message_id (for chunked messages) or fall back to row.id - message_id = getattr(row, "message_id", None) or row.id - chunk_index = getattr(row, "chunk_index", 0) or 0 - - # Deduplicate by message_id - keep only the first (best-ranked) chunk - if deduplicate and message_id in seen_message_ids: - continue - seen_message_ids.add(message_id) - # Build message dict with key fields message_dict = { - "id": message_id, # Use original message_id, not chunk id + "id": row.id, "text": getattr(row, "text", ""), "organization_id": getattr(row, "organization_id", None), "agent_id": getattr(row, "agent_id", None), "role": getattr(row, "role", None), "created_at": getattr(row, "created_at", None), "conversation_id": getattr(row, "conversation_id", None), - "chunk_index": chunk_index, # Include chunk index for debugging } messages.append(message_dict) diff --git a/letta/llm_api/openai_client.py b/letta/llm_api/openai_client.py index ba352597..ca26c768 100644 --- a/letta/llm_api/openai_client.py +++ b/letta/llm_api/openai_client.py @@ -911,8 +911,7 @@ class OpenAIClient(LLMClientBase): if isinstance(result, Exception): current_size = len(chunk_inputs) - if current_batch_size > 1 and current_size > 1: - # Multiple inputs in batch - try splitting the batch + if current_batch_size > 1: new_batch_size = max(1, current_batch_size // 2) logger.warning( f"Embeddings request failed for batch starting at {start_idx} with size {current_size}. " diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py index 96c63837..fd4c0871 100644 --- a/tests/test_embeddings.py +++ b/tests/test_embeddings.py @@ -203,99 +203,3 @@ async def test_openai_embedding_minimum_chunk_failure(default_user): with pytest.raises(Exception, match="API error"): await client.request_embeddings(test_inputs, embedding_config) - - -def test_split_text_in_half(): - """Test the _split_text_in_half helper function.""" - from letta.helpers.tpuf_client import _split_text_in_half - - # Test with text that has sentence boundaries - long_text = "This is a test sentence. " * 100 - splits = _split_text_in_half(long_text) - assert len(splits) == 2 - assert len(splits[0]) > 0 - assert len(splits[1]) > 0 - # Should split at a sentence boundary - assert splits[0].endswith(".") - - # Test with text that has no good break points - no_breaks = "a" * 1000 - splits = _split_text_in_half(no_breaks) - assert len(splits) == 2 - assert len(splits[0]) + len(splits[1]) == 1000 - - # Test with empty text - splits = _split_text_in_half("") - assert splits == [] - - # Test with short text (still splits) - short_text = "hello world" - splits = _split_text_in_half(short_text) - assert len(splits) == 2 - - -def test_chunked_message_query_deduplication(): - """Test that chunked messages are deduplicated by message_id in query results.""" - from unittest.mock import MagicMock - - from letta.helpers.tpuf_client import TurbopufferClient - - # Create a mock result with multiple chunks from the same message - mock_result = MagicMock() - - # Simulate 3 rows: 2 chunks from message-1, 1 chunk from message-2 - # The chunks are ranked by relevance (row order = rank order) - row1 = MagicMock() - row1.id = "message-1_chunk_1" # Second chunk of message-1, but ranked first - row1.message_id = "message-1" - row1.chunk_index = 1 - row1.text = "chunk 1 text" - row1.organization_id = "org-1" - row1.agent_id = "agent-1" - row1.role = "user" - row1.created_at = None - row1.conversation_id = None - - row2 = MagicMock() - row2.id = "message-2" - row2.message_id = "message-2" - row2.chunk_index = 0 - row2.text = "message 2 text" - row2.organization_id = "org-1" - row2.agent_id = "agent-1" - row2.role = "assistant" - row2.created_at = None - row2.conversation_id = None - - row3 = MagicMock() - row3.id = "message-1" # First chunk of message-1, but ranked third - row3.message_id = "message-1" - row3.chunk_index = 0 - row3.text = "chunk 0 text" - row3.organization_id = "org-1" - row3.agent_id = "agent-1" - row3.role = "user" - row3.created_at = None - row3.conversation_id = None - - mock_result.rows = [row1, row2, row3] - - # Process results with deduplication - client = TurbopufferClient.__new__(TurbopufferClient) # Create without __init__ - results = client._process_message_query_results(mock_result, deduplicate=True) - - # Should have 2 messages (message-1 deduplicated) - assert len(results) == 2 - - # First result should be message-1 (from the best-ranked chunk) - assert results[0]["id"] == "message-1" - assert results[0]["text"] == "chunk 1 text" # Text from best-ranked chunk - assert results[0]["chunk_index"] == 1 - - # Second result should be message-2 - assert results[1]["id"] == "message-2" - assert results[1]["text"] == "message 2 text" - - # Test without deduplication - results_no_dedup = client._process_message_query_results(mock_result, deduplicate=False) - assert len(results_no_dedup) == 3 # All rows returned