feat: added groq support via local option w/ auth (#1203)
This commit is contained in:
@@ -132,7 +132,9 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
|
||||
model_endpoint = azure_creds["azure_endpoint"]
|
||||
|
||||
else: # local models
|
||||
backend_options = ["webui", "webui-legacy", "llamacpp", "koboldcpp", "ollama", "lmstudio", "lmstudio-legacy", "vllm", "openai"]
|
||||
# backend_options_old = ["webui", "webui-legacy", "llamacpp", "koboldcpp", "ollama", "lmstudio", "lmstudio-legacy", "vllm", "openai"]
|
||||
backend_options = builtins.list(DEFAULT_ENDPOINTS.keys())
|
||||
# assert backend_options_old == backend_options, (backend_options_old, backend_options)
|
||||
default_model_endpoint_type = None
|
||||
if config.default_llm_config.model_endpoint_type in backend_options:
|
||||
# set from previous config
|
||||
@@ -223,8 +225,12 @@ def get_model_options(
|
||||
|
||||
else:
|
||||
# Attempt to do OpenAI endpoint style model fetching
|
||||
# TODO support local auth
|
||||
fetched_model_options_response = openai_get_model_list(url=model_endpoint, api_key=None, fix_url=True)
|
||||
# TODO support local auth with api-key header
|
||||
if credentials.openllm_auth_type == "bearer_token":
|
||||
api_key = credentials.openllm_key
|
||||
else:
|
||||
api_key = None
|
||||
fetched_model_options_response = openai_get_model_list(url=model_endpoint, api_key=api_key, fix_url=True)
|
||||
model_options = [obj["id"] for obj in fetched_model_options_response["data"]]
|
||||
# NOTE no filtering of local model options
|
||||
|
||||
@@ -289,6 +295,44 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
|
||||
raise KeyboardInterrupt
|
||||
|
||||
else: # local models
|
||||
|
||||
# ask about local auth
|
||||
if model_endpoint_type in ["groq"]: # TODO all llm engines under 'local' that will require api keys
|
||||
use_local_auth = True
|
||||
local_auth_type = "bearer_token"
|
||||
local_auth_key = questionary.password(
|
||||
"Enter your Groq API key:",
|
||||
).ask()
|
||||
if local_auth_key is None:
|
||||
raise KeyboardInterrupt
|
||||
credentials.openllm_auth_type = local_auth_type
|
||||
credentials.openllm_key = local_auth_key
|
||||
credentials.save()
|
||||
else:
|
||||
use_local_auth = questionary.confirm(
|
||||
"Is your LLM endpoint authenticated? (default no)",
|
||||
default=False,
|
||||
).ask()
|
||||
if use_local_auth is None:
|
||||
raise KeyboardInterrupt
|
||||
if use_local_auth:
|
||||
local_auth_type = questionary.select(
|
||||
"What HTTP authentication method does your endpoint require?",
|
||||
choices=SUPPORTED_AUTH_TYPES,
|
||||
default=SUPPORTED_AUTH_TYPES[0],
|
||||
).ask()
|
||||
if local_auth_type is None:
|
||||
raise KeyboardInterrupt
|
||||
local_auth_key = questionary.password(
|
||||
"Enter your authentication key:",
|
||||
).ask()
|
||||
if local_auth_key is None:
|
||||
raise KeyboardInterrupt
|
||||
# credentials = MemGPTCredentials.load()
|
||||
credentials.openllm_auth_type = local_auth_type
|
||||
credentials.openllm_key = local_auth_key
|
||||
credentials.save()
|
||||
|
||||
# ollama also needs model type
|
||||
if model_endpoint_type == "ollama":
|
||||
default_model = (
|
||||
@@ -311,7 +355,7 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
|
||||
)
|
||||
|
||||
# vllm needs huggingface model tag
|
||||
if model_endpoint_type == "vllm":
|
||||
if model_endpoint_type in ["vllm", "groq"]:
|
||||
try:
|
||||
# Don't filter model list for vLLM since model list is likely much smaller than OpenAI/Azure endpoint
|
||||
# + probably has custom model names
|
||||
@@ -366,31 +410,6 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
|
||||
if model_wrapper is None:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
# ask about local auth
|
||||
use_local_auth = questionary.confirm(
|
||||
"Is your LLM endpoint authenticated? (default no)",
|
||||
default=False,
|
||||
).ask()
|
||||
if use_local_auth is None:
|
||||
raise KeyboardInterrupt
|
||||
if use_local_auth:
|
||||
local_auth_type = questionary.select(
|
||||
"What HTTP authentication method does your endpoint require?",
|
||||
choices=SUPPORTED_AUTH_TYPES,
|
||||
default=SUPPORTED_AUTH_TYPES[0],
|
||||
).ask()
|
||||
if local_auth_type is None:
|
||||
raise KeyboardInterrupt
|
||||
local_auth_key = questionary.password(
|
||||
"Enter your authentication key:",
|
||||
).ask()
|
||||
if local_auth_key is None:
|
||||
raise KeyboardInterrupt
|
||||
# credentials = MemGPTCredentials.load()
|
||||
credentials.openllm_auth_type = local_auth_type
|
||||
credentials.openllm_key = local_auth_key
|
||||
credentials.save()
|
||||
|
||||
# set: context_window
|
||||
if str(model) not in LLM_MAX_TOKENS:
|
||||
# Ask the user to specify the context length
|
||||
|
||||
@@ -13,6 +13,7 @@ from memgpt.local_llm.llamacpp.api import get_llamacpp_completion
|
||||
from memgpt.local_llm.koboldcpp.api import get_koboldcpp_completion
|
||||
from memgpt.local_llm.ollama.api import get_ollama_completion
|
||||
from memgpt.local_llm.vllm.api import get_vllm_completion
|
||||
from memgpt.local_llm.groq.api import get_groq_completion
|
||||
from memgpt.local_llm.llm_chat_completion_wrappers import simple_summary_wrapper
|
||||
from memgpt.local_llm.constants import DEFAULT_WRAPPER
|
||||
from memgpt.local_llm.utils import get_available_wrappers, count_tokens
|
||||
@@ -155,6 +156,8 @@ def get_chat_completion(
|
||||
result, usage = get_ollama_completion(endpoint, auth_type, auth_key, model, prompt, context_window)
|
||||
elif endpoint_type == "vllm":
|
||||
result, usage = get_vllm_completion(endpoint, auth_type, auth_key, model, prompt, context_window, user)
|
||||
elif endpoint_type == "groq":
|
||||
result, usage = get_groq_completion(endpoint, auth_type, auth_key, model, prompt, context_window)
|
||||
else:
|
||||
raise LocalLLMError(
|
||||
f"Invalid endpoint type {endpoint_type}, please set variable depending on your backend (webui, lmstudio, llamacpp, koboldcpp)"
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
from memgpt.local_llm.llm_chat_completion_wrappers.chatml import ChatMLInnerMonologueWrapper, ChatMLOuterInnerMonologueWrapper
|
||||
|
||||
DEFAULT_ENDPOINTS = {
|
||||
# Local
|
||||
"koboldcpp": "http://localhost:5001",
|
||||
"llamacpp": "http://localhost:8080",
|
||||
"lmstudio": "http://localhost:1234",
|
||||
@@ -10,6 +11,9 @@ DEFAULT_ENDPOINTS = {
|
||||
"webui-legacy": "http://localhost:5000",
|
||||
"webui": "http://localhost:5000",
|
||||
"vllm": "http://localhost:8000",
|
||||
# APIs
|
||||
"openai": "https://api.openai.com",
|
||||
"groq": "https://api.groq.com/openai",
|
||||
}
|
||||
|
||||
DEFAULT_OLLAMA_MODEL = "dolphin2.2-mistral:7b-q6_K"
|
||||
|
||||
79
memgpt/local_llm/groq/api.py
Normal file
79
memgpt/local_llm/groq/api.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from urllib.parse import urljoin
|
||||
import requests
|
||||
from typing import Tuple
|
||||
|
||||
from memgpt.local_llm.settings.settings import get_completions_settings
|
||||
from memgpt.local_llm.utils import post_json_auth_request
|
||||
from memgpt.utils import count_tokens
|
||||
|
||||
|
||||
API_CHAT_SUFFIX = "/v1/chat/completions"
|
||||
# LMSTUDIO_API_COMPLETIONS_SUFFIX = "/v1/completions"
|
||||
|
||||
|
||||
def get_groq_completion(endpoint: str, auth_type: str, auth_key: str, model: str, prompt: str, context_window: int) -> Tuple[str, dict]:
|
||||
"""TODO no support for function calling OR raw completions, so we need to route the request into /chat/completions instead"""
|
||||
from memgpt.utils import printd
|
||||
|
||||
prompt_tokens = count_tokens(prompt)
|
||||
if prompt_tokens > context_window:
|
||||
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
|
||||
|
||||
settings = get_completions_settings()
|
||||
settings.update(
|
||||
{
|
||||
# see https://console.groq.com/docs/text-chat, supports:
|
||||
# "temperature": ,
|
||||
# "max_tokens": ,
|
||||
# "top_p",
|
||||
# "stream",
|
||||
# "stop",
|
||||
}
|
||||
)
|
||||
|
||||
URI = urljoin(endpoint.strip("/") + "/", API_CHAT_SUFFIX.strip("/"))
|
||||
|
||||
# Settings for the generation, includes the prompt + stop tokens, max length, etc
|
||||
request = settings
|
||||
request["model"] = model
|
||||
request["max_tokens"] = context_window
|
||||
# NOTE: Hack for chat/completion-only endpoints: put the entire completion string inside the first message
|
||||
message_structure = [{"role": "user", "content": prompt}]
|
||||
request["messages"] = message_structure
|
||||
|
||||
if not endpoint.startswith(("http://", "https://")):
|
||||
raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")
|
||||
|
||||
try:
|
||||
response = post_json_auth_request(uri=URI, json_payload=request, auth_type=auth_type, auth_key=auth_key)
|
||||
if response.status_code == 200:
|
||||
result_full = response.json()
|
||||
printd(f"JSON API response:\n{result_full}")
|
||||
result = result_full["choices"][0]["message"]["content"]
|
||||
usage = result_full.get("usage", None)
|
||||
else:
|
||||
# Example error: msg={"error":"Context length exceeded. Tokens in context: 8000, Context length: 8000"}
|
||||
if "context length" in str(response.text).lower():
|
||||
# "exceeds context length" is what appears in the LM Studio error message
|
||||
# raise an alternate exception that matches OpenAI's message, which is "maximum context length"
|
||||
raise Exception(f"Request exceeds maximum context length (code={response.status_code}, msg={response.text}, URI={URI})")
|
||||
else:
|
||||
raise Exception(
|
||||
f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
|
||||
+ f" Make sure that the inference server is running and reachable at {URI}."
|
||||
)
|
||||
except:
|
||||
# TODO handle gracefully
|
||||
raise
|
||||
|
||||
# Pass usage statistics back to main thread
|
||||
# These are used to compute memory warning messages
|
||||
completion_tokens = usage.get("completion_tokens", None) if usage is not None else None
|
||||
total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None
|
||||
usage = {
|
||||
"prompt_tokens": prompt_tokens, # can grab from usage dict, but it's usually wrong (set to 0)
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
}
|
||||
|
||||
return result, usage
|
||||
@@ -5,7 +5,19 @@ from memgpt.constants import JSON_LOADS_STRICT
|
||||
from memgpt.errors import LLMJSONParsingError
|
||||
|
||||
|
||||
def extract_first_json(string):
|
||||
def replace_escaped_underscores(string: str):
|
||||
"""Handles the case of escaped underscores, e.g.:
|
||||
|
||||
{
|
||||
"function":"send\_message",
|
||||
"params": {
|
||||
"inner\_thoughts": "User is asking for information about themselves. Retrieving data from core memory.",
|
||||
"message": "I know that you are Chad. Is there something specific you would like to know or talk about regarding yourself?"
|
||||
"""
|
||||
return string.replace("\_", "_")
|
||||
|
||||
|
||||
def extract_first_json(string: str):
|
||||
"""Handles the case of two JSON objects back-to-back"""
|
||||
from memgpt.utils import printd
|
||||
|
||||
@@ -163,6 +175,9 @@ def clean_json(raw_llm_output, messages=None, functions=None):
|
||||
lambda output: json.loads(repair_even_worse_json(output), strict=JSON_LOADS_STRICT),
|
||||
lambda output: extract_first_json(output + "}}"),
|
||||
lambda output: clean_and_interpret_send_message_json(output),
|
||||
# replace underscores
|
||||
lambda output: json.loads(replace_escaped_underscores(output), strict=JSON_LOADS_STRICT),
|
||||
lambda output: extract_first_json(replace_escaped_underscores(output) + "}}"),
|
||||
]
|
||||
|
||||
for strategy in strategies:
|
||||
|
||||
@@ -18,6 +18,7 @@ total KV lookups are required to find the final value), and
|
||||
sample 30 different ordering configurations including both
|
||||
the initial key position and nesting key positions.
|
||||
"""
|
||||
|
||||
import math
|
||||
import json
|
||||
import argparse
|
||||
|
||||
@@ -4,6 +4,14 @@ from memgpt.constants import JSON_LOADS_STRICT
|
||||
import memgpt.local_llm.json_parser as json_parser
|
||||
|
||||
|
||||
EXAMPLE_ESCAPED_UNDERSCORES = """{
|
||||
"function":"send\_message",
|
||||
"params": {
|
||||
"inner\_thoughts": "User is asking for information about themselves. Retrieving data from core memory.",
|
||||
"message": "I know that you are Chad. Is there something specific you would like to know or talk about regarding yourself?"
|
||||
"""
|
||||
|
||||
|
||||
EXAMPLE_MISSING_CLOSING_BRACE = """{
|
||||
"function": "send_message",
|
||||
"params": {
|
||||
@@ -72,6 +80,7 @@ def test_json_parsers():
|
||||
"""Try various broken JSON and check that the parsers can fix it"""
|
||||
|
||||
test_strings = [
|
||||
EXAMPLE_ESCAPED_UNDERSCORES,
|
||||
EXAMPLE_MISSING_CLOSING_BRACE,
|
||||
EXAMPLE_BAD_TOKEN_END,
|
||||
EXAMPLE_DOUBLE_JSON,
|
||||
|
||||
Reference in New Issue
Block a user