feat: pull model list for openai-compatible endpoints (#630)
* allow entering custom model name when using openai/azure * pull models from endpoint * added/tested vllm and azure * no print * make red * make the endpoint question give you an opportunity to enter your openai api key again in case you made a mistake / want to swap it out * add cascading workflow for openai+azure model listings * patched bug w/ azure listing
This commit is contained in:
@@ -14,6 +14,8 @@ from memgpt.connectors.storage import StorageConnector
|
||||
from memgpt.constants import LLM_MAX_TOKENS
|
||||
from memgpt.local_llm.constants import DEFAULT_ENDPOINTS, DEFAULT_OLLAMA_MODEL, DEFAULT_WRAPPER_NAME
|
||||
from memgpt.local_llm.utils import get_available_wrappers
|
||||
from memgpt.openai_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
|
||||
from memgpt.server.utils import shorten_key_middle
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
@@ -63,6 +65,18 @@ def configure_llm_endpoint(config: MemGPTConfig):
|
||||
openai_api_key = questionary.text(
|
||||
"Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):"
|
||||
).ask()
|
||||
config.openai_key = openai_api_key
|
||||
config.save()
|
||||
else:
|
||||
# Give the user an opportunity to overwrite the key
|
||||
openai_api_key = None
|
||||
default_input = shorten_key_middle(config.openai_key) if config.openai_key.startswith("sk-") else config.openai_key
|
||||
openai_api_key = questionary.text(
|
||||
"Enter your OpenAI API key (hit enter to use existing key):",
|
||||
default=default_input,
|
||||
).ask()
|
||||
# If the user modified it, use the new one
|
||||
if openai_api_key != default_input:
|
||||
config.openai_key = openai_api_key
|
||||
config.save()
|
||||
|
||||
@@ -78,6 +92,11 @@ def configure_llm_endpoint(config: MemGPTConfig):
|
||||
raise ValueError(
|
||||
"Missing environment variables for Azure (see https://memgpt.readme.io/docs/endpoints#azure-openai). Please set then run `memgpt configure` again."
|
||||
)
|
||||
else:
|
||||
config.azure_key = azure_creds["azure_key"]
|
||||
config.azure_endpoint = azure_creds["azure_endpoint"]
|
||||
config.azure_version = azure_creds["azure_version"]
|
||||
config.save()
|
||||
|
||||
model_endpoint_type = "azure"
|
||||
model_endpoint = azure_creds["azure_endpoint"]
|
||||
@@ -119,16 +138,56 @@ def configure_llm_endpoint(config: MemGPTConfig):
|
||||
return model_endpoint_type, model_endpoint
|
||||
|
||||
|
||||
def configure_model(config: MemGPTConfig, model_endpoint_type: str):
|
||||
def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoint: str):
|
||||
# set: model, model_wrapper
|
||||
model, model_wrapper = None, None
|
||||
if model_endpoint_type == "openai" or model_endpoint_type == "azure":
|
||||
model_options = ["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo", "gpt-3.5-turbo-16k"]
|
||||
# TODO: select
|
||||
valid_model = config.model in model_options
|
||||
# Get the model list from the openai / azure endpoint
|
||||
hardcoded_model_options = ["gpt-4", "gpt-4-32k", "gpt-4-1106-preview", "gpt-3.5-turbo", "gpt-3.5-turbo-16k"]
|
||||
fetched_model_options = None
|
||||
try:
|
||||
if model_endpoint_type == "openai":
|
||||
fetched_model_options = openai_get_model_list(url=model_endpoint, api_key=config.openai_key)
|
||||
elif model_endpoint_type == "azure":
|
||||
fetched_model_options = azure_openai_get_model_list(
|
||||
url=model_endpoint, api_key=config.azure_key, api_version=config.azure_version
|
||||
)
|
||||
fetched_model_options = [obj["id"] for obj in fetched_model_options["data"] if obj["id"].startswith("gpt-")]
|
||||
except:
|
||||
# NOTE: if this fails, it means the user's key is probably bad
|
||||
typer.secho(
|
||||
f"Failed to get model list from {model_endpoint} - make sure your API key and endpoints are correct!", fg=typer.colors.RED
|
||||
)
|
||||
|
||||
# First ask if the user wants to see the full model list (some may be incompatible)
|
||||
see_all_option_str = "[see all options]"
|
||||
other_option_str = "[enter model name manually]"
|
||||
|
||||
# Check if the model we have set already is even in the list (informs our default)
|
||||
valid_model = config.model in hardcoded_model_options
|
||||
model = questionary.select(
|
||||
"Select default model (recommended: gpt-4):", choices=model_options, default=config.model if valid_model else model_options[0]
|
||||
"Select default model (recommended: gpt-4):",
|
||||
choices=hardcoded_model_options + [see_all_option_str, other_option_str],
|
||||
default=config.model if valid_model else hardcoded_model_options[0],
|
||||
).ask()
|
||||
|
||||
# If the user asked for the full list, show it
|
||||
if model == see_all_option_str:
|
||||
typer.secho(f"Warning: not all models shown are guaranteed to work with MemGPT", fg=typer.colors.RED)
|
||||
model = questionary.select(
|
||||
"Select default model (recommended: gpt-4):",
|
||||
choices=fetched_model_options + [other_option_str],
|
||||
default=config.model if valid_model else fetched_model_options[0],
|
||||
).ask()
|
||||
|
||||
# Finally if the user asked to manually input, allow it
|
||||
if model == other_option_str:
|
||||
model = ""
|
||||
while len(model) == 0:
|
||||
model = questionary.text(
|
||||
"Enter custom model name:",
|
||||
).ask()
|
||||
|
||||
else: # local models
|
||||
# ollama also needs model type
|
||||
if model_endpoint_type == "ollama":
|
||||
@@ -139,24 +198,51 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str):
|
||||
).ask()
|
||||
model = None if len(model) == 0 else model
|
||||
|
||||
default_model = config.model if config.model and config.model_endpoint_type == "vllm" else ""
|
||||
|
||||
# vllm needs huggingface model tag
|
||||
if model_endpoint_type == "vllm":
|
||||
default_model = config.model if config.model and config.model_endpoint_type == "vllm" else ""
|
||||
model = questionary.text(
|
||||
"Enter HuggingFace model tag (e.g. ehartford/dolphin-2.2.1-mistral-7b):",
|
||||
default=default_model,
|
||||
).ask()
|
||||
model = None if len(model) == 0 else model
|
||||
model_wrapper = None # no model wrapper for vLLM
|
||||
try:
|
||||
# Don't filter model list for vLLM since model list is likely much smaller than OpenAI/Azure endpoint
|
||||
# + probably has custom model names
|
||||
model_options = openai_get_model_list(url=smart_urljoin(model_endpoint, "v1"), api_key=None)
|
||||
model_options = [obj["id"] for obj in model_options["data"]]
|
||||
except:
|
||||
print(f"Failed to get model list from {model_endpoint}, using defaults")
|
||||
model_options = None
|
||||
|
||||
# If we got model options from vLLM endpoint, allow selection + custom input
|
||||
if model_options is not None:
|
||||
other_option_str = "other (enter name)"
|
||||
valid_model = config.model in model_options
|
||||
model_options.append(other_option_str)
|
||||
model = questionary.select(
|
||||
"Select default model:", choices=model_options, default=config.model if valid_model else model_options[0]
|
||||
).ask()
|
||||
|
||||
# If we got custom input, ask for raw input
|
||||
if model == other_option_str:
|
||||
model = questionary.text(
|
||||
"Enter HuggingFace model tag (e.g. ehartford/dolphin-2.2.1-mistral-7b):",
|
||||
default=default_model,
|
||||
).ask()
|
||||
# TODO allow empty string for input?
|
||||
model = None if len(model) == 0 else model
|
||||
|
||||
else:
|
||||
model = questionary.text(
|
||||
"Enter HuggingFace model tag (e.g. ehartford/dolphin-2.2.1-mistral-7b):",
|
||||
default=default_model,
|
||||
).ask()
|
||||
model = None if len(model) == 0 else model
|
||||
|
||||
# model wrapper
|
||||
if model_endpoint_type != "vllm":
|
||||
available_model_wrappers = builtins.list(get_available_wrappers().keys())
|
||||
model_wrapper = questionary.select(
|
||||
f"Select default model wrapper (recommended: {DEFAULT_WRAPPER_NAME}):",
|
||||
choices=available_model_wrappers,
|
||||
default=DEFAULT_WRAPPER_NAME,
|
||||
).ask()
|
||||
available_model_wrappers = builtins.list(get_available_wrappers().keys())
|
||||
model_wrapper = questionary.select(
|
||||
f"Select default model wrapper (recommended: {DEFAULT_WRAPPER_NAME}):",
|
||||
choices=available_model_wrappers,
|
||||
default=DEFAULT_WRAPPER_NAME,
|
||||
).ask()
|
||||
|
||||
# set: context_window
|
||||
if str(model) not in LLM_MAX_TOKENS:
|
||||
@@ -228,6 +314,7 @@ def configure_embedding_endpoint(config: MemGPTConfig):
|
||||
raise ValueError(
|
||||
"Missing environment variables for Azure (see https://memgpt.readme.io/docs/endpoints#azure-openai). Please set then run `memgpt configure` again."
|
||||
)
|
||||
# TODO we need to write these out to the config once we use them if we plan to ping for embedding lists with them
|
||||
|
||||
embedding_endpoint_type = "azure"
|
||||
embedding_endpoint = azure_creds["azure_embedding_endpoint"]
|
||||
@@ -345,7 +432,9 @@ def configure():
|
||||
config = MemGPTConfig.load()
|
||||
try:
|
||||
model_endpoint_type, model_endpoint = configure_llm_endpoint(config)
|
||||
model, model_wrapper, context_window = configure_model(config, model_endpoint_type)
|
||||
model, model_wrapper, context_window = configure_model(
|
||||
config=config, model_endpoint_type=model_endpoint_type, model_endpoint=model_endpoint
|
||||
)
|
||||
embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = configure_embedding_endpoint(config)
|
||||
default_preset, default_persona, default_human, default_agent = configure_cli(config)
|
||||
archival_storage_type, archival_storage_uri, archival_storage_path = configure_archival_storage(config)
|
||||
|
||||
@@ -2,6 +2,7 @@ import random
|
||||
import time
|
||||
import requests
|
||||
import time
|
||||
from typing import Callable, TypeVar, Union
|
||||
import urllib
|
||||
|
||||
from box import Box
|
||||
@@ -75,6 +76,94 @@ def clean_azure_endpoint(raw_endpoint_name):
|
||||
return endpoint_address
|
||||
|
||||
|
||||
def openai_get_model_list(url: str, api_key: Union[str, None]) -> dict:
    """List the models served by an OpenAI-compatible endpoint.

    Sends ``GET <url>/models`` and returns the decoded JSON body.
    API reference: https://platform.openai.com/docs/api-reference/models/list

    Args:
        url: Base URL of the API; ``"models"`` is appended via ``smart_urljoin``.
        api_key: Bearer token for the ``Authorization`` header. When ``None``
            no ``Authorization`` header is sent.

    Returns:
        The parsed JSON response. Callers in this project read
        ``response["data"]`` as a list of objects carrying an ``"id"`` key.

    Raises:
        requests.exceptions.HTTPError: The server answered with a 4XX/5XX status.
        requests.exceptions.RequestException: Any other requests-level failure
            (e.g. the connection could not be established).
        Exception: Anything else raised while handling the response; it is
            logged through ``printd`` and re-raised unchanged.
    """
    from memgpt.utils import printd

    url = smart_urljoin(url, "models")

    headers = {"Content-Type": "application/json"}
    if api_key is not None:
        headers["Authorization"] = f"Bearer {api_key}"

    printd(f"Sending request to {url}")

    # Bind `response` before the try block: if requests.get() itself raises
    # (connection refused, DNS failure, ...) the handlers below still need a
    # value to log. Without this the log line raised UnboundLocalError and
    # masked the real error.
    response = None
    try:
        response = requests.get(url, headers=headers)
        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
        response = response.json()  # convert to dict from string
        printd(f"response = {response}")
        return response
    except requests.exceptions.HTTPError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX)
        try:
            # Best effort: prefer the decoded error body in the log output
            response = response.json()
        except Exception:
            pass
        printd(f"Got HTTPError, exception={http_err}, response={response}")
        raise
    except requests.exceptions.RequestException as req_err:
        # Handle other requests-related errors (e.g., connection error)
        try:
            response = response.json()
        except Exception:
            pass
        printd(f"Got RequestException, exception={req_err}, response={response}")
        raise
    except Exception as e:
        # Handle other potential errors
        try:
            response = response.json()
        except Exception:
            pass
        printd(f"Got unknown Exception, exception={e}, response={response}")
        raise
|
||||
|
||||
|
||||
def azure_openai_get_model_list(url: str, api_key: Union[str, None], api_version: str) -> dict:
    """List the models exposed by an Azure OpenAI resource.

    Sends ``GET <url>/openai/models?api-version=<api_version>`` and returns
    the decoded JSON body.
    API reference: https://learn.microsoft.com/en-us/rest/api/azureopenai/models/list?view=rest-azureopenai-2023-05-15&tabs=HTTP

    Args:
        url: Base URL of the Azure resource (e.g. ``https://xxx.openai.azure.com``).
        api_key: Value for the ``api-key`` header. When ``None`` the header
            is not sent.
        api_version: Value of the ``api-version`` query parameter.

    Returns:
        The parsed JSON response. Callers in this project read
        ``response["data"]`` as a list of objects carrying an ``"id"`` key.

    Raises:
        requests.exceptions.HTTPError: The server answered with a 4XX/5XX status.
        requests.exceptions.RequestException: Any other requests-level failure
            (e.g. the connection could not be established).
        Exception: Anything else raised while handling the response; it is
            logged through ``printd`` and re-raised unchanged.
    """
    from memgpt.utils import printd

    # https://xxx.openai.azure.com/openai/models?api-version=xxx
    url = smart_urljoin(url, "openai")
    url = smart_urljoin(url, f"models?api-version={api_version}")

    headers = {"Content-Type": "application/json"}
    if api_key is not None:
        headers["api-key"] = f"{api_key}"

    printd(f"Sending request to {url}")

    # Bind `response` before the try block: if requests.get() itself raises
    # (connection refused, DNS failure, ...) the handlers below still need a
    # value to log. Without this the log line raised UnboundLocalError and
    # masked the real error.
    response = None
    try:
        response = requests.get(url, headers=headers)
        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
        response = response.json()  # convert to dict from string
        printd(f"response = {response}")
        return response
    except requests.exceptions.HTTPError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX)
        try:
            # Best effort: prefer the decoded error body in the log output
            response = response.json()
        except Exception:
            pass
        printd(f"Got HTTPError, exception={http_err}, response={response}")
        raise
    except requests.exceptions.RequestException as req_err:
        # Handle other requests-related errors (e.g., connection error)
        try:
            response = response.json()
        except Exception:
            pass
        printd(f"Got RequestException, exception={req_err}, response={response}")
        raise
    except Exception as e:
        # Handle other potential errors
        try:
            response = response.json()
        except Exception:
            pass
        printd(f"Got unknown Exception, exception={e}, response={response}")
        raise
|
||||
|
||||
|
||||
def openai_chat_completions_request(url, api_key, data):
|
||||
"""https://platform.openai.com/docs/guides/text-generation?lang=curl"""
|
||||
from memgpt.utils import printd
|
||||
|
||||
Reference in New Issue
Block a user