feat: use openai embedding client instead of llama-index (#1313)
This commit is contained in:
@@ -3,6 +3,7 @@ from typing import Any, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import tiktoken
|
||||
from openai import OpenAI
|
||||
|
||||
from letta.constants import EMBEDDING_TO_TOKENIZER_DEFAULT, EMBEDDING_TO_TOKENIZER_MAP, MAX_EMBEDDING_DIM
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
@@ -201,6 +202,21 @@ class GoogleVertexEmbeddings:
|
||||
return response.embeddings[0].embedding
|
||||
|
||||
|
||||
class OpenAIEmbeddings:
    """Minimal embedding client built on the ``openai`` SDK's ``OpenAI`` client.

    Exposes ``get_text_embedding`` so instances can be passed to helpers that
    call that method on an embedding model object.
    """

    def __init__(self, api_key: str, model: str, base_url: Optional[str] = None):
        """Create the underlying ``OpenAI`` client and remember the model name.

        Args:
            api_key: API key forwarded to the ``OpenAI`` client.
            model: Embedding model name sent with every embeddings request.
            base_url: Optional endpoint override. A falsy value (``None`` or
                ``""``) is not forwarded, so the client is constructed with
                only the API key.
        """
        # Forward base_url only when one was actually supplied; an empty
        # string must not be passed through as an endpoint.
        if base_url:
            self.client = OpenAI(api_key=api_key, base_url=base_url)
        else:
            self.client = OpenAI(api_key=api_key)
        self.model = model

    def get_text_embedding(self, text: str) -> List[float]:
        """Request an embedding for ``text`` and return the first result.

        Args:
            text: Input passed as ``input`` to ``client.embeddings.create``.

        Returns:
            The ``embedding`` attribute of the first entry in the response's
            ``data`` list (presumably a list of floats, per the OpenAI
            embeddings API - confirm if a non-default encoding is used).
        """
        response = self.client.embeddings.create(input=text, model=self.model)
        return response.data[0].embedding
|
||||
|
||||
|
||||
def query_embedding(embedding_model, query_text: str):
|
||||
"""Generate padded embedding for querying database"""
|
||||
query_vec = embedding_model.get_text_embedding(query_text)
|
||||
@@ -218,15 +234,9 @@ def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None
|
||||
from letta.settings import model_settings
|
||||
|
||||
if endpoint_type == "openai":
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
|
||||
additional_kwargs = {"user_id": user_id} if user_id else {}
|
||||
model = OpenAIEmbedding(
|
||||
api_base=config.embedding_endpoint,
|
||||
api_key=model_settings.openai_api_key,
|
||||
additional_kwargs=additional_kwargs,
|
||||
return OpenAIEmbeddings(
|
||||
api_key=model_settings.openai_api_key, model=config.embedding_model, base_url=model_settings.openai_api_base
|
||||
)
|
||||
return model
|
||||
|
||||
elif endpoint_type == "azure":
|
||||
assert all(
|
||||
|
||||
Reference in New Issue
Block a user