Merge pull request #140 from cpacker/azure-patch
Patch azure support Co-Authored-By: rivms <50959956+rivms@users.noreply.github.com>
This commit is contained in:
@@ -132,7 +132,10 @@ If you're using Azure OpenAI, set these variables instead:
|
||||
export AZURE_OPENAI_KEY=...
|
||||
export AZURE_OPENAI_ENDPOINT=...
|
||||
export AZURE_OPENAI_VERSION=...
|
||||
|
||||
# set the below if you are using deployment ids
|
||||
export AZURE_OPENAI_DEPLOYMENT=...
|
||||
export AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT=...
|
||||
|
||||
# then use the --use_azure_openai flag
|
||||
memgpt --use_azure_openai
|
||||
|
||||
@@ -28,6 +28,12 @@ from memgpt.persistence_manager import (
|
||||
|
||||
from memgpt.config import Config
|
||||
from memgpt.constants import MEMGPT_DIR
|
||||
from memgpt.openai_tools import (
|
||||
configure_azure_support,
|
||||
check_azure_embeddings,
|
||||
get_set_azure_env_vars,
|
||||
)
|
||||
|
||||
import asyncio
|
||||
|
||||
app = typer.Typer()
|
||||
@@ -187,6 +193,18 @@ async def main(
|
||||
if debug:
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
|
||||
# Azure OpenAI support
|
||||
if use_azure_openai:
|
||||
configure_azure_support()
|
||||
check_azure_embeddings()
|
||||
else:
|
||||
azure_vars = get_set_azure_env_vars()
|
||||
if len(azure_vars) > 0:
|
||||
print(
|
||||
f"Error: Environment variables {', '.join([x[0] for x in azure_vars])} should not be set if --use_azure_openai is False"
|
||||
)
|
||||
return
|
||||
|
||||
if any(
|
||||
(
|
||||
persona,
|
||||
@@ -285,38 +303,6 @@ async def main(
|
||||
f"⛔️ Warning - you are running MemGPT with {cfg.model}, which is not officially supported (yet). Expect bugs!"
|
||||
)
|
||||
|
||||
# Azure OpenAI support
|
||||
if use_azure_openai:
|
||||
azure_openai_key = os.getenv("AZURE_OPENAI_KEY")
|
||||
azure_openai_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
|
||||
azure_openai_version = os.getenv("AZURE_OPENAI_VERSION")
|
||||
azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
|
||||
if None in [
|
||||
azure_openai_key,
|
||||
azure_openai_endpoint,
|
||||
azure_openai_version,
|
||||
azure_openai_deployment,
|
||||
]:
|
||||
print(
|
||||
f"Error: missing Azure OpenAI environment variables. Please see README section on Azure."
|
||||
)
|
||||
return
|
||||
|
||||
import openai
|
||||
|
||||
openai.api_type = "azure"
|
||||
openai.api_key = azure_openai_key
|
||||
openai.api_base = azure_openai_endpoint
|
||||
openai.api_version = azure_openai_version
|
||||
# deployment gets passed into chatcompletion
|
||||
else:
|
||||
azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
|
||||
if azure_openai_deployment is not None:
|
||||
print(
|
||||
f"Error: AZURE_OPENAI_DEPLOYMENT should not be set if --use_azure_openai is False"
|
||||
)
|
||||
return
|
||||
|
||||
if cfg.index:
|
||||
persistence_manager = InMemoryStateManagerWithFaiss(
|
||||
cfg.index, cfg.archival_database
|
||||
|
||||
@@ -116,18 +116,26 @@ async def acompletions_with_backoff(**kwargs):
|
||||
|
||||
# OpenAI / Azure model
|
||||
else:
|
||||
azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
|
||||
if azure_openai_deployment is not None:
|
||||
kwargs["deployment_id"] = azure_openai_deployment
|
||||
if using_azure():
|
||||
azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
|
||||
if azure_openai_deployment is not None:
|
||||
kwargs["deployment_id"] = azure_openai_deployment
|
||||
else:
|
||||
kwargs["engine"] = MODEL_TO_AZURE_ENGINE[kwargs["model"]]
|
||||
kwargs.pop("model")
|
||||
return await openai.ChatCompletion.acreate(**kwargs)
|
||||
|
||||
|
||||
@aretry_with_exponential_backoff
|
||||
async def acreate_embedding_with_backoff(**kwargs):
|
||||
"""Wrapper around Embedding.acreate w/ backoff"""
|
||||
azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
|
||||
if azure_openai_deployment is not None:
|
||||
kwargs["deployment_id"] = azure_openai_deployment
|
||||
if using_azure():
|
||||
azure_openai_deployment = os.getenv("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT")
|
||||
if azure_openai_deployment is not None:
|
||||
kwargs["deployment_id"] = azure_openai_deployment
|
||||
else:
|
||||
kwargs["engine"] = kwargs["model"]
|
||||
kwargs.pop("model")
|
||||
return await openai.Embedding.acreate(**kwargs)
|
||||
|
||||
|
||||
@@ -138,3 +146,63 @@ async def async_get_embedding_with_backoff(text, model="text-embedding-ada-002")
|
||||
response = await acreate_embedding_with_backoff(input=[text], model=model)
|
||||
embedding = response["data"][0]["embedding"]
|
||||
return embedding
|
||||
|
||||
|
||||
# Lookup table from OpenAI model names to Azure engine names. It is consulted
# by the chat-completion wrapper when Azure is in use and no deployment id is
# set: the "model" kwarg is looked up here and passed as "engine" instead.
# Note the Azure-side names drop the dot ("gpt-35-turbo", not "gpt-3.5-turbo").
MODEL_TO_AZURE_ENGINE = {
|
||||
"gpt-4": "gpt-4",
|
||||
"gpt-4-32k": "gpt-4-32k",
|
||||
"gpt-3.5": "gpt-35-turbo",
|
||||
"gpt-3.5-turbo": "gpt-35-turbo",
|
||||
"gpt-3.5-turbo-16k": "gpt-35-turbo-16k",
|
||||
}
|
||||
|
||||
|
||||
def get_set_azure_env_vars():
    """Return the Azure OpenAI environment variables that are currently set.

    Returns:
        A list of ``(name, value)`` tuples, in the fixed order listed below,
        containing only the variables whose value is not ``None``. The list
        is empty when no Azure variable is set.
    """
    azure_env_names = (
        "AZURE_OPENAI_KEY",
        "AZURE_OPENAI_ENDPOINT",
        "AZURE_OPENAI_VERSION",
        "AZURE_OPENAI_DEPLOYMENT",
        # Plural spelling: this is the name the README documents and the name
        # the embedding wrapper reads, so it must be detected here as well.
        "AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT",
        # Singular spelling that this function checked previously; kept so
        # environments that export it are still reported.
        "AZURE_OPENAI_EMBEDDING_DEPLOYMENT",
    )
    candidates = ((name, os.getenv(name)) for name in azure_env_names)
    return [pair for pair in candidates if pair[1] is not None]
|
||||
|
||||
|
||||
def using_azure():
    """Report whether at least one Azure OpenAI environment variable is set."""
    # A non-empty list of set variables is truthy; an empty one is falsy.
    return bool(get_set_azure_env_vars())
|
||||
|
||||
|
||||
def configure_azure_support():
    """Point the ``openai`` module at an Azure OpenAI resource.

    Reads AZURE_OPENAI_KEY, AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_VERSION
    from the environment and copies them onto the module-level ``openai``
    settings (``api_type``, ``api_key``, ``api_base``, ``api_version``).

    Raises:
        ValueError: if any of the three variables is unset. Raising (rather
            than printing and returning) stops the caller from carrying on
            with a client that was never configured for Azure.
    """
    azure_openai_key = os.getenv("AZURE_OPENAI_KEY")
    azure_openai_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
    azure_openai_version = os.getenv("AZURE_OPENAI_VERSION")

    required = (azure_openai_key, azure_openai_endpoint, azure_openai_version)
    if any(value is None for value in required):
        raise ValueError(
            "Error: missing Azure OpenAI environment variables. Please see README section on Azure."
        )

    openai.api_type = "azure"
    openai.api_key = azure_openai_key
    openai.api_base = azure_openai_endpoint
    openai.api_version = azure_openai_version
    # The deployment id is not set here; it is passed per request to the
    # completion / embedding calls.
|
||||
|
||||
|
||||
def check_azure_embeddings():
    """Ensure an embeddings deployment is configured when deployment ids are used.

    Raises:
        ValueError: if AZURE_OPENAI_DEPLOYMENT is set but
            AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT is not.
    """
    chat_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
    # Must be the plural name: it is the one the README documents and the one
    # the embedding wrapper reads. Checking the singular spelling made this
    # raise even for a correctly configured environment.
    embeddings_deployment = os.getenv("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT")

    if chat_deployment is not None and embeddings_deployment is None:
        raise ValueError(
            "Error: It looks like you are using Azure deployment ids and computing embeddings, make sure you are setting one for embeddings as well. Please see README section on Azure"
        )
|
||||
|
||||
Reference in New Issue
Block a user