class OpenAIEmbeddings:
    """Minimal embedding client built directly on the ``openai`` SDK.

    Exposes ``get_text_embedding`` so that it can be passed to helpers such as
    ``query_embedding``, which call that method on whatever embedding model
    they are given.
    """

    def __init__(self, api_key: str, model: str, base_url: Optional[str] = None):
        """Create the underlying OpenAI client.

        Args:
            api_key: API key handed to the OpenAI client.
            model: Name of the embedding model requested on every call.
            base_url: Optional endpoint override. When falsy (``None`` or an
                empty string) it is not forwarded at all, so the client falls
                back to its own default endpoint.
        """
        # Function-scope import: the SDK is only needed once a client is
        # actually constructed, and it keeps this class self-contained.
        from openai import OpenAI

        client_kwargs = {"api_key": api_key}
        if base_url:
            client_kwargs["base_url"] = base_url

        self.client = OpenAI(**client_kwargs)
        self.model = model

    def get_text_embedding(self, text: str) -> List[float]:
        """Return the embedding vector for ``text``.

        Sends a single embeddings request for ``text`` using ``self.model``
        and returns the ``embedding`` of the first item in the response data.
        """
        response = self.client.embeddings.create(input=text, model=self.model)
        return response.data[0].embedding