feat: use openai embedding client instead of llama-index (#1313)

This commit is contained in:
Sarah Wooders
2025-03-17 09:13:36 -07:00
committed by GitHub
parent a2b419a528
commit e404dbaaac

View File

@@ -3,6 +3,7 @@ from typing import Any, List, Optional
import numpy as np
import tiktoken
from openai import OpenAI
from letta.constants import EMBEDDING_TO_TOKENIZER_DEFAULT, EMBEDDING_TO_TOKENIZER_MAP, MAX_EMBEDDING_DIM
from letta.schemas.embedding_config import EmbeddingConfig
@@ -201,6 +202,21 @@ class GoogleVertexEmbeddings:
return response.embeddings[0].embedding
class OpenAIEmbeddings:
def __init__(self, api_key: str, model: str, base_url: str):
if base_url:
self.client = OpenAI(api_key=api_key, base_url=base_url)
else:
self.client = OpenAI(api_key=api_key)
self.model = model
def get_text_embedding(self, text: str):
response = self.client.embeddings.create(input=text, model=self.model)
return response.data[0].embedding
def query_embedding(embedding_model, query_text: str):
"""Generate padded embedding for querying database"""
query_vec = embedding_model.get_text_embedding(query_text)
@@ -218,15 +234,9 @@ def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None
from letta.settings import model_settings
if endpoint_type == "openai":
from llama_index.embeddings.openai import OpenAIEmbedding
additional_kwargs = {"user_id": user_id} if user_id else {}
model = OpenAIEmbedding(
api_base=config.embedding_endpoint,
api_key=model_settings.openai_api_key,
additional_kwargs=additional_kwargs,
return OpenAIEmbeddings(
api_key=model_settings.openai_api_key, model=config.embedding_model, base_url=model_settings.openai_api_base
)
return model
elif endpoint_type == "azure":
assert all(