updated local APIs to return usage info (#585)

* updated APIs to return usage info

* tested all endpoints
This commit is contained in:
Charles Packer
2023-12-13 21:11:20 -08:00
committed by GitHub
parent 2048ba179b
commit 8cc1ed0f59
8 changed files with 182 additions and 75 deletions

View File

@@ -15,14 +15,11 @@ from memgpt.local_llm.ollama.api import get_ollama_completion
from memgpt.local_llm.vllm.api import get_vllm_completion
from memgpt.local_llm.llm_chat_completion_wrappers import simple_summary_wrapper
from memgpt.local_llm.constants import DEFAULT_WRAPPER
from memgpt.local_llm.utils import get_available_wrappers
from memgpt.local_llm.utils import get_available_wrappers, count_tokens
from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE
from memgpt.errors import LocalLLMConnectionError, LocalLLMError
from memgpt.constants import CLI_WARNING_PREFIX
DEBUG = False
# DEBUG = True
has_shown_warning = False
@@ -38,6 +35,8 @@ def get_chat_completion(
endpoint=None,
endpoint_type=None,
):
from memgpt.utils import printd
assert context_window is not None, "Local LLM calls need the context length to be explicitly set"
assert endpoint is not None, "Local LLM calls need the endpoint (eg http://localendpoint:1234) to be explicitly set"
assert endpoint_type is not None, "Local LLM calls need the endpoint type (eg webui) to be explicitly set"
@@ -78,8 +77,7 @@ def get_chat_completion(
# First step: turn the message sequence into a prompt that the model expects
try:
prompt = llm_wrapper.chat_completion_to_prompt(messages, functions)
if DEBUG:
print(prompt)
printd(prompt)
except Exception as e:
raise LocalLLMError(
f"Failed to convert ChatCompletion messages into prompt string with wrapper {str(llm_wrapper)} - error: {str(e)}"
@@ -87,19 +85,19 @@ def get_chat_completion(
try:
if endpoint_type == "webui":
result = get_webui_completion(endpoint, prompt, context_window, grammar=grammar_name)
result, usage = get_webui_completion(endpoint, prompt, context_window, grammar=grammar_name)
elif endpoint_type == "webui-legacy":
result = get_webui_completion_legacy(endpoint, prompt, context_window, grammar=grammar_name)
result, usage = get_webui_completion_legacy(endpoint, prompt, context_window, grammar=grammar_name)
elif endpoint_type == "lmstudio":
result = get_lmstudio_completion(endpoint, prompt, context_window)
result, usage = get_lmstudio_completion(endpoint, prompt, context_window)
elif endpoint_type == "llamacpp":
result = get_llamacpp_completion(endpoint, prompt, context_window, grammar=grammar_name)
result, usage = get_llamacpp_completion(endpoint, prompt, context_window, grammar=grammar_name)
elif endpoint_type == "koboldcpp":
result = get_koboldcpp_completion(endpoint, prompt, context_window, grammar=grammar_name)
result, usage = get_koboldcpp_completion(endpoint, prompt, context_window, grammar=grammar_name)
elif endpoint_type == "ollama":
result = get_ollama_completion(endpoint, model, prompt, context_window)
result, usage = get_ollama_completion(endpoint, model, prompt, context_window)
elif endpoint_type == "vllm":
result = get_vllm_completion(endpoint, model, prompt, context_window, user)
result, usage = get_vllm_completion(endpoint, model, prompt, context_window, user)
else:
raise LocalLLMError(
f"Invalid endpoint type {endpoint_type}, please set variable depending on your backend (webui, lmstudio, llamacpp, koboldcpp)"
@@ -109,16 +107,37 @@ def get_chat_completion(
if result is None or result == "":
raise LocalLLMError(f"Got back an empty response string from {endpoint}")
if DEBUG:
print(f"Raw LLM output:\n{result}")
printd(f"Raw LLM output:\n{result}")
try:
chat_completion_result = llm_wrapper.output_to_chat_completion_response(result)
if DEBUG:
print(json.dumps(chat_completion_result, indent=2))
printd(json.dumps(chat_completion_result, indent=2))
except Exception as e:
raise LocalLLMError(f"Failed to parse JSON from local LLM response - error: {str(e)}")
# Fill in potential missing usage information (used for tracking token use)
if not ("prompt_tokens" in usage and "completion_tokens" in usage and "total_tokens" in usage):
raise LocalLLMError(f"usage dict in response was missing fields ({usage})")
if usage["prompt_tokens"] is None:
printd(f"usage dict was missing prompt_tokens, computing on-the-fly...")
usage["prompt_tokens"] = count_tokens(prompt)
# NOTE: we should compute on-the-fly anyways since we might have to correct for errors during JSON parsing
usage["completion_tokens"] = count_tokens(json.dumps(chat_completion_result))
"""
if usage["completion_tokens"] is None:
printd(f"usage dict was missing completion_tokens, computing on-the-fly...")
# chat_completion_result is dict with 'role' and 'content'
# token counter wants a string
usage["completion_tokens"] = count_tokens(json.dumps(chat_completion_result))
"""
# NOTE: this is the token count that matters most
if usage["total_tokens"] is None:
printd(f"usage dict was missing total_tokens, computing on-the-fly...")
usage["total_tokens"] = usage["prompt_tokens"] + usage["completion_tokens"]
# unpack with response.choices[0].message.content
response = Box(
{
@@ -126,15 +145,17 @@ def get_chat_completion(
"choices": [
{
"message": chat_completion_result,
"finish_reason": "stop", # TODO vary based on backend response
# TODO vary 'finish_reason' based on backend response
# NOTE if we got this far (parsing worked), then it's probably OK to treat this as a stop
"finish_reason": "stop",
}
],
"usage": {
# TODO fix, actually use real info
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
"prompt_tokens": usage["prompt_tokens"],
"completion_tokens": usage["completion_tokens"],
"total_tokens": usage["total_tokens"],
},
}
)
printd(response)
return response

View File

@@ -6,12 +6,12 @@ from .settings import SIMPLE
from ..utils import load_grammar_file, count_tokens
KOBOLDCPP_API_SUFFIX = "/api/v1/generate"
DEBUG = False
# DEBUG = True
def get_koboldcpp_completion(endpoint, prompt, context_window, grammar=None, settings=SIMPLE):
"""See https://lite.koboldai.net/koboldcpp_api for API spec"""
from memgpt.utils import printd
prompt_tokens = count_tokens(prompt)
if prompt_tokens > context_window:
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
@@ -34,10 +34,9 @@ def get_koboldcpp_completion(endpoint, prompt, context_window, grammar=None, set
URI = urljoin(endpoint.strip("/") + "/", KOBOLDCPP_API_SUFFIX.strip("/"))
response = requests.post(URI, json=request)
if response.status_code == 200:
result = response.json()
result = result["results"][0]["text"]
if DEBUG:
print(f"json API response.text: {result}")
result_full = response.json()
printd(f"JSON API response:\n{result_full}")
result = result_full["results"][0]["text"]
else:
raise Exception(
f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
@@ -48,4 +47,16 @@ def get_koboldcpp_completion(endpoint, prompt, context_window, grammar=None, set
# TODO handle gracefully
raise
return result
# Pass usage statistics back to main thread
# These are used to compute memory warning messages
# KoboldCpp doesn't return anything?
# https://lite.koboldai.net/koboldcpp_api#/v1/post_v1_generate
completion_tokens = None
total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None
usage = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
}
return result, usage

View File

@@ -6,12 +6,12 @@ from .settings import SIMPLE
from ..utils import load_grammar_file, count_tokens
LLAMACPP_API_SUFFIX = "/completion"
DEBUG = False
# DEBUG = True
def get_llamacpp_completion(endpoint, prompt, context_window, grammar=None, settings=SIMPLE):
"""See https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md for instructions on how to run the LLM web server"""
from memgpt.utils import printd
prompt_tokens = count_tokens(prompt)
if prompt_tokens > context_window:
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
@@ -33,10 +33,9 @@ def get_llamacpp_completion(endpoint, prompt, context_window, grammar=None, sett
URI = urljoin(endpoint.strip("/") + "/", LLAMACPP_API_SUFFIX.strip("/"))
response = requests.post(URI, json=request)
if response.status_code == 200:
result = response.json()
result = result["content"]
if DEBUG:
print(f"json API response.text: {result}")
result_full = response.json()
printd(f"JSON API response:\n{result_full}")
result = result_full["content"]
else:
raise Exception(
f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
@@ -47,4 +46,14 @@ def get_llamacpp_completion(endpoint, prompt, context_window, grammar=None, sett
# TODO handle gracefully
raise
return result
# Pass usage statistics back to main thread
# These are used to compute memory warning messages
completion_tokens = result_full.get("tokens_predicted", None)
total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None
usage = {
"prompt_tokens": prompt_tokens, # can grab from "tokens_evaluated", but it's usually wrong (set to 0)
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
}
return result, usage

View File

@@ -7,30 +7,42 @@ from ..utils import count_tokens
LMSTUDIO_API_CHAT_SUFFIX = "/v1/chat/completions"
LMSTUDIO_API_COMPLETIONS_SUFFIX = "/v1/completions"
DEBUG = False
# TODO move to "completions" by default, not "chat"
def get_lmstudio_completion(endpoint, prompt, context_window, settings=SIMPLE, api="chat"):
"""Based on the example for using LM Studio as a backend from https://github.com/lmstudio-ai/examples/tree/main/Hello%2C%20world%20-%20OpenAI%20python%20client"""
from memgpt.utils import printd
prompt_tokens = count_tokens(prompt)
if prompt_tokens > context_window:
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
# Settings for the generation, includes the prompt + stop tokens, max length, etc
request = settings
request["max_tokens"] = context_window
# Uses the ChatCompletions API style
# Seems to work better, probably because it's applying some extra settings under-the-hood?
if api == "chat":
# Uses the ChatCompletions API style
# Seems to work better, probably because it's applying some extra settings under-the-hood?
URI = urljoin(endpoint.strip("/") + "/", LMSTUDIO_API_CHAT_SUFFIX.strip("/"))
# Settings for the generation, includes the prompt + stop tokens, max length, etc
request = settings
request["max_tokens"] = context_window
# Put the entire completion string inside the first message
message_structure = [{"role": "user", "content": prompt}]
request["messages"] = message_structure
# Uses basic string completions (string in, string out)
# Does not work as well as ChatCompletions for some reason
elif api == "completions":
# Uses basic string completions (string in, string out)
# Does not work as well as ChatCompletions for some reason
URI = urljoin(endpoint.strip("/") + "/", LMSTUDIO_API_COMPLETIONS_SUFFIX.strip("/"))
# Settings for the generation, includes the prompt + stop tokens, max length, etc
request = settings
request["max_tokens"] = context_window
# Standard completions format, formatted string goes in prompt
request["prompt"] = prompt
else:
raise ValueError(api)
@@ -40,13 +52,14 @@ def get_lmstudio_completion(endpoint, prompt, context_window, settings=SIMPLE, a
try:
response = requests.post(URI, json=request)
if response.status_code == 200:
result = response.json()
result_full = response.json()
printd(f"JSON API response:\n{result_full}")
if api == "chat":
result = result["choices"][0]["message"]["content"]
result = result_full["choices"][0]["message"]["content"]
usage = result_full.get("usage", None)
elif api == "completions":
result = result["choices"][0]["text"]
if DEBUG:
print(f"json API response.text: {result}")
result = result_full["choices"][0]["text"]
usage = result_full.get("usage", None)
else:
# Example error: msg={"error":"Context length exceeded. Tokens in context: 8000, Context length: 8000"}
if "context length" in str(response.text).lower():
@@ -62,4 +75,14 @@ def get_lmstudio_completion(endpoint, prompt, context_window, settings=SIMPLE, a
# TODO handle gracefully
raise
return result
# Pass usage statistics back to main thread
# These are used to compute memory warning messages
completion_tokens = usage.get("completion_tokens", None) if usage is not None else None
total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None
usage = {
"prompt_tokens": prompt_tokens, # can grab from usage dict, but it's usually wrong (set to 0)
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
}
return result, usage

View File

@@ -7,11 +7,12 @@ from ..utils import count_tokens
from ...errors import LocalLLMError
OLLAMA_API_SUFFIX = "/api/generate"
DEBUG = False
def get_ollama_completion(endpoint, model, prompt, context_window, settings=SIMPLE, grammar=None):
"""See https://github.com/jmorganca/ollama/blob/main/docs/api.md for instructions on how to run the LLM web server"""
from memgpt.utils import printd
prompt_tokens = count_tokens(prompt)
if prompt_tokens > context_window:
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
@@ -39,10 +40,10 @@ def get_ollama_completion(endpoint, model, prompt, context_window, settings=SIMP
URI = urljoin(endpoint.strip("/") + "/", OLLAMA_API_SUFFIX.strip("/"))
response = requests.post(URI, json=request)
if response.status_code == 200:
result = response.json()
result = result["response"]
if DEBUG:
print(f"json API response.text: {result}")
# https://github.com/jmorganca/ollama/blob/main/docs/api.md
result_full = response.json()
printd(f"JSON API response:\n{result_full}")
result = result_full["response"]
else:
raise Exception(
f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
@@ -53,4 +54,15 @@ def get_ollama_completion(endpoint, model, prompt, context_window, settings=SIMP
# TODO handle gracefully
raise
return result
# Pass usage statistics back to main thread
# These are used to compute memory warning messages
# https://github.com/jmorganca/ollama/blob/main/docs/api.md#response
completion_tokens = result_full.get("eval_count", None)
total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None
usage = {
"prompt_tokens": prompt_tokens, # can also grab from "prompt_eval_count"
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
}
return result, usage

View File

@@ -5,11 +5,12 @@ import requests
from ..utils import load_grammar_file, count_tokens
WEBUI_API_SUFFIX = "/v1/completions"
DEBUG = False
def get_vllm_completion(endpoint, model, prompt, context_window, user, settings={}, grammar=None):
"""https://github.com/vllm-project/vllm/blob/main/examples/api_client.py"""
from memgpt.utils import printd
prompt_tokens = count_tokens(prompt)
if prompt_tokens > context_window:
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
@@ -36,10 +37,10 @@ def get_vllm_completion(endpoint, model, prompt, context_window, user, settings=
URI = urljoin(endpoint.strip("/") + "/", WEBUI_API_SUFFIX.strip("/"))
response = requests.post(URI, json=request)
if response.status_code == 200:
result = response.json()
result = result["choices"][0]["text"]
if DEBUG:
print(f"json API response.text: {result}")
result_full = response.json()
printd(f"JSON API response:\n{result_full}")
result = result_full["choices"][0]["text"]
usage = result_full.get("usage", None)
else:
raise Exception(
f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
@@ -50,4 +51,14 @@ def get_vllm_completion(endpoint, model, prompt, context_window, user, settings=
# TODO handle gracefully
raise
return result
# Pass usage statistics back to main thread
# These are used to compute memory warning messages
completion_tokens = usage.get("completion_tokens", None) if usage is not None else None
total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None
usage = {
"prompt_tokens": prompt_tokens, # can grab from usage dict, but it's usually wrong (set to 0)
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
}
return result, usage

View File

@@ -6,11 +6,12 @@ from .settings import SIMPLE
from ..utils import load_grammar_file, count_tokens
WEBUI_API_SUFFIX = "/v1/completions"
DEBUG = False
def get_webui_completion(endpoint, prompt, context_window, settings=SIMPLE, grammar=None):
"""Compatibility for the new OpenAI API: https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API#examples"""
from memgpt.utils import printd
prompt_tokens = count_tokens(prompt)
if prompt_tokens > context_window:
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
@@ -33,10 +34,10 @@ def get_webui_completion(endpoint, prompt, context_window, settings=SIMPLE, gram
URI = urljoin(endpoint.strip("/") + "/", WEBUI_API_SUFFIX.strip("/"))
response = requests.post(URI, json=request)
if response.status_code == 200:
result = response.json()
result = result["choices"][0]["text"]
if DEBUG:
print(f"json API response.text: {result}")
result_full = response.json()
printd(f"JSON API response:\n{result_full}")
result = result_full["choices"][0]["text"]
usage = result_full.get("usage", None)
else:
raise Exception(
f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
@@ -47,4 +48,14 @@ def get_webui_completion(endpoint, prompt, context_window, settings=SIMPLE, gram
# TODO handle gracefully
raise
return result
# Pass usage statistics back to main thread
# These are used to compute memory warning messages
completion_tokens = usage.get("completion_tokens", None) if usage is not None else None
total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None
usage = {
"prompt_tokens": prompt_tokens, # can grab from usage dict, but it's usually wrong (set to 0)
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
}
return result, usage

View File

@@ -6,11 +6,12 @@ from .legacy_settings import SIMPLE
from ..utils import load_grammar_file, count_tokens
WEBUI_API_SUFFIX = "/api/v1/generate"
DEBUG = False
def get_webui_completion(endpoint, prompt, context_window, settings=SIMPLE, grammar=None):
"""See https://github.com/oobabooga/text-generation-webui for instructions on how to run the LLM web server"""
from memgpt.utils import printd
prompt_tokens = count_tokens(prompt)
if prompt_tokens > context_window:
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
@@ -31,10 +32,9 @@ def get_webui_completion(endpoint, prompt, context_window, settings=SIMPLE, gram
URI = urljoin(endpoint.strip("/") + "/", WEBUI_API_SUFFIX.strip("/"))
response = requests.post(URI, json=request)
if response.status_code == 200:
result = response.json()
result = result["results"][0]["text"]
if DEBUG:
print(f"json API response.text: {result}")
result_full = response.json()
printd(f"JSON API response:\n{result_full}")
result = result_full["results"][0]["text"]
else:
raise Exception(
f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
@@ -45,4 +45,13 @@ def get_webui_completion(endpoint, prompt, context_window, settings=SIMPLE, gram
# TODO handle gracefully
raise
return result
# TODO correct for legacy
completion_tokens = None
total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None
usage = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
}
return result, usage