letta-server/memgpt/local_llm/koboldcpp/api.py
Charles Packer e90c00ad63 Add grammar-based sampling (for webui, llamacpp, and koboldcpp) (#293)
* add llamacpp server support

* use gbnf loader

* cleanup and warning about grammar when not using llama.cpp

* added memgpt-specific grammar file

* add grammar support to webui api calls

* black

* typo

* add koboldcpp support

* no more defaulting to webui, should error out instead

* fix grammar

* patch kobold (testing, now working) + cleanup log messages

Co-Authored-By: Drake-AI <drake-ai@users.noreply.github.com>
2023-11-04 12:02:44 -07:00
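The commit above wires a GBNF grammar loader into the local backends ("use gbnf loader", "added memgpt-specific grammar file"). As a rough illustration of that idea only, a loader along these lines would read a named .gbnf file and return its text so it can be attached to a request payload. This is a hypothetical sketch, not the actual load_grammar_file implementation from ..utils, and the grammars directory location is an assumption:

import os

# Hypothetical helper, shown only to illustrate the "gbnf loader" mentioned in the
# commit message; the real load_grammar_file in ..utils may resolve paths differently.
GRAMMARS_DIR = os.path.join(os.path.dirname(__file__), "grammars")  # assumed location

def load_grammar_file_sketch(grammar_name: str) -> str:
    """Read a .gbnf grammar file and return its contents as a plain string."""
    grammar_path = os.path.join(GRAMMARS_DIR, f"{grammar_name}.gbnf")
    if not os.path.isfile(grammar_path):
        raise FileNotFoundError(f"Grammar file not found: {grammar_path}")
    with open(grammar_path, "r") as f:
        return f.read()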

60 lines | 2.0 KiB | Python

import os
from urllib.parse import urljoin

import requests
import tiktoken

from .settings import SIMPLE
from ..utils import load_grammar_file
from ...constants import LLM_MAX_TOKENS

HOST = os.getenv("OPENAI_API_BASE")
HOST_TYPE = os.getenv("BACKEND_TYPE")  # default None == ChatCompletion
KOBOLDCPP_API_SUFFIX = "/api/v1/generate"
# DEBUG = False
DEBUG = True


def count_tokens(s: str, model: str = "gpt-4") -> int:
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(s))


def get_koboldcpp_completion(prompt, grammar=None, settings=SIMPLE):
    """See https://lite.koboldai.net/koboldcpp_api for API spec"""
    prompt_tokens = count_tokens(prompt)
    if prompt_tokens > LLM_MAX_TOKENS:
        raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {LLM_MAX_TOKENS} tokens)")

    # Settings for the generation, includes the prompt + stop tokens, max length, etc
    request = settings
    request["prompt"] = prompt

    # Set grammar
    if grammar is not None:
        request["grammar"] = load_grammar_file(grammar)

    if not HOST.startswith(("http://", "https://")):
        raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")

    try:
        # NOTE: llama.cpp server returns the following when it's out of context
        # curl: (52) Empty reply from server
        URI = urljoin(HOST.strip("/") + "/", KOBOLDCPP_API_SUFFIX.strip("/"))
        response = requests.post(URI, json=request)
        if response.status_code == 200:
            result = response.json()
            result = result["results"][0]["text"]
            if DEBUG:
                print(f"json API response.text: {result}")
        else:
            raise Exception(
                f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
                + f" Make sure that the koboldcpp server is running and reachable at {URI}."
            )
    except:
        # TODO handle gracefully
        raise

    return result
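
For reference, a minimal, untested sketch of calling this wrapper. The server URL and port are assumptions rather than values taken from this file (5001 is koboldcpp's usual default), and since HOST and HOST_TYPE are read at import time, the environment variables must be set before the module is imported. The koboldcpp /api/v1/generate endpoint returns JSON of the shape {"results": [{"text": ...}]}, which is what the function unpacks above:

import os

# Point the client at a running koboldcpp server before importing the module,
# because HOST is captured from the environment at import time.
os.environ["OPENAI_API_BASE"] = "http://localhost:5001"  # assumed local server
os.environ["BACKEND_TYPE"] = "koboldcpp"

from memgpt.local_llm.koboldcpp.api import get_koboldcpp_completion

# Plain completion; passing grammar="<name>" would constrain sampling with a
# .gbnf grammar that load_grammar_file can resolve (the name is illustrative).
reply = get_koboldcpp_completion("You are MemGPT. Reply with a JSON function call:")
print(reply)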