add ollama support (#314)
* untested
* patch
* updated
* clarified using tags in docs
* tested ollama, working
* fixed template issue by creating dummy template, also added missing context length indicator
* moved count_tokens to utils.py
* clean

docs/ollama.md (new file, 39 lines)
@@ -0,0 +1,39 @@
### MemGPT + Ollama

!!! warning "Be careful when downloading Ollama models!"

    Make sure to use tags when downloading Ollama models! Don't do `ollama run dolphin2.2-mistral`; do `ollama run dolphin2.2-mistral:7b-q6_K` instead.

    If you don't specify a tag, Ollama may default to using a highly compressed model variant (e.g. Q4). We highly recommend **NOT** using a compression level below Q4 (stick to Q6, Q8, or fp16 if possible). In our testing, models below Q6 start to become extremely unstable when used with MemGPT.

1. Download + install [Ollama](https://github.com/jmorganca/ollama)
2. Download a model to test with by running `ollama run <MODEL_NAME>` in the terminal (check the [Ollama model library](https://ollama.ai/library) for available models)
3. In addition to setting `OPENAI_API_BASE` and `BACKEND_TYPE`, we need to set `OLLAMA_MODEL` (to the Ollama model name)

For example, if we want to use Dolphin 2.2.1 Mistral, we can download it by running:
```sh
# Let's use the q6_K variant
ollama run dolphin2.2-mistral:7b-q6_K
```
```text
pulling manifest
pulling d8a5ee4aba09... 100% |█████████████████████████████████████████████████████████████████████████| (4.1/4.1 GB, 20 MB/s)
pulling a47b02e00552... 100% |██████████████████████████████████████████████████████████████████████████████| (106/106 B, 77 B/s)
pulling 9640c2212a51... 100% |████████████████████████████████████████████████████████████████████████████████| (41/41 B, 22 B/s)
pulling de6bcd73f9b4... 100% |████████████████████████████████████████████████████████████████████████████████| (58/58 B, 28 B/s)
pulling 95c3d8d4429f... 100% |█████████████████████████████████████████████████████████████████████████████| (455/455 B, 330 B/s)
verifying sha256 digest
writing manifest
removing any unused layers
success
```

In your terminal where you're running MemGPT, run:
```sh
# By default, Ollama runs an API server on port 11434
export OPENAI_API_BASE=http://localhost:11434
export BACKEND_TYPE=ollama

# Make sure to add the tag!
export OLLAMA_MODEL=dolphin2.2-mistral:7b-q6_K
```
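
As a quick sanity check (not part of the committed docs), you can confirm that the Ollama server is reachable and the tagged model has been pulled before launching MemGPT. The sketch below assumes the default port from the docs above, the `requests` package, and Ollama's `/api/tags` listing endpoint:

```python
# Sanity check (illustrative only, not part of this commit): verify the Ollama
# server is up and the tagged model has been pulled before launching MemGPT.
import os
import requests

base = os.getenv("OPENAI_API_BASE", "http://localhost:11434")
model = os.getenv("OLLAMA_MODEL", "dolphin2.2-mistral:7b-q6_K")

# /api/tags lists locally available models (see the Ollama API docs)
resp = requests.get(base.rstrip("/") + "/api/tags")
resp.raise_for_status()
names = [m["name"] for m in resp.json().get("models", [])]
print("Ollama is up; model pulled:", model in names)
```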

@@ -8,6 +8,7 @@ from .webui.api import get_webui_completion
from .lmstudio.api import get_lmstudio_completion
from .llamacpp.api import get_llamacpp_completion
from .koboldcpp.api import get_koboldcpp_completion
from .ollama.api import get_ollama_completion
from .llm_chat_completion_wrappers import airoboros, dolphin, zephyr, simple_summary_wrapper
from .utils import DotDict
from ..prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE
@@ -96,6 +97,8 @@ def get_chat_completion(
        result = get_llamacpp_completion(prompt, grammar=grammar_name)
    elif HOST_TYPE == "koboldcpp":
        result = get_koboldcpp_completion(prompt, grammar=grammar_name)
    elif HOST_TYPE == "ollama":
        result = get_ollama_completion(prompt)
    else:
        raise LocalLLMError(
            f"BACKEND_TYPE is not set, please set variable depending on your backend (webui, lmstudio, llamacpp, koboldcpp)"

memgpt/local_llm/koboldcpp/api.py
@@ -1,10 +1,9 @@
import os
from urllib.parse import urljoin
import requests
import tiktoken

from .settings import SIMPLE
from ..utils import load_grammar_file
from ..utils import load_grammar_file, count_tokens
from ...constants import LLM_MAX_TOKENS

HOST = os.getenv("OPENAI_API_BASE")
@@ -14,11 +13,6 @@ KOBOLDCPP_API_SUFFIX = "/api/v1/generate"
DEBUG = True


def count_tokens(s: str, model: str = "gpt-4") -> int:
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(s))


def get_koboldcpp_completion(prompt, grammar=None, settings=SIMPLE):
    """See https://lite.koboldai.net/koboldcpp_api for API spec"""
    prompt_tokens = count_tokens(prompt)

memgpt/local_llm/llamacpp/api.py
@@ -1,10 +1,9 @@
import os
from urllib.parse import urljoin
import requests
import tiktoken

from .settings import SIMPLE
from ..utils import load_grammar_file
from ..utils import load_grammar_file, count_tokens
from ...constants import LLM_MAX_TOKENS

HOST = os.getenv("OPENAI_API_BASE")
@@ -14,11 +13,6 @@ LLAMACPP_API_SUFFIX = "/completion"
DEBUG = True


def count_tokens(s: str, model: str = "gpt-4") -> int:
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(s))


def get_llamacpp_completion(prompt, grammar=None, settings=SIMPLE):
    """See https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md for instructions on how to run the LLM web server"""
    prompt_tokens = count_tokens(prompt)

memgpt/local_llm/ollama/api.py (new file, 57 lines)
@@ -0,0 +1,57 @@
import os
from urllib.parse import urljoin
import requests

from .settings import SIMPLE
from ..utils import count_tokens
from ...constants import LLM_MAX_TOKENS
from ...errors import LocalLLMError

HOST = os.getenv("OPENAI_API_BASE")
HOST_TYPE = os.getenv("BACKEND_TYPE")  # default None == ChatCompletion
MODEL_NAME = os.getenv("OLLAMA_MODEL")  # ollama API requires this in the request
OLLAMA_API_SUFFIX = "/api/generate"
DEBUG = False


def get_ollama_completion(prompt, settings=SIMPLE, grammar=None):
    """See https://github.com/jmorganca/ollama/blob/main/docs/api.md for instructions on how to run the LLM web server"""
    prompt_tokens = count_tokens(prompt)
    if prompt_tokens > LLM_MAX_TOKENS:
        raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {LLM_MAX_TOKENS} tokens)")

    if MODEL_NAME is None:
        raise LocalLLMError(f"Error: OLLAMA_MODEL not specified. Set OLLAMA_MODEL to the model you want to run (e.g. 'dolphin2.2-mistral')")

    # Settings for the generation, includes the prompt + stop tokens, max length, etc
    request = settings
    request["prompt"] = prompt
    request["model"] = MODEL_NAME

    # Set grammar
    if grammar is not None:
        # request["grammar_string"] = load_grammar_file(grammar)
        raise NotImplementedError(f"Ollama does not support grammars")

    if not HOST.startswith(("http://", "https://")):
        raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")

    try:
        URI = urljoin(HOST.strip("/") + "/", OLLAMA_API_SUFFIX.strip("/"))
        response = requests.post(URI, json=request)
        if response.status_code == 200:
            result = response.json()
            result = result["response"]
            if DEBUG:
                print(f"json API response.text: {result}")
        else:
            raise Exception(
                f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
                + f" Make sure that the ollama API server is running and reachable at {URI}."
            )

    except:
        # TODO handle gracefully
        raise

    return result
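
For illustration (not part of this commit), a minimal sketch of exercising the new helper directly. Note that the module reads `OPENAI_API_BASE`, `BACKEND_TYPE`, and `OLLAMA_MODEL` at import time, so the environment must be set before importing; the values mirror the docs above and assume a local Ollama server is running:

```python
# Illustrative usage sketch (not part of this commit). Set the env vars before
# the import, because the module captures them with os.getenv() at import time.
import os

os.environ["OPENAI_API_BASE"] = "http://localhost:11434"
os.environ["BACKEND_TYPE"] = "ollama"
os.environ["OLLAMA_MODEL"] = "dolphin2.2-mistral:7b-q6_K"

from memgpt.local_llm.ollama.api import get_ollama_completion

# Returns the raw "response" string from Ollama's /api/generate endpoint
print(get_ollama_completion("USER: Say hello in one word.\nASSISTANT:"))
```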

memgpt/local_llm/ollama/settings.py (new file, 34 lines)
@@ -0,0 +1,34 @@
from ...constants import LLM_MAX_TOKENS

# see https://github.com/jmorganca/ollama/blob/main/docs/api.md
# and https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
SIMPLE = {
    "options": {
        "stop": [
            "\nUSER:",
            "\nASSISTANT:",
            "\nFUNCTION RETURN:",
            "\nUSER",
            "\nASSISTANT",
            "\nFUNCTION RETURN",
            "\nFUNCTION",
            "\nFUNC",
            "<|im_start|>",
            "<|im_end|>",
            "<|im_sep|>",
            # '\n' +
            # '</s>',
            # '<|',
            # '\n#',
            # '\n\n\n',
        ],
        "num_ctx": LLM_MAX_TOKENS,
    },
    "stream": False,
    # turn off Ollama's own prompt formatting
    "system": "",
    "template": "{{ .Prompt }}",
    # "system": None,
    # "template": None,
    "context": None,
}
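
As a rough sketch (an assumption based on `api.py` above, not shown in the commit), `get_ollama_completion` uses this settings dict as the request body and adds `prompt` and `model` on top, so the JSON POSTed to `/api/generate` looks roughly like this; the empty `system` string and the `{{ .Prompt }}` dummy template are what disable Ollama's own prompt formatting (the "template issue" mentioned in the commit message). The example prompt and num_ctx value here are placeholders:

```python
# Rough sketch (assumption, not from the commit) of the JSON body that
# get_ollama_completion ends up POSTing to <OPENAI_API_BASE>/api/generate.
payload = {
    "options": {
        "stop": ["\nUSER:", "\nASSISTANT:", "..."],  # trimmed for brevity
        "num_ctx": 8192,  # LLM_MAX_TOKENS (actual value comes from constants.py)
    },
    "stream": False,
    "system": "",                    # disable Ollama's built-in system prompt
    "template": "{{ .Prompt }}",     # dummy template: pass the prompt through verbatim
    "context": None,
    "prompt": "USER: hello\nASSISTANT:",    # filled in by get_ollama_completion
    "model": "dolphin2.2-mistral:7b-q6_K",  # taken from OLLAMA_MODEL
}
```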

memgpt/local_llm/utils.py
@@ -1,4 +1,5 @@
import os
import tiktoken


class DotDict(dict):
@@ -31,3 +32,8 @@ def load_grammar_file(grammar):
        grammar_str = file.read()

    return grammar_str


def count_tokens(s: str, model: str = "gpt-4") -> int:
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(s))
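
A quick illustration of the now-shared helper (illustrative only): it uses tiktoken's gpt-4 encoding, which only approximates the tokenizer of a local Ollama model but is close enough for the context-length check in each backend:

```python
# Example (illustrative): count_tokens approximates prompt length using
# tiktoken's gpt-4 (cl100k_base) encoding, which will not exactly match a
# local model's tokenizer, but suffices for the LLM_MAX_TOKENS check.
from memgpt.local_llm.utils import count_tokens

n = count_tokens("You are MemGPT, a memory-augmented assistant.")
print(n)  # a small integer, roughly 10 tokens
```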

memgpt/local_llm/webui/api.py
@@ -1,10 +1,9 @@
import os
from urllib.parse import urljoin
import requests
import tiktoken

from .settings import SIMPLE
from ..utils import load_grammar_file
from ..utils import load_grammar_file, count_tokens
from ...constants import LLM_MAX_TOKENS

HOST = os.getenv("OPENAI_API_BASE")
@@ -13,11 +12,6 @@ WEBUI_API_SUFFIX = "/api/v1/generate"
DEBUG = False


def count_tokens(s: str, model: str = "gpt-4") -> int:
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(s))


def get_webui_completion(prompt, settings=SIMPLE, grammar=None):
    """See https://github.com/oobabooga/text-generation-webui for instructions on how to run the LLM web server"""
    prompt_tokens = count_tokens(prompt)

mkdocs.yml
@@ -21,6 +21,7 @@ nav:
    - 'LM Studio': lmstudio.md
    - 'llama.cpp': llamacpp.md
    - 'koboldcpp': koboldcpp.md
    - 'ollama': ollama.md
    - 'Troubleshooting': local_llm_faq.md
  - 'Integrations':
    - 'Autogen': autogen.md