fix: Add embedding tests to azure (#1920)

This commit is contained in:
Matthew Zhou
2024-10-22 11:53:49 -07:00
committed by GitHub
parent 1a93b85bfd
commit a2e1cfd9e5
5 changed files with 90 additions and 21 deletions

View File

@@ -21,7 +21,6 @@ from letta.constants import (
EMBEDDING_TO_TOKENIZER_MAP,
MAX_EMBEDDING_DIM,
)
from letta.credentials import LettaCredentials
from letta.schemas.embedding_config import EmbeddingConfig
from letta.utils import is_valid_url, printd
@@ -138,6 +137,18 @@ class EmbeddingEndpoint:
return self._call_api(text)
class AzureOpenAIEmbedding:
    """Thin wrapper around the Azure OpenAI embeddings API for embedding a single text."""

    def __init__(self, api_endpoint: str, api_key: str, api_version: str, model: str):
        # Imported lazily so the `openai` package is only required when Azure embeddings are used.
        from openai import AzureOpenAI

        self.client = AzureOpenAI(
            api_key=api_key,
            api_version=api_version,
            azure_endpoint=api_endpoint,
        )
        self.model = model

    def get_text_embedding(self, text: str):
        """Return the embedding vector (list of floats) for one piece of text."""
        response = self.client.embeddings.create(input=[text], model=self.model)
        return response.data[0].embedding
def default_embedding_model():
# Default to a Hugging Face model running locally.
# Warning: this is a very low-quality embedding model.
@@ -161,8 +172,8 @@ def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None
endpoint_type = config.embedding_endpoint_type
# TODO refactor to pass credentials through args
credentials = LettaCredentials.load()
# TODO: refactor to pass in settings from server
from letta.settings import model_settings
if endpoint_type == "openai":
from llama_index.embeddings.openai import OpenAIEmbedding
@@ -170,7 +181,7 @@ def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None
additional_kwargs = {"user_id": user_id} if user_id else {}
model = OpenAIEmbedding(
api_base=config.embedding_endpoint,
api_key=credentials.openai_key,
api_key=model_settings.openai_api_key,
additional_kwargs=additional_kwargs,
)
return model
@@ -178,22 +189,29 @@ def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None
elif endpoint_type == "azure":
assert all(
[
credentials.azure_key is not None,
credentials.azure_embedding_endpoint is not None,
credentials.azure_version is not None,
model_settings.azure_api_key is not None,
model_settings.azure_base_url is not None,
model_settings.azure_api_version is not None,
]
)
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
# from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
## https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings
# model = "text-embedding-ada-002"
# deployment = credentials.azure_embedding_deployment if credentials.azure_embedding_deployment is not None else model
# return AzureOpenAIEmbedding(
# model=model,
# deployment_name=deployment,
# api_key=credentials.azure_key,
# azure_endpoint=credentials.azure_endpoint,
# api_version=credentials.azure_version,
# )
# https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings
model = "text-embedding-ada-002"
deployment = credentials.azure_embedding_deployment if credentials.azure_embedding_deployment is not None else model
return AzureOpenAIEmbedding(
model=model,
deployment_name=deployment,
api_key=credentials.azure_key,
azure_endpoint=credentials.azure_endpoint,
api_version=credentials.azure_version,
api_endpoint=model_settings.azure_base_url,
api_key=model_settings.azure_api_key,
api_version=model_settings.azure_api_version,
model=config.embedding_model,
)
elif endpoint_type == "hugging-face":

View File

@@ -1,3 +1,5 @@
from collections import defaultdict
import requests
from letta.llm_api.helpers import make_post_request
@@ -20,7 +22,14 @@ def get_azure_model_list_endpoint(base_url: str, api_version: str):
return f"{base_url}/openai/models?api-version={api_version}"
def azure_openai_get_model_list(base_url: str, api_key: str, api_version: str) -> list:
def get_azure_deployment_list_endpoint(base_url: str):
    """Build the URL for Azure OpenAI's deployments listing endpoint.

    NOTE: the api-version must be 2023-03-15-preview — that is the only API
    version this deployments endpoint accepts.
    TODO: use the Azure client library here instead of hand-building URLs.
    """
    pinned_api_version = "2023-03-15-preview"
    return f"{base_url}/openai/deployments?api-version={pinned_api_version}"
def azure_openai_get_deployed_model_list(base_url: str, api_key: str, api_version: str) -> list:
"""https://learn.microsoft.com/en-us/rest/api/azureopenai/models/list?view=rest-azureopenai-2023-05-15&tabs=HTTP"""
# https://xxx.openai.azure.com/openai/models?api-version=xxx
@@ -28,18 +37,48 @@ def azure_openai_get_model_list(base_url: str, api_key: str, api_version: str) -
if api_key is not None:
headers["api-key"] = f"{api_key}"
# 1. Get all available models
url = get_azure_model_list_endpoint(base_url, api_version)
try:
response = requests.get(url, headers=headers)
response.raise_for_status()
except requests.RequestException as e:
raise RuntimeError(f"Failed to retrieve model list: {e}")
all_available_models = response.json().get("data", [])
return response.json().get("data", [])
# 2. Get all the deployed models
url = get_azure_deployment_list_endpoint(base_url)
try:
response = requests.get(url, headers=headers)
response.raise_for_status()
except requests.RequestException as e:
raise RuntimeError(f"Failed to retrieve model list: {e}")
deployed_models = response.json().get("data", [])
deployed_model_names = set([m["id"] for m in deployed_models])
# 3. Only return the models in available models if they have been deployed
deployed_models = [m for m in all_available_models if m["id"] in deployed_model_names]
# 4. Remove redundant deployments, only include the ones with the latest deployment
# Create a dictionary to store the latest model for each ID
latest_models = defaultdict()
# Iterate through the models and update the dictionary with the most recent model
for model in deployed_models:
model_id = model["id"]
updated_at = model["created_at"]
# If the model ID is new or the current model has a more recent created_at, update the dictionary
if model_id not in latest_models or updated_at > latest_models[model_id]["created_at"]:
latest_models[model_id] = model
# Extract the unique models
return list(latest_models.values())
def azure_openai_get_chat_completion_model_list(base_url: str, api_key: str, api_version: str) -> list:
model_list = azure_openai_get_model_list(base_url, api_key, api_version)
model_list = azure_openai_get_deployed_model_list(base_url, api_key, api_version)
# Extract models that support text generation
model_options = [m for m in model_list if m.get("capabilities").get("chat_completion") == True]
return model_options
@@ -53,10 +92,11 @@ def azure_openai_get_embeddings_model_list(base_url: str, api_key: str, api_vers
return m.get("capabilities").get("embeddings") == True and valid_name
model_list = azure_openai_get_model_list(base_url, api_key, api_version)
model_list = azure_openai_get_deployed_model_list(base_url, api_key, api_version)
# Extract models that support embeddings
model_options = [m for m in model_list if valid_embedding_model(m)]
return model_options

View File

@@ -0,0 +1,6 @@
{
"embedding_endpoint_type": "azure",
"embedding_model": "text-embedding-ada-002",
"embedding_dim": 768,
"embedding_chunk_size": 300
}

View File

@@ -101,7 +101,7 @@ def test_openai_gpt_4_edit_core_memory():
def test_embedding_endpoint_openai():
filename = os.path.join(embedding_config_dir, "text-embedding-ada-002.json")
filename = os.path.join(embedding_config_dir, "openai_embed.json")
run_embedding_endpoint(filename)
@@ -151,6 +151,11 @@ def test_azure_gpt_4o_mini_edit_core_memory():
print(f"Got successful response from client: \n\n{response}")
def test_azure_embedding_endpoint():
filename = os.path.join(embedding_config_dir, "azure_embed.json")
run_embedding_endpoint(filename)
# ======================================================================================================================
# LETTA HOSTED
# ======================================================================================================================