feat: add groq support via local option w/ auth (#1203)

This commit is contained in:
Charles Packer
2024-04-01 15:31:05 -07:00
committed by GitHub
parent 534054a144
commit d4cf8bda2c
7 changed files with 160 additions and 30 deletions

View File

@@ -132,7 +132,9 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
model_endpoint = azure_creds["azure_endpoint"]
else: # local models
backend_options = ["webui", "webui-legacy", "llamacpp", "koboldcpp", "ollama", "lmstudio", "lmstudio-legacy", "vllm", "openai"]
# backend_options_old = ["webui", "webui-legacy", "llamacpp", "koboldcpp", "ollama", "lmstudio", "lmstudio-legacy", "vllm", "openai"]
backend_options = builtins.list(DEFAULT_ENDPOINTS.keys())
# assert backend_options_old == backend_options, (backend_options_old, backend_options)
default_model_endpoint_type = None
if config.default_llm_config.model_endpoint_type in backend_options:
# set from previous config
@@ -223,8 +225,12 @@ def get_model_options(
else:
# Attempt to do OpenAI endpoint style model fetching
# TODO support local auth
fetched_model_options_response = openai_get_model_list(url=model_endpoint, api_key=None, fix_url=True)
# TODO support local auth with api-key header
if credentials.openllm_auth_type == "bearer_token":
api_key = credentials.openllm_key
else:
api_key = None
fetched_model_options_response = openai_get_model_list(url=model_endpoint, api_key=api_key, fix_url=True)
model_options = [obj["id"] for obj in fetched_model_options_response["data"]]
# NOTE no filtering of local model options
@@ -289,6 +295,44 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
raise KeyboardInterrupt
else: # local models
# ask about local auth
if model_endpoint_type in ["groq"]: # TODO all llm engines under 'local' that will require api keys
use_local_auth = True
local_auth_type = "bearer_token"
local_auth_key = questionary.password(
"Enter your Groq API key:",
).ask()
if local_auth_key is None:
raise KeyboardInterrupt
credentials.openllm_auth_type = local_auth_type
credentials.openllm_key = local_auth_key
credentials.save()
else:
use_local_auth = questionary.confirm(
"Is your LLM endpoint authenticated? (default no)",
default=False,
).ask()
if use_local_auth is None:
raise KeyboardInterrupt
if use_local_auth:
local_auth_type = questionary.select(
"What HTTP authentication method does your endpoint require?",
choices=SUPPORTED_AUTH_TYPES,
default=SUPPORTED_AUTH_TYPES[0],
).ask()
if local_auth_type is None:
raise KeyboardInterrupt
local_auth_key = questionary.password(
"Enter your authentication key:",
).ask()
if local_auth_key is None:
raise KeyboardInterrupt
# credentials = MemGPTCredentials.load()
credentials.openllm_auth_type = local_auth_type
credentials.openllm_key = local_auth_key
credentials.save()
# ollama also needs model type
if model_endpoint_type == "ollama":
default_model = (
@@ -311,7 +355,7 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
)
# vllm needs huggingface model tag
if model_endpoint_type == "vllm":
if model_endpoint_type in ["vllm", "groq"]:
try:
# Don't filter model list for vLLM since model list is likely much smaller than OpenAI/Azure endpoint
# + probably has custom model names
@@ -366,31 +410,6 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
if model_wrapper is None:
raise KeyboardInterrupt
# ask about local auth
use_local_auth = questionary.confirm(
"Is your LLM endpoint authenticated? (default no)",
default=False,
).ask()
if use_local_auth is None:
raise KeyboardInterrupt
if use_local_auth:
local_auth_type = questionary.select(
"What HTTP authentication method does your endpoint require?",
choices=SUPPORTED_AUTH_TYPES,
default=SUPPORTED_AUTH_TYPES[0],
).ask()
if local_auth_type is None:
raise KeyboardInterrupt
local_auth_key = questionary.password(
"Enter your authentication key:",
).ask()
if local_auth_key is None:
raise KeyboardInterrupt
# credentials = MemGPTCredentials.load()
credentials.openllm_auth_type = local_auth_type
credentials.openllm_key = local_auth_key
credentials.save()
# set: context_window
if str(model) not in LLM_MAX_TOKENS:
# Ask the user to specify the context length

View File

@@ -13,6 +13,7 @@ from memgpt.local_llm.llamacpp.api import get_llamacpp_completion
from memgpt.local_llm.koboldcpp.api import get_koboldcpp_completion
from memgpt.local_llm.ollama.api import get_ollama_completion
from memgpt.local_llm.vllm.api import get_vllm_completion
from memgpt.local_llm.groq.api import get_groq_completion
from memgpt.local_llm.llm_chat_completion_wrappers import simple_summary_wrapper
from memgpt.local_llm.constants import DEFAULT_WRAPPER
from memgpt.local_llm.utils import get_available_wrappers, count_tokens
@@ -155,6 +156,8 @@ def get_chat_completion(
result, usage = get_ollama_completion(endpoint, auth_type, auth_key, model, prompt, context_window)
elif endpoint_type == "vllm":
result, usage = get_vllm_completion(endpoint, auth_type, auth_key, model, prompt, context_window, user)
elif endpoint_type == "groq":
result, usage = get_groq_completion(endpoint, auth_type, auth_key, model, prompt, context_window)
else:
raise LocalLLMError(
f"Invalid endpoint type {endpoint_type}, please set variable depending on your backend (webui, lmstudio, llamacpp, koboldcpp)"

View File

@@ -2,6 +2,7 @@
from memgpt.local_llm.llm_chat_completion_wrappers.chatml import ChatMLInnerMonologueWrapper, ChatMLOuterInnerMonologueWrapper
DEFAULT_ENDPOINTS = {
# Local
"koboldcpp": "http://localhost:5001",
"llamacpp": "http://localhost:8080",
"lmstudio": "http://localhost:1234",
@@ -10,6 +11,9 @@ DEFAULT_ENDPOINTS = {
"webui-legacy": "http://localhost:5000",
"webui": "http://localhost:5000",
"vllm": "http://localhost:8000",
# APIs
"openai": "https://api.openai.com",
"groq": "https://api.groq.com/openai",
}
DEFAULT_OLLAMA_MODEL = "dolphin2.2-mistral:7b-q6_K"

View File

@@ -0,0 +1,79 @@
from urllib.parse import urljoin
import requests
from typing import Tuple
from memgpt.local_llm.settings.settings import get_completions_settings
from memgpt.local_llm.utils import post_json_auth_request
from memgpt.utils import count_tokens
API_CHAT_SUFFIX = "/v1/chat/completions"
# LMSTUDIO_API_COMPLETIONS_SUFFIX = "/v1/completions"
def get_groq_completion(endpoint: str, auth_type: str, auth_key: str, model: str, prompt: str, context_window: int) -> Tuple[str, dict]:
    """Query Groq's OpenAI-compatible /chat/completions API with a raw prompt string.

    TODO: no support for function calling OR raw completions, so we need to route
    the request into /chat/completions instead (the whole rendered prompt is
    stuffed into a single user message — see NOTE below).

    Args:
        endpoint: Base URL of the API (e.g. "https://api.groq.com/openai").
        auth_type: HTTP auth scheme forwarded to post_json_auth_request (e.g. "bearer_token").
        auth_key: Secret paired with auth_type; passed through unmodified.
        model: Model identifier sent in the request body.
        prompt: Fully rendered completion-style prompt.
        context_window: Maximum context length; oversized prompts are rejected locally.

    Returns:
        Tuple of (completion text, usage dict with prompt/completion/total token counts).

    Raises:
        ValueError: if endpoint does not begin with http:// or https://.
        Exception: on context-length overflow (local or server-reported) or any
            non-200 API response.
    """
    from memgpt.utils import printd

    # Fail fast on a malformed endpoint before building the request.
    if not endpoint.startswith(("http://", "https://")):
        raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")

    # Reject locally before paying for a request the server would refuse anyway.
    prompt_tokens = count_tokens(prompt)
    if prompt_tokens > context_window:
        raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")

    # Settings for the generation, includes the prompt + stop tokens, max length, etc.
    # Groq supports temperature / max_tokens / top_p / stream / stop
    # (see https://console.groq.com/docs/text-chat); currently using the defaults.
    settings = get_completions_settings()

    URI = urljoin(endpoint.strip("/") + "/", API_CHAT_SUFFIX.strip("/"))

    request = settings
    request["model"] = model
    request["max_tokens"] = context_window

    # NOTE: Hack for chat/completion-only endpoints: put the entire completion string inside the first message
    message_structure = [{"role": "user", "content": prompt}]
    request["messages"] = message_structure

    response = post_json_auth_request(uri=URI, json_payload=request, auth_type=auth_type, auth_key=auth_key)
    if response.status_code == 200:
        result_full = response.json()
        printd(f"JSON API response:\n{result_full}")
        result = result_full["choices"][0]["message"]["content"]
        usage = result_full.get("usage", None)
    elif "context length" in str(response.text).lower():
        # Example error: msg={"error":"Context length exceeded. Tokens in context: 8000, Context length: 8000"}
        # Raise an alternate exception that matches OpenAI's message, which is "maximum context length"
        raise Exception(f"Request exceeds maximum context length (code={response.status_code}, msg={response.text}, URI={URI})")
    else:
        raise Exception(
            f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
            + f" Make sure that the inference server is running and reachable at {URI}."
        )

    # Pass usage statistics back to main thread.
    # These are used to compute memory warning messages.
    completion_tokens = usage.get("completion_tokens", None) if usage is not None else None
    total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None
    usage = {
        "prompt_tokens": prompt_tokens,  # can grab from usage dict, but it's usually wrong (set to 0)
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,
    }
    return result, usage

View File

@@ -5,7 +5,19 @@ from memgpt.constants import JSON_LOADS_STRICT
from memgpt.errors import LLMJSONParsingError
def extract_first_json(string):
def replace_escaped_underscores(string: str):
    r"""Un-escape LLM-emitted backslash-underscore sequences in raw JSON output.

    Handles the case of escaped underscores, e.g.:

    {
      "function":"send\_message",
      "params": {
        "inner\_thoughts": "User is asking for information about themselves. Retrieving data from core memory.",
        "message": "I know that you are Chad. Is there something specific you would like to know or talk about regarding yourself?"
    """
    # r"\_" (raw string) instead of "\_": backslash-underscore is an invalid
    # escape sequence — a DeprecationWarning today and a SyntaxWarning in
    # Python 3.12+. The replaced value is byte-identical.
    return string.replace(r"\_", "_")
def extract_first_json(string: str):
"""Handles the case of two JSON objects back-to-back"""
from memgpt.utils import printd
@@ -163,6 +175,9 @@ def clean_json(raw_llm_output, messages=None, functions=None):
lambda output: json.loads(repair_even_worse_json(output), strict=JSON_LOADS_STRICT),
lambda output: extract_first_json(output + "}}"),
lambda output: clean_and_interpret_send_message_json(output),
# replace underscores
lambda output: json.loads(replace_escaped_underscores(output), strict=JSON_LOADS_STRICT),
lambda output: extract_first_json(replace_escaped_underscores(output) + "}}"),
]
for strategy in strategies:

View File

@@ -18,6 +18,7 @@ total KV lookups are required to find the final value), and
sample 30 different ordering configurations including both
the initial key position and nesting key positions.
"""
import math
import json
import argparse

View File

@@ -4,6 +4,14 @@ from memgpt.constants import JSON_LOADS_STRICT
import memgpt.local_llm.json_parser as json_parser
EXAMPLE_ESCAPED_UNDERSCORES = """{
"function":"send\_message",
"params": {
"inner\_thoughts": "User is asking for information about themselves. Retrieving data from core memory.",
"message": "I know that you are Chad. Is there something specific you would like to know or talk about regarding yourself?"
"""
EXAMPLE_MISSING_CLOSING_BRACE = """{
"function": "send_message",
"params": {
@@ -72,6 +80,7 @@ def test_json_parsers():
"""Try various broken JSON and check that the parsers can fix it"""
test_strings = [
EXAMPLE_ESCAPED_UNDERSCORES,
EXAMPLE_MISSING_CLOSING_BRACE,
EXAMPLE_BAD_TOKEN_END,
EXAMPLE_DOUBLE_JSON,