fix: Add embedding tests to azure (#1920)
This commit is contained in:
@@ -21,7 +21,6 @@ from letta.constants import (
|
||||
EMBEDDING_TO_TOKENIZER_MAP,
|
||||
MAX_EMBEDDING_DIM,
|
||||
)
|
||||
from letta.credentials import LettaCredentials
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.utils import is_valid_url, printd
|
||||
|
||||
@@ -138,6 +137,18 @@ class EmbeddingEndpoint:
|
||||
return self._call_api(text)
|
||||
|
||||
|
||||
class AzureOpenAIEmbedding:
|
||||
def __init__(self, api_endpoint: str, api_key: str, api_version: str, model: str):
|
||||
from openai import AzureOpenAI
|
||||
|
||||
self.client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_endpoint)
|
||||
self.model = model
|
||||
|
||||
def get_text_embedding(self, text: str):
|
||||
embeddings = self.client.embeddings.create(input=[text], model=self.model).data[0].embedding
|
||||
return embeddings
|
||||
|
||||
|
||||
def default_embedding_model():
|
||||
# default to hugging face model running local
|
||||
# warning: this is a terrible model
|
||||
@@ -161,8 +172,8 @@ def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None
|
||||
|
||||
endpoint_type = config.embedding_endpoint_type
|
||||
|
||||
# TODO refactor to pass credentials through args
|
||||
credentials = LettaCredentials.load()
|
||||
# TODO: refactor to pass in settings from server
|
||||
from letta.settings import model_settings
|
||||
|
||||
if endpoint_type == "openai":
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
@@ -170,7 +181,7 @@ def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None
|
||||
additional_kwargs = {"user_id": user_id} if user_id else {}
|
||||
model = OpenAIEmbedding(
|
||||
api_base=config.embedding_endpoint,
|
||||
api_key=credentials.openai_key,
|
||||
api_key=model_settings.openai_api_key,
|
||||
additional_kwargs=additional_kwargs,
|
||||
)
|
||||
return model
|
||||
@@ -178,22 +189,29 @@ def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None
|
||||
elif endpoint_type == "azure":
|
||||
assert all(
|
||||
[
|
||||
credentials.azure_key is not None,
|
||||
credentials.azure_embedding_endpoint is not None,
|
||||
credentials.azure_version is not None,
|
||||
model_settings.azure_api_key is not None,
|
||||
model_settings.azure_base_url is not None,
|
||||
model_settings.azure_api_version is not None,
|
||||
]
|
||||
)
|
||||
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
|
||||
# from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
|
||||
|
||||
## https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings
|
||||
# model = "text-embedding-ada-002"
|
||||
# deployment = credentials.azure_embedding_deployment if credentials.azure_embedding_deployment is not None else model
|
||||
# return AzureOpenAIEmbedding(
|
||||
# model=model,
|
||||
# deployment_name=deployment,
|
||||
# api_key=credentials.azure_key,
|
||||
# azure_endpoint=credentials.azure_endpoint,
|
||||
# api_version=credentials.azure_version,
|
||||
# )
|
||||
|
||||
# https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings
|
||||
model = "text-embedding-ada-002"
|
||||
deployment = credentials.azure_embedding_deployment if credentials.azure_embedding_deployment is not None else model
|
||||
return AzureOpenAIEmbedding(
|
||||
model=model,
|
||||
deployment_name=deployment,
|
||||
api_key=credentials.azure_key,
|
||||
azure_endpoint=credentials.azure_endpoint,
|
||||
api_version=credentials.azure_version,
|
||||
api_endpoint=model_settings.azure_base_url,
|
||||
api_key=model_settings.azure_api_key,
|
||||
api_version=model_settings.azure_api_version,
|
||||
model=config.embedding_model,
|
||||
)
|
||||
|
||||
elif endpoint_type == "hugging-face":
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from collections import defaultdict
|
||||
|
||||
import requests
|
||||
|
||||
from letta.llm_api.helpers import make_post_request
|
||||
@@ -20,7 +22,14 @@ def get_azure_model_list_endpoint(base_url: str, api_version: str):
|
||||
return f"{base_url}/openai/models?api-version={api_version}"
|
||||
|
||||
|
||||
def azure_openai_get_model_list(base_url: str, api_key: str, api_version: str) -> list:
|
||||
def get_azure_deployment_list_endpoint(base_url: str):
|
||||
# Please note that it has to be 2023-03-15-preview
|
||||
# That's the only api version that works with this deployments endpoint
|
||||
# TODO: Use the Azure Client library here instead
|
||||
return f"{base_url}/openai/deployments?api-version=2023-03-15-preview"
|
||||
|
||||
|
||||
def azure_openai_get_deployed_model_list(base_url: str, api_key: str, api_version: str) -> list:
|
||||
"""https://learn.microsoft.com/en-us/rest/api/azureopenai/models/list?view=rest-azureopenai-2023-05-15&tabs=HTTP"""
|
||||
|
||||
# https://xxx.openai.azure.com/openai/models?api-version=xxx
|
||||
@@ -28,18 +37,48 @@ def azure_openai_get_model_list(base_url: str, api_key: str, api_version: str) -
|
||||
if api_key is not None:
|
||||
headers["api-key"] = f"{api_key}"
|
||||
|
||||
# 1. Get all available models
|
||||
url = get_azure_model_list_endpoint(base_url, api_version)
|
||||
try:
|
||||
response = requests.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
except requests.RequestException as e:
|
||||
raise RuntimeError(f"Failed to retrieve model list: {e}")
|
||||
all_available_models = response.json().get("data", [])
|
||||
|
||||
return response.json().get("data", [])
|
||||
# 2. Get all the deployed models
|
||||
url = get_azure_deployment_list_endpoint(base_url)
|
||||
try:
|
||||
response = requests.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
except requests.RequestException as e:
|
||||
raise RuntimeError(f"Failed to retrieve model list: {e}")
|
||||
|
||||
deployed_models = response.json().get("data", [])
|
||||
deployed_model_names = set([m["id"] for m in deployed_models])
|
||||
|
||||
# 3. Only return the models in available models if they have been deployed
|
||||
deployed_models = [m for m in all_available_models if m["id"] in deployed_model_names]
|
||||
|
||||
# 4. Remove redundant deployments, only include the ones with the latest deployment
|
||||
# Create a dictionary to store the latest model for each ID
|
||||
latest_models = defaultdict()
|
||||
|
||||
# Iterate through the models and update the dictionary with the most recent model
|
||||
for model in deployed_models:
|
||||
model_id = model["id"]
|
||||
updated_at = model["created_at"]
|
||||
|
||||
# If the model ID is new or the current model has a more recent created_at, update the dictionary
|
||||
if model_id not in latest_models or updated_at > latest_models[model_id]["created_at"]:
|
||||
latest_models[model_id] = model
|
||||
|
||||
# Extract the unique models
|
||||
return list(latest_models.values())
|
||||
|
||||
|
||||
def azure_openai_get_chat_completion_model_list(base_url: str, api_key: str, api_version: str) -> list:
|
||||
model_list = azure_openai_get_model_list(base_url, api_key, api_version)
|
||||
model_list = azure_openai_get_deployed_model_list(base_url, api_key, api_version)
|
||||
# Extract models that support text generation
|
||||
model_options = [m for m in model_list if m.get("capabilities").get("chat_completion") == True]
|
||||
return model_options
|
||||
@@ -53,10 +92,11 @@ def azure_openai_get_embeddings_model_list(base_url: str, api_key: str, api_vers
|
||||
|
||||
return m.get("capabilities").get("embeddings") == True and valid_name
|
||||
|
||||
model_list = azure_openai_get_model_list(base_url, api_key, api_version)
|
||||
model_list = azure_openai_get_deployed_model_list(base_url, api_key, api_version)
|
||||
# Extract models that support embeddings
|
||||
|
||||
model_options = [m for m in model_list if valid_embedding_model(m)]
|
||||
|
||||
return model_options
|
||||
|
||||
|
||||
|
||||
6
tests/configs/embedding_model_configs/azure_embed.json
Normal file
6
tests/configs/embedding_model_configs/azure_embed.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"embedding_endpoint_type": "azure",
|
||||
"embedding_model": "text-embedding-ada-002",
|
||||
"embedding_dim": 768,
|
||||
"embedding_chunk_size": 300
|
||||
}
|
||||
@@ -101,7 +101,7 @@ def test_openai_gpt_4_edit_core_memory():
|
||||
|
||||
|
||||
def test_embedding_endpoint_openai():
|
||||
filename = os.path.join(embedding_config_dir, "text-embedding-ada-002.json")
|
||||
filename = os.path.join(embedding_config_dir, "openai_embed.json")
|
||||
run_embedding_endpoint(filename)
|
||||
|
||||
|
||||
@@ -151,6 +151,11 @@ def test_azure_gpt_4o_mini_edit_core_memory():
|
||||
print(f"Got successful response from client: \n\n{response}")
|
||||
|
||||
|
||||
def test_azure_embedding_endpoint():
|
||||
filename = os.path.join(embedding_config_dir, "azure_embed.json")
|
||||
run_embedding_endpoint(filename)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# LETTA HOSTED
|
||||
# ======================================================================================================================
|
||||
|
||||
Reference in New Issue
Block a user