fix: various patches for Azure support + strip Box (#982)
This commit is contained in:
@@ -117,9 +117,11 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
|
||||
)
|
||||
else:
|
||||
credentials.azure_key = azure_creds["azure_key"]
|
||||
credentials.azure_endpoint = azure_creds["azure_endpoint"]
|
||||
credentials.azure_version = azure_creds["azure_version"]
|
||||
config.save()
|
||||
credentials.azure_embedding_version = azure_creds["azure_embedding_version"]
|
||||
credentials.azure_embedding_endpoint = azure_creds["azure_embedding_endpoint"]
|
||||
if "azure_embedding_deployment" in azure_creds:
|
||||
credentials.azure_embedding_deployment = azure_creds["azure_embedding_deployment"]
|
||||
credentials.save()
|
||||
|
||||
model_endpoint_type = "azure"
|
||||
model_endpoint = azure_creds["azure_endpoint"]
|
||||
@@ -417,7 +419,12 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden
|
||||
raise ValueError(
|
||||
"Missing environment variables for Azure (see https://memgpt.readme.io/docs/endpoints#azure-openai). Please set then run `memgpt configure` again."
|
||||
)
|
||||
# TODO we need to write these out to the config once we use them if we plan to ping for embedding lists with them
|
||||
credentials.azure_key = azure_creds["azure_key"]
|
||||
credentials.azure_version = azure_creds["azure_version"]
|
||||
credentials.azure_embedding_endpoint = azure_creds["azure_embedding_endpoint"]
|
||||
if "azure_deployment" in azure_creds:
|
||||
credentials.azure_deployment = azure_creds["azure_deployment"]
|
||||
credentials.save()
|
||||
|
||||
embedding_endpoint_type = "azure"
|
||||
embedding_endpoint = azure_creds["azure_embedding_endpoint"]
|
||||
|
||||
@@ -34,9 +34,13 @@ class MemGPTCredentials:
|
||||
# azure config
|
||||
azure_auth_type: str = "api_key"
|
||||
azure_key: Optional[str] = None
|
||||
azure_endpoint: Optional[str] = None
|
||||
# base llm / model
|
||||
azure_version: Optional[str] = None
|
||||
azure_endpoint: Optional[str] = None
|
||||
azure_deployment: Optional[str] = None
|
||||
# embeddings
|
||||
azure_embedding_version: Optional[str] = None
|
||||
azure_embedding_endpoint: Optional[str] = None
|
||||
azure_embedding_deployment: Optional[str] = None
|
||||
|
||||
# custom llm API config
|
||||
@@ -63,9 +67,11 @@ class MemGPTCredentials:
|
||||
# azure
|
||||
"azure_auth_type": get_field(config, "azure", "auth_type"),
|
||||
"azure_key": get_field(config, "azure", "key"),
|
||||
"azure_endpoint": get_field(config, "azure", "endpoint"),
|
||||
"azure_version": get_field(config, "azure", "version"),
|
||||
"azure_endpoint": get_field(config, "azure", "endpoint"),
|
||||
"azure_deployment": get_field(config, "azure", "deployment"),
|
||||
"azure_embedding_version": get_field(config, "azure", "embedding_version"),
|
||||
"azure_embedding_endpoint": get_field(config, "azure", "embedding_endpoint"),
|
||||
"azure_embedding_deployment": get_field(config, "azure", "embedding_deployment"),
|
||||
# open llm
|
||||
"openllm_auth_type": get_field(config, "openllm", "auth_type"),
|
||||
@@ -92,9 +98,11 @@ class MemGPTCredentials:
|
||||
# azure config
|
||||
set_field(config, "azure", "auth_type", self.azure_auth_type)
|
||||
set_field(config, "azure", "key", self.azure_key)
|
||||
set_field(config, "azure", "endpoint", self.azure_endpoint)
|
||||
set_field(config, "azure", "version", self.azure_version)
|
||||
set_field(config, "azure", "endpoint", self.azure_endpoint)
|
||||
set_field(config, "azure", "deployment", self.azure_deployment)
|
||||
set_field(config, "azure", "embedding_version", self.azure_embedding_version)
|
||||
set_field(config, "azure", "embedding_endpoint", self.azure_embedding_endpoint)
|
||||
set_field(config, "azure", "embedding_deployment", self.azure_embedding_deployment)
|
||||
|
||||
# openai config
|
||||
|
||||
@@ -5,12 +5,11 @@ import time
|
||||
from typing import Callable, TypeVar, Union
|
||||
import urllib
|
||||
|
||||
from box import Box
|
||||
|
||||
from memgpt.credentials import MemGPTCredentials
|
||||
from memgpt.local_llm.chat_completion_proxy import get_chat_completion
|
||||
from memgpt.constants import CLI_WARNING_PREFIX
|
||||
from memgpt.models.chat_completion_response import ChatCompletionResponse
|
||||
from memgpt.models.embedding_response import EmbeddingResponse
|
||||
|
||||
from memgpt.data_types import AgentState
|
||||
|
||||
@@ -74,6 +73,8 @@ def smart_urljoin(base_url, relative_url):
|
||||
|
||||
def clean_azure_endpoint(raw_endpoint_name):
|
||||
"""Make sure the endpoint is of format 'https://YOUR_RESOURCE_NAME.openai.azure.com'"""
|
||||
if raw_endpoint_name is None:
|
||||
raise ValueError(raw_endpoint_name)
|
||||
endpoint_address = raw_endpoint_name.strip("/").replace(".openai.azure.com", "")
|
||||
endpoint_address = endpoint_address.replace("http://", "")
|
||||
endpoint_address = endpoint_address.replace("https://", "")
|
||||
@@ -231,7 +232,7 @@ def openai_embeddings_request(url, api_key, data):
|
||||
response.raise_for_status() # Raises HTTPError for 4XX/5XX status
|
||||
response = response.json() # convert to dict from string
|
||||
printd(f"response.json = {response}")
|
||||
response = Box(response) # convert to 'dot-dict' style which is the openai python client default
|
||||
response = EmbeddingResponse(**response) # convert to 'dot-dict' style which is the openai python client default
|
||||
return response
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
# Handle HTTP errors (e.g., response 4XX, 5XX)
|
||||
@@ -251,6 +252,11 @@ def azure_openai_chat_completions_request(resource_name, deployment_id, api_vers
|
||||
"""https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions"""
|
||||
from memgpt.utils import printd
|
||||
|
||||
assert resource_name is not None, "Missing required field when calling Azure OpenAI"
|
||||
assert deployment_id is not None, "Missing required field when calling Azure OpenAI"
|
||||
assert api_version is not None, "Missing required field when calling Azure OpenAI"
|
||||
assert api_key is not None, "Missing required field when calling Azure OpenAI"
|
||||
|
||||
resource_name = clean_azure_endpoint(resource_name)
|
||||
url = f"https://{resource_name}.openai.azure.com/openai/deployments/{deployment_id}/chat/completions?api-version={api_version}"
|
||||
headers = {"Content-Type": "application/json", "api-key": f"{api_key}"}
|
||||
@@ -274,7 +280,7 @@ def azure_openai_chat_completions_request(resource_name, deployment_id, api_vers
|
||||
# NOTE: azure openai does not include "content" in the response when it is None, so we need to add it
|
||||
if "content" not in response["choices"][0].get("message"):
|
||||
response["choices"][0]["message"]["content"] = None
|
||||
response = Box(response) # convert to 'dot-dict' style which is the openai python client default
|
||||
response = ChatCompletionResponse(**response) # convert to 'dot-dict' style which is the openai python client default
|
||||
return response
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
# Handle HTTP errors (e.g., response 4XX, 5XX)
|
||||
@@ -305,7 +311,7 @@ def azure_openai_embeddings_request(resource_name, deployment_id, api_version, a
|
||||
response.raise_for_status() # Raises HTTPError for 4XX/5XX status
|
||||
response = response.json() # convert to dict from string
|
||||
printd(f"response.json = {response}")
|
||||
response = Box(response) # convert to 'dot-dict' style which is the openai python client default
|
||||
response = EmbeddingResponse(**response) # convert to 'dot-dict' style which is the openai python client default
|
||||
return response
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
# Handle HTTP errors (e.g., response 4XX, 5XX)
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
"""Key idea: create drop-in replacement for agent's ChatCompletion call that runs on an OpenLLM backend"""
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
import requests
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from box import Box
|
||||
|
||||
from memgpt.local_llm.grammars.gbnf_grammar_generator import create_dynamic_model_from_function, generate_gbnf_grammar_and_documentation
|
||||
from memgpt.local_llm.webui.api import get_webui_completion
|
||||
from memgpt.local_llm.webui.legacy_api import get_webui_completion as get_webui_completion_legacy
|
||||
|
||||
10
memgpt/models/embedding_response.py
Normal file
10
memgpt/models/embedding_response.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from typing import List, Literal
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class EmbeddingResponse(BaseModel):
|
||||
"""OpenAI embedding response model: https://platform.openai.com/docs/api-reference/embeddings/object"""
|
||||
|
||||
index: int # the index of the embedding in the list of embeddings
|
||||
embedding: List[float]
|
||||
object: Literal["embedding"] = "embedding"
|
||||
@@ -668,12 +668,19 @@ def verify_first_message_correctness(
|
||||
response_message = response.choices[0].message
|
||||
|
||||
# First message should be a call to send_message with a non-empty content
|
||||
if require_send_message and not (response_message.function_call or response_message.tool_calls):
|
||||
if ("function_call" in response_message and response_message.function_call is not None) and (
|
||||
"tool_calls" in response_message and response_message.tool_calls is not None
|
||||
):
|
||||
printd(f"First message includes both function call AND tool call: {response_message}")
|
||||
return False
|
||||
elif "function_call" in response_message and response_message.function_call is not None:
|
||||
function_call = response_message.function_call
|
||||
elif "tool_calls" in response_message and response_message.tool_calls is not None:
|
||||
function_call = response_message.tool_calls[0].function
|
||||
else:
|
||||
printd(f"First message didn't include function call: {response_message}")
|
||||
return False
|
||||
|
||||
assert not (response_message.function_call and response_message.tool_calls), response_message
|
||||
function_call = response_message.function_call if response_message.function_call else response_message.tool_calls[0].function
|
||||
function_name = function_call.name if function_call is not None else ""
|
||||
if require_send_message and function_name != "send_message" and function_name != "archival_memory_search":
|
||||
printd(f"First message function call wasn't send_message or archival_memory_search: {response_message}")
|
||||
|
||||
Reference in New Issue
Block a user