updated local APIs to return usage info (#585)
* updated APIs to return usage info
* tested all endpoints
@@ -15,14 +15,11 @@ from memgpt.local_llm.ollama.api import get_ollama_completion
 from memgpt.local_llm.vllm.api import get_vllm_completion
 from memgpt.local_llm.llm_chat_completion_wrappers import simple_summary_wrapper
 from memgpt.local_llm.constants import DEFAULT_WRAPPER
-from memgpt.local_llm.utils import get_available_wrappers
+from memgpt.local_llm.utils import get_available_wrappers, count_tokens
 from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE
 from memgpt.errors import LocalLLMConnectionError, LocalLLMError
 from memgpt.constants import CLI_WARNING_PREFIX

 DEBUG = False
-# DEBUG = True

 has_shown_warning = False
-
-
@@ -38,6 +35,8 @@ def get_chat_completion(
     endpoint=None,
     endpoint_type=None,
 ):
+    from memgpt.utils import printd
+
     assert context_window is not None, "Local LLM calls need the context length to be explicitly set"
     assert endpoint is not None, "Local LLM calls need the endpoint (eg http://localendpoint:1234) to be explicitly set"
     assert endpoint_type is not None, "Local LLM calls need the endpoint type (eg webui) to be explicitly set"
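Every local call now hard-requires these three settings. A minimal caller sketch (the module path and the argument names other than context_window/endpoint/endpoint_type are assumptions inferred from how this diff uses them; the URL and model name are placeholders):

    from memgpt.local_llm.chat_completion_proxy import get_chat_completion  # assumed module path

    response = get_chat_completion(
        model="dolphin-2.1-mistral-7b",  # placeholder model name
        messages=[{"role": "user", "content": "hello"}],
        context_window=8192,  # satisfies the first assert
        endpoint="http://localhost:5000",  # satisfies the second assert
        endpoint_type="webui",  # satisfies the third assert
    )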
@@ -78,8 +77,7 @@ def get_chat_completion(
     # First step: turn the message sequence into a prompt that the model expects
     try:
         prompt = llm_wrapper.chat_completion_to_prompt(messages, functions)
-        if DEBUG:
-            print(prompt)
+        printd(prompt)
     except Exception as e:
         raise LocalLLMError(
             f"Failed to convert ChatCompletion messages into prompt string with wrapper {str(llm_wrapper)} - error: {str(e)}"
@@ -87,19 +85,19 @@ def get_chat_completion(

     try:
         if endpoint_type == "webui":
-            result = get_webui_completion(endpoint, prompt, context_window, grammar=grammar_name)
+            result, usage = get_webui_completion(endpoint, prompt, context_window, grammar=grammar_name)
         elif endpoint_type == "webui-legacy":
-            result = get_webui_completion_legacy(endpoint, prompt, context_window, grammar=grammar_name)
+            result, usage = get_webui_completion_legacy(endpoint, prompt, context_window, grammar=grammar_name)
         elif endpoint_type == "lmstudio":
-            result = get_lmstudio_completion(endpoint, prompt, context_window)
+            result, usage = get_lmstudio_completion(endpoint, prompt, context_window)
         elif endpoint_type == "llamacpp":
-            result = get_llamacpp_completion(endpoint, prompt, context_window, grammar=grammar_name)
+            result, usage = get_llamacpp_completion(endpoint, prompt, context_window, grammar=grammar_name)
         elif endpoint_type == "koboldcpp":
-            result = get_koboldcpp_completion(endpoint, prompt, context_window, grammar=grammar_name)
+            result, usage = get_koboldcpp_completion(endpoint, prompt, context_window, grammar=grammar_name)
         elif endpoint_type == "ollama":
-            result = get_ollama_completion(endpoint, model, prompt, context_window)
+            result, usage = get_ollama_completion(endpoint, model, prompt, context_window)
         elif endpoint_type == "vllm":
-            result = get_vllm_completion(endpoint, model, prompt, context_window, user)
+            result, usage = get_vllm_completion(endpoint, model, prompt, context_window, user)
         else:
             raise LocalLLMError(
                 f"Invalid endpoint type {endpoint_type}, please set variable depending on your backend (webui, lmstudio, llamacpp, koboldcpp)"
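Every backend helper in this dispatch now returns a (result, usage) pair instead of a bare string. A sketch of that contract (the function here is illustrative; the real helpers are the get_*_completion functions above):

    from typing import Dict, Optional, Tuple

    def get_example_completion(endpoint: str, prompt: str, context_window: int) -> Tuple[str, Dict[str, Optional[int]]]:
        # Raw completion text plus a usage dict; any count a backend cannot
        # report is left as None and back-filled downstream with count_tokens().
        result = "example output"
        usage = {"prompt_tokens": None, "completion_tokens": None, "total_tokens": None}
        return result, usage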
@@ -109,16 +107,37 @@ def get_chat_completion(

     if result is None or result == "":
         raise LocalLLMError(f"Got back an empty response string from {endpoint}")
-    if DEBUG:
-        print(f"Raw LLM output:\n{result}")
+    printd(f"Raw LLM output:\n{result}")

     try:
         chat_completion_result = llm_wrapper.output_to_chat_completion_response(result)
-        if DEBUG:
-            print(json.dumps(chat_completion_result, indent=2))
+        printd(json.dumps(chat_completion_result, indent=2))
     except Exception as e:
         raise LocalLLMError(f"Failed to parse JSON from local LLM response - error: {str(e)}")

+    # Fill in potential missing usage information (used for tracking token use)
+    if not ("prompt_tokens" in usage and "completion_tokens" in usage and "total_tokens" in usage):
+        raise LocalLLMError(f"usage dict in response was missing fields ({usage})")
+
+    if usage["prompt_tokens"] is None:
+        printd(f"usage dict was missing prompt_tokens, computing on-the-fly...")
+        usage["prompt_tokens"] = count_tokens(prompt)
+
+    # NOTE: we should compute on-the-fly anyways since we might have to correct for errors during JSON parsing
+    usage["completion_tokens"] = count_tokens(json.dumps(chat_completion_result))
+    """
+    if usage["completion_tokens"] is None:
+        printd(f"usage dict was missing completion_tokens, computing on-the-fly...")
+        # chat_completion_result is dict with 'role' and 'content'
+        # token counter wants a string
+        usage["completion_tokens"] = count_tokens(json.dumps(chat_completion_result))
+    """
+
+    # NOTE: this is the token count that matters most
+    if usage["total_tokens"] is None:
+        printd(f"usage dict was missing total_tokens, computing on-the-fly...")
+        usage["total_tokens"] = usage["prompt_tokens"] + usage["completion_tokens"]
+
     # unpack with response.choices[0].message.content
     response = Box(
         {
@@ -126,15 +145,17 @@ def get_chat_completion(
             "choices": [
                 {
                     "message": chat_completion_result,
-                    "finish_reason": "stop",  # TODO vary based on backend response
+                    # TODO vary 'finish_reason' based on backend response
+                    # NOTE if we got this far (parsing worked), then it's probably OK to treat this as a stop
+                    "finish_reason": "stop",
                 }
             ],
             "usage": {
-                # TODO fix, actually use real info
-                "prompt_tokens": 0,
-                "completion_tokens": 0,
-                "total_tokens": 0,
+                "prompt_tokens": usage["prompt_tokens"],
+                "completion_tokens": usage["completion_tokens"],
+                "total_tokens": usage["total_tokens"],
             },
         }
     )
+    printd(response)
     return response
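Since the response is wrapped in Box (python-box), callers keep attribute-style access while the usage block now carries real numbers, per the "unpack with response.choices[0].message.content" comment. A self-contained sketch of the resulting shape:

    from box import Box  # python-box, as used above

    response = Box(
        {
            "choices": [{"message": {"role": "assistant", "content": "hi"}, "finish_reason": "stop"}],
            "usage": {"prompt_tokens": 50, "completion_tokens": 10, "total_tokens": 60},
        }
    )
    assert response.choices[0].message.content == "hi"
    assert response.usage.total_tokens == 60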
@@ -6,12 +6,12 @@ from .settings import SIMPLE
 from ..utils import load_grammar_file, count_tokens

 KOBOLDCPP_API_SUFFIX = "/api/v1/generate"
 DEBUG = False
-# DEBUG = True


 def get_koboldcpp_completion(endpoint, prompt, context_window, grammar=None, settings=SIMPLE):
     """See https://lite.koboldai.net/koboldcpp_api for API spec"""
+    from memgpt.utils import printd

     prompt_tokens = count_tokens(prompt)
     if prompt_tokens > context_window:
         raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
@@ -34,10 +34,9 @@ def get_koboldcpp_completion(endpoint, prompt, context_window, grammar=None, settings=SIMPLE):
     URI = urljoin(endpoint.strip("/") + "/", KOBOLDCPP_API_SUFFIX.strip("/"))
     response = requests.post(URI, json=request)
     if response.status_code == 200:
-        result = response.json()
-        result = result["results"][0]["text"]
-        if DEBUG:
-            print(f"json API response.text: {result}")
+        result_full = response.json()
+        printd(f"JSON API response:\n{result_full}")
+        result = result_full["results"][0]["text"]
     else:
         raise Exception(
             f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
@@ -48,4 +47,16 @@ def get_koboldcpp_completion(endpoint, prompt, context_window, grammar=None, settings=SIMPLE):
         # TODO handle gracefully
         raise

-    return result
+    # Pass usage statistics back to main thread
+    # These are used to compute memory warning messages
+    # KoboldCpp doesn't return anything?
+    # https://lite.koboldai.net/koboldcpp_api#/v1/post_v1_generate
+    completion_tokens = None
+    total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None
+    usage = {
+        "prompt_tokens": prompt_tokens,
+        "completion_tokens": completion_tokens,
+        "total_tokens": total_tokens,
+    }
+
+    return result, usage
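KoboldCpp's generate endpoint reports no token counts, so completion_tokens and total_tokens come back as None here and get_chat_completion back-fills them. A sketch of that back-fill, mirroring the logic added above (count_tokens is passed in to keep the sketch self-contained):

    import json

    def backfill_usage(usage, prompt, chat_completion_result, count_tokens):
        # Mirror of the back-fill in get_chat_completion above.
        if usage["prompt_tokens"] is None:
            usage["prompt_tokens"] = count_tokens(prompt)
        # Completion tokens are recomputed from the parsed result regardless,
        # since JSON cleanup may change what the backend originally counted.
        usage["completion_tokens"] = count_tokens(json.dumps(chat_completion_result))
        if usage["total_tokens"] is None:
            usage["total_tokens"] = usage["prompt_tokens"] + usage["completion_tokens"]
        return usage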
@@ -6,12 +6,12 @@ from .settings import SIMPLE
 from ..utils import load_grammar_file, count_tokens

 LLAMACPP_API_SUFFIX = "/completion"
 DEBUG = False
-# DEBUG = True


 def get_llamacpp_completion(endpoint, prompt, context_window, grammar=None, settings=SIMPLE):
     """See https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md for instructions on how to run the LLM web server"""
+    from memgpt.utils import printd

     prompt_tokens = count_tokens(prompt)
     if prompt_tokens > context_window:
         raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
@@ -33,10 +33,9 @@ def get_llamacpp_completion(endpoint, prompt, context_window, grammar=None, settings=SIMPLE):
     URI = urljoin(endpoint.strip("/") + "/", LLAMACPP_API_SUFFIX.strip("/"))
     response = requests.post(URI, json=request)
     if response.status_code == 200:
-        result = response.json()
-        result = result["content"]
-        if DEBUG:
-            print(f"json API response.text: {result}")
+        result_full = response.json()
+        printd(f"JSON API response:\n{result_full}")
+        result = result_full["content"]
     else:
         raise Exception(
             f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
@@ -47,4 +46,14 @@ def get_llamacpp_completion(endpoint, prompt, context_window, grammar=None, settings=SIMPLE):
         # TODO handle gracefully
         raise

-    return result
+    # Pass usage statistics back to main thread
+    # These are used to compute memory warning messages
+    completion_tokens = result_full.get("tokens_predicted", None)
+    total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None
+    usage = {
+        "prompt_tokens": prompt_tokens,  # can grab from "tokens_evaluated", but it's usually wrong (set to 0)
+        "completion_tokens": completion_tokens,
+        "total_tokens": total_tokens,
+    }
+
+    return result, usage
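A quick contract check in the spirit of the "tested all endpoints" note (the module path is an assumption from the package layout, and the URL is a placeholder for a running llama.cpp server):

    from memgpt.local_llm.llamacpp.api import get_llamacpp_completion  # assumed module path

    result, usage = get_llamacpp_completion("http://localhost:8080", "Hello", context_window=8192)
    assert isinstance(result, str)
    assert set(usage) == {"prompt_tokens", "completion_tokens", "total_tokens"}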
@@ -7,30 +7,42 @@ from ..utils import count_tokens

 LMSTUDIO_API_CHAT_SUFFIX = "/v1/chat/completions"
 LMSTUDIO_API_COMPLETIONS_SUFFIX = "/v1/completions"
 DEBUG = False


+# TODO move to "completions" by default, not "chat"
 def get_lmstudio_completion(endpoint, prompt, context_window, settings=SIMPLE, api="chat"):
     """Based on the example for using LM Studio as a backend from https://github.com/lmstudio-ai/examples/tree/main/Hello%2C%20world%20-%20OpenAI%20python%20client"""
+    from memgpt.utils import printd

     prompt_tokens = count_tokens(prompt)
     if prompt_tokens > context_window:
         raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")

-    # Settings for the generation, includes the prompt + stop tokens, max length, etc
-    request = settings
-    request["max_tokens"] = context_window
-
-    # Uses the ChatCompletions API style
-    # Seems to work better, probably because it's applying some extra settings under-the-hood?
     if api == "chat":
+        # Uses the ChatCompletions API style
+        # Seems to work better, probably because it's applying some extra settings under-the-hood?
         URI = urljoin(endpoint.strip("/") + "/", LMSTUDIO_API_CHAT_SUFFIX.strip("/"))
+
+        # Settings for the generation, includes the prompt + stop tokens, max length, etc
+        request = settings
+        request["max_tokens"] = context_window
+
         # Put the entire completion string inside the first message
         message_structure = [{"role": "user", "content": prompt}]
         request["messages"] = message_structure

-    # Uses basic string completions (string in, string out)
-    # Does not work as well as ChatCompletions for some reason
     elif api == "completions":
+        # Uses basic string completions (string in, string out)
+        # Does not work as well as ChatCompletions for some reason
         URI = urljoin(endpoint.strip("/") + "/", LMSTUDIO_API_COMPLETIONS_SUFFIX.strip("/"))
+
+        # Settings for the generation, includes the prompt + stop tokens, max length, etc
+        request = settings
+        request["max_tokens"] = context_window
+
+        # Standard completions format, formatted string goes in prompt
         request["prompt"] = prompt
+
     else:
         raise ValueError(api)
@@ -40,13 +52,14 @@ def get_lmstudio_completion(endpoint, prompt, context_window, settings=SIMPLE, api="chat"):
     try:
         response = requests.post(URI, json=request)
         if response.status_code == 200:
-            result = response.json()
+            result_full = response.json()
+            printd(f"JSON API response:\n{result_full}")
             if api == "chat":
-                result = result["choices"][0]["message"]["content"]
+                result = result_full["choices"][0]["message"]["content"]
+                usage = result_full.get("usage", None)
             elif api == "completions":
-                result = result["choices"][0]["text"]
-                if DEBUG:
-                    print(f"json API response.text: {result}")
+                result = result_full["choices"][0]["text"]
+                usage = result_full.get("usage", None)
         else:
             # Example error: msg={"error":"Context length exceeded. Tokens in context: 8000, Context length: 8000"}
             if "context length" in str(response.text).lower():
@@ -62,4 +75,14 @@ def get_lmstudio_completion(endpoint, prompt, context_window, settings=SIMPLE, api="chat"):
         # TODO handle gracefully
         raise

-    return result
+    # Pass usage statistics back to main thread
+    # These are used to compute memory warning messages
+    completion_tokens = usage.get("completion_tokens", None) if usage is not None else None
+    total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None
+    usage = {
+        "prompt_tokens": prompt_tokens,  # can grab from usage dict, but it's usually wrong (set to 0)
+        "completion_tokens": completion_tokens,
+        "total_tokens": total_tokens,
+    }
+
+    return result, usage
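Both LM Studio styles receive the same wrapper-formatted prompt; only the request shape differs. Illustrative payloads for the two branches (max_tokens value is a placeholder):

    prompt = "SYSTEM...\nUSER: hi\nASSISTANT:"  # placeholder for the wrapper-formatted prompt

    # api="chat" -> POST /v1/chat/completions, prompt wrapped in one user message
    chat_request = {"max_tokens": 8192, "messages": [{"role": "user", "content": prompt}]}

    # api="completions" -> POST /v1/completions, raw prompt string
    completions_request = {"max_tokens": 8192, "prompt": prompt}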
@@ -7,11 +7,12 @@ from ..utils import count_tokens
 from ...errors import LocalLLMError

 OLLAMA_API_SUFFIX = "/api/generate"
 DEBUG = False


 def get_ollama_completion(endpoint, model, prompt, context_window, settings=SIMPLE, grammar=None):
     """See https://github.com/jmorganca/ollama/blob/main/docs/api.md for instructions on how to run the LLM web server"""
+    from memgpt.utils import printd

     prompt_tokens = count_tokens(prompt)
     if prompt_tokens > context_window:
         raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
@@ -39,10 +40,10 @@ def get_ollama_completion(endpoint, model, prompt, context_window, settings=SIMPLE, grammar=None):
     URI = urljoin(endpoint.strip("/") + "/", OLLAMA_API_SUFFIX.strip("/"))
     response = requests.post(URI, json=request)
     if response.status_code == 200:
-        result = response.json()
-        result = result["response"]
-        if DEBUG:
-            print(f"json API response.text: {result}")
+        # https://github.com/jmorganca/ollama/blob/main/docs/api.md
+        result_full = response.json()
+        printd(f"JSON API response:\n{result_full}")
+        result = result_full["response"]
     else:
         raise Exception(
             f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
@@ -53,4 +54,15 @@ def get_ollama_completion(endpoint, model, prompt, context_window, settings=SIMPLE, grammar=None):
         # TODO handle gracefully
         raise

-    return result
+    # Pass usage statistics back to main thread
+    # These are used to compute memory warning messages
+    # https://github.com/jmorganca/ollama/blob/main/docs/api.md#response
+    completion_tokens = result_full.get("eval_count", None)
+    total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None
+    usage = {
+        "prompt_tokens": prompt_tokens,  # can also grab from "prompt_eval_count"
+        "completion_tokens": completion_tokens,
+        "total_tokens": total_tokens,
+    }
+
+    return result, usage
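For reference, an abridged Ollama /api/generate response body and the mapping used above; eval_count is the generated-token count, while the prompt side is counted locally:

    result_full = {"response": "Hello!", "prompt_eval_count": 26, "eval_count": 8, "done": True}  # abridged

    completion_tokens = result_full.get("eval_count", None)  # 8
    prompt_tokens = 26  # count_tokens(prompt) in the code above
    total_tokens = prompt_tokens + completion_tokens  # 34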
@@ -5,11 +5,12 @@ import requests
 from ..utils import load_grammar_file, count_tokens

 WEBUI_API_SUFFIX = "/v1/completions"
 DEBUG = False


 def get_vllm_completion(endpoint, model, prompt, context_window, user, settings={}, grammar=None):
     """https://github.com/vllm-project/vllm/blob/main/examples/api_client.py"""
+    from memgpt.utils import printd

     prompt_tokens = count_tokens(prompt)
     if prompt_tokens > context_window:
         raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
@@ -36,10 +37,10 @@ def get_vllm_completion(endpoint, model, prompt, context_window, user, settings={}, grammar=None):
     URI = urljoin(endpoint.strip("/") + "/", WEBUI_API_SUFFIX.strip("/"))
     response = requests.post(URI, json=request)
     if response.status_code == 200:
-        result = response.json()
-        result = result["choices"][0]["text"]
-        if DEBUG:
-            print(f"json API response.text: {result}")
+        result_full = response.json()
+        printd(f"JSON API response:\n{result_full}")
+        result = result_full["choices"][0]["text"]
+        usage = result_full.get("usage", None)
     else:
         raise Exception(
             f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
@@ -50,4 +51,14 @@ def get_vllm_completion(endpoint, model, prompt, context_window, user, settings={}, grammar=None):
         # TODO handle gracefully
         raise

-    return result
+    # Pass usage statistics back to main thread
+    # These are used to compute memory warning messages
+    completion_tokens = usage.get("completion_tokens", None) if usage is not None else None
+    total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None
+    usage = {
+        "prompt_tokens": prompt_tokens,  # can grab from usage dict, but it's usually wrong (set to 0)
+        "completion_tokens": completion_tokens,
+        "total_tokens": total_tokens,
+    }
+
+    return result, usage
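vLLM's OpenAI-compatible server already ships a usage object, but only completion_tokens is trusted from it, since its prompt_tokens is often 0. An abridged (illustrative) response body and the extraction:

    result_full = {  # abridged OpenAI-style body
        "choices": [{"text": " Hello!"}],
        "usage": {"prompt_tokens": 0, "completion_tokens": 9, "total_tokens": 9},
    }
    usage = result_full.get("usage", None)
    completion_tokens = usage.get("completion_tokens", None) if usage is not None else None  # 9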
@@ -6,11 +6,12 @@ from .settings import SIMPLE
 from ..utils import load_grammar_file, count_tokens

 WEBUI_API_SUFFIX = "/v1/completions"
 DEBUG = False


 def get_webui_completion(endpoint, prompt, context_window, settings=SIMPLE, grammar=None):
     """Compatibility for the new OpenAI API: https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API#examples"""
+    from memgpt.utils import printd

     prompt_tokens = count_tokens(prompt)
     if prompt_tokens > context_window:
         raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
@@ -33,10 +34,10 @@ def get_webui_completion(endpoint, prompt, context_window, settings=SIMPLE, grammar=None):
     URI = urljoin(endpoint.strip("/") + "/", WEBUI_API_SUFFIX.strip("/"))
     response = requests.post(URI, json=request)
     if response.status_code == 200:
-        result = response.json()
-        result = result["choices"][0]["text"]
-        if DEBUG:
-            print(f"json API response.text: {result}")
+        result_full = response.json()
+        printd(f"JSON API response:\n{result_full}")
+        result = result_full["choices"][0]["text"]
+        usage = result_full.get("usage", None)
     else:
         raise Exception(
             f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
@@ -47,4 +48,14 @@ def get_webui_completion(endpoint, prompt, context_window, settings=SIMPLE, grammar=None):
         # TODO handle gracefully
         raise

-    return result
+    # Pass usage statistics back to main thread
+    # These are used to compute memory warning messages
+    completion_tokens = usage.get("completion_tokens", None) if usage is not None else None
+    total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None
+    usage = {
+        "prompt_tokens": prompt_tokens,  # can grab from usage dict, but it's usually wrong (set to 0)
+        "completion_tokens": completion_tokens,
+        "total_tokens": total_tokens,
+    }
+
+    return result, usage
@@ -6,11 +6,12 @@ from .legacy_settings import SIMPLE
 from ..utils import load_grammar_file, count_tokens

 WEBUI_API_SUFFIX = "/api/v1/generate"
 DEBUG = False


 def get_webui_completion(endpoint, prompt, context_window, settings=SIMPLE, grammar=None):
     """See https://github.com/oobabooga/text-generation-webui for instructions on how to run the LLM web server"""
+    from memgpt.utils import printd

     prompt_tokens = count_tokens(prompt)
     if prompt_tokens > context_window:
         raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
@@ -31,10 +32,9 @@ def get_webui_completion(endpoint, prompt, context_window, settings=SIMPLE, grammar=None):
     URI = urljoin(endpoint.strip("/") + "/", WEBUI_API_SUFFIX.strip("/"))
     response = requests.post(URI, json=request)
     if response.status_code == 200:
-        result = response.json()
-        result = result["results"][0]["text"]
-        if DEBUG:
-            print(f"json API response.text: {result}")
+        result_full = response.json()
+        printd(f"JSON API response:\n{result_full}")
+        result = result_full["results"][0]["text"]
     else:
         raise Exception(
             f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
@@ -45,4 +45,13 @@ def get_webui_completion(endpoint, prompt, context_window, settings=SIMPLE, grammar=None):
         # TODO handle gracefully
         raise

-    return result
+    # TODO correct for legacy
+    completion_tokens = None
+    total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None
+    usage = {
+        "prompt_tokens": prompt_tokens,
+        "completion_tokens": completion_tokens,
+        "total_tokens": total_tokens,
+    }
+
+    return result, usage
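Downstream, these counts feed the memory-pressure warnings mentioned in the "compute memory warning messages" comments; a sketch of such a consumer (the threshold constant and function are illustrative, not part of this diff):

    MEMORY_WARNING_FRAC = 0.75  # illustrative threshold

    def near_context_limit(response, context_window):
        # response.usage.total_tokens is "the token count that matters most" per the diff
        return response.usage.total_tokens > MEMORY_WARNING_FRAC * context_window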