fix: remove unused embedding generation (#9013)

* remove unused embedding generation

* prevent double embed

* fix embedding dimension comparison and valueerror
This commit is contained in:
Kian Jones
2026-01-21 15:50:51 -08:00
committed by Caren Thomas
parent dbc4f88701
commit 2bb4caffc3
4 changed files with 34 additions and 17 deletions

View File

@@ -7,6 +7,7 @@ from datetime import datetime, timezone
from typing import Any, Callable, List, Optional, Tuple
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
from letta.errors import LettaInvalidArgumentError
from letta.otel.tracing import trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole, TagMatchMode
@@ -321,6 +322,7 @@ class TurbopufferClient:
actor: "PydanticUser",
tags: Optional[List[str]] = None,
created_at: Optional[datetime] = None,
embeddings: Optional[List[List[float]]] = None,
) -> List[PydanticPassage]:
"""Insert passages into Turbopuffer.
@@ -332,6 +334,7 @@ class TurbopufferClient:
actor: User actor for embedding generation
tags: Optional list of tags to attach to all passages
created_at: Optional timestamp for retroactive entries (defaults to current UTC time)
embeddings: Optional pre-computed embeddings (must match 1:1 with text_chunks). If provided, skips embedding generation.
Returns:
List of PydanticPassage objects that were inserted
@@ -345,9 +348,30 @@ class TurbopufferClient:
logger.warning("All text chunks were empty, skipping insertion")
return []
# generate embeddings using the default config
filtered_texts = [text for _, text in filtered_chunks]
embeddings = await self._generate_embeddings(filtered_texts, actor)
# use provided embeddings only if dimensions match TPUF's expected dimension
use_provided_embeddings = False
if embeddings is not None:
if len(embeddings) != len(text_chunks):
raise LettaInvalidArgumentError(
f"embeddings length ({len(embeddings)}) must match text_chunks length ({len(text_chunks)})",
argument_name="embeddings",
)
# check if first non-empty embedding has correct dimensions
filtered_indices = [i for i, _ in filtered_chunks]
sample_embedding = embeddings[filtered_indices[0]] if filtered_indices else None
if sample_embedding is not None and len(sample_embedding) == self.default_embedding_config.embedding_dim:
use_provided_embeddings = True
filtered_embeddings = [embeddings[i] for i, _ in filtered_chunks]
else:
logger.debug(
f"Embedding dimension mismatch (got {len(sample_embedding) if sample_embedding else 'None'}, "
f"expected {self.default_embedding_config.embedding_dim}), regenerating embeddings"
)
if not use_provided_embeddings:
filtered_embeddings = await self._generate_embeddings(filtered_texts, actor)
namespace_name = await self._get_archive_namespace_name(archive_id)
@@ -379,7 +403,7 @@ class TurbopufferClient:
tags_arrays = []  # Store tags as arrays
passages = []
for (original_idx, text), embedding in zip(filtered_chunks, embeddings): for (original_idx, text), embedding in zip(filtered_chunks, filtered_embeddings):
passage_id = passage_ids[original_idx]
# append to columns

View File

@@ -2321,15 +2321,6 @@ class AgentManager:
# Use Turbopuffer for vector search if archive is configured for TPUF
if archive.vector_db_provider == VectorDBProvider.TPUF:
from letta.helpers.tpuf_client import TurbopufferClient
from letta.llm_api.llm_client import LLMClient
# Generate embedding for query
embedding_client = LLMClient.create(
provider_type=embedding_config.embedding_endpoint_type,
actor=actor,
)
embeddings = await embedding_client.request_embeddings([query_text], embedding_config)
query_embedding = embeddings[0]
# Query Turbopuffer - use hybrid search when text is available
tpuf_client = TurbopufferClient()

View File

@@ -345,13 +345,14 @@ class ArchiveManager:
tpuf_client = TurbopufferClient()
# Insert to Turbopuffer with the same ID as SQL # Insert to Turbopuffer with the same ID as SQL, reusing existing embedding
await tpuf_client.insert_archival_memories(
archive_id=archive.id,
text_chunks=[created_passage.text],
passage_ids=[created_passage.id],
organization_id=actor.organization_id,
actor=actor,
embeddings=[created_passage.embedding],
)
logger.info(f"Uploaded passage {created_passage.id} to Turbopuffer for archive {archive_id}")
except Exception as e:

View File

@@ -525,20 +525,21 @@ class PassageManager:
tpuf_client = TurbopufferClient()
# Extract IDs and texts from the created passages # Extract IDs, texts, and embeddings from the created passages
passage_ids = [p.id for p in passages]
passage_texts = [p.text for p in passages]
passage_embeddings = [p.embedding for p in passages]
# Insert to Turbopuffer with the same IDs as SQL # Insert to Turbopuffer with the same IDs as SQL, reusing existing embeddings
# TurbopufferClient will generate embeddings internally using default config
await tpuf_client.insert_archival_memories(
archive_id=archive.id,
text_chunks=passage_texts,
passage_ids=passage_ids, # Use same IDs as SQL passage_ids=passage_ids,
organization_id=actor.organization_id,
actor=actor,
tags=tags,
created_at=passages[0].created_at if passages else None,
embeddings=passage_embeddings,
)
except Exception as e:
logger.error(f"Failed to insert passages to Turbopuffer: {e}")