diff --git a/letta/helpers/tpuf_client.py b/letta/helpers/tpuf_client.py index 17ac59fa..a59ffe5b 100644 --- a/letta/helpers/tpuf_client.py +++ b/letta/helpers/tpuf_client.py @@ -12,6 +12,7 @@ import httpx from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE from letta.errors import LettaInvalidArgumentError +from letta.llm_api.llm_client import LLMClient from letta.otel.tracing import trace_method, log_event from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageRole, TagMatchMode @@ -158,6 +159,38 @@ def async_retry_with_backoff( _GLOBAL_TURBOPUFFER_SEMAPHORE = asyncio.Semaphore(5) +def _split_text_in_half(text: str) -> List[str]: + """Split text roughly in half, trying to break at sentence/word boundaries.""" + if not text: + return [] + + mid = len(text) // 2 + + # Look for a good break point (sentence end, newline, or space) near the middle + best_break = mid + search_range = min(500, mid // 2) # Search within 500 chars of middle + + for offset in range(search_range): + # Check both directions from middle + for pos in [mid + offset, mid - offset]: + if 0 <= pos < len(text) and text[pos] in ".!?\n ": + best_break = pos + 1 + break + else: + continue + break + + first_half = text[:best_break].strip() + second_half = text[best_break:].strip() + + result = [] + if first_half: + result.append(first_half) + if second_half: + result.append(second_half) + return result + + def _run_turbopuffer_write_in_thread( api_key: str, region: str, @@ -255,8 +288,6 @@ class TurbopufferClient: Returns: List of embedding vectors """ - from letta.llm_api.llm_client import LLMClient - # filter out empty strings after stripping filtered_texts = [text for text in texts if text.strip()] @@ -271,6 +302,98 @@ class TurbopufferClient: embeddings = await embedding_client.request_embeddings(filtered_texts, self.default_embedding_config) return embeddings + @trace_method + async def _generate_embeddings_with_chunking( + self, texts: List[str], actor: "PydanticUser", max_retries: int = 5 + ) -> List[Tuple[str, List[float], int]]: + """Generate embeddings with automatic chunking for texts that exceed context length. + + For texts that are too long, recursively splits them until embedding succeeds. + + Args: + texts: List of texts to embed + actor: User actor for embedding generation + max_retries: Maximum split attempts per text before giving up + + Returns: + List of (text, embedding, original_index) tuples. A single input text may + produce multiple output tuples if it was split. 
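+
+        Example (illustrative values only; real vectors come from the
+        configured embedding provider, and a text is only split when the
+        provider rejects it as exceeding its context length):
+
+            await self._generate_embeddings_with_chunking(["hi", "<very long text>"], actor)
+            # -> [("hi", [0.12, ...], 0),
+            #     ("<first half>", [0.08, ...], 1),
+            #     ("<second half>", [0.31, ...], 1)]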
+ """ + if not texts: + return [] + + embedding_client = LLMClient.create( + provider_type=self.default_embedding_config.embedding_endpoint_type, + actor=actor, + ) + + results: List[Tuple[str, List[float], int]] = [] + + for original_idx, text in enumerate(texts): + text = text.strip() + if not text: + continue + + # Try to embed, splitting recursively on context length errors + texts_to_embed = [text] + retry_count = 0 + + while texts_to_embed and retry_count < max_retries: + try: + embeddings = await embedding_client.request_embeddings(texts_to_embed, self.default_embedding_config) + + # Success - add all chunks to results + for chunk_text, embedding in zip(texts_to_embed, embeddings): + results.append((chunk_text, embedding, original_idx)) + texts_to_embed = [] # Clear the queue + + except Exception as e: + error_str = str(e).lower() + is_context_length_error = "context length" in error_str or "maximum context" in error_str + + if is_context_length_error and len(texts_to_embed) == 1: + # Single text is too long - split it + long_text = texts_to_embed[0] + split_texts = _split_text_in_half(long_text) + + if len(split_texts) <= 1: + # Can't split further, re-raise the original exception + logger.error(f"Cannot split text further, still exceeds context limit: {len(long_text)} chars") + raise e + + logger.warning( + f"Text exceeds context limit ({len(long_text)} chars), splitting into {len(split_texts)} chunks and retrying" + ) + texts_to_embed = split_texts + retry_count += 1 + + elif is_context_length_error and len(texts_to_embed) > 1: + # Multiple texts - try one at a time to find the problematic one + logger.warning("Batch embedding failed with context error, trying individually") + new_texts = [] + for t in texts_to_embed: + try: + emb = await embedding_client.request_embeddings([t], self.default_embedding_config) + results.append((t, emb[0], original_idx)) + except Exception as inner_e: + if "context length" in str(inner_e).lower(): + # This text is too long - split it + new_texts.extend(_split_text_in_half(t)) + else: + raise inner_e + texts_to_embed = new_texts + retry_count += 1 + else: + # Non-context-length error, re-raise the original exception + raise e + + if texts_to_embed: + # Exhausted retries + logger.error(f"Failed to embed text after {max_retries} split attempts") + raise RuntimeError(f"Text could not be embedded after {max_retries} split attempts") + + return results + @trace_method async def _get_archive_namespace_name(self, archive_id: str) -> str: """Get namespace name for a specific archive.""" @@ -636,30 +759,52 @@ class TurbopufferClient: """ from turbopuffer import AsyncTurbopuffer - # filter out empty message texts - filtered_messages = [(i, text) for i, text in enumerate(message_texts) if text.strip()] + # validation checks + if not message_ids: + raise LettaInvalidArgumentError("message_ids must be provided for Turbopuffer insertion", argument_name="message_ids") + if len(message_ids) != len(message_texts): + raise LettaInvalidArgumentError( + f"message_ids length ({len(message_ids)}) must match message_texts length ({len(message_texts)})", + argument_name="message_ids", + ) + if len(message_ids) != len(roles): + raise LettaInvalidArgumentError( + f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})", argument_name="roles" + ) + if len(message_ids) != len(created_ats): + raise LettaInvalidArgumentError( + f"message_ids length ({len(message_ids)}) must match created_ats length ({len(created_ats)})", argument_name="created_ats" + ) + 
if conversation_ids is not None and len(conversation_ids) != len(message_ids):
+            raise LettaInvalidArgumentError(
+                f"conversation_ids length ({len(conversation_ids)}) must match message_ids length ({len(message_ids)})",
+                argument_name="conversation_ids",
+            )
 
-        if not filtered_messages:
+        # Filter out empty messages, remembering each message's original index
+        valid_messages = []  # list of (original index, stripped text) tuples
+        for i, text in enumerate(message_texts):
+            text = text.strip()
+            if text:
+                valid_messages.append((i, text))
+
+        if not valid_messages:
             logger.warning("All message texts were empty, skipping insertion")
             return True
 
-        # generate embeddings using the default config
-        filtered_texts = [text for _, text in filtered_messages]
-        embeddings = await self._generate_embeddings(filtered_texts, actor)
+        # Generate embeddings with automatic chunking for texts that exceed context length.
+        # Returns (chunk_text, embedding, valid_messages index) tuples; a single long
+        # message may contribute several chunks.
+        texts_to_embed = [text for _, text in valid_messages]
+        embedding_results = await self._generate_embeddings_with_chunking(texts_to_embed, actor)
 
         namespace_name = await self._get_message_namespace_name(organization_id)
 
-        # validation checks
-        if not message_ids:
-            raise ValueError("message_ids must be provided for Turbopuffer insertion")
-        if len(message_ids) != len(message_texts):
-            raise ValueError(f"message_ids length ({len(message_ids)}) must match message_texts length ({len(message_texts)})")
-        if len(message_ids) != len(roles):
-            raise ValueError(f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})")
-        if len(message_ids) != len(created_ats):
-            raise ValueError(f"message_ids length ({len(message_ids)}) must match created_ats length ({len(created_ats)})")
-        if conversation_ids is not None and len(conversation_ids) != len(message_ids):
-            raise ValueError(f"conversation_ids length ({len(conversation_ids)}) must match message_ids length ({len(message_ids)})")
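+        # Illustrative walkthrough (hypothetical values): with
+        #     valid_messages = [(0, "hi"), (2, "<very long text>")]
+        # where the long text was split into two chunks, embedding_results is
+        #     [("hi", <vec>, 0), ("<first half>", <vec>, 1), ("<second half>", <vec>, 1)]
+        # and the mapping/composite-ID scheme below writes records with ids
+        #     [message_ids[0], message_ids[2], f"{message_ids[2]}_chunk_1"]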
+        # Build a mapping from valid_messages index to original message metadata
+        # This lets us look up the original message_id, role, etc. for each chunk
+        valid_idx_to_original = {valid_idx: original_idx for valid_idx, (original_idx, _) in enumerate(valid_messages)}
+
+        # Track chunk indices per message for composite IDs
+        message_chunk_counts: dict = {}  # message_id -> number of chunks emitted so far
 
         # prepare column-based data for turbopuffer - optimized for batch insert
         ids = []
@@ -667,18 +812,32 @@ class TurbopufferClient:
         texts = []
         organization_ids_list = []
         agent_ids_list = []
+        message_id_list = []  # Store original message_id for deduplication
+        chunk_index_list = []  # Store chunk index
         message_roles = []
         created_at_timestamps = []
         project_ids_list = []
         template_ids_list = []
         conversation_ids_list = []
 
-        for (original_idx, text), embedding in zip(filtered_messages, embeddings):
+        for chunk_text, embedding, valid_idx in embedding_results:
+            # Map back to original message metadata
+            original_idx = valid_idx_to_original[valid_idx]
             message_id = message_ids[original_idx]
             role = roles[original_idx]
             created_at = created_ats[original_idx]
             conversation_id = conversation_ids[original_idx] if conversation_ids else None
 
+            # Track chunk index for this message
+            chunk_idx = message_chunk_counts.get(message_id, 0)
+            message_chunk_counts[message_id] = chunk_idx + 1
+
+            # Use a composite ID for chunks: bare message_id for chunk 0, message_id_chunk_N for the rest
+            if chunk_idx == 0:
+                record_id = message_id
+            else:
+                record_id = f"{message_id}_chunk_{chunk_idx}"
+
             # ensure the provided timestamp is timezone-aware and in UTC
             if created_at.tzinfo is None:
                 # assume UTC if no timezone provided
@@ -688,11 +847,13 @@ class TurbopufferClient:
                 timestamp = created_at.astimezone(timezone.utc)
 
             # append to columns
-            ids.append(message_id)
+            ids.append(record_id)
             vectors.append(embedding)
-            texts.append(text)
+            texts.append(chunk_text)
             organization_ids_list.append(organization_id)
             agent_ids_list.append(agent_id)
+            message_id_list.append(message_id)  # Original message ID for deduplication
+            chunk_index_list.append(chunk_idx)
             message_roles.append(role.value)
             created_at_timestamps.append(timestamp)
             project_ids_list.append(project_id)
@@ -706,6 +867,8 @@ class TurbopufferClient:
             "text": texts,
             "organization_id": organization_ids_list,
             "agent_id": agent_ids_list,
+            "message_id": message_id_list,  # Original message ID for deduplication
+            "chunk_index": chunk_index_list,  # Chunk index (0 for first/only chunk)
             "role": message_roles,
             "created_at": created_at_timestamps,
         }
@@ -722,6 +885,12 @@ class TurbopufferClient:
         if template_id is not None:
             upsert_columns["template_id"] = template_ids_list
 
+        # Log when messages had to be chunked; count only the chunks of split messages
+        num_messages_chunked = sum(1 for count in message_chunk_counts.values() if count > 1)
+        total_chunks = sum(count for count in message_chunk_counts.values() if count > 1)
+        if num_messages_chunked > 0:
+            logger.info(f"Split {num_messages_chunked} message(s) into {total_chunks} chunks for embedding")
+
        try:
            # Use global semaphore to limit concurrent Turbopuffer writes
            async with _GLOBAL_TURBOPUFFER_SEMAPHORE:
@@ -1383,23 +1552,35 @@ class TurbopufferClient:
             logger.error(f"Failed to query messages from Turbopuffer: {e}")
             raise
 
-    def _process_message_query_results(self, result) -> List[dict]:
+    def _process_message_query_results(self, result, deduplicate: bool = True) -> List[dict]:
         """Process results from a message query into message dicts.
 
         For RRF, we only need the rank order - scores are not used.
+        Deduplicates by message_id to handle chunked messages, keeping the first (best-ranked) chunk. 
""" messages = [] + seen_message_ids = set() for row in result.rows: + # Get the original message_id (for chunked messages) or fall back to row.id + message_id = getattr(row, "message_id", None) or row.id + chunk_index = getattr(row, "chunk_index", 0) or 0 + + # Deduplicate by message_id - keep only the first (best-ranked) chunk + if deduplicate and message_id in seen_message_ids: + continue + seen_message_ids.add(message_id) + # Build message dict with key fields message_dict = { - "id": row.id, + "id": message_id, # Use original message_id, not chunk id "text": getattr(row, "text", ""), "organization_id": getattr(row, "organization_id", None), "agent_id": getattr(row, "agent_id", None), "role": getattr(row, "role", None), "created_at": getattr(row, "created_at", None), "conversation_id": getattr(row, "conversation_id", None), + "chunk_index": chunk_index, # Include chunk index for debugging } messages.append(message_dict) diff --git a/letta/llm_api/openai_client.py b/letta/llm_api/openai_client.py index ca26c768..ba352597 100644 --- a/letta/llm_api/openai_client.py +++ b/letta/llm_api/openai_client.py @@ -911,7 +911,8 @@ class OpenAIClient(LLMClientBase): if isinstance(result, Exception): current_size = len(chunk_inputs) - if current_batch_size > 1: + if current_batch_size > 1 and current_size > 1: + # Multiple inputs in batch - try splitting the batch new_batch_size = max(1, current_batch_size // 2) logger.warning( f"Embeddings request failed for batch starting at {start_idx} with size {current_size}. " diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py index fd4c0871..96c63837 100644 --- a/tests/test_embeddings.py +++ b/tests/test_embeddings.py @@ -203,3 +203,99 @@ async def test_openai_embedding_minimum_chunk_failure(default_user): with pytest.raises(Exception, match="API error"): await client.request_embeddings(test_inputs, embedding_config) + + +def test_split_text_in_half(): + """Test the _split_text_in_half helper function.""" + from letta.helpers.tpuf_client import _split_text_in_half + + # Test with text that has sentence boundaries + long_text = "This is a test sentence. 
" * 100 + splits = _split_text_in_half(long_text) + assert len(splits) == 2 + assert len(splits[0]) > 0 + assert len(splits[1]) > 0 + # Should split at a sentence boundary + assert splits[0].endswith(".") + + # Test with text that has no good break points + no_breaks = "a" * 1000 + splits = _split_text_in_half(no_breaks) + assert len(splits) == 2 + assert len(splits[0]) + len(splits[1]) == 1000 + + # Test with empty text + splits = _split_text_in_half("") + assert splits == [] + + # Test with short text (still splits) + short_text = "hello world" + splits = _split_text_in_half(short_text) + assert len(splits) == 2 + + +def test_chunked_message_query_deduplication(): + """Test that chunked messages are deduplicated by message_id in query results.""" + from unittest.mock import MagicMock + + from letta.helpers.tpuf_client import TurbopufferClient + + # Create a mock result with multiple chunks from the same message + mock_result = MagicMock() + + # Simulate 3 rows: 2 chunks from message-1, 1 chunk from message-2 + # The chunks are ranked by relevance (row order = rank order) + row1 = MagicMock() + row1.id = "message-1_chunk_1" # Second chunk of message-1, but ranked first + row1.message_id = "message-1" + row1.chunk_index = 1 + row1.text = "chunk 1 text" + row1.organization_id = "org-1" + row1.agent_id = "agent-1" + row1.role = "user" + row1.created_at = None + row1.conversation_id = None + + row2 = MagicMock() + row2.id = "message-2" + row2.message_id = "message-2" + row2.chunk_index = 0 + row2.text = "message 2 text" + row2.organization_id = "org-1" + row2.agent_id = "agent-1" + row2.role = "assistant" + row2.created_at = None + row2.conversation_id = None + + row3 = MagicMock() + row3.id = "message-1" # First chunk of message-1, but ranked third + row3.message_id = "message-1" + row3.chunk_index = 0 + row3.text = "chunk 0 text" + row3.organization_id = "org-1" + row3.agent_id = "agent-1" + row3.role = "user" + row3.created_at = None + row3.conversation_id = None + + mock_result.rows = [row1, row2, row3] + + # Process results with deduplication + client = TurbopufferClient.__new__(TurbopufferClient) # Create without __init__ + results = client._process_message_query_results(mock_result, deduplicate=True) + + # Should have 2 messages (message-1 deduplicated) + assert len(results) == 2 + + # First result should be message-1 (from the best-ranked chunk) + assert results[0]["id"] == "message-1" + assert results[0]["text"] == "chunk 1 text" # Text from best-ranked chunk + assert results[0]["chunk_index"] == 1 + + # Second result should be message-2 + assert results[1]["id"] == "message-2" + assert results[1]["text"] == "message 2 text" + + # Test without deduplication + results_no_dedup = client._process_message_query_results(mock_result, deduplicate=False) + assert len(results_no_dedup) == 3 # All rows returned