feat: various fixes (#2320)

Co-authored-by: Shubham Naik <shub@memgpt.ai>
Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
Co-authored-by: Shubham Naik <shubham.naik10@gmail.com>
Co-authored-by: Caren Thomas <caren@letta.com>
Co-authored-by: cpacker <packercharles@gmail.com>
This commit is contained in:
Sarah Wooders
2024-12-31 10:53:33 +04:00
committed by GitHub
parent 644fff77c3
commit ece8dab05d
79 changed files with 565 additions and 783 deletions

View File

@@ -1,13 +1,14 @@
import base64
import sqlite3
from typing import Optional, Union
import base64
import numpy as np
from sqlalchemy import event
from sqlalchemy.engine import Engine
import sqlite3
from letta.constants import MAX_EMBEDDING_DIM
def adapt_array(arr):
"""
Converts numpy array to binary for SQLite storage
@@ -19,12 +20,13 @@ def adapt_array(arr):
arr = np.array(arr, dtype=np.float32)
elif not isinstance(arr, np.ndarray):
raise ValueError(f"Unsupported type: {type(arr)}")
# Convert to bytes and then base64 encode
bytes_data = arr.tobytes()
base64_data = base64.b64encode(bytes_data)
return sqlite3.Binary(base64_data)
def convert_array(text):
"""
Converts binary back to numpy array
@@ -38,23 +40,24 @@ def convert_array(text):
# Handle both bytes and sqlite3.Binary
binary_data = bytes(text) if isinstance(text, sqlite3.Binary) else text
try:
# First decode base64
decoded_data = base64.b64decode(binary_data)
# Then convert to numpy array
return np.frombuffer(decoded_data, dtype=np.float32)
except Exception as e:
except Exception:
return None
def verify_embedding_dimension(embedding: np.ndarray, expected_dim: int = MAX_EMBEDDING_DIM) -> bool:
"""
Verifies that an embedding has the expected dimension
Args:
embedding: Input embedding array
expected_dim: Expected embedding dimension (default: 4096)
Returns:
bool: True if dimension matches, False otherwise
"""
@@ -62,28 +65,27 @@ def verify_embedding_dimension(embedding: np.ndarray, expected_dim: int = MAX_EM
return False
return embedding.shape[0] == expected_dim
def validate_and_transform_embedding(
embedding: Union[bytes, sqlite3.Binary, list, np.ndarray],
expected_dim: int = MAX_EMBEDDING_DIM,
dtype: np.dtype = np.float32
embedding: Union[bytes, sqlite3.Binary, list, np.ndarray], expected_dim: int = MAX_EMBEDDING_DIM, dtype: np.dtype = np.float32
) -> Optional[np.ndarray]:
"""
Validates and transforms embeddings to ensure correct dimensionality.
Args:
embedding: Input embedding in various possible formats
expected_dim: Expected embedding dimension (default 4096)
dtype: NumPy dtype for the embedding (default float32)
Returns:
np.ndarray: Validated and transformed embedding
Raises:
ValueError: If embedding dimension doesn't match expected dimension
"""
if embedding is None:
return None
# Convert to numpy array based on input type
if isinstance(embedding, (bytes, sqlite3.Binary)):
vec = convert_array(embedding)
@@ -93,48 +95,49 @@ def validate_and_transform_embedding(
vec = embedding.astype(dtype)
else:
raise ValueError(f"Unsupported embedding type: {type(embedding)}")
# Validate dimension
if vec.shape[0] != expected_dim:
raise ValueError(
f"Invalid embedding dimension: got {vec.shape[0]}, expected {expected_dim}"
)
raise ValueError(f"Invalid embedding dimension: got {vec.shape[0]}, expected {expected_dim}")
return vec
def cosine_distance(embedding1, embedding2, expected_dim=MAX_EMBEDDING_DIM):
"""
Calculate cosine distance between two embeddings
Args:
embedding1: First embedding
embedding2: Second embedding
expected_dim: Expected embedding dimension (default 4096)
Returns:
float: Cosine distance
"""
if embedding1 is None or embedding2 is None:
return 0.0 # Maximum distance if either embedding is None
try:
vec1 = validate_and_transform_embedding(embedding1, expected_dim)
vec2 = validate_and_transform_embedding(embedding2, expected_dim)
except ValueError as e:
except ValueError:
return 0.0
similarity = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
distance = float(1.0 - similarity)
return distance
@event.listens_for(Engine, "connect")
def register_functions(dbapi_connection, connection_record):
"""Register SQLite functions"""
if isinstance(dbapi_connection, sqlite3.Connection):
dbapi_connection.create_function("cosine_distance", 2, cosine_distance)
# Register adapters and converters for numpy arrays
sqlite3.register_adapter(np.ndarray, adapt_array)
sqlite3.register_converter("ARRAY", convert_array)