# HTTP helpers for talking to OpenAI-compatible endpoints (model list, chat completions, embeddings).
import requests
|
|
import time
|
|
from typing import Union, Optional
|
|
|
|
from memgpt.models.chat_completion_response import ChatCompletionResponse
|
|
from memgpt.models.chat_completion_request import ChatCompletionRequest
|
|
from memgpt.models.embedding_response import EmbeddingResponse
|
|
from memgpt.utils import smart_urljoin
|
|
|
|
|
|
def openai_get_model_list(url: str, api_key: Union[str, None], fix_url: Optional[bool] = False) -> dict:
    """Fetch the list of models from an OpenAI-compatible endpoint.

    See https://platform.openai.com/docs/api-reference/models/list

    Args:
        url: Base URL of the API server (e.g. "https://api.openai.com/v1").
        api_key: Bearer token, or None for servers that need no auth.
        fix_url: If True, append a "/v1" suffix when the URL lacks one
            (MemGPT config stores e.g. the vLLM address without it).

    Returns:
        The decoded JSON response as a dict.

    Raises:
        requests.exceptions.HTTPError: on 4XX/5XX responses.
        requests.exceptions.RequestException: on connection-level failures.
    """
    from memgpt.utils import printd

    # In some cases we may want to double-check the URL and do basic correction, eg:
    # In MemGPT config the address for vLLM is w/o a /v1 suffix for simplicity
    # However if we're treating the server as an OpenAI proxy we want the /v1 suffix on our model hit
    if fix_url:
        if not url.endswith("/v1"):
            url = smart_urljoin(url, "v1")

    url = smart_urljoin(url, "models")

    headers = {"Content-Type": "application/json"}
    if api_key is not None:
        headers["Authorization"] = f"Bearer {api_key}"

    printd(f"Sending request to {url}")
    # Pre-initialize so the error handlers below can safely reference `response`
    # even when requests.get itself fails before a Response object exists
    # (previously this path raised UnboundLocalError and masked the real error).
    response = None
    try:
        response = requests.get(url, headers=headers)
        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
        response = response.json()  # convert to dict from string
        printd(f"response = {response}")
        return response
    except requests.exceptions.HTTPError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX): try to pull the
        # server's JSON error body into the log, best-effort only.
        try:
            response = response.json()
        except Exception:
            pass
        printd(f"Got HTTPError, exception={http_err}, response={response}")
        raise http_err
    except requests.exceptions.RequestException as req_err:
        # Handle other requests-related errors (e.g., connection error)
        try:
            response = response.json()
        except Exception:
            pass
        printd(f"Got RequestException, exception={req_err}, response={response}")
        raise req_err
    except Exception as e:
        # Handle other potential errors
        try:
            response = response.json()
        except Exception:
            pass
        printd(f"Got unknown Exception, exception={e}, response={response}")
        raise e
|
|
|
|
|
|
def openai_chat_completions_request(url: str, api_key: str, data: ChatCompletionRequest) -> ChatCompletionResponse:
    """POST a chat-completions request to an OpenAI-compatible endpoint.

    See https://platform.openai.com/docs/guides/text-generation?lang=curl

    Args:
        url: Base URL of the API server; "chat/completions" is appended.
        api_key: Bearer token used for the Authorization header.
        data: The request object; serialized with None fields excluded.

    Returns:
        The parsed ChatCompletionResponse.

    Raises:
        requests.exceptions.HTTPError: on 4XX/5XX responses.
        requests.exceptions.RequestException: on connection-level failures.
    """
    from memgpt.utils import printd

    url = smart_urljoin(url, "chat/completions")
    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
    payload = data.model_dump(exclude_none=True)

    # Some servers reject explicit null fields: if "functions"/"tools" came
    # through as None, drop them together with their selector field
    # ("function_call"/"tool_choice", which normally defaults to "auto").
    for field, selector in (("functions", "function_call"), ("tools", "tool_choice")):
        if field in payload and payload[field] is None:
            del payload[field]
            payload.pop(selector, None)

    printd(f"Sending request to {url}")
    try:
        # (To simulate a 429 rate-limit or a context-overflow response for
        # testing, craft a mock requests.Response / oversized message here.)
        resp = requests.post(url, headers=headers, json=payload)
        printd(f"response = {resp}")
        resp.raise_for_status()  # HTTPError on 4XX/5XX
        body = resp.json()  # decode JSON body into a dict
        printd(f"response.json = {body}")
        # Wrap in the pydantic model for openai-client-style dot access
        return ChatCompletionResponse(**body)
    except requests.exceptions.HTTPError as http_err:
        # HTTP-level failure (4XX/5XX)
        printd(f"Got HTTPError, exception={http_err}, payload={payload}")
        raise http_err
    except requests.exceptions.RequestException as req_err:
        # Transport-level failure (e.g. connection refused, timeout)
        printd(f"Got RequestException, exception={req_err}")
        raise req_err
    except Exception as e:
        # Anything else (e.g. JSON decode or model validation errors)
        printd(f"Got unknown Exception, exception={e}")
        raise e
|
|
|
|
|
|
def openai_embeddings_request(url: str, api_key: str, data: dict) -> EmbeddingResponse:
    """POST an embeddings request to an OpenAI-compatible endpoint.

    See https://platform.openai.com/docs/api-reference/embeddings/create

    Args:
        url: Base URL of the API server; "embeddings" is appended.
        api_key: Bearer token used for the Authorization header.
        data: JSON-serializable request payload.

    Returns:
        The parsed EmbeddingResponse.

    Raises:
        requests.exceptions.HTTPError: on 4XX/5XX responses.
        requests.exceptions.RequestException: on connection-level failures.
    """
    from memgpt.utils import printd

    request_url = smart_urljoin(url, "embeddings")
    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}

    printd(f"Sending request to {request_url}")
    try:
        resp = requests.post(request_url, headers=headers, json=data)
        printd(f"response = {resp}")
        resp.raise_for_status()  # HTTPError on 4XX/5XX
        body = resp.json()  # decode JSON body into a dict
        printd(f"response.json = {body}")
        # Wrap in the pydantic model for openai-client-style dot access
        return EmbeddingResponse(**body)
    except requests.exceptions.HTTPError as http_err:
        # HTTP-level failure (4XX/5XX)
        printd(f"Got HTTPError, exception={http_err}, payload={data}")
        raise http_err
    except requests.exceptions.RequestException as req_err:
        # Transport-level failure (e.g. connection refused, timeout)
        printd(f"Got RequestException, exception={req_err}")
        raise req_err
    except Exception as e:
        # Anything else (e.g. JSON decode or model validation errors)
        printd(f"Got unknown Exception, exception={e}")
        raise e