fix: re-add embedding dimension padding for archival memories (#6041)
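Fix embedding dimension padding: restore zero-padding of archival-memory passage embeddings to the fixed maximum embedding dimension.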
This commit is contained in:
@@ -56,20 +56,33 @@ class Passage(PassageBase):
|
||||
|
||||
created_at: datetime = Field(default_factory=get_utc_time, description="The creation date of the passage.")
|
||||
|
||||
@field_validator("embedding")
|
||||
@field_validator("embedding", mode="before")
|
||||
@classmethod
|
||||
def pad_embeddings(cls, embedding: List[float]) -> List[float]:
|
||||
def pad_embeddings(cls, embedding: List[float], info) -> List[float]:
|
||||
"""Pad embeddings to `MAX_EMBEDDING_DIM`. This is necessary to ensure all stored embeddings are the same size."""
|
||||
# Only do this if using pgvector
|
||||
from letta.helpers.tpuf_client import should_use_tpuf
|
||||
if embedding is None:
|
||||
return embedding
|
||||
|
||||
if not should_use_tpuf() or settings.environment != "PRODUCTION":
|
||||
import numpy as np
|
||||
# Check if this is an archival memory passage (has archive_id) or file passage (has file_id)
|
||||
data = info.data if hasattr(info, "data") else {}
|
||||
is_archival = data.get("archive_id") is not None
|
||||
is_file = data.get("file_id") is not None
|
||||
|
||||
# Pad if using pgvector
|
||||
if settings.letta_pg_uri_no_default:
|
||||
# For archival memory: always pad
|
||||
# For file passages: only pad if NOT using turbopuffer
|
||||
from letta.helpers.tpuf_client import should_use_tpuf
|
||||
|
||||
should_pad = is_archival or (is_file and not should_use_tpuf())
|
||||
|
||||
if should_pad:
|
||||
import numpy as np
|
||||
|
||||
if embedding and len(embedding) != MAX_EMBEDDING_DIM:
|
||||
np_embedding = np.array(embedding)
|
||||
padded_embedding = np.pad(np_embedding, (0, MAX_EMBEDDING_DIM - np_embedding.shape[0]), mode="constant")
|
||||
return padded_embedding.tolist()
|
||||
if np_embedding.shape[0] != MAX_EMBEDDING_DIM:
|
||||
padded_embedding = np.pad(np_embedding, (0, MAX_EMBEDDING_DIM - np_embedding.shape[0]), mode="constant")
|
||||
return padded_embedding.tolist()
|
||||
|
||||
return embedding
|
||||
|
||||
|
||||
Reference in New Issue
Block a user