diff --git a/README.md b/README.md
index 347dc39f..2362d4e8 100644
--- a/README.md
+++ b/README.md
@@ -5,6 +5,8 @@
Try out our MemGPT chatbot on Discord!
+
+
⭐ NEW: You can now run MemGPT with local LLMs! ⭐
[](https://discord.gg/9GEQrxmVyE)
[](https://arxiv.org/abs/2310.08560)
@@ -130,6 +132,9 @@ python main.py --model gpt-3.5-turbo
Please report any bugs you encounter regarding MemGPT running on GPT-3.5 to https://github.com/cpacker/MemGPT/issues/59.
+### Local LLM support
+You can run MemGPT with local LLMs too. See [instructions here](/memgpt/local_llm) and report any bugs/improvements here https://github.com/cpacker/MemGPT/discussions/67.
+
### `main.py` flags
```text
diff --git a/memgpt/local_llm/README.md b/memgpt/local_llm/README.md
new file mode 100644
index 00000000..a79c0f9e
--- /dev/null
+++ b/memgpt/local_llm/README.md
@@ -0,0 +1,103 @@
+⁉️ Need help configuring local LLMs with MemGPT? Ask for help on [our Discord](https://discord.gg/9GEQrxmVyE) or [post on the GitHub discussion](https://github.com/cpacker/MemGPT/discussions/67).
+
+👀 If you have a hosted ChatCompletion-compatible endpoint that works with function calling, you can simply set `OPENAI_API_BASE` (`export OPENAI_API_BASE=...`) to the IP+port of your endpoint. **As of 10/22/2023, most ChatCompletion endpoints do *NOT* support function calls, so if you want to play with MemGPT and open models, you probably need to follow the instructions below.**
+
+🙋 Our examples assume that you're using [oobabooga web UI](https://github.com/oobabooga/text-generation-webui#starting-the-web-ui) to put your LLMs behind a web server. If you need help setting this up, check the instructions [here](https://github.com/oobabooga/text-generation-webui#starting-the-web-ui). More LLM web server support to come soon (tell us what you use and we'll add it)!
+
+---
+
+# How to connect MemGPT to non-OpenAI LLMs
+
+**If you have an LLM that is function-call finetuned**:
+ - Implement a wrapper class for that model
+ - The wrapper class needs to implement two functions:
+ - One to go from ChatCompletion messages/functions schema to a prompt string
+ - And one to go from raw LLM outputs to a ChatCompletion response
+ - Put that model behind a server (e.g. using WebUI) and set `OPENAI_API_BASE`
+
+```python
+class LLMChatCompletionWrapper(ABC):
+
+ @abstractmethod
+ def chat_completion_to_prompt(self, messages, functions):
+ """Go from ChatCompletion to a single prompt string"""
+ pass
+
+ @abstractmethod
+ def output_to_chat_completion_response(self, raw_llm_output):
+ """Turn the LLM output string into a ChatCompletion response"""
+ pass
+```
+
+## Example with [Airoboros](https://huggingface.co/jondurbin/airoboros-l2-70b-2.1) (llama2 finetune)
+
+To help you get started, we've implemented an example wrapper class for a popular llama2 model **finetuned on function calling** (Airoboros). We want MemGPT to run well on open models as much as you do, so we'll be actively updating this page with more examples. Additionally, we welcome contributions from the community! If you find an open LLM that works well with MemGPT, please open a PR with a model wrapper and we'll merge it ASAP.
+
+```python
+class Airoboros21Wrapper(LLMChatCompletionWrapper):
+ """Wrapper for Airoboros 70b v2.1: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1"""
+
+ def chat_completion_to_prompt(self, messages, functions):
+ """
+ Examples for how airoboros expects its prompt inputs: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#prompt-format
+ Examples for how airoboros expects to see function schemas: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#agentfunction-calling
+ """
+
+ def output_to_chat_completion_response(self, raw_llm_output):
+ """Turn raw LLM output into a ChatCompletion style response with:
+ "message" = {
+ "role": "assistant",
+ "content": ...,
+ "function_call": {
+ "name": ...
+ "arguments": {
+ "arg1": val1,
+ ...
+ }
+ }
+ }
+ """
+```
+See full file [here](llm_chat_completion_wrappers/airoboros.py). WebUI exposes a lot of parameters that can dramatically change LLM outputs, to change these you can modify the [WebUI settings file](/memgpt/local_llm/webui/settings.py).
+
+### Running the example
+
+```sh
+# running airoboros behind a textgen webui server
+export OPENAI_API_BASE =
+export BACKEND_TYPE = webui
+
+# using --no_verify because this airoboros example does not output inner monologue, just functions
+# airoboros is able to properly call `send_message`
+$ python3 main.py --no_verify
+
+Running... [exit by typing '/exit']
+💭 Bootup sequence complete. Persona activated. Testing messaging functionality.
+
+💭 None
+🤖 Welcome! My name is Sam. How can I assist you today?
+Enter your message: My name is Brad, not Chad...
+
+💭 None
+⚡🧠 [function] updating memory with core_memory_replace:
+ First name: Chad
+ → First name: Brad
+```
+
+---
+
+## Status of ChatCompletion w/ function calling and open LLMs
+
+MemGPT uses function calling to do memory management. With [OpenAI's ChatCompletion API](https://platform.openai.com/docs/api-reference/chat/), you can pass in a function schema in the `functions` keyword arg, and the API response will include a `function_call` field that includes the function name and the function arguments (generated JSON). How this works under the hood is your `functions` keyword is combined with the `messages` and `system` to form one big string input to the transformer, and the output of the transformer is parsed to extract the JSON function call.
+
+In the future, more open LLMs and LLM servers (that can host OpenAI-compatable ChatCompletion endpoints) may start including parsing code to do this automatically as standard practice. However, in the meantime, when you see a model that says it supports “function calling”, like Airoboros, it doesn't mean that you can just load Airoboros into a ChatCompletion-compatable endpoint like WebUI, and then use the same OpenAI API call and it'll just work.
+
+1. When a model page says it supports function calling, they probably mean that the model was finetuned on some function call data (not that you can just use ChatCompletion with functions out-of-the-box). Remember, LLMs are just string-in-string-out, so there are many ways to format the function call data. E.g. Airoboros formats the function schema in YAML style (see https://huggingface.co/jondurbin/airoboros-l2-70b-3.1.2#agentfunction-calling) and the output is in JSON style. To get this to work behind a ChatCompletion API, you still have to do the parsing from `functions` keyword arg (containing the schema) to the model's expected schema style in the prompt (YAML for Airoboros), and you have to run some code to extract the function call (JSON for Airoboros) and package it cleanly as a `function_call` field in the response.
+
+2. Partly because of how complex it is to support function calling, most (all?) of the community projects that do OpenAI ChatCompletion endpoints for arbitrary open LLMs do not support function calling, because if they did, they would need to write model-specific parsing code for each one.
+
+## What is this all this extra code for?
+
+Because of the poor state of function calling support in existing ChatCompletion API serving code, we instead provide a light wrapper on top of ChatCompletion that adds parsers to handle function calling support. These parsers need to be specific to the model you're using (or at least specific to the way it was trained on function calling). We hope that our example code will help the community add additional compatability of MemGPT with more function-calling LLMs - we will also add more model support as we test more models and find those that work well enough to run MemGPT's function set.
+
+To run the example of MemGPT with Airoboros, you'll need to host the model behind some LLM web server (for example [webui](https://github.com/oobabooga/text-generation-webui#starting-the-web-ui)). Then, all you need to do is point MemGPT to this API endpoint by setting the environment variables `OPENAI_API_BASE` and `BACKEND_TYPE`. Now, instead of calling ChatCompletion on OpenAI's API, MemGPT will use it's own ChatCompletion wrapper that parses the system, messages, and function arguments into a format that Airoboros has been finetuned on, and once Airoboros generates a string output, MemGPT will parse the response to extract a potential function call (knowing what we know about Airoboros expected function call output).
diff --git a/memgpt/local_llm/__init__.py b/memgpt/local_llm/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/memgpt/local_llm/chat_completion_proxy.py b/memgpt/local_llm/chat_completion_proxy.py
new file mode 100644
index 00000000..a5290717
--- /dev/null
+++ b/memgpt/local_llm/chat_completion_proxy.py
@@ -0,0 +1,74 @@
+"""Key idea: create drop-in replacement for agent's ChatCompletion call that runs on an OpenLLM backend"""
+
+import os
+import requests
+import json
+
+from .webui.api import get_webui_completion
+from .llm_chat_completion_wrappers import airoboros
+from .utils import DotDict
+
+HOST = os.getenv("OPENAI_API_BASE")
+HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
+DEBUG = True
+
+
+async def get_chat_completion(
+ model, # no model, since the model is fixed to whatever you set in your own backend
+ messages,
+ functions,
+ function_call="auto",
+):
+ if function_call != "auto":
+ raise ValueError(f"function_call == {function_call} not supported (auto only)")
+
+ if model == "airoboros_v2.1":
+ llm_wrapper = airoboros.Airoboros21Wrapper()
+ else:
+ # Warn the user that we're using the fallback
+ print(
+ f"Warning: could not find an LLM wrapper for {model}, using the airoboros wrapper"
+ )
+ llm_wrapper = airoboros.Airoboros21Wrapper()
+
+ # First step: turn the message sequence into a prompt that the model expects
+ prompt = llm_wrapper.chat_completion_to_prompt(messages, functions)
+ if DEBUG:
+ print(prompt)
+
+ try:
+ if HOST_TYPE == "webui":
+ result = get_webui_completion(prompt)
+ else:
+ print(f"Warning: BACKEND_TYPE was not set, defaulting to webui")
+ result = get_webui_completion(prompt)
+ except requests.exceptions.ConnectionError as e:
+ raise ValueError(f"Was unable to connect to host {HOST}")
+
+ chat_completion_result = llm_wrapper.output_to_chat_completion_response(result)
+ if DEBUG:
+ print(json.dumps(chat_completion_result, indent=2))
+
+ # unpack with response.choices[0].message.content
+ response = DotDict(
+ {
+ "model": None,
+ "choices": [
+ DotDict(
+ {
+ "message": DotDict(chat_completion_result),
+ "finish_reason": "stop", # TODO vary based on backend response
+ }
+ )
+ ],
+ "usage": DotDict(
+ {
+ # TODO fix, actually use real info
+ "prompt_tokens": 0,
+ "completion_tokens": 0,
+ "total_tokens": 0,
+ }
+ ),
+ }
+ )
+ return response
diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/__init__.py b/memgpt/local_llm/llm_chat_completion_wrappers/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py b/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py
new file mode 100644
index 00000000..98d3625e
--- /dev/null
+++ b/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py
@@ -0,0 +1,204 @@
+import json
+
+from .wrapper_base import LLMChatCompletionWrapper
+
+
+class Airoboros21Wrapper(LLMChatCompletionWrapper):
+ """Wrapper for Airoboros 70b v2.1: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1
+
+ Note: this wrapper formats a prompt that only generates JSON, no inner thoughts
+ """
+
+ def __init__(
+ self,
+ simplify_json_content=True,
+ clean_function_args=True,
+ include_assistant_prefix=True,
+ include_opening_brace_in_prefix=True,
+ include_section_separators=True,
+ ):
+ self.simplify_json_content = simplify_json_content
+ self.clean_func_args = clean_function_args
+ self.include_assistant_prefix = include_assistant_prefix
+ self.include_opening_brance_in_prefix = include_opening_brace_in_prefix
+ self.include_section_separators = include_section_separators
+
+ def chat_completion_to_prompt(self, messages, functions):
+ """Example for airoboros: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#prompt-format
+
+ A chat.
+ USER: {prompt}
+ ASSISTANT:
+
+ Functions support: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#agentfunction-calling
+
+ As an AI assistant, please select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format.
+
+ Input: I want to know how many times 'Python' is mentioned in my text file.
+
+ Available functions:
+ file_analytics:
+ description: This tool performs various operations on a text file.
+ params:
+ action: The operation we want to perform on the data, such as "count_occurrences", "find_line", etc.
+ filters:
+ keyword: The word or phrase we want to search for.
+
+ OpenAI functions schema style:
+
+ {
+ "name": "send_message",
+ "description": "Sends a message to the human user",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ # https://json-schema.org/understanding-json-schema/reference/array.html
+ "message": {
+ "type": "string",
+ "description": "Message contents. All unicode (including emojis) are supported.",
+ },
+ },
+ "required": ["message"],
+ }
+ },
+ """
+ prompt = ""
+
+ # System insturctions go first
+ assert messages[0]["role"] == "system"
+ prompt += messages[0]["content"]
+
+ # Next is the functions preamble
+ def create_function_description(schema):
+ # airorobos style
+ func_str = ""
+ func_str += f"{schema['name']}:"
+ func_str += f"\n description: {schema['description']}"
+ func_str += f"\n params:"
+ for param_k, param_v in schema["parameters"]["properties"].items():
+ # TODO we're ignoring type
+ func_str += f"\n {param_k}: {param_v['description']}"
+ # TODO we're ignoring schema['parameters']['required']
+ return func_str
+
+ # prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
+ prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
+ prompt += f"\nAvailable functions:"
+ for function_dict in functions:
+ prompt += f"\n{create_function_description(function_dict)}"
+
+ def create_function_call(function_call):
+ """Go from ChatCompletion to Airoboros style function trace (in prompt)
+
+ ChatCompletion data (inside message['function_call']):
+ "function_call": {
+ "name": ...
+ "arguments": {
+ "arg1": val1,
+ ...
+ }
+
+ Airoboros output:
+ {
+ "function": "send_message",
+ "params": {
+ "message": "Hello there! I am Sam, an AI developed by Liminal Corp. How can I assist you today?"
+ }
+ }
+ """
+ airo_func_call = {
+ "function": function_call["name"],
+ "params": json.loads(function_call["arguments"]),
+ }
+ return json.dumps(airo_func_call, indent=2)
+
+ # Add a sep for the conversation
+ if self.include_section_separators:
+ prompt += "\n### INPUT"
+
+ # Last are the user/assistant messages
+ for message in messages[1:]:
+ assert message["role"] in ["user", "assistant", "function"], message
+
+ if message["role"] == "user":
+ if self.simplify_json_content:
+ try:
+ content_json = json.loads(message["content"])
+ content_simple = content_json["message"]
+ prompt += f"\nUSER: {content_simple}"
+ except:
+ prompt += f"\nUSER: {message['content']}"
+ elif message["role"] == "assistant":
+ prompt += f"\nASSISTANT: {message['content']}"
+ # need to add the function call if there was one
+ if message["function_call"]:
+ prompt += f"\n{create_function_call(message['function_call'])}"
+ elif message["role"] == "function":
+ # TODO find a good way to add this
+ # prompt += f"\nASSISTANT: (function return) {message['content']}"
+ prompt += f"\nFUNCTION RETURN: {message['content']}"
+ continue
+ else:
+ raise ValueError(message)
+
+ # Add a sep for the response
+ if self.include_section_separators:
+ prompt += "\n### RESPONSE"
+
+ if self.include_assistant_prefix:
+ prompt += f"\nASSISTANT:"
+ if self.include_opening_brance_in_prefix:
+ prompt += "\n{"
+
+ return prompt
+
+ def clean_function_args(self, function_name, function_args):
+ """Some basic MemGPT-specific cleaning of function args"""
+ cleaned_function_name = function_name
+ cleaned_function_args = function_args.copy()
+
+ if function_name == "send_message":
+ # strip request_heartbeat
+ cleaned_function_args.pop("request_heartbeat", None)
+
+ # TODO more cleaning to fix errors LLM makes
+ return cleaned_function_name, cleaned_function_args
+
+ def output_to_chat_completion_response(self, raw_llm_output):
+ """Turn raw LLM output into a ChatCompletion style response with:
+ "message" = {
+ "role": "assistant",
+ "content": ...,
+ "function_call": {
+ "name": ...
+ "arguments": {
+ "arg1": val1,
+ ...
+ }
+ }
+ }
+ """
+ if self.include_opening_brance_in_prefix and raw_llm_output[0] != "{":
+ raw_llm_output = "{" + raw_llm_output
+
+ try:
+ function_json_output = json.loads(raw_llm_output)
+ except Exception as e:
+ raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}")
+ function_name = function_json_output["function"]
+ function_parameters = function_json_output["params"]
+
+ if self.clean_func_args:
+ function_name, function_parameters = self.clean_function_args(
+ function_name, function_parameters
+ )
+
+ message = {
+ "role": "assistant",
+ "content": None,
+ "function_call": {
+ "name": function_name,
+ "arguments": json.dumps(function_parameters),
+ },
+ }
+ return message
diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/wrapper_base.py b/memgpt/local_llm/llm_chat_completion_wrappers/wrapper_base.py
new file mode 100644
index 00000000..b1186c46
--- /dev/null
+++ b/memgpt/local_llm/llm_chat_completion_wrappers/wrapper_base.py
@@ -0,0 +1,13 @@
+from abc import ABC, abstractmethod
+
+
+class LLMChatCompletionWrapper(ABC):
+ @abstractmethod
+ def chat_completion_to_prompt(self, messages, functions):
+ """Go from ChatCompletion to a single prompt string"""
+ pass
+
+ @abstractmethod
+ def output_to_chat_completion_response(self, raw_llm_output):
+ """Turn the LLM output string into a ChatCompletion response"""
+ pass
diff --git a/memgpt/local_llm/utils.py b/memgpt/local_llm/utils.py
new file mode 100644
index 00000000..42a0ce27
--- /dev/null
+++ b/memgpt/local_llm/utils.py
@@ -0,0 +1,8 @@
+class DotDict(dict):
+ """Allow dot access on properties similar to OpenAI response object"""
+
+ def __getattr__(self, attr):
+ return self.get(attr)
+
+ def __setattr__(self, key, value):
+ self[key] = value
diff --git a/memgpt/local_llm/webui/api.py b/memgpt/local_llm/webui/api.py
new file mode 100644
index 00000000..3cff08e0
--- /dev/null
+++ b/memgpt/local_llm/webui/api.py
@@ -0,0 +1,33 @@
+import os
+import requests
+
+from .settings import SIMPLE
+
+HOST = os.getenv("OPENAI_API_BASE")
+HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
+WEBUI_API_SUFFIX = "/v1/generate"
+DEBUG = True
+
+
+def get_webui_completion(prompt, settings=SIMPLE):
+ """See https://github.com/oobabooga/text-generation-webui for instructions on how to run the LLM web server"""
+
+ # Settings for the generation, includes the prompt + stop tokens, max length, etc
+ request = settings
+ request["prompt"] = prompt
+
+ try:
+ URI = f"{HOST}{WEBUI_API_SUFFIX}"
+ response = requests.post(URI, json=request)
+ if response.status_code == 200:
+ result = response.json()
+ result = result["results"][0]["text"]
+ if DEBUG:
+ print(f"json API response.text: {result}")
+ else:
+ raise Exception(f"API call got non-200 response code")
+ except:
+ # TODO handle gracefully
+ raise
+
+ return result
diff --git a/memgpt/local_llm/webui/settings.py b/memgpt/local_llm/webui/settings.py
new file mode 100644
index 00000000..2e9ecbce
--- /dev/null
+++ b/memgpt/local_llm/webui/settings.py
@@ -0,0 +1,12 @@
+SIMPLE = {
+ "stopping_strings": [
+ "\nUSER:",
+ "\nASSISTANT:",
+ # '\n' +
+ # '',
+ # '<|',
+ # '\n#',
+ # '\n\n\n',
+ ],
+ "truncation_length": 4096, # assuming llama2 models
+}
diff --git a/memgpt/openai_tools.py b/memgpt/openai_tools.py
index 98444878..3d63d134 100644
--- a/memgpt/openai_tools.py
+++ b/memgpt/openai_tools.py
@@ -3,8 +3,16 @@ import random
import os
import time
+from .local_llm.chat_completion_proxy import get_chat_completion
+
+HOST = os.getenv("OPENAI_API_BASE")
+HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
+
import openai
+if HOST is not None:
+ openai.api_base = HOST
+
def retry_with_exponential_backoff(
func,
@@ -102,18 +110,24 @@ def aretry_with_exponential_backoff(
@aretry_with_exponential_backoff
async def acompletions_with_backoff(**kwargs):
- azure_openai_deployment = os.getenv('AZURE_OPENAI_DEPLOYMENT')
- if azure_openai_deployment is not None:
- kwargs['deployment_id'] = azure_openai_deployment
- return await openai.ChatCompletion.acreate(**kwargs)
+ # Local model
+ if HOST_TYPE is not None:
+ return await get_chat_completion(**kwargs)
+
+ # OpenAI / Azure model
+ else:
+ azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
+ if azure_openai_deployment is not None:
+ kwargs["deployment_id"] = azure_openai_deployment
+ return await openai.ChatCompletion.acreate(**kwargs)
@aretry_with_exponential_backoff
async def acreate_embedding_with_backoff(**kwargs):
"""Wrapper around Embedding.acreate w/ backoff"""
- azure_openai_deployment = os.getenv('AZURE_OPENAI_DEPLOYMENT')
+ azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
if azure_openai_deployment is not None:
- kwargs['deployment_id'] = azure_openai_deployment
+ kwargs["deployment_id"] = azure_openai_deployment
return await openai.Embedding.acreate(**kwargs)
@@ -121,6 +135,6 @@ async def async_get_embedding_with_backoff(text, model="text-embedding-ada-002")
"""To get text embeddings, import/call this function
It specifies defaults + handles rate-limiting + is async"""
text = text.replace("\n", " ")
- response = await acreate_embedding_with_backoff(input = [text], model=model)
- embedding = response['data'][0]['embedding']
- return embedding
\ No newline at end of file
+ response = await acreate_embedding_with_backoff(input=[text], model=model)
+ embedding = response["data"][0]["embedding"]
+ return embedding