Removing dead code + legacy commands (#536)
This commit is contained in:
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@@ -42,7 +42,7 @@ jobs:
|
||||
PGVECTOR_TEST_DB_URL: ${{ secrets.PGVECTOR_TEST_DB_URL }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
run: |
|
||||
poetry install -E dev -E postgres -E local -E legacy
|
||||
poetry install -E dev -E postgres -E local
|
||||
|
||||
- name: Set Poetry config
|
||||
env:
|
||||
|
||||
@@ -29,7 +29,7 @@ Once Poetry is installed, navigate to the MemGPT directory and install the MemGP
|
||||
```shell
|
||||
cd MemGPT
|
||||
poetry shell
|
||||
poetry install -E dev -E postgres -E local -E legacy
|
||||
poetry install -E dev -E postgres -E local
|
||||
```
|
||||
|
||||
Now when you want to use `memgpt`, make sure you first activate the `poetry` environment using `poetry shell`:
|
||||
@@ -54,7 +54,7 @@ python3 -m venv venv
|
||||
|
||||
Once you've activated your virtual environment and are in the MemGPT project directory, you can install the dependencies with `pip`:
|
||||
```shell
|
||||
pip install -e '.[dev,postgres,local,legacy]'
|
||||
pip install -e '.[dev,postgres,local]'
|
||||
```
|
||||
|
||||
Now, you should be able to run `memgpt` from the command-line using the downloaded source code (if you used a virtual environment, you have to activate the virtual environment to access `memgpt`):
|
||||
@@ -105,8 +105,8 @@ pytest -s tests
|
||||
### Creating new tests
|
||||
If you added a major feature change, please add new tests in the `tests/` directory.
|
||||
|
||||
## 4. 🧩 Adding new dependencies
|
||||
If you need to add a new dependency to MemGPT, please add the package via `poetry add <PACKAGE_NAME>`. This will update the `pyproject.toml` and `poetry.lock` files. If the dependency does not need to be installed by all users, make sure to mark the dependency as optional in the `pyproject.toml` file and if needed, create a new extra under `[tool.poetry.extras]`.
|
||||
## 4. 🧩 Adding new dependencies
|
||||
If you need to add a new dependency to MemGPT, please add the package via `poetry add <PACKAGE_NAME>`. This will update the `pyproject.toml` and `poetry.lock` files. If the dependency does not need to be installed by all users, make sure to mark the dependency as optional in the `pyproject.toml` file and if needed, create a new extra under `[tool.poetry.extras]`.
|
||||
|
||||
## 5. 🚀 Submitting Changes
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ Once Poetry is installed, navigate to the MemGPT directory and install the MemGP
|
||||
```shell
|
||||
cd MemGPT
|
||||
poetry shell
|
||||
poetry install -E dev -E postgres -E local -E legacy
|
||||
poetry install -E dev -E postgres -E local
|
||||
```
|
||||
|
||||
Now when you want to use `memgpt`, make sure you first activate the `poetry` environment using `poetry shell`:
|
||||
@@ -38,7 +38,7 @@ python3 -m venv venv
|
||||
|
||||
Once you've activated your virtual environment and are in the MemGPT project directory, you can install the dependencies with `pip`:
|
||||
```shell
|
||||
pip install -e '.[dev,postgres,local,legacy]'
|
||||
pip install -e '.[dev,postgres,local]'
|
||||
```
|
||||
|
||||
Now, you should be able to run `memgpt` from the command-line using the downloaded source code (if you used a virtual environment, you have to activate the virtual environment to access `memgpt`):
|
||||
|
||||
6
main.py
6
main.py
@@ -1,3 +1,7 @@
|
||||
from memgpt.main import app
|
||||
import typer
|
||||
|
||||
app()
|
||||
typer.secho(
|
||||
"Command `python main.py` no longer supported. Please run `memgpt run`. See https://memgpt.readthedocs.io/en/latest/quickstart/.",
|
||||
fg=typer.colors.YELLOW,
|
||||
)
|
||||
|
||||
@@ -6,14 +6,12 @@ import traceback
|
||||
|
||||
from memgpt.persistence_manager import LocalStateManager
|
||||
from memgpt.config import AgentConfig, MemGPTConfig
|
||||
from .system import get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages
|
||||
from .memory import CoreMemory as Memory, summarize_messages
|
||||
from .openai_tools import completions_with_backoff as create, is_context_overflow_error
|
||||
from memgpt.openai_tools import chat_completion_with_backoff
|
||||
from .utils import get_local_time, parse_json, united_diff, printd, count_tokens, get_schema_diff
|
||||
from .constants import (
|
||||
from memgpt.system import get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages
|
||||
from memgpt.memory import CoreMemory as Memory, summarize_messages
|
||||
from memgpt.openai_tools import create, is_context_overflow_error
|
||||
from memgpt.utils import get_local_time, parse_json, united_diff, printd, count_tokens, get_schema_diff
|
||||
from memgpt.constants import (
|
||||
FIRST_MESSAGE_ATTEMPTS,
|
||||
MAX_PAUSE_HEARTBEATS,
|
||||
MESSAGE_SUMMARY_WARNING_FRAC,
|
||||
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
|
||||
MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
|
||||
@@ -759,33 +757,8 @@ class Agent(object):
|
||||
function_call="auto",
|
||||
):
|
||||
"""Get response from LLM API"""
|
||||
|
||||
# TODO: Legacy code - delete
|
||||
if self.config is None:
|
||||
try:
|
||||
response = create(
|
||||
model=self.model,
|
||||
context_window=self.context_window,
|
||||
messages=message_sequence,
|
||||
functions=self.functions,
|
||||
function_call=function_call,
|
||||
)
|
||||
|
||||
# special case for 'length'
|
||||
if response.choices[0].finish_reason == "length":
|
||||
raise Exception("Finish reason was length (maximum context length)")
|
||||
|
||||
# catches for soft errors
|
||||
if response.choices[0].finish_reason not in ["stop", "function_call"]:
|
||||
raise Exception(f"API call finish with bad finish reason: {response}")
|
||||
|
||||
# unpack with response.choices[0].message.content
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
response = chat_completion_with_backoff(
|
||||
response = create(
|
||||
agent_config=self.config,
|
||||
messages=message_sequence,
|
||||
functions=self.functions,
|
||||
|
||||
@@ -3,21 +3,14 @@ import json
|
||||
import sys
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
from prettytable import PrettyTable
|
||||
import questionary
|
||||
|
||||
from llama_index import set_global_service_context
|
||||
from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext
|
||||
from llama_index import ServiceContext
|
||||
|
||||
from memgpt.interface import CLIInterface as interface # for printing to terminal
|
||||
from memgpt.cli.cli_config import configure
|
||||
import memgpt.agent as agent
|
||||
import memgpt.system as system
|
||||
import memgpt.presets.presets as presets
|
||||
import memgpt.constants as constants
|
||||
import memgpt.personas.personas as personas
|
||||
import memgpt.humans.humans as humans
|
||||
import memgpt.utils as utils
|
||||
from memgpt.utils import printd
|
||||
from memgpt.persistence_manager import LocalStateManager
|
||||
@@ -25,10 +18,6 @@ from memgpt.config import MemGPTConfig, AgentConfig
|
||||
from memgpt.constants import MEMGPT_DIR
|
||||
from memgpt.agent import Agent
|
||||
from memgpt.embeddings import embedding_model
|
||||
from memgpt.openai_tools import (
|
||||
configure_azure_support,
|
||||
check_azure_embeddings,
|
||||
)
|
||||
|
||||
|
||||
def run(
|
||||
@@ -196,11 +185,6 @@ def run(
|
||||
# start event loop
|
||||
from memgpt.main import run_agent_loop
|
||||
|
||||
# setup azure if using
|
||||
# TODO: cleanup this code
|
||||
if config.model_endpoint == "azure":
|
||||
configure_azure_support()
|
||||
|
||||
run_agent_loop(memgpt_agent, first, no_verify, config) # TODO: add back no_verify
|
||||
|
||||
|
||||
|
||||
@@ -4,14 +4,11 @@ from prettytable import PrettyTable
|
||||
import typer
|
||||
import os
|
||||
import shutil
|
||||
from collections import defaultdict
|
||||
|
||||
# from memgpt.cli import app
|
||||
from memgpt import utils
|
||||
|
||||
import memgpt.humans.humans as humans
|
||||
import memgpt.personas.personas as personas
|
||||
from memgpt.config import MemGPTConfig, AgentConfig, Config
|
||||
from memgpt.config import MemGPTConfig, AgentConfig
|
||||
from memgpt.constants import MEMGPT_DIR
|
||||
from memgpt.connectors.storage import StorageConnector
|
||||
from memgpt.constants import LLM_MAX_TOKENS
|
||||
|
||||
333
memgpt/config.py
333
memgpt/config.py
@@ -1,44 +1,14 @@
|
||||
import glob
|
||||
import inspect
|
||||
import random
|
||||
import string
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
import textwrap
|
||||
from dataclasses import dataclass
|
||||
import configparser
|
||||
|
||||
|
||||
import questionary
|
||||
|
||||
from colorama import Fore, Style
|
||||
|
||||
from typing import List, Type
|
||||
|
||||
import memgpt
|
||||
import memgpt.utils as utils
|
||||
from memgpt.interface import CLIInterface as interface
|
||||
from memgpt.personas.personas import get_persona_text
|
||||
from memgpt.humans.humans import get_human_text
|
||||
from memgpt.constants import MEMGPT_DIR, LLM_MAX_TOKENS
|
||||
import memgpt.constants as constants
|
||||
import memgpt.personas.personas as personas
|
||||
import memgpt.humans.humans as humans
|
||||
from memgpt.presets.presets import DEFAULT_PRESET, preset_options
|
||||
|
||||
|
||||
# Options shown by the "Which model would you like to use?" prompt in
# Config.config_init. A Choice with an explicit `value` displays the long
# title but is given the short model id as its value; the first entry is
# passed as the prompt's default.
model_choices = [
    questionary.Choice("gpt-4"),
    questionary.Choice(
        "gpt-4-turbo (developer preview)",
        value="gpt-4-1106-preview",
    ),
    questionary.Choice(
        "gpt-3.5-turbo (experimental! function-calling performance is not quite at the level of gpt-4 yet)",
        value="gpt-3.5-turbo-16k",
    ),
]
|
||||
from memgpt.constants import MEMGPT_DIR, LLM_MAX_TOKENS, DEFAULT_HUMAN, DEFAULT_PERSONA
|
||||
from memgpt.presets.presets import DEFAULT_PRESET
|
||||
|
||||
|
||||
# helper functions for writing to configs
|
||||
@@ -85,8 +55,8 @@ class MemGPTConfig:
|
||||
azure_embedding_deployment: str = None
|
||||
|
||||
# persona parameters
|
||||
persona: str = personas.DEFAULT
|
||||
human: str = humans.DEFAULT
|
||||
persona: str = DEFAULT_PERSONA
|
||||
human: str = DEFAULT_HUMAN
|
||||
agent: str = None
|
||||
|
||||
# embedding parameters
|
||||
@@ -377,298 +347,3 @@ class AgentConfig:
|
||||
utils.printd(f"Removing missing argument {key} from agent config")
|
||||
del agent_config[key]
|
||||
return cls(**agent_config)
|
||||
|
||||
|
||||
class Config:
    """Legacy run configuration, persisted as a JSON file under ``configs_dir``.

    Instances are normally built through one of the two alternate
    constructors: ``legacy_flags_init`` (values passed in directly) or
    ``config_init`` (load a saved file, or prompt the user interactively).
    """

    # Example personas/humans are looked up relative to the working directory;
    # user-provided ones and saved configs live under MEMGPT_DIR.
    personas_dir = os.path.join("memgpt", "personas", "examples")
    custom_personas_dir = os.path.join(MEMGPT_DIR, "personas")
    humans_dir = os.path.join("memgpt", "humans", "examples")
    custom_humans_dir = os.path.join(MEMGPT_DIR, "humans")
    configs_dir = os.path.join(MEMGPT_DIR, "configs")

    def __init__(self):
        """Create the custom persona/human directories and set default fields.

        Note that ``model``, ``memgpt_persona``, ``human_persona`` and
        ``archival_storage_index`` are NOT initialised here; they are assigned
        by ``legacy_flags_init``, ``config_init`` or ``load_config``.
        """
        os.makedirs(Config.custom_personas_dir, exist_ok=True)
        os.makedirs(Config.custom_humans_dir, exist_ok=True)
        self.load_type = None
        self.archival_storage_files = None
        self.compute_embeddings = False
        self.agent_save_file = None
        self.persistence_manager_save_file = None
        # Custom endpoint, if any; a non-empty value disables embedding computation
        # in configure_archival_storage.
        self.host = os.getenv("OPENAI_API_BASE")
        self.index = None
        self.config_file = None
        self.preload_archival = False

    @classmethod
    def legacy_flags_init(
        cls: Type["Config"],
        model: str,
        memgpt_persona: str,
        human_persona: str,
        load_type: str = None,
        archival_storage_files: str = None,
        archival_storage_index: str = None,
        compute_embeddings: bool = False,
    ):
        """Build a Config from explicitly supplied values (no prompting).

        Archival storage is configured only when ``archival_storage_files`` is
        truthy. Nothing is written to disk by this constructor.
        """
        self = cls()
        self.model = model
        self.memgpt_persona = memgpt_persona
        self.human_persona = human_persona
        self.load_type = load_type
        self.archival_storage_files = archival_storage_files
        self.archival_storage_index = archival_storage_index
        self.compute_embeddings = compute_embeddings
        recompute_embeddings = self.compute_embeddings
        if self.archival_storage_index:
            recompute_embeddings = False # TODO Legacy support -- can't recompute embeddings on a path that's not specified.
        if self.archival_storage_files:
            self.configure_archival_storage(recompute_embeddings)
        return self

    @classmethod
    def config_init(cls: Type["Config"], config_file: str = None):
        """Build a Config from a saved file or, failing that, interactively.

        If ``config_file`` is None, the most recent valid saved config is
        offered to the user. When a config file ends up selected it is loaded
        and returned; otherwise the user is prompted for every setting and the
        result is written to a new config file.
        """
        self = cls()
        self.config_file = config_file
        if self.config_file is None:
            cfg = Config.get_most_recent_config()
            use_cfg = False
            if cfg:
                print(f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Found saved config file.{Style.RESET_ALL}")
                use_cfg = questionary.confirm(f"Use most recent config file '{cfg}'?").ask()
            if use_cfg:
                self.config_file = cfg

        # Path 1: an existing config file was given or accepted.
        if self.config_file:
            self.load_config(self.config_file)
            recompute_embeddings = False
            if self.compute_embeddings:
                if self.archival_storage_index:
                    # An index already exists: only rebuild it if the user asks.
                    recompute_embeddings = questionary.confirm(
                        f"Would you like to recompute embeddings? Do this if your files have changed.\n Files: {self.archival_storage_files}",
                        default=False,
                    ).ask()
                else:
                    recompute_embeddings = True
            if self.load_type:
                self.configure_archival_storage(recompute_embeddings)
                self.write_config()
            return self

        # Path 2: no config file -- ask for everything.
        # print("No settings file found, configuring MemGPT...")
        print(f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ No settings file found, configuring MemGPT...{Style.RESET_ALL}")

        self.model = questionary.select(
            "Which model would you like to use?",
            model_choices,
            default=model_choices[0],
        ).ask()

        self.memgpt_persona = questionary.select(
            "Which persona would you like MemGPT to use?",
            Config.get_memgpt_personas(),
        ).ask()
        print(self.memgpt_persona)

        self.human_persona = questionary.select(
            "Which user would you like to use?",
            Config.get_user_personas(),
        ).ask()

        self.archival_storage_index = None
        self.preload_archival = questionary.confirm(
            "Would you like to preload anything into MemGPT's archival memory?", default=False
        ).ask()
        if self.preload_archival:
            self.load_type = questionary.select(
                "What would you like to load?",
                choices=[
                    questionary.Choice("A folder or file", value="folder"),
                    questionary.Choice("A SQL database", value="sql"),
                    questionary.Choice("A glob pattern", value="glob"),
                ],
            ).ask()
            if self.load_type == "folder" or self.load_type == "sql":
                archival_storage_path = questionary.path("Please enter the folder or file (tab for autocomplete):").ask()
                if os.path.isdir(archival_storage_path):
                    # A directory is turned into a glob over its direct children.
                    self.archival_storage_files = os.path.join(archival_storage_path, "*")
                else:
                    self.archival_storage_files = archival_storage_path
            else:
                self.archival_storage_files = questionary.path("Please enter the glob pattern (tab for autocomplete):").ask()
            self.compute_embeddings = questionary.confirm(
                "Would you like to compute embeddings over these files to enable embeddings search?"
            ).ask()
            self.configure_archival_storage(self.compute_embeddings)

        self.write_config()
        return self

    def configure_archival_storage(self, recompute_embeddings):
        """Populate ``archival_database`` (and possibly ``index``) from the archival files.

        When ``recompute_embeddings`` is truthy and no custom ``host`` is set,
        a new embedding index is built first. If ``host`` is set, a warning is
        shown instead and no index is built.
        """
        if recompute_embeddings:
            if self.host:
                interface.warning_message(
                    "⛔️ Embeddings on a non-OpenAI endpoint are not yet supported, falling back to substring matching search."
                )
            else:
                self.archival_storage_index = utils.prepare_archival_index_from_files_compute_embeddings(self.archival_storage_files)
        # Use the embedding index only if embeddings are enabled AND an index exists;
        # otherwise load the files without one.
        if self.compute_embeddings and self.archival_storage_index:
            self.index, self.archival_database = utils.prepare_archival_index(self.archival_storage_index)
        else:
            self.archival_database = utils.prepare_archival_index_from_files(self.archival_storage_files)

    def to_dict(self):
        """Return the persisted fields as a dict (the shape written by ``write_config``)."""
        return {
            "model": self.model,
            "memgpt_persona": self.memgpt_persona,
            "human_persona": self.human_persona,
            "preload_archival": self.preload_archival,
            "archival_storage_files": self.archival_storage_files,
            "archival_storage_index": self.archival_storage_index,
            "compute_embeddings": self.compute_embeddings,
            "load_type": self.load_type,
            "agent_save_file": self.agent_save_file,
            "persistence_manager_save_file": self.persistence_manager_save_file,
            "host": self.host,
        }

    def load_config(self, config_file):
        """Load every persisted field from the JSON file at ``config_file``.

        All keys are required: each one is read with ``cfg[...]``, so a file
        lacking any of them makes this method raise.
        """
        with open(config_file, "rt") as f:
            cfg = json.load(f)
        self.model = cfg["model"]
        self.memgpt_persona = cfg["memgpt_persona"]
        self.human_persona = cfg["human_persona"]
        self.preload_archival = cfg["preload_archival"]
        self.archival_storage_files = cfg["archival_storage_files"]
        self.archival_storage_index = cfg["archival_storage_index"]
        self.compute_embeddings = cfg["compute_embeddings"]
        self.load_type = cfg["load_type"]
        self.agent_save_file = cfg["agent_save_file"]
        self.persistence_manager_save_file = cfg["persistence_manager_save_file"]
        self.host = cfg["host"]

    def write_config(self, configs_dir=None):
        """Write ``to_dict()`` as JSON to ``self.config_file``.

        If no config file has been chosen yet, a new one named after the
        current local time is created inside ``configs_dir`` (default:
        ``Config.configs_dir``) and remembered in ``self.config_file``.
        """
        if configs_dir is None:
            configs_dir = Config.configs_dir
        os.makedirs(configs_dir, exist_ok=True)
        if self.config_file is None:
            # Spaces and colons in the timestamp are replaced so it can be used as a file name.
            filename = os.path.join(configs_dir, utils.get_local_time().replace(" ", "_").replace(":", "_"))
            self.config_file = f"{filename}.json"
        with open(self.config_file, "wt") as f:
            json.dump(self.to_dict(), f, indent=4)
        print(f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Saved config file to {self.config_file}.{Style.RESET_ALL}")

    @staticmethod
    def is_valid_config_file(file: str):
        """Return True if ``file`` loads via ``load_config`` and names both personas.

        Any exception while loading (unreadable file, bad JSON, missing key)
        is treated as "not valid" rather than propagated.
        """
        # Constructing Config() also creates the custom persona/human directories.
        cfg = Config()
        try:
            cfg.load_config(file)
        except Exception:
            return False
        return cfg.memgpt_persona is not None and cfg.human_persona is not None # TODO: more validation for configs

    @staticmethod
    def get_memgpt_personas():
        """Return the questionary choices for the MemGPT persona prompt.

        Order: personas from the custom directory, then example personas that
        are not in the built-in default list, then the built-in defaults, and
        finally a separator plus a disabled hint entry.
        """
        dir_path = Config.personas_dir
        all_personas = Config.get_personas(dir_path)
        default_personas = [
            "sam",
            "sam_pov",
            "memgpt_starter",
            "memgpt_doc",
            "sam_simple_pov_gpt35",
        ]
        # Set difference: the relative order of these entries is not guaranteed.
        custom_personas_in_examples = list(set(all_personas) - set(default_personas))
        custom_personas = Config.get_personas(Config.custom_personas_dir)
        return (
            Config.get_persona_choices(
                [p for p in custom_personas],
                get_persona_text,
                Config.custom_personas_dir,
            )
            + Config.get_persona_choices(
                [p for p in custom_personas_in_examples + default_personas],
                get_persona_text,
                None,
                # Config.personas_dir,
            )
            + [
                questionary.Separator(),
                questionary.Choice(
                    f"📝 You can create your own personas by adding .txt files to {Config.custom_personas_dir}.",
                    disabled=True,
                ),
            ]
        )

    @staticmethod
    def get_user_personas():
        """Return the questionary choices for the human (user) profile prompt.

        Same layout as ``get_memgpt_personas`` but reading the humans
        directories and using ``get_human_text`` for the preview text.
        """
        dir_path = Config.humans_dir
        all_personas = Config.get_personas(dir_path)
        default_personas = ["basic", "cs_phd"]
        # Set difference: the relative order of these entries is not guaranteed.
        custom_personas_in_examples = list(set(all_personas) - set(default_personas))
        custom_personas = Config.get_personas(Config.custom_humans_dir)
        return (
            Config.get_persona_choices(
                [p for p in custom_personas],
                get_human_text,
                Config.custom_humans_dir,
            )
            + Config.get_persona_choices(
                [p for p in custom_personas_in_examples + default_personas],
                get_human_text,
                None,
                # Config.humans_dir,
            )
            + [
                questionary.Separator(),
                questionary.Choice(
                    f"📝 You can create your own human profiles by adding .txt files to {Config.custom_humans_dir}.",
                    disabled=True,
                ),
            ]
        )

    @staticmethod
    def get_personas(dir_path) -> List[str]:
        """Return the sorted stems (file names without extension) of ``*.txt`` files in ``dir_path``."""
        files = sorted(glob.glob(os.path.join(dir_path, "*.txt")))
        stems = []
        for f in files:
            filename = os.path.basename(f)
            stem, _ = os.path.splitext(filename)
            stems.append(stem)
        return stems

    @staticmethod
    def get_persona_choices(personas, text_getter, dir):
        """Build one questionary Choice per name in ``personas``.

        Each choice shows the name followed by an indented preview obtained
        from ``text_getter(name, dir)``; its value is the tuple ``(name, dir)``.
        """
        return [
            questionary.Choice(
                title=[
                    ("class:question", f"{p}"),
                    ("class:text", f"\n{indent(text_getter(p, dir))}"),
                ],
                value=(p, dir),
            )
            for p in personas
        ]

    @staticmethod
    def get_most_recent_config(configs_dir=None):
        """Return the path of the most recently modified valid config file, or None.

        ``configs_dir`` defaults to ``Config.configs_dir`` and is created if
        missing. Only regular files accepted by ``is_valid_config_file`` are
        considered.
        """
        if configs_dir is None:
            configs_dir = Config.configs_dir
        os.makedirs(configs_dir, exist_ok=True)
        files = [
            os.path.join(configs_dir, f)
            for f in os.listdir(configs_dir)
            if os.path.isfile(os.path.join(configs_dir, f)) and Config.is_valid_config_file(os.path.join(configs_dir, f))
        ]
        # Return the file with the most recent modification time
        if len(files) == 0:
            return None
        return max(files, key=os.path.getmtime)
|
||||
|
||||
def indent(text, num_lines=5):
    """Wrap *text* to 100 columns and prefix every resulting line with a space.

    When the wrapped text has more than ``num_lines`` lines, only the first
    ``num_lines - 1`` lines are kept, followed by a ``"... (truncated)"``
    marker line and the final line of the text.
    """
    rows = textwrap.fill(text, width=100).split("\n")
    if len(rows) > num_lines:
        head = rows[: num_lines - 1]
        rows = [*head, "... (truncated)", rows[-1]]
    # str.split always yields at least one element, so prefixing each row is
    # the same as prefixing the joined string.
    return "\n".join(" " + row for row in rows)
|
||||
|
||||
@@ -3,6 +3,8 @@ import os
|
||||
# Per-user MemGPT directory: ".memgpt" inside the expanded home directory.
MEMGPT_DIR = os.path.join(os.path.expanduser("~"), ".memgpt")

DEFAULT_MEMGPT_MODEL = "gpt-4"
# Names (not file paths) of the default persona and human profiles.
DEFAULT_PERSONA = "sam_pov"
DEFAULT_HUMAN = "basic"

FIRST_MESSAGE_ATTEMPTS = 10
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import requests
|
||||
|
||||
|
||||
from memgpt.constants import MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE, MAX_PAUSE_HEARTBEATS
|
||||
from memgpt.openai_tools import completions_with_backoff as create
|
||||
from memgpt.openai_tools import create
|
||||
|
||||
|
||||
def message_chatgpt(self, message: str):
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
import os

DEFAULT = "cs_phd"


def get_human_text(key=DEFAULT, dir=None):
    """Return the stripped contents of the human-profile text file for *key*.

    ``key`` may be given with or without the ``.txt`` extension. ``dir`` is
    the directory to search; when it is None the ``examples`` directory next
    to this module is used.

    Raises:
        FileNotFoundError: if the resolved path does not exist.
    """
    search_dir = os.path.join(os.path.dirname(__file__), "examples") if dir is None else dir
    filename = key if key.endswith(".txt") else f"{key}.txt"
    file_path = os.path.join(search_dir, filename)

    # Guard clause: fail early when there is nothing to read.
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"No file found for key {key}, path={file_path}")

    with open(file_path, "r") as handle:
        return handle.read().strip()
|
||||
377
memgpt/main.py
377
memgpt/main.py
@@ -20,28 +20,10 @@ console = Console()
|
||||
from memgpt.interface import CLIInterface as interface # for printing to terminal
|
||||
import memgpt.agent as agent
|
||||
import memgpt.system as system
|
||||
import memgpt.utils as utils
|
||||
import memgpt.presets.presets as presets
|
||||
import memgpt.constants as constants
|
||||
import memgpt.personas.personas as personas
|
||||
import memgpt.humans.humans as humans
|
||||
from memgpt.persistence_manager import (
|
||||
LocalStateManager,
|
||||
InMemoryStateManager,
|
||||
InMemoryStateManagerWithPreloadedArchivalMemory,
|
||||
InMemoryStateManagerWithFaiss,
|
||||
)
|
||||
from memgpt.cli.cli import run, attach, version
|
||||
from memgpt.cli.cli_config import configure, list, add
|
||||
from memgpt.cli.cli_load import app as load_app
|
||||
from memgpt.config import Config, MemGPTConfig, AgentConfig
|
||||
from memgpt.constants import MEMGPT_DIR
|
||||
from memgpt.agent import Agent
|
||||
from memgpt.openai_tools import (
|
||||
configure_azure_support,
|
||||
check_azure_embeddings,
|
||||
get_set_azure_env_vars,
|
||||
)
|
||||
from memgpt.connectors.storage import StorageConnector
|
||||
|
||||
app = typer.Typer(pretty_exceptions_enable=False)
|
||||
@@ -65,313 +47,7 @@ def clear_line(strip_ui=False):
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def save(memgpt_agent, cfg):
    """Checkpoint the agent and its persistence manager, then rewrite the config.

    Two files are written under ``MEMGPT_DIR/saved_state``, both named after
    the current local time: ``<timestamp>.json`` for the agent and
    ``<timestamp>.persistence.pickle`` for the persistence manager. Each step
    is best-effort: a failure is printed and the function carries on. On
    success the corresponding path is recorded on ``cfg``; ``cfg.write_config()``
    is always called at the end.
    """
    directory = os.path.join(MEMGPT_DIR, "saved_state")
    # Spaces and colons in the timestamp are replaced so it can be used as a file name.
    stamp = utils.get_local_time().replace(" ", "_").replace(":", "_")
    checkpoint_path = os.path.join(directory, f"{stamp}.json")

    try:
        if not os.path.exists(directory):
            os.makedirs(directory)
        memgpt_agent.save_to_json_file(checkpoint_path)
        print(f"Saved checkpoint to: {checkpoint_path}")
        cfg.agent_save_file = checkpoint_path
    except Exception as err:
        print(f"Saving state to {checkpoint_path} failed with: {err}")

    # The persistence manager is stored next to the agent checkpoint.
    manager_path = checkpoint_path.replace(".json", ".persistence.pickle")
    try:
        memgpt_agent.persistence_manager.save(manager_path)
        print(f"Saved persistence manager to: {manager_path}")
        cfg.persistence_manager_save_file = manager_path
    except Exception as err:
        print(f"Saving persistence manager to {manager_path} failed with: {err}")

    cfg.write_config()
|
||||
|
||||
|
||||
def load(memgpt_agent, filename):
    """Restore an agent checkpoint and its persistence manager in place.

    ``filename`` is the checkpoint to load; ``.json`` is appended when it is
    missing. When ``filename`` is None, the most recently modified ``*.json``
    file in ``MEMGPT_DIR/saved_state`` is used, and the function returns
    early if there is none. Failures are printed, not raised.
    """
    if filename is None:
        # No checkpoint named: fall back to the newest one on disk.
        save_path = os.path.join(constants.MEMGPT_DIR, "saved_state")
        print(f"/load warning: no checkpoint specified, loading most recent checkpoint from {save_path} instead")
        candidates = glob.glob(os.path.join(save_path, "*.json"))
        if not candidates:
            print(f"/load error: no .json checkpoint files found")
            return
        filename = max(candidates, key=os.path.getmtime)
    elif not filename.endswith(".json"):
        filename += ".json"

    try:
        memgpt_agent.load_from_json_file_inplace(filename)
        print(f"Loaded checkpoint {filename}")
    except Exception as err:
        print(f"Loading {filename} failed with: {err}")

    # The persistence manager lives next to the checkpoint and is attempted
    # even if loading the checkpoint itself failed.
    manager_path = filename.replace(".json", ".persistence.pickle")
    try:
        # TODO(fixme):for different types of persistence managers that require different load/save methods
        memgpt_agent.persistence_manager = InMemoryStateManager.load(manager_path)
        print(f"Loaded persistence manager from {manager_path}")
    except Exception as err:
        print(f"/load warning: loading persistence manager from {manager_path} failed with: {err}")
|
||||
|
||||
|
||||
# Default entry point of the Typer app: runs when the CLI is invoked without a
# subcommand and forwards the legacy flags to `main`.
# NOTE: deliberately documented with comments rather than a docstring, since
# Typer uses a callback's docstring as CLI help text.
@app.callback(invoke_without_command=True)  # make default command
# @app.command("legacy-run")
def legacy_run(
    ctx: typer.Context,
    persona: str = typer.Option(None, help="Specify persona"),
    human: str = typer.Option(None, help="Specify human"),
    model: str = typer.Option(constants.DEFAULT_MEMGPT_MODEL, help="Specify the LLM model"),
    first: bool = typer.Option(False, "--first", help="Use --first to send the first message in the sequence"),
    strip_ui: bool = typer.Option(False, "--strip_ui", help="Remove all the bells and whistles in CLI output (helpful for testing)"),
    debug: bool = typer.Option(False, "--debug", help="Use --debug to enable debugging output"),
    no_verify: bool = typer.Option(False, "--no_verify", help="Bypass message verification"),
    archival_storage_faiss_path: str = typer.Option(
        "",
        "--archival_storage_faiss_path",
        help="Specify archival storage with FAISS index to load (a folder with a .index and .json describing documents to be loaded)",
    ),
    archival_storage_files: str = typer.Option(
        "",
        "--archival_storage_files",
        help="Specify files to pre-load into archival memory (glob pattern)",
    ),
    archival_storage_files_compute_embeddings: str = typer.Option(
        "",
        "--archival_storage_files_compute_embeddings",
        help="Specify files to pre-load into archival memory (glob pattern), and compute embeddings over them",
    ),
    archival_storage_sqldb: str = typer.Option(
        "",
        "--archival_storage_sqldb",
        help="Specify SQL database to pre-load into archival memory",
    ),
    use_azure_openai: bool = typer.Option(
        False,
        "--use_azure_openai",
        help="Use Azure OpenAI (requires additional environment variables)",
    ),  # TODO: just pass in?
):
    # Only act when no subcommand was requested; otherwise Typer dispatches to it.
    if ctx.invoked_subcommand is None:
        typer.secho(
            "Warning: Running legacy run command. You may need to `pip install pymemgpt[legacy] -U`. Run `memgpt run` instead.",
            fg=typer.colors.RED,
            bold=True,
        )
        # The prompt defaults to "no"; any falsy answer aborts.
        proceed = questionary.confirm("Continue with legacy CLI?", default=False).ask()
        if proceed:
            main(
                persona=persona,
                human=human,
                model=model,
                first=first,
                debug=debug,
                no_verify=no_verify,
                archival_storage_faiss_path=archival_storage_faiss_path,
                archival_storage_files=archival_storage_files,
                archival_storage_files_compute_embeddings=archival_storage_files_compute_embeddings,
                archival_storage_sqldb=archival_storage_sqldb,
                use_azure_openai=use_azure_openai,
                strip_ui=strip_ui,
            )
|
||||
|
||||
|
||||
def main(
    persona,
    human,
    model,
    first,
    debug,
    no_verify,
    archival_storage_faiss_path,
    archival_storage_files,
    archival_storage_files_compute_embeddings,
    archival_storage_sqldb,
    use_azure_openai,
    strip_ui,
):
    """Run a legacy (flag-driven) MemGPT CLI session.

    Builds a config either from the legacy command-line flags or from
    ``Config.config_init()``, selects an in-memory persistence manager,
    creates an agent from ``presets.DEFAULT_PRESET`` and hands control to
    ``run_agent_loop`` with ``legacy=True``.

    Args:
        persona: Persona name, or ``None`` to pick a packaged default.
        human: Human profile name, or ``None`` to use ``humans.DEFAULT``.
        model: Model name; ``None`` falls back to
            ``constants.DEFAULT_MEMGPT_MODEL`` on the legacy-flag path.
        first: Forwarded unchanged to ``run_agent_loop``.
        debug: Sets ``utils.DEBUG`` and switches the root logger to DEBUG.
        no_verify: Forwarded unchanged to ``run_agent_loop``.
        archival_storage_faiss_path: Passed to ``Config.legacy_flags_init``
            as both ``archival_storage_files`` and ``archival_storage_index``.
        archival_storage_files: Passed to ``Config.legacy_flags_init`` with
            ``compute_embeddings=False``.
        archival_storage_files_compute_embeddings: Passed to
            ``Config.legacy_flags_init`` with ``compute_embeddings=True``.
        archival_storage_sqldb: Passed to ``Config.legacy_flags_init`` with
            ``load_type="sql"``.
        use_azure_openai: When true, runs the Azure environment checks; when
            false, any set Azure variable aborts the run.
        strip_ui: Stored on ``interface.STRIP_UI`` and forwarded to
            ``run_agent_loop``.

    Returns:
        None. Returns early (after printing a message) when Azure variables
        are set without ``use_azure_openai``, or when the SQL file named in
        the config does not exist.
    """
    interface.STRIP_UI = strip_ui
    utils.DEBUG = debug
    # Quiet by default; only --debug raises the root logger to DEBUG.
    logging.getLogger().setLevel(logging.CRITICAL)
    if debug:
        logging.getLogger().setLevel(logging.DEBUG)

    # Azure OpenAI support
    if use_azure_openai:
        configure_azure_support()
        check_azure_embeddings()
    else:
        # Azure variables without the flag are treated as a configuration error.
        azure_vars = get_set_azure_env_vars()
        if len(azure_vars) > 0:
            print(f"Error: Environment variables {', '.join([x[0] for x in azure_vars])} should not be set if --use_azure_openai is False")
            return

    # Any legacy flag (or a non-default model) selects the flag-driven config path.
    if any(
        (
            persona,
            human,
            model != constants.DEFAULT_MEMGPT_MODEL,
            archival_storage_faiss_path,
            archival_storage_files,
            archival_storage_files_compute_embeddings,
            archival_storage_sqldb,
        )
    ):
        interface.important_message("⚙️ Using legacy command line arguments.")
        model = model  # NOTE(review): self-assignment, has no effect
        if model is None:
            model = constants.DEFAULT_MEMGPT_MODEL
        memgpt_persona = persona
        if memgpt_persona is None:
            # No persona given: choose a default based on the model name.
            memgpt_persona = (
                personas.GPT35_DEFAULT if (model is not None and "gpt-3.5" in model) else personas.DEFAULT,
                None,  # represents the personas dir in pymemgpt package
            )
        else:
            # Look in the custom personas dir first, then fall back to the
            # default location. The second lookup is not guarded, so a
            # FileNotFoundError from it propagates to the caller.
            try:
                personas.get_persona_text(memgpt_persona, Config.custom_personas_dir)
                memgpt_persona = (memgpt_persona, Config.custom_personas_dir)
            except FileNotFoundError:
                personas.get_persona_text(memgpt_persona)
                memgpt_persona = (memgpt_persona, None)

        human_persona = human
        if human_persona is None:
            human_persona = (humans.DEFAULT, None)
        else:
            # Same custom-dir-then-default lookup as for the persona above.
            try:
                humans.get_human_text(human_persona, Config.custom_humans_dir)
                human_persona = (human_persona, Config.custom_humans_dir)
            except FileNotFoundError:
                humans.get_human_text(human_persona)
                human_persona = (human_persona, None)

        # NOTE(review): unconditional print; looks like leftover debugging output.
        print(persona, model, memgpt_persona)
        # Exactly one archival source is used; the first matching flag wins.
        if archival_storage_files:
            cfg = Config.legacy_flags_init(
                model,
                memgpt_persona,
                human_persona,
                load_type="folder",
                archival_storage_files=archival_storage_files,
                compute_embeddings=False,
            )
        elif archival_storage_faiss_path:
            cfg = Config.legacy_flags_init(
                model,
                memgpt_persona,
                human_persona,
                load_type="folder",
                archival_storage_files=archival_storage_faiss_path,
                archival_storage_index=archival_storage_faiss_path,
                compute_embeddings=True,
            )
        elif archival_storage_files_compute_embeddings:
            # NOTE(review): three unconditional prints; look like leftover debugging output.
            print(model)
            print(memgpt_persona)
            print(human_persona)
            cfg = Config.legacy_flags_init(
                model,
                memgpt_persona,
                human_persona,
                load_type="folder",
                archival_storage_files=archival_storage_files_compute_embeddings,
                compute_embeddings=True,
            )
        elif archival_storage_sqldb:
            cfg = Config.legacy_flags_init(
                model,
                memgpt_persona,
                human_persona,
                load_type="sql",
                archival_storage_files=archival_storage_sqldb,
                compute_embeddings=False,
            )
        else:
            cfg = Config.legacy_flags_init(
                model,
                memgpt_persona,
                human_persona,
            )
    else:
        # No legacy flags at all: let Config build the configuration itself.
        cfg = Config.config_init()

    interface.important_message("Running... [exit by typing '/exit', list available commands with '/help']")
    if cfg.model != constants.DEFAULT_MEMGPT_MODEL:
        interface.warning_message(
            f"⛔️ Warning - you are running MemGPT with {cfg.model}, which is not officially supported (yet). Expect bugs!"
        )

    # Choose the persistence manager from what the config loaded.
    if cfg.index:
        persistence_manager = InMemoryStateManagerWithFaiss(cfg.index, cfg.archival_database)
    elif cfg.archival_storage_files:
        print(f"Preloaded {len(cfg.archival_database)} chunks into archival memory.")
        persistence_manager = InMemoryStateManagerWithPreloadedArchivalMemory(cfg.archival_database)
    else:
        persistence_manager = InMemoryStateManager()

    if archival_storage_files_compute_embeddings:
        interface.important_message(
            f"(legacy) To avoid computing embeddings next time, replace --archival_storage_files_compute_embeddings={archival_storage_files_compute_embeddings} with\n\t --archival_storage_faiss_path={cfg.archival_storage_index} (if your files haven't changed)."
        )

    # Moved defaults out of FLAGS so that we can dynamically select the default persona based on model
    chosen_human = cfg.human_persona
    chosen_persona = cfg.memgpt_persona

    memgpt_agent = presets.use_preset(
        presets.DEFAULT_PRESET,
        None,  # no agent config to provide
        cfg.model,
        personas.get_persona_text(*chosen_persona),
        humans.get_human_text(*chosen_human),
        interface,
        persistence_manager,
    )

    print_messages = interface.print_messages
    print_messages(memgpt_agent.messages)

    if cfg.load_type == "sql":  # TODO: move this into config.py in a clean manner
        if not os.path.exists(cfg.archival_storage_files):
            print(f"File {cfg.archival_storage_files} does not exist")
            return
        # Ingest data from file into archival storage
        else:
            print(f"Database found! Loading database into archival memory")
            data_list = utils.read_database_as_list(cfg.archival_storage_files)
            # NOTE(review): user_message is built but never used in this function;
            # building it still indexes data_list[0] and data_list[1].
            user_message = f"Your archival memory has been loaded with a SQL database called {data_list[0]}, which contains schema {data_list[1]}. Remember to refer to this first while answering any user questions!"
            for row in data_list:
                memgpt_agent.persistence_manager.archival_memory.insert(row)
            print(f"Database loaded into archival memory.")

    if cfg.agent_save_file:
        # Ask before restoring a previously saved agent.
        load_save_file = questionary.confirm(f"Load in saved agent '{cfg.agent_save_file}'?").ask()
        if load_save_file:
            load(memgpt_agent, cfg.agent_save_file)

    # run agent loop
    run_agent_loop(memgpt_agent, first, no_verify, cfg, strip_ui, legacy=True)
|
||||
|
||||
|
||||
def run_agent_loop(memgpt_agent, first, no_verify=False, cfg=None, strip_ui=False, legacy=False):
|
||||
def run_agent_loop(memgpt_agent, first, no_verify=False, cfg=None, strip_ui=False):
|
||||
counter = 0
|
||||
user_input = None
|
||||
skip_next_user_input = False
|
||||
@@ -412,49 +88,14 @@ def run_agent_loop(memgpt_agent, first, no_verify=False, cfg=None, strip_ui=Fals
|
||||
# Handle CLI commands
|
||||
# Commands to not get passed as input to MemGPT
|
||||
if user_input.startswith("/"):
|
||||
if legacy:
|
||||
# legacy agent save functions (TODO: eventually remove)
|
||||
if user_input.lower() == "/load" or user_input.lower().startswith("/load "):
|
||||
command = user_input.strip().split()
|
||||
filename = command[1] if len(command) > 1 else None
|
||||
load(memgpt_agent=memgpt_agent, filename=filename)
|
||||
continue
|
||||
elif user_input.lower() == "/exit":
|
||||
# autosave
|
||||
save(memgpt_agent=memgpt_agent, cfg=cfg)
|
||||
break
|
||||
|
||||
elif user_input.lower() == "/savechat":
|
||||
filename = utils.get_local_time().replace(" ", "_").replace(":", "_")
|
||||
filename = f"{filename}.pkl"
|
||||
directory = os.path.join(MEMGPT_DIR, "saved_chats")
|
||||
try:
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory)
|
||||
with open(os.path.join(directory, filename), "wb") as f:
|
||||
pickle.dump(memgpt_agent.messages, f)
|
||||
print(f"Saved messages to: {filename}")
|
||||
except Exception as e:
|
||||
print(f"Saving chat to {filename} failed with: {e}")
|
||||
continue
|
||||
|
||||
elif user_input.lower() == "/save":
|
||||
save(memgpt_agent=memgpt_agent, cfg=cfg)
|
||||
continue
|
||||
else:
|
||||
# updated agent save functions
|
||||
if user_input.lower() == "/exit":
|
||||
memgpt_agent.save()
|
||||
break
|
||||
elif user_input.lower() == "/save" or user_input.lower() == "/savechat":
|
||||
memgpt_agent.save()
|
||||
continue
|
||||
|
||||
if user_input.lower() == "/attach":
|
||||
if legacy:
|
||||
typer.secho("Error: /attach is not supported in legacy mode.", fg=typer.colors.RED, bold=True)
|
||||
continue
|
||||
|
||||
# updated agent save functions
|
||||
if user_input.lower() == "/exit":
|
||||
memgpt_agent.save()
|
||||
break
|
||||
elif user_input.lower() == "/save" or user_input.lower() == "/savechat":
|
||||
memgpt_agent.save()
|
||||
continue
|
||||
elif user_input.lower() == "/attach":
|
||||
# TODO: check if agent already has it
|
||||
data_source_options = StorageConnector.list_loaded_data()
|
||||
if len(data_source_options) == 0:
|
||||
|
||||
371
memgpt/memory.py
371
memgpt/memory.py
@@ -1,30 +1,15 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import os
|
||||
import datetime
|
||||
import re
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from memgpt.constants import MESSAGE_SUMMARY_WARNING_FRAC, MEMGPT_DIR
|
||||
from memgpt.utils import cosine_similarity, get_local_time, printd, count_tokens
|
||||
from memgpt.constants import MESSAGE_SUMMARY_WARNING_FRAC
|
||||
from memgpt.utils import get_local_time, printd, count_tokens
|
||||
from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
|
||||
from memgpt import utils
|
||||
from memgpt.openai_tools import get_embedding_with_backoff, chat_completion_with_backoff
|
||||
from llama_index import (
|
||||
VectorStoreIndex,
|
||||
EmptyIndex,
|
||||
get_response_synthesizer,
|
||||
load_index_from_storage,
|
||||
StorageContext,
|
||||
Document,
|
||||
)
|
||||
from memgpt.openai_tools import create
|
||||
from llama_index import Document
|
||||
from llama_index.node_parser import SimpleNodeParser
|
||||
from llama_index.node_parser import SimpleNodeParser
|
||||
from llama_index.retrievers import VectorIndexRetriever
|
||||
from llama_index.query_engine import RetrieverQueryEngine
|
||||
from llama_index.indices.postprocessor import SimilarityPostprocessor
|
||||
|
||||
from memgpt.embeddings import embedding_model
|
||||
from memgpt.config import MemGPTConfig
|
||||
|
||||
from memgpt.embeddings import embedding_model
|
||||
from memgpt.config import MemGPTConfig
|
||||
@@ -138,7 +123,7 @@ def summarize_messages(
|
||||
{"role": "user", "content": summary_input},
|
||||
]
|
||||
|
||||
response = chat_completion_with_backoff(
|
||||
response = create(
|
||||
agent_config=agent_config,
|
||||
messages=message_sequence,
|
||||
)
|
||||
@@ -178,206 +163,6 @@ class ArchivalMemory(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class DummyArchivalMemory(ArchivalMemory):
    """List-backed, in-memory stand-in for an archival memory database.

    Archival memory is the deeper store for reflections, insights and other
    data that does not fit in active memory but is too important to leave
    only in recall memory. This version keeps everything in a Python list
    and searches it with a plain substring match.
    """

    def __init__(self, archival_memory_database=None):
        # Entries are dicts that carry at least a 'content' string.
        if archival_memory_database is None:
            archival_memory_database = []
        self._archive = archival_memory_database

    def __len__(self):
        return len(self._archive)

    def __repr__(self) -> str:
        if len(self._archive) > 0:
            memory_str = "\n".join(entry["content"] for entry in self._archive)
        else:
            memory_str = "<empty>"
        return "\n### ARCHIVAL MEMORY ###" + f"\n{memory_str}"

    def insert(self, memory_string):
        """Append ``memory_string`` to the archive, stamped with the local time."""
        # can eventually upgrade to adding semantic tags, etc
        record = {
            "timestamp": get_local_time(),
            "content": memory_string,
        }
        self._archive.append(record)

    def search(self, query_string, count=None, start=None):
        """Case-insensitive substring search over the archive.

        Returns a ``(page, total)`` tuple: the requested slice of the
        matching entries and the total number of matches.
        """
        # Deliberately naive linear scan; this is the dummy implementation.
        matches = []
        for entry in self._archive:
            if query_string.lower() in entry["content"].lower():
                matches.append(entry)
        printd(
            f"archive_memory.search (text-based): search for query '{query_string}' returned the following results (limit 5):\n{[matches[start:count]]}"
        )

        # start/count page through the matches; either may be omitted.
        total = len(matches)
        if count is None:
            page = matches if start is None else matches[start:]
        elif start is None:
            page = matches[:count]
        else:
            page = matches[start : start + count]
        return page, total
|
||||
|
||||
|
||||
class DummyArchivalMemoryWithEmbeddings(DummyArchivalMemory):
    """In-memory archival store whose search ranks entries by embedding similarity."""

    def __init__(self, archival_memory_database=None, embedding_model="text-embedding-ada-002"):
        # Entries are dicts; search() reads each entry's 'embedding' key.
        if archival_memory_database is None:
            archival_memory_database = []
        self._archive = archival_memory_database
        self.embedding_model = embedding_model

    def __len__(self):
        return len(self._archive)

    def _insert(self, memory_string, embedding):
        """Store ``memory_string`` together with an already-computed embedding."""
        printd(f"Got an embedding, type {type(embedding)}, len {len(embedding)}")
        record = {
            "timestamp": get_local_time(),
            "content": memory_string,
            "embedding": embedding,
            "embedding_metadata": {"model": self.embedding_model},
        }
        self._archive.append(record)

    def insert(self, memory_string):
        """Embed ``memory_string`` with the configured model and store it."""
        return self._insert(
            memory_string,
            get_embedding_with_backoff(memory_string, model=self.embedding_model),
        )

    def search(self, query_string, count, start):
        """Rank every stored entry against the query embedding (no caching).

        Returns a ``(page, total)`` tuple: the requested slice of the ranked
        entries and the total number of entries ranked.
        """
        # see: https://github.com/openai/openai-cookbook/blob/main/examples/Semantic_text_search_using_embeddings.ipynb
        # The wrapped embedding call handles backoff / rate limits.
        query_embedding = get_embedding_with_backoff(query_string, model=self.embedding_model)

        scored = [(memory, cosine_similarity(memory["embedding"], query_embedding)) for memory in self._archive]
        # Highest similarity first; sorted() is stable, so ties keep archive order.
        ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)
        printd(
            f"archive_memory.search (vector-based): search for query '{query_string}' returned the following results (limit 5) and scores:\n{str([str(t[0]['content']) + '- score ' + str(t[1]) for t in ranked[:5]])}"
        )

        # Drop the scores, keep the ordering.
        matches = [memory for memory, _score in ranked]

        # start/count page through the matches; either may be None.
        total = len(matches)
        if count is None:
            page = matches if start is None else matches[start:]
        elif start is None:
            page = matches[:count]
        else:
            page = matches[start : start + count]
        return page, total
|
||||
|
||||
|
||||
class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
    """Dummy in-memory version of an archival memory database, using a FAISS
    index for fast nearest-neighbors embedding search.

    Archival memory is effectively "infinite" overflow for core memory,
    and is read-only via string queries.

    Archival Memory: A more structured and deep storage space for the AI's reflections,
    insights, or any other data that doesn't fit into the active memory but
    is essential enough not to be left only to the recall memory.
    """

    def __init__(self, index=None, archival_memory_database=None, embedding_model="text-embedding-ada-002", k=100):
        """Create the store.

        Args:
            index: Existing FAISS index to search; a fresh ``IndexFlatL2(1536)``
                is created when omitted.
            archival_memory_database: List of ``{'content': str}`` dicts whose
                positions correspond to the rows of ``index``.
            embedding_model: Model name passed to ``get_embedding_with_backoff``.
            k: Number of neighbours requested from the index per query.
        """
        if index is None:
            import faiss

            self.index = faiss.IndexFlatL2(1536)  # openai embedding vector size.
        else:
            self.index = index
        self.k = k
        self._archive = [] if archival_memory_database is None else archival_memory_database  # consists of {'content': str} dicts
        self.embedding_model = embedding_model
        self.embeddings_dict = {}  # query string -> query embedding
        self.search_results = {}  # query string -> cached result list

    def __len__(self):
        return len(self._archive)

    def insert(self, memory_string):
        """Embed ``memory_string``, append it to the archive and add it to the index."""
        import numpy as np

        # Get the embedding
        embedding = get_embedding_with_backoff(memory_string, model=self.embedding_model)

        print(f"Got an embedding, type {type(embedding)}, len {len(embedding)}")

        self._archive.append(
            {
                # can eventually upgrade to adding semantic tags, etc
                "timestamp": get_local_time(),
                "content": memory_string,
            }
        )
        embedding = np.array([embedding]).astype("float32")
        self.index.add(embedding)
        # NOTE(review): results cached by search() are not invalidated here, so a
        # repeated query will not see entries inserted after its first run.

    def search(self, query_string, count=None, start=None):
        """Nearest-neighbour search over the FAISS index, cached per query string.

        Returns a ``(page, total)`` tuple: the requested slice of the result
        list and the length of the full result list. Positions for which the
        index had no valid neighbour hold ``""``.
        """
        # see: https://github.com/openai/openai-cookbook/blob/main/examples/Semantic_text_search_using_embeddings.ipynb
        # our wrapped embedding call supports backoff/rate-limits
        import numpy as np

        if query_string in self.embeddings_dict:
            search_result = self.search_results[query_string]
        else:
            query_embedding = get_embedding_with_backoff(query_string, model=self.embedding_model)
            _, indices = self.index.search(np.array([np.array(query_embedding, dtype=np.float32)]), self.k)
            # FAISS pads missing neighbours with -1; without the lower bound a
            # -1 label would silently resolve to the last archive entry.
            search_result = [self._archive[idx] if 0 <= idx < len(self._archive) else "" for idx in indices[0]]
            self.embeddings_dict[query_string] = query_embedding
            self.search_results[query_string] = search_result

        if start is not None and count is not None:
            toprint = search_result[start : start + count]
        else:
            toprint = search_result[:5]
        # The f-string is evaluated even when debugging is off, so it must not
        # do arithmetic on a None start or slice a non-string entry.
        first_shown = 0 if start is None else start
        printd(
            f"archive_memory.search (vector-based): search for query '{query_string}' returned the following results ({first_shown}--{first_shown+5}/{len(search_result)}) and scores:\n{str([t[:60] if isinstance(t, str) and len(t) > 60 else t for t in toprint])}"
        )

        matches = search_result

        # start/count support paging through results
        if start is not None and count is not None:
            return matches[start : start + count], len(matches)
        elif start is None and count is not None:
            return matches[:count], len(matches)
        elif start is not None and count is None:
            return matches[start:], len(matches)
        else:
            return matches, len(matches)
|
||||
|
||||
|
||||
class RecallMemory(ABC):
|
||||
@abstractmethod
|
||||
def text_search(self, query_string, count=None, start=None):
|
||||
@@ -402,6 +187,8 @@ class DummyRecallMemory(RecallMemory):
|
||||
effectively allowing it to 'remember' prior engagements with a user.
|
||||
"""
|
||||
|
||||
# TODO: replace this with StorageConnector based implementation
|
||||
|
||||
def __init__(self, message_database=None, restrict_search_to_summaries=False):
|
||||
self._message_logs = [] if message_database is None else message_database # consists of full message dicts
|
||||
|
||||
@@ -508,150 +295,6 @@ class DummyRecallMemory(RecallMemory):
|
||||
return matches, len(matches)
|
||||
|
||||
|
||||
class DummyRecallMemoryWithEmbeddings(DummyRecallMemory):
    """Lazily manage embeddings by keeping a string->embed dict"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.embeddings = dict()  # message content -> embedding
        self.embedding_model = "text-embedding-ada-002"
        # When true, messages without a stored embedding are skipped instead
        # of being embedded on demand.
        self.only_use_preloaded_embeddings = False

    def text_search(self, query_string, count, start):
        """Rank non-system, non-function messages by embedding similarity.

        Args:
            query_string: Text to embed and compare against each message.
            count: Page size, or ``None`` for no upper bound.
            start: Page offset, or ``None`` to start from the first result.

        Returns:
            ``(page, total)``: the requested slice of the ranked messages and
            the total number of messages ranked.
        """
        message_pool = [d for d in self._message_logs if d["message"]["role"] not in ["system", "function"]]

        # Make sure every message we rank has an embedding. Each message is
        # added to the filtered pool exactly once: already-embedded messages
        # are always kept, missing ones are either computed or skipped.
        message_pool_filtered = []
        for d in message_pool:
            message_str = d["message"]["content"]
            if message_str not in self.embeddings:
                if self.only_use_preloaded_embeddings:
                    printd(f"recall_memory.text_search -- '{message_str}' was not in embedding dict, skipping.")
                    continue
                printd(f"recall_memory.text_search -- '{message_str}' was not in embedding dict, computing now")
                self.embeddings[message_str] = get_embedding_with_backoff(message_str, model=self.embedding_model)
            message_pool_filtered.append(d)

        # our wrapped version supports backoff/rate-limits
        query_embedding = get_embedding_with_backoff(query_string, model=self.embedding_model)
        similarity_scores = [cosine_similarity(self.embeddings[d["message"]["content"]], query_embedding) for d in message_pool_filtered]

        # Sort the messages by similarity, highest first
        sorted_archive_with_scores = sorted(
            zip(message_pool_filtered, similarity_scores),
            key=lambda pair: pair[1],
            reverse=True,
        )
        printd(
            f"recall_memory.text_search (vector-based): search for query '{query_string}' returned the following results (limit 5) and scores:\n{str([str(t[0]['message']['content']) + '- score ' + str(t[1]) for t in sorted_archive_with_scores[:5]])}"
        )

        # Extract the sorted messages without the scores
        matches = [item[0] for item in sorted_archive_with_scores]

        # start/count support paging through results
        if start is not None and count is not None:
            return matches[start : start + count], len(matches)
        elif start is None and count is not None:
            return matches[:count], len(matches)
        elif start is not None and count is None:
            return matches[start:], len(matches)
        else:
            return matches, len(matches)
|
||||
|
||||
|
||||
class LocalArchivalMemory(ArchivalMemory):
    """Archival memory built on top of Llama Index"""

    def __init__(self, agent_config, top_k: Optional[int] = 100):
        """Load the agent's saved index if one exists, otherwise start empty.

        :param agent_config: agent configuration; its ``name`` and
            ``save_agent_index_dir()`` locate the saved index
        :param top_k: number of results requested from the retriever per query
        :type top_k: Optional[int]
        """

        self.top_k = top_k
        self.agent_config = agent_config

        # locate saved index
        # Default to "no saved index" so an unnamed agent config does not leave
        # `directory` unbound below.
        directory = None
        if self.agent_config.name is not None:
            directory = agent_config.save_agent_index_dir()
            if not os.path.exists(directory):
                # no existing archival storage
                directory = None

        # load/create index
        if directory:
            storage_context = StorageContext.from_defaults(persist_dir=directory)
            self.index = load_index_from_storage(storage_context)
        else:
            self.index = EmptyIndex()

        # create retriever
        if isinstance(self.index, EmptyIndex):
            self.retriever = None  # can't create a retriever over an empty index
        else:
            self.retriever = VectorIndexRetriever(
                index=self.index,  # does this get refreshed?
                similarity_top_k=self.top_k,
            )

        # TODO: have some mechanism for cleanup otherwise will lead to OOM
        self.cache = {}  # query string -> retrieved nodes

    def save(self):
        """Save the index to disk"""
        # don't need to save data source, since we assume data source data is already loaded into the agent index
        utils.save_agent_index(self.index, self.agent_config)

    def insert(self, memory_string):
        """Insert ``memory_string`` into the index and rebuild the retriever."""
        self.index.insert(memory_string)

        # TODO: figure out if this needs to be refreshed (probably not)
        self.retriever = VectorIndexRetriever(
            index=self.index,
            similarity_top_k=self.top_k,
        )

    def search(self, query_string, count=None, start=None):
        """Retrieve nodes for ``query_string`` and return one page of them.

        Returns ``(results, len(results))`` where each result is a dict with
        ``timestamp`` and ``content`` keys; ``([], 0)`` when the index is empty.
        """
        print("searching with local")
        if self.retriever is None:
            print("Warning: archival memory is empty")
            return [], 0

        start = start if start else 0
        count = count if count else self.top_k
        # End of the requested page, capped at top_k (the retriever never
        # returns more than that). Slicing start:end yields at most `count`
        # items instead of adding `start` a second time.
        end = min(start + count, self.top_k)

        if query_string not in self.cache:
            self.cache[query_string] = self.retriever.retrieve(query_string)

        results = self.cache[query_string][start:end]
        results = [{"timestamp": get_local_time(), "content": node.node.text} for node in results]
        return results, len(results)

    def __repr__(self) -> str:
        if isinstance(self.index, EmptyIndex):
            memory_str = "<empty>"
        else:
            memory_str = self.index.ref_doc_info
        return f"\n### ARCHIVAL MEMORY ###" + f"\n{memory_str}"
|
||||
|
||||
|
||||
class EmbeddingArchivalMemory(ArchivalMemory):
|
||||
"""Archival memory with embedding based search"""
|
||||
|
||||
|
||||
@@ -311,39 +311,14 @@ def retry_with_exponential_backoff(
|
||||
return wrapper
|
||||
|
||||
|
||||
# TODO: delete/ignore --legacy
|
||||
@retry_with_exponential_backoff
def completions_with_backoff(**kwargs):
    """Send a chat-completion request, retried by ``retry_with_exponential_backoff``.

    Uses the local-model path when ``HOST_TYPE`` is set; otherwise posts the
    keyword arguments to the OpenAI chat completions endpoint.

    Raises:
        Exception: if ``OPENAI_API_KEY`` is not set on the OpenAI/Azure path.
    """
    # Local model
    if HOST_TYPE is not None:
        return get_chat_completion(**kwargs)

    # OpenAI / Azure model
    else:
        if using_azure():
            # Prefer an explicit deployment id; otherwise map the model name
            # to its Azure engine name.
            azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
            if azure_openai_deployment is not None:
                kwargs["deployment_id"] = azure_openai_deployment
            else:
                kwargs["engine"] = MODEL_TO_AZURE_ENGINE[kwargs["model"]]
                kwargs.pop("model")
        # context_window is not part of the request payload
        if "context_window" in kwargs:
            kwargs.pop("context_window")

        api_url = "https://api.openai.com/v1"
        # os.getenv returns None when the variable is unset (os.get_env does not exist)
        api_key = os.getenv("OPENAI_API_KEY")
        if api_key is None:
            raise Exception("OPENAI_API_KEY is not defined - please set it")
        return openai_chat_completions_request(api_url, api_key, data=kwargs)
|
||||
|
||||
|
||||
@retry_with_exponential_backoff
|
||||
def chat_completion_with_backoff(
|
||||
def create(
|
||||
agent_config,
|
||||
messages,
|
||||
functions=None,
|
||||
function_call="auto",
|
||||
):
|
||||
"""Return response to chat completion with backoff"""
|
||||
from memgpt.utils import printd
|
||||
from memgpt.config import MemGPTConfig
|
||||
|
||||
@@ -392,91 +367,3 @@ def chat_completion_with_backoff(
|
||||
wrapper=agent_config.model_wrapper,
|
||||
user=config.anon_clientid,
|
||||
)
|
||||
|
||||
|
||||
# TODO: deprecate
|
||||
@retry_with_exponential_backoff
def create_embedding_with_backoff(**kwargs):
    """Send an embeddings request, retried by ``retry_with_exponential_backoff``.

    Routes to the Azure embeddings endpoint when any Azure environment
    variable is set (see ``using_azure``), otherwise to the OpenAI endpoint.

    Raises:
        Exception: if the API key for the selected backend is not set.
    """
    if using_azure():
        azure_openai_deployment = os.getenv("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT")
        if azure_openai_deployment is not None:
            kwargs["deployment_id"] = azure_openai_deployment
        else:
            kwargs["engine"] = kwargs["model"]
            kwargs.pop("model")

        # os.getenv returns None when the variable is unset (os.get_env does not exist)
        api_key = os.getenv("AZURE_OPENAI_KEY")
        if api_key is None:
            # Name the variable that is actually read above.
            raise Exception("AZURE_OPENAI_KEY is not defined - please set it")
        # TODO check
        # api_version???
        # resource_name???
        # "engine" instead of "model"???
        return azure_openai_embeddings_request(
            resource_name=None, deployment_id=azure_openai_deployment, api_version=None, api_key=api_key, data=kwargs
        )

    else:
        api_url = "https://api.openai.com/v1"
        api_key = os.getenv("OPENAI_API_KEY")
        if api_key is None:
            raise Exception("OPENAI_API_KEY is not defined - please set it")
        return openai_embeddings_request(url=api_url, api_key=api_key, data=kwargs)
|
||||
|
||||
|
||||
def get_embedding_with_backoff(text, model="text-embedding-ada-002"):
    """Return the embedding vector for ``text``.

    Newlines in ``text`` are replaced with spaces before the request is made
    through ``create_embedding_with_backoff``.
    """
    single_line = " ".join(text.split("\n"))
    payload = create_embedding_with_backoff(input=[single_line], model=model)
    # One input was sent, so the embedding is the first data item.
    return payload["data"][0]["embedding"]
|
||||
|
||||
|
||||
# Maps a model name to the Azure engine name sent in its place.
# completions_with_backoff reads it when AZURE_OPENAI_DEPLOYMENT is not set.
MODEL_TO_AZURE_ENGINE = {
    "gpt-4-1106-preview": "gpt-4-1106-preview",  # TODO check
    "gpt-4": "gpt-4",
    "gpt-4-32k": "gpt-4-32k",
    "gpt-3.5": "gpt-35-turbo",  # differs: engine name has no dot
    "gpt-3.5-turbo": "gpt-35-turbo",  # differs: engine name has no dot
    "gpt-3.5-turbo-16k": "gpt-35-turbo-16k",  # differs: engine name has no dot
}
|
||||
|
||||
|
||||
def get_set_azure_env_vars():
    """Return ``(name, value)`` pairs for the Azure OpenAI variables that are set.

    The pairs keep a fixed order; variables missing from the environment are
    left out.
    """
    names = (
        "AZURE_OPENAI_KEY",
        "AZURE_OPENAI_ENDPOINT",
        "AZURE_OPENAI_VERSION",
        "AZURE_OPENAI_DEPLOYMENT",
        "AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT",
    )
    found = []
    for name in names:
        value = os.getenv(name)
        if value is not None:
            found.append((name, value))
    return found
|
||||
|
||||
|
||||
def using_azure():
    """Return True when at least one Azure OpenAI environment variable is set."""
    return bool(get_set_azure_env_vars())
|
||||
|
||||
|
||||
def configure_azure_support():
    """Check that the core Azure OpenAI environment variables are present.

    Prints an error when the key, endpoint or version variable is missing.
    Always returns None; nothing is configured beyond this check.
    """
    required_names = ("AZURE_OPENAI_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_VERSION")
    required_values = [os.getenv(name) for name in required_names]
    if any(value is None for value in required_values):
        print("Error: missing Azure OpenAI environment variables. Please see README section on Azure.")
        return
|
||||
|
||||
|
||||
def check_azure_embeddings():
    """Require an embeddings deployment id whenever a chat deployment id is set.

    Raises:
        ValueError: if ``AZURE_OPENAI_DEPLOYMENT`` is set but
            ``AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT`` is not.
    """
    chat_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
    embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT")
    # Nothing to check without a chat deployment, or when both ids are present.
    if chat_deployment is None or embedding_deployment is not None:
        return
    raise ValueError(
        "Error: It looks like you are using Azure deployment ids and computing embeddings, make sure you are setting one for embeddings as well. Please see README section on Azure"
    )
|
||||
|
||||
@@ -1,16 +1,11 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import os
|
||||
import pickle
|
||||
from memgpt.config import AgentConfig
|
||||
from .memory import (
|
||||
from memgpt.memory import (
|
||||
DummyRecallMemory,
|
||||
DummyRecallMemoryWithEmbeddings,
|
||||
DummyArchivalMemory,
|
||||
DummyArchivalMemoryWithEmbeddings,
|
||||
DummyArchivalMemoryWithFaiss,
|
||||
EmbeddingArchivalMemory,
|
||||
)
|
||||
from .utils import get_local_time, printd
|
||||
from memgpt.utils import get_local_time, printd
|
||||
|
||||
|
||||
class PersistenceManager(ABC):
|
||||
@@ -35,73 +30,6 @@ class PersistenceManager(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class InMemoryStateManager(PersistenceManager):
    """Persistence manager that keeps all agent state in plain Python attributes.

    Messages are stored as ``{"timestamp": ..., "message": ...}`` dicts.
    ``messages`` is the current working list, while ``all_messages`` only
    ever grows and is the list handed to the recall memory.
    """

    recall_memory_cls = DummyRecallMemory
    archival_memory_cls = DummyArchivalMemory

    def __init__(self):
        # State is kept on the instance, which is handy when debugging stateful versions.
        self.memory = None
        self.messages = []
        self.all_messages = []

    @staticmethod
    def _stamp(raw_messages):
        """Wrap each raw message in a dict tagged with the current local time."""
        return [{"timestamp": get_local_time(), "message": raw} for raw in raw_messages]

    @staticmethod
    def load(filename):
        """Unpickle and return the object stored in ``filename``."""
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only load files from a trusted source.
        with open(filename, "rb") as handle:
            return pickle.load(handle)

    def save(self, filename):
        """Pickle this manager into ``filename`` using the highest protocol."""
        with open(filename, "wb") as handle:
            pickle.dump(self, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def init(self, agent):
        """Seed the manager from ``agent`` and build its recall/archival memories."""
        printd(f"Initializing {self.__class__.__name__} with agent object")
        # Two independent timestamped copies of the agent's messages.
        self.all_messages = self._stamp(agent.messages.copy())
        self.messages = self._stamp(agent.messages.copy())
        self.memory = agent.memory
        printd(f"InMemoryStateManager.all_messages.len = {len(self.all_messages)}")
        printd(f"InMemoryStateManager.messages.len = {len(self.messages)}")

        # Persistence manager also handles DB-related state
        self.recall_memory = self.recall_memory_cls(message_database=self.all_messages)
        self.archival_memory_db = []
        self.archival_memory = self.archival_memory_cls(archival_memory_database=self.archival_memory_db)

    def trim_messages(self, num):
        """Keep the first message and drop everything before index ``num``."""
        self.messages = [self.messages[0], *self.messages[num:]]

    def prepend_to_messages(self, added_messages):
        """Insert timestamped ``added_messages`` right after the first message."""
        stamped = self._stamp(added_messages)

        printd(f"{self.__class__.__name__}.prepend_to_message")
        self.messages = [self.messages[0], *stamped, *self.messages[1:]]
        # all_messages is extended in place so existing references see the additions.
        self.all_messages.extend(stamped)

    def append_to_messages(self, added_messages):
        """Add timestamped ``added_messages`` to the end of the message list."""
        stamped = self._stamp(added_messages)

        printd(f"{self.__class__.__name__}.append_to_messages")
        self.messages = [*self.messages, *stamped]
        self.all_messages.extend(stamped)

    def swap_system_message(self, new_system_message):
        """Replace the first message with a timestamped ``new_system_message``."""
        stamped_system = {"timestamp": get_local_time(), "message": new_system_message}

        printd(f"{self.__class__.__name__}.swap_system_message")
        self.messages[0] = stamped_system
        self.all_messages.append(stamped_system)

    def update_memory(self, new_memory):
        """Replace the stored memory object with ``new_memory``."""
        printd(f"{self.__class__.__name__}.update_memory")
        self.memory = new_memory
|
||||
|
||||
|
||||
class LocalStateManager(PersistenceManager):
|
||||
"""In-memory state manager has nothing to manage, all agents are held in-memory"""
|
||||
|
||||
@@ -189,54 +117,3 @@ class LocalStateManager(PersistenceManager):
|
||||
def update_memory(self, new_memory):
|
||||
printd(f"{self.__class__.__name__}.update_memory")
|
||||
self.memory = new_memory
|
||||
|
||||
|
||||
class InMemoryStateManagerWithPreloadedArchivalMemory(InMemoryStateManager):
    """In-memory state manager whose archival database is supplied by the caller."""

    recall_memory_cls = DummyRecallMemory
    archival_memory_cls = DummyArchivalMemory

    def __init__(self, archival_memory_db):
        # Only the archival database is stored here; the remaining state is set in init().
        self.archival_memory_db = archival_memory_db

    def init(self, agent):
        """Seed the manager from ``agent`` and wrap the preloaded archival database."""
        manager_name = self.__class__.__name__
        print(f"Initializing {manager_name} with agent object")

        full_history = []
        for raw in agent.messages.copy():
            full_history.append({"timestamp": get_local_time(), "message": raw})
        self.all_messages = full_history

        working_set = []
        for raw in agent.messages.copy():
            working_set.append({"timestamp": get_local_time(), "message": raw})
        self.messages = working_set

        self.memory = agent.memory
        print(f"{manager_name}.all_messages.len = {len(self.all_messages)}")
        print(f"{manager_name}.messages.len = {len(self.messages)}")
        self.recall_memory = self.recall_memory_cls(message_database=self.all_messages)
        self.archival_memory = self.archival_memory_cls(archival_memory_database=self.archival_memory_db)
|
||||
|
||||
|
||||
class InMemoryStateManagerWithEmbeddings(InMemoryStateManager):
    """In-memory state manager that selects the embedding-based dummy memory classes."""

    recall_memory_cls = DummyRecallMemoryWithEmbeddings
    archival_memory_cls = DummyArchivalMemoryWithEmbeddings
|
||||
|
||||
|
||||
class InMemoryStateManagerWithFaiss(InMemoryStateManager):
    """In-memory state manager whose archival memory is built from a supplied index.

    The index, the archival database and ``a_k`` are stored at construction
    time and passed to the archival memory class in ``init``.
    """

    recall_memory_cls = DummyRecallMemoryWithEmbeddings
    archival_memory_cls = DummyArchivalMemoryWithFaiss

    def __init__(self, archival_index, archival_memory_db, a_k=100):
        super().__init__()
        self.a_k = a_k
        self.archival_memory_db = archival_memory_db
        self.archival_index = archival_index

    def save(self, _filename):
        """Saving is not supported for this manager."""
        raise NotImplementedError

    def init(self, agent):
        """Seed the manager from ``agent`` and build the recall and archival memories."""
        manager_name = self.__class__.__name__
        print(f"Initializing {manager_name} with agent object")

        def stamp_all(raw_messages):
            # Tag every raw message with the local time at which it was wrapped.
            return [{"timestamp": get_local_time(), "message": raw} for raw in raw_messages]

        self.all_messages = stamp_all(agent.messages.copy())
        self.messages = stamp_all(agent.messages.copy())
        self.memory = agent.memory
        print(f"{manager_name}.all_messages.len = {len(self.all_messages)}")
        print(f"{manager_name}.messages.len = {len(self.messages)}")

        # Persistence manager also handles DB-related state
        self.recall_memory = self.recall_memory_cls(message_database=self.all_messages)
        archival_kwargs = {
            "index": self.archival_index,
            "archival_memory_database": self.archival_memory_db,
            "k": self.a_k,
        }
        self.archival_memory = self.archival_memory_cls(**archival_kwargs)
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
# MemGPT over LlamaIndex API Docs
|
||||
|
||||
MemGPT enables you to chat with your data -- try running this example to talk to the LlamaIndex API docs!
|
||||
|
||||
1.
|
||||
a. Download LlamaIndex API docs and FAISS index from [Hugging Face](https://huggingface.co/datasets/MemGPT/llamaindex-api-docs).
|
||||
```bash
|
||||
# Make sure you have git-lfs installed (https://git-lfs.com)
|
||||
git lfs install
|
||||
git clone https://huggingface.co/datasets/MemGPT/llamaindex-api-docs
|
||||
```
|
||||
|
||||
**-- OR --**
|
||||
|
||||
b. Build the index:
|
||||
1. Build `llama_index` API docs with `make text`. Instructions [here](https://github.com/run-llama/llama_index/blob/main/docs/DOCS_README.md). Copy over the generated `_build/text` folder to this directory.
|
||||
2. Generate embeddings and FAISS index.
|
||||
```bash
|
||||
python3 scrape_docs.py
|
||||
python3 generate_embeddings_for_docs.py all_docs.jsonl
|
||||
python3 build_index.py --embedding_files all_docs.embeddings.jsonl --output_index_file all_docs.index
|
||||
```
|
||||
|
||||
2. In the root `MemGPT` directory, run
|
||||
```bash
|
||||
python3 main.py --archival_storage_faiss_path=<ARCHIVAL_STORAGE_FAISS_PATH> --persona=memgpt_doc --human=basic
|
||||
```
|
||||
where `ARCHIVAL_STORAGE_FAISS_PATH` is the directory where `all_docs.jsonl` and `all_docs.index` are located.
|
||||
If you downloaded from Hugging Face, it will be `memgpt/personas/docqa/llamaindex-api-docs`.
|
||||
If you built the index yourself, it will be `memgpt/personas/docqa`.
|
||||
|
||||
## Demo
|
||||
<div align="center">
|
||||
<img src="https://memgpt.ai/assets/img/docqa_demo.gif" alt="MemGPT demo video for llamaindex api docs search" width="800">
|
||||
</div>
|
||||
@@ -1,41 +0,0 @@
|
||||
import faiss
|
||||
from glob import glob
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
import argparse
|
||||
import json
|
||||
|
||||
|
||||
def build_index(embedding_files: str, index_name: str, dim: int = 1536):
    """Build a flat L2 FAISS index from JSONL embedding files and write it to disk.

    Args:
        embedding_files: Glob expression selecting the embedding files. Each
            non-blank line of every matched file is parsed as JSON and
            treated as one embedding row.
        index_name: Path the finished index is written to.
        dim: Dimensionality of the embeddings. Defaults to 1536, the value
            that used to be hard-coded.

    Raises:
        Exception: whatever ``index.add`` raises is re-raised after the
            offending array has been printed.
    """
    index = faiss.IndexFlatL2(dim)

    # Sorted so the row order of the index is reproducible across runs.
    for embedding_file in sorted(glob(embedding_files)):
        print(embedding_file)
        with open(embedding_file, "rt", encoding="utf-8") as file:
            # Blank lines (e.g. a trailing newline) are not valid JSON; skip them.
            embeddings = [json.loads(line) for line in tqdm(file) if line.strip()]

        data = np.array(embeddings).astype("float32")
        print(data.shape)
        try:
            index.add(data)
        except Exception:
            # Show the array that failed, then re-raise with the original traceback.
            print(data)
            raise

    faiss.write_index(index, index_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: both options are plain strings forwarded to build_index.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--embedding_files",
        type=str,
        help="embedding_filepaths glob expression",
    )
    cli.add_argument(
        "--output_index_file",
        type=str,
        help="output filepath",
    )
    options = cli.parse_args()

    build_index(
        embedding_files=options.embedding_files,
        index_name=options.output_index_file,
    )
|
||||
@@ -1,134 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
from tqdm import tqdm
|
||||
import openai
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
openai.api_key = os.getenv("OPENAI_API_KEY")
|
||||
|
||||
sys.path.append("../../../")
|
||||
from openai_tools import async_get_embedding_with_backoff
|
||||
from openai_parallel_request_processor import process_api_requests_from_file
|
||||
|
||||
|
||||
# Rate limits specific to our own OpenAI org (they apply to text-embedding-ada-002).
TPM_LIMIT = 1_000_000  # passed on as max_tokens_per_minute
RPM_LIMIT = 3_000  # passed on as max_requests_per_minute

# Model used for every embedding request, and the input file used when none is given.
EMBEDDING_MODEL = "text-embedding-ada-002"
DEFAULT_FILE = "iclr/data/qa_data/30_total_documents/nq-open-30_total_documents_gold_at_0.jsonl.gz"
|
||||
|
||||
|
||||
async def generate_requests_file(filename):
    """Write an embedding-requests JSONL file and hand it to the parallel processor.

    Every line of ``filename`` is parsed as JSON and iterated as a list of
    documents with ``title`` and ``text`` keys. One request per document is
    written to ``<base>_embedding_requests.jsonl``. After the user confirms
    by pressing enter, the requests are processed and the results are saved
    to ``<base>.embeddings.jsonl.gz``.
    """
    stem, _extension = os.path.splitext(filename)
    requests_path = f"{stem}_embedding_requests.jsonl"

    records = []
    with open(filename, "r") as source:
        for raw_line in source:
            records.append(json.loads(raw_line))

    with open(requests_path, "w") as sink:
        for documents in records:
            # Documents are numbered from 1 within each input line.
            for position, doc in enumerate(documents, start=1):
                title = doc["title"]
                text = doc["text"]
                payload = {
                    "model": EMBEDDING_MODEL,
                    "input": f"Document [{position}] (Title: {title}) {text}",
                }
                sink.write(json.dumps(payload) + "\n")

    # Give the user a chance to inspect the requests file before any API call is made.
    input(f"Generated requests file ({requests_path}), continue with embedding batch requests? (hit enter)")
    await process_api_requests_from_file(
        requests_filepath=requests_path,
        save_filepath=f"{stem}.embeddings.jsonl.gz",  # Adjust as necessary
        request_url="https://api.openai.com/v1/embeddings",
        api_key=os.getenv("OPENAI_API_KEY"),
        max_requests_per_minute=RPM_LIMIT,
        max_tokens_per_minute=TPM_LIMIT,
        token_encoding_name=EMBEDDING_MODEL,
        max_attempts=5,
        logging_level=logging.INFO,
    )
|
||||
|
||||
|
||||
async def generate_embedding_file(filename, parallel_mode=False):
    """Embed every document in a JSONL file and save the vectors to a sister file.

    Args:
        filename: Path to a JSONL file in which each line is a JSON list of
            documents with ``title`` and ``text`` keys.
        parallel_mode: When true, delegate to ``generate_requests_file`` and
            return without doing anything else.

    The output path is ``filename`` with its last ``.jsonl`` replaced by
    ``.embeddings.jsonl``. If that file already exists nothing is generated.
    Embeddings are written one JSON value per line, UTF-8 encoded.

    Raises:
        Exception: any error from the embedding call is re-raised after the
            offending document string has been printed.
    """
    if parallel_mode:
        await generate_requests_file(filename)
        return

    # Derive the sister filename by cutting at the last ".jsonl".
    base_name = filename.rsplit(".jsonl", 1)[0]
    sister_filename = f"{base_name}.embeddings.jsonl"

    # Skip the (expensive) embedding calls if the output is already there.
    if os.path.exists(sister_filename):
        print(f"{sister_filename} already exists. Skipping embedding generation.")
        return

    # Read as UTF-8 to match the encoding used for the output below.
    with open(filename, "rt", encoding="utf-8") as f:
        all_data = [json.loads(line) for line in f]

    embedding_data = []

    # Outer loop progress bar
    for i, documents in enumerate(tqdm(all_data, desc="Processing data", total=len(all_data))):
        # Inner loop progress bar
        for doc in tqdm(documents, desc=f"Embedding documents for data {i+1}/{len(all_data)}", total=len(documents), leave=False):
            title = doc["title"]
            text = doc["text"]
            document_string = f"[Title: {title}] {text}"
            try:
                embedding = await async_get_embedding_with_backoff(document_string, model=EMBEDDING_MODEL)
            except Exception:
                # Show which document failed, then re-raise with the original traceback.
                print(document_string)
                raise
            embedding_data.append(embedding)

    # Save the embeddings to the sister file, one JSON value per line.
    with open(sister_filename, "wb") as f:
        for embedding in embedding_data:
            f.write((json.dumps(embedding) + "\n").encode("utf-8"))

    print(f"Embeddings saved to {sister_filename}")
|
||||
|
||||
|
||||
async def main():
    """Parse the command line and generate embeddings for the chosen file.

    Accepts an optional positional ``filename`` (defaulting to
    ``DEFAULT_FILE``) and a ``--parallel`` flag that is forwarded as
    ``parallel_mode``.
    """
    # This file used to define main() twice. The earlier sys.argv-based version
    # was always shadowed by this argparse-based one, so only this one is kept.
    parser = argparse.ArgumentParser()
    parser.add_argument("filename", nargs="?", default=DEFAULT_FILE, help="Path to the input file")
    parser.add_argument("--parallel", action="store_true", help="Enable parallel mode")
    args = parser.parse_args()

    await generate_embedding_file(args.filename, parallel_mode=args.parallel)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # asyncio.run creates and closes the event loop for us. Calling
    # asyncio.get_event_loop() with no running loop is deprecated.
    asyncio.run(main())
|
||||
@@ -1,456 +0,0 @@
|
||||
"""
|
||||
API REQUEST PARALLEL PROCESSOR
|
||||
|
||||
Using the OpenAI API to process lots of text quickly takes some care.
|
||||
If you trickle in a million API requests one by one, they'll take days to complete.
|
||||
If you flood a million API requests in parallel, they'll exceed the rate limits and fail with errors.
|
||||
To maximize throughput, parallel requests need to be throttled to stay under rate limits.
|
||||
|
||||
This script parallelizes requests to the OpenAI API while throttling to stay under rate limits.
|
||||
|
||||
Features:
|
||||
- Streams requests from file, to avoid running out of memory for giant jobs
|
||||
- Makes requests concurrently, to maximize throughput
|
||||
- Throttles request and token usage, to stay under rate limits
|
||||
- Retries failed requests up to {max_attempts} times, to avoid missing data
|
||||
- Logs errors, to diagnose problems with requests
|
||||
|
||||
Example command to call script:
|
||||
```
|
||||
python examples/api_request_parallel_processor.py \
|
||||
--requests_filepath examples/data/example_requests_to_parallel_process.jsonl \
|
||||
--save_filepath examples/data/example_requests_to_parallel_process_results.jsonl \
|
||||
--request_url https://api.openai.com/v1/embeddings \
|
||||
--max_requests_per_minute 1500 \
|
||||
--max_tokens_per_minute 6250000 \
|
||||
--token_encoding_name cl100k_base \
|
||||
--max_attempts 5 \
|
||||
--logging_level 20
|
||||
```
|
||||
|
||||
Inputs:
|
||||
- requests_filepath : str
|
||||
- path to the file containing the requests to be processed
|
||||
- file should be a jsonl file, where each line is a json object with API parameters and an optional metadata field
|
||||
- e.g., {"model": "text-embedding-ada-002", "input": "embed me", "metadata": {"row_id": 1}}
|
||||
- as with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically)
|
||||
- an example file is provided at examples/data/example_requests_to_parallel_process.jsonl
|
||||
- the code to generate the example file is appended to the bottom of this script
|
||||
- save_filepath : str, optional
|
||||
- path to the file where the results will be saved
|
||||
- file will be a jsonl file, where each line is an array with the original request plus the API response
|
||||
- e.g., [{"model": "text-embedding-ada-002", "input": "embed me"}, {...}]
|
||||
- if omitted, results will be saved to {requests_filename}_results.jsonl
|
||||
- request_url : str, optional
|
||||
- URL of the API endpoint to call
|
||||
- if omitted, will default to "https://api.openai.com/v1/embeddings"
|
||||
- api_key : str, optional
|
||||
- API key to use
|
||||
- if omitted, the script will attempt to read it from an environment variable {os.getenv("OPENAI_API_KEY")}
|
||||
- max_requests_per_minute : float, optional
|
||||
- target number of requests to make per minute (will make less if limited by tokens)
|
||||
- leave headroom by setting this to 50% or 75% of your limit
|
||||
- if requests are limiting you, try batching multiple embeddings or completions into one request
|
||||
- if omitted, will default to 1,500
|
||||
- max_tokens_per_minute : float, optional
|
||||
- target number of tokens to use per minute (will use less if limited by requests)
|
||||
- leave headroom by setting this to 50% or 75% of your limit
|
||||
- if omitted, will default to 125,000
|
||||
- token_encoding_name : str, optional
|
||||
- name of the token encoding used, as defined in the `tiktoken` package
|
||||
- if omitted, will default to "cl100k_base" (used by `text-embedding-ada-002`)
|
||||
- max_attempts : int, optional
|
||||
- number of times to retry a failed request before giving up
|
||||
- if omitted, will default to 5
|
||||
- logging_level : int, optional
|
||||
- level of logging to use; higher numbers will log fewer messages
|
||||
- 40 = ERROR; will log only when requests fail after all retries
|
||||
- 30 = WARNING; will log when requests hit rate limits or other errors
|
||||
- 20 = INFO; will log when requests start and the status at finish
|
||||
- 10 = DEBUG; will log various things as the loop runs to see when they occur
|
||||
- if omitted, will default to 20 (INFO).
|
||||
|
||||
The script is structured as follows:
|
||||
- Imports
|
||||
- Define main()
|
||||
- Initialize things
|
||||
- In main loop:
|
||||
- Get next request if one is not already waiting for capacity
|
||||
- Update available token & request capacity
|
||||
- If enough capacity available, call API
|
||||
- The loop pauses if a rate limit error is hit
|
||||
- The loop breaks when no tasks remain
|
||||
- Define dataclasses
|
||||
- StatusTracker (stores script metadata counters; only one instance is created)
|
||||
- APIRequest (stores API inputs, outputs, metadata; one method to call API)
|
||||
- Define functions
|
||||
- api_endpoint_from_url (extracts API endpoint from request URL)
|
||||
- append_to_jsonl (writes to results file)
|
||||
- num_tokens_consumed_from_request (bigger function to infer token usage from request)
|
||||
- task_id_generator_function (yields 0, 1, 2, ...)
|
||||
- Run main()
|
||||
"""
|
||||
|
||||
# imports
|
||||
import aiohttp # for making API calls concurrently
|
||||
import argparse # for running script from command line
|
||||
import asyncio # for running API calls concurrently
|
||||
import json # for saving results to a jsonl file
|
||||
import logging # for logging rate limit warnings and other messages
|
||||
import os # for reading API key
|
||||
import re # for matching endpoint from request URL
|
||||
import tiktoken # for counting tokens
|
||||
import time # for sleeping after rate limit is hit
|
||||
from dataclasses import (
|
||||
dataclass,
|
||||
field,
|
||||
) # for storing API inputs, outputs, and metadata
|
||||
|
||||
|
||||
async def process_api_requests_from_file(
    requests_filepath: str,
    save_filepath: str,
    request_url: str,
    api_key: str,
    max_requests_per_minute: float,
    max_tokens_per_minute: float,
    token_encoding_name: str,
    max_attempts: int,
    logging_level: int,
):
    """Processes API requests in parallel, throttling to stay under rate limits.

    Args:
        requests_filepath: JSONL file with one request per line; an optional
            ``metadata`` key is removed from the request and kept alongside it.
        save_filepath: JSONL file that results and final errors are appended to.
        request_url: Endpoint URL; the API endpoint name is derived from it.
        api_key: Sent as a bearer token in the ``Authorization`` header.
        max_requests_per_minute: Request budget that refills continuously.
        max_tokens_per_minute: Token budget that refills continuously.
        token_encoding_name: Encoding name used to estimate token usage.
        max_attempts: Number of attempts each request is given.
        logging_level: Level passed to ``logging.basicConfig``.
    """
    # constants
    seconds_to_pause_after_rate_limit_error = 15
    seconds_to_sleep_each_loop = 0.001  # 1 ms limits max throughput to 1,000 requests per second

    # initialize logging
    logging.basicConfig(level=logging_level)
    logging.debug(f"Logging initialized at level {logging_level}")

    # infer API endpoint and construct request header
    api_endpoint = api_endpoint_from_url(request_url)
    request_header = {"Authorization": f"Bearer {api_key}"}

    # initialize trackers
    queue_of_requests_to_retry = asyncio.Queue()
    task_id_generator = task_id_generator_function()  # generates integer IDs of 0, 1, 2, ...
    status_tracker = StatusTracker()  # single instance to track a collection of variables
    next_request = None  # variable to hold the next request to call
    # The event loop only keeps weak references to tasks, so hold strong
    # references here until each task finishes.
    in_flight_tasks = set()

    # initialize available capacity counts
    available_request_capacity = max_requests_per_minute
    available_token_capacity = max_tokens_per_minute
    last_update_time = time.time()

    # initialize flags
    file_not_finished = True  # after file is empty, we'll skip reading it
    logging.debug("Initialization complete.")

    # initialize file reading
    with open(requests_filepath) as file:
        # `requests` will provide requests one at a time
        requests = file.__iter__()
        logging.debug("File opened. Entering main loop")
        async with aiohttp.ClientSession() as session:  # Initialize ClientSession here
            while True:
                # get next request (if one is not already waiting for capacity)
                if next_request is None:
                    if not queue_of_requests_to_retry.empty():
                        next_request = queue_of_requests_to_retry.get_nowait()
                        logging.debug(f"Retrying request {next_request.task_id}: {next_request}")
                    elif file_not_finished:
                        try:
                            # get new request
                            request_json = json.loads(next(requests))
                            next_request = APIRequest(
                                task_id=next(task_id_generator),
                                request_json=request_json,
                                token_consumption=num_tokens_consumed_from_request(request_json, api_endpoint, token_encoding_name),
                                attempts_left=max_attempts,
                                # popped last, so the stored request_json no longer carries the metadata
                                metadata=request_json.pop("metadata", None),
                            )
                            status_tracker.num_tasks_started += 1
                            status_tracker.num_tasks_in_progress += 1
                            logging.debug(f"Reading request {next_request.task_id}: {next_request}")
                        except StopIteration:
                            # if file runs out, set flag to stop reading it
                            logging.debug("Read file exhausted")
                            file_not_finished = False

                # update available capacity
                current_time = time.time()
                seconds_since_update = current_time - last_update_time
                available_request_capacity = min(
                    available_request_capacity + max_requests_per_minute * seconds_since_update / 60.0,
                    max_requests_per_minute,
                )
                available_token_capacity = min(
                    available_token_capacity + max_tokens_per_minute * seconds_since_update / 60.0,
                    max_tokens_per_minute,
                )
                last_update_time = current_time

                # if enough capacity available, call API
                if next_request:
                    next_request_tokens = next_request.token_consumption
                    if available_request_capacity >= 1 and available_token_capacity >= next_request_tokens:
                        # update counters
                        available_request_capacity -= 1
                        available_token_capacity -= next_request_tokens
                        next_request.attempts_left -= 1

                        # call API
                        task = asyncio.create_task(
                            next_request.call_api(
                                session=session,
                                request_url=request_url,
                                request_header=request_header,
                                retry_queue=queue_of_requests_to_retry,
                                save_filepath=save_filepath,
                                status_tracker=status_tracker,
                            )
                        )
                        in_flight_tasks.add(task)
                        task.add_done_callback(in_flight_tasks.discard)
                        next_request = None  # reset next_request to empty

                # if all tasks are finished, break
                if status_tracker.num_tasks_in_progress == 0:
                    break

                # main loop sleeps briefly so concurrent tasks can run
                await asyncio.sleep(seconds_to_sleep_each_loop)

                # if a rate limit error was hit recently, pause to cool down
                seconds_since_rate_limit_error = time.time() - status_tracker.time_of_last_rate_limit_error
                if seconds_since_rate_limit_error < seconds_to_pause_after_rate_limit_error:
                    remaining_seconds_to_pause = seconds_to_pause_after_rate_limit_error - seconds_since_rate_limit_error
                    # Announce the pause before sleeping so the log explains the stall while it happens.
                    logging.warning(
                        f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}"
                    )
                    await asyncio.sleep(remaining_seconds_to_pause)
                    # ^e.g., if pause is 15 seconds and final limit was hit 5 seconds ago

        # after finishing, log final status
        logging.info(f"""Parallel processing complete. Results saved to {save_filepath}""")
        if status_tracker.num_tasks_failed > 0:
            logging.warning(
                f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed. Errors logged to {save_filepath}."
            )
        if status_tracker.num_rate_limit_errors > 0:
            logging.warning(f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate.")
|
||||
|
||||
|
||||
# dataclasses
|
||||
|
||||
|
||||
@dataclass
class StatusTracker:
    """Stores metadata about the script's progress. Only one instance is created."""

    num_tasks_started: int = 0
    num_tasks_in_progress: int = 0  # script ends when this reaches 0
    num_tasks_succeeded: int = 0
    num_tasks_failed: int = 0
    num_rate_limit_errors: int = 0
    num_api_errors: int = 0  # excluding rate limit errors, counted above
    num_other_errors: int = 0
    # Holds a time.time() value (a float); used to cool off after hitting rate limits.
    time_of_last_rate_limit_error: float = 0.0
|
||||
|
||||
|
||||
@dataclass
class APIRequest:
    """Holds one API request with its bookkeeping, plus the coroutine that sends it."""

    task_id: int
    request_json: dict
    token_consumption: int
    attempts_left: int
    metadata: dict
    result: list = field(default_factory=list)

    async def call_api(
        self,
        session: aiohttp.ClientSession,
        request_url: str,
        request_header: dict,
        retry_queue: asyncio.Queue,
        save_filepath: str,
        status_tracker: StatusTracker,
    ):
        """POST the request, then save the reply, queue a retry, or record the failure.

        A reply containing an ``"error"`` key, or any exception raised while
        sending or decoding, counts as a failed attempt. A failed request is
        put back on ``retry_queue`` while ``attempts_left`` is non-zero.
        Otherwise its collected errors are appended to ``save_filepath``.
        """
        logging.info(f"Starting request #{self.task_id}")
        failure = None
        try:
            async with session.post(url=request_url, headers=request_header, json=self.request_json) as http_response:
                reply = await http_response.json()
            if "error" in reply:
                logging.warning(f"Request {self.task_id} failed with error {reply['error']}")
                status_tracker.num_api_errors += 1
                failure = reply
                if "Rate limit" in reply["error"].get("message", ""):
                    status_tracker.time_of_last_rate_limit_error = time.time()
                    status_tracker.num_rate_limit_errors += 1
                    status_tracker.num_api_errors -= 1  # rate limit errors are counted separately
        except Exception as exc:  # deliberately broad: every failure is logged and saved
            logging.warning(f"Request {self.task_id} failed with Exception {exc}")
            status_tracker.num_other_errors += 1
            failure = exc

        # Success path: persist the reply and update the counters.
        if not failure:
            record = [self.request_json, reply]
            if self.metadata:
                record.append(self.metadata)
            append_to_jsonl(record, save_filepath)
            status_tracker.num_tasks_in_progress -= 1
            status_tracker.num_tasks_succeeded += 1
            logging.debug(f"Request {self.task_id} saved to {save_filepath}")
            return

        self.result.append(failure)
        if self.attempts_left:
            retry_queue.put_nowait(self)
            return

        # Out of attempts: save every collected error as a string.
        logging.error(f"Request {self.request_json} failed after all attempts. Saving errors: {self.result}")
        record = [self.request_json, [str(item) for item in self.result]]
        if self.metadata:
            record.append(self.metadata)
        append_to_jsonl(record, save_filepath)
        status_tracker.num_tasks_in_progress -= 1
        status_tracker.num_tasks_failed += 1
|
||||
|
||||
|
||||
# functions
|
||||
|
||||
|
||||
def api_endpoint_from_url(request_url):
    """Extract the API endpoint from the request URL.

    Args:
        request_url: A URL of the form ``https://<host>/v<digits>/<endpoint>``.

    Returns:
        Everything after the version segment, e.g. ``"embeddings"`` or
        ``"chat/completions"``.

    Raises:
        ValueError: if ``request_url`` does not have the expected form.
    """
    match = re.search(r"^https://[^/]+/v\d+/(.+)$", request_url)
    if match is None:
        # Fail with a clear message instead of a TypeError from indexing None.
        raise ValueError(f"Could not extract an API endpoint from request URL: {request_url!r}")
    return match[1]
|
||||
|
||||
|
||||
def append_to_jsonl(data, filename: str) -> None:
    """Serialize ``data`` as JSON and add it as one new line at the end of ``filename``."""
    serialized = json.dumps(data)
    with open(filename, "a") as sink:
        sink.write(f"{serialized}\n")
|
||||
|
||||
|
||||
def num_tokens_consumed_from_request(
    request_json: dict,
    api_endpoint: str,
    token_encoding_name: str,
):
    """Estimate how many tokens a request will consume.

    Supports completion, chat-completion and embedding requests.
    Completion-style requests are charged their prompt tokens plus
    ``n * max_tokens``. Embedding requests are charged their input tokens.

    Raises:
        TypeError: if ``prompt`` / ``input`` is neither a string nor a list.
        NotImplementedError: for any other endpoint.
    """
    # The embedding model name is accepted as an alias for its encoding.
    if token_encoding_name == "text-embedding-ada-002":
        encoding = tiktoken.get_encoding("cl100k_base")
    else:
        encoding = tiktoken.get_encoding(token_encoding_name)

    def count(text):
        return len(encoding.encode(text))

    if api_endpoint.endswith("completions"):
        max_tokens = request_json.get("max_tokens", 15)
        n = request_json.get("n", 1)
        completion_tokens = n * max_tokens

        if api_endpoint.startswith("chat/"):
            total = 0
            for message in request_json["messages"]:
                total += 4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
                for key, value in message.items():
                    total += count(value)
                    if key == "name":  # if there's a name, the role is omitted
                        total -= 1  # role is always required and always 1 token
            total += 2  # every reply is primed with <im_start>assistant
            return total + completion_tokens

        # normal completions
        prompt = request_json["prompt"]
        if isinstance(prompt, str):  # single prompt
            return count(prompt) + completion_tokens
        if isinstance(prompt, list):  # multiple prompts
            prompt_tokens = sum(count(item) for item in prompt)
            return prompt_tokens + completion_tokens * len(prompt)
        raise TypeError('Expecting either string or list of strings for "prompt" field in completion request')

    if api_endpoint == "embeddings":
        payload = request_json["input"]
        if isinstance(payload, str):  # single input
            return count(payload)
        if isinstance(payload, list):  # multiple inputs
            return sum(count(item) for item in payload)
        raise TypeError('Expecting either string or list of strings for "inputs" field in embedding request')

    # more logic needed to support other API calls (e.g., edits, inserts, DALL-E)
    raise NotImplementedError(f'API endpoint "{api_endpoint}" not implemented in this script')
|
||||
|
||||
|
||||
def task_id_generator_function():
    """Yield an endless ascending sequence of integer task ids, starting at 0."""
    next_id = 0
    while True:
        yield next_id
        next_id = next_id + 1
|
||||
|
||||
|
||||
# run script
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Parse command line arguments.
    parser = argparse.ArgumentParser()
    # Required: the default save path below is derived from this value, so
    # omitting it used to fail with an AttributeError on None.
    parser.add_argument("--requests_filepath", required=True)
    parser.add_argument("--save_filepath", default=None)
    parser.add_argument("--request_url", default="https://api.openai.com/v1/embeddings")
    parser.add_argument("--api_key", default=os.getenv("OPENAI_API_KEY"))
    # The rate limits are forwarded as floats and their defaults are floats,
    # so parse them as floats instead of declaring type=int.
    parser.add_argument("--max_requests_per_minute", type=float, default=3_000 * 0.5)
    parser.add_argument("--max_tokens_per_minute", type=float, default=250_000 * 0.5)
    parser.add_argument("--token_encoding_name", default="cl100k_base")
    parser.add_argument("--max_attempts", type=int, default=5)
    parser.add_argument("--logging_level", type=int, default=logging.INFO)
    args = parser.parse_args()

    if args.save_filepath is None:
        # Swap only the final extension. A plain str.replace(".jsonl", ...) left
        # the path unchanged for inputs without ".jsonl" in the name, which made
        # the results file the same as the requests file.
        root, _ext = os.path.splitext(args.requests_filepath)
        args.save_filepath = root + "_results.jsonl"

    # Run the script.
    asyncio.run(
        process_api_requests_from_file(
            requests_filepath=args.requests_filepath,
            save_filepath=args.save_filepath,
            request_url=args.request_url,
            api_key=args.api_key,
            max_requests_per_minute=float(args.max_requests_per_minute),
            max_tokens_per_minute=float(args.max_tokens_per_minute),
            token_encoding_name=args.token_encoding_name,
            max_attempts=int(args.max_attempts),
            logging_level=int(args.logging_level),
        )
    )
|
||||
|
||||
|
||||
"""
|
||||
APPENDIX
|
||||
|
||||
The example requests file at openai-cookbook/examples/data/example_requests_to_parallel_process.jsonl contains 10,000 requests to text-embedding-ada-002.
|
||||
|
||||
It was generated with the following code:
|
||||
|
||||
```python
|
||||
import json
|
||||
|
||||
filename = "data/example_requests_to_parallel_process.jsonl"
|
||||
n_requests = 10_000
|
||||
jobs = [{"model": "text-embedding-ada-002", "input": str(x) + "\n"} for x in range(n_requests)]
|
||||
with open(filename, "w") as f:
|
||||
for job in jobs:
|
||||
json_string = json.dumps(job)
|
||||
f.write(json_string + "\n")
|
||||
```
|
||||
|
||||
As with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically).
|
||||
"""
|
||||
@@ -1,68 +0,0 @@
|
||||
import os
|
||||
import re
|
||||
import tiktoken
|
||||
import json
|
||||
|
||||
# Define the directory where the documentation resides
|
||||
docs_dir = "text"
|
||||
|
||||
encoding = tiktoken.encoding_for_model("gpt-4")
|
||||
PASSAGE_TOKEN_LEN = 800
|
||||
|
||||
|
||||
def extract_text_from_sphinx_txt(file_path):
    """Split a text file into passages of roughly PASSAGE_TOKEN_LEN tokens.

    The first line with non-whitespace content is used as the title. Every
    later line that has content and is not made up solely of ``-``, ``=`` or
    ``*`` characters is accumulated into passages; a passage is closed as soon
    as its token count exceeds PASSAGE_TOKEN_LEN.

    :param file_path: Path of the UTF-8 text file to read.
    :return: List of dicts with keys ``title``, ``text`` and ``num_tokens``.
    """
    lines = []
    title = ""
    with open(file_path, "r", encoding="utf-8") as file:
        for line in file:
            if not title:
                title = line.strip()
                continue
            if line and re.match(r"^.*\S.*$", line) and not re.match(r"^[-=*]+$", line):
                lines.append(line)

    def count_tokens_in(text):
        return len(encoding.encode(text, allowed_special={"<|endoftext|>"}))

    passages = []
    curr_passage = []
    curr_token_ct = 0
    for line in lines:
        try:
            line_token_ct = count_tokens_in(line)
        except Exception:
            print("line", line)
            raise
        if line_token_ct > PASSAGE_TOKEN_LEN:
            # Flush what has been accumulated first so passages stay in
            # document order.
            if curr_passage:
                passages.append({"title": title, "text": "".join(curr_passage), "num_tokens": curr_token_ct})
                curr_passage = []
                curr_token_ct = 0
            # An oversized line becomes its own (truncated) passage. Report the
            # token count of the text actually stored, not the count of the
            # unrelated pending passage.
            truncated = line[:3200]
            truncated_token_ct = line_token_ct if len(truncated) == len(line) else count_tokens_in(truncated)
            passages.append({"title": title, "text": truncated, "num_tokens": truncated_token_ct})
            continue
        curr_token_ct += line_token_ct
        curr_passage.append(line)
        if curr_token_ct > PASSAGE_TOKEN_LEN:
            passages.append({"title": title, "text": "".join(curr_passage), "num_tokens": curr_token_ct})
            curr_passage = []
            curr_token_ct = 0

    if curr_passage:
        passages.append({"title": title, "text": "".join(curr_passage), "num_tokens": curr_token_ct})
    return passages
|
||||
|
||||
|
||||
# Iterate over all files in the directory and its subdirectories
|
||||
passages = []
for dirpath, _dirnames, filenames in os.walk(docs_dir):
    for filename in filenames:
        if not filename.endswith(".txt"):
            continue
        passages.append(extract_text_from_sphinx_txt(os.path.join(dirpath, filename)))
# One entry is appended per matching file, so the list length is the file count.
total_files = len(passages)
print("total .txt files:", total_files)

# Write one JSON line per source file; each line holds that file's passage list.
with open("all_docs.jsonl", "w", encoding="utf-8") as out:
    for doc_passages in passages:
        out.write(json.dumps(doc_passages))
        out.write("\n")
|
||||
@@ -1,19 +0,0 @@
|
||||
# Preloading Archival Memory with Files
|
||||
MemGPT enables you to chat with your data locally -- this example gives the workflow for loading documents into MemGPT's archival memory.
|
||||
|
||||
To run our example, in which you can search over the SEC 10-K filings of Uber, Lyft, and Airbnb:
|
||||
|
||||
1. Download the .txt files from [Hugging Face](https://huggingface.co/datasets/MemGPT/example-sec-filings/tree/main) and place them in this directory.
|
||||
|
||||
2. In the root `MemGPT` directory, run
|
||||
```bash
|
||||
python3 main.py --archival_storage_files="memgpt/personas/examples/preload_archival/*.txt" --persona=memgpt_doc --human=basic
|
||||
```
|
||||
|
||||
|
||||
If you would like to load your own local files into MemGPT's archival memory, run the command above but replace `--archival_storage_files="memgpt/personas/examples/preload_archival/*.txt"` with your own file glob expression (enclosed in quotes).
|
||||
|
||||
## Demo
|
||||
<div align="center">
|
||||
<img src="https://memgpt.ai/assets/img/preload_archival_demo.gif" alt="MemGPT demo video for searching through preloaded files" width="800">
|
||||
</div>
|
||||
@@ -1,17 +0,0 @@
|
||||
import os
|
||||
|
||||
DEFAULT = "sam_pov"
GPT35_DEFAULT = "sam_simple_pov_gpt35"


def get_persona_text(key=DEFAULT, dir=None):
    """Return the stripped contents of the persona file named by ``key``.

    :param key: Persona name; ``.txt`` is appended unless already present.
    :param dir: Directory to look in. Defaults to the ``examples`` directory
        next to this module.
    :return: File contents with leading/trailing whitespace removed.
    :raises FileNotFoundError: If no such file exists.
    """
    if dir is None:
        dir = os.path.join(os.path.dirname(__file__), "examples")
    filename = key if key.endswith(".txt") else f"{key}.txt"
    file_path = os.path.join(dir, filename)

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"No file found for key {key}, path={file_path}")

    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(file_path, "r", encoding="utf-8") as file:
        return file.read().strip()
|
||||
290
memgpt/utils.py
290
memgpt/utils.py
@@ -1,24 +1,16 @@
|
||||
from datetime import datetime
|
||||
import csv
|
||||
import difflib
|
||||
import demjson3 as demjson
|
||||
import numpy as np
|
||||
import json
|
||||
import pytz
|
||||
import os
|
||||
import tiktoken
|
||||
import glob
|
||||
import sqlite3
|
||||
import fitz
|
||||
from tqdm import tqdm
|
||||
import typer
|
||||
import memgpt
|
||||
from memgpt.openai_tools import get_embedding_with_backoff
|
||||
from memgpt.constants import MEMGPT_DIR
|
||||
from llama_index import set_global_service_context, ServiceContext, VectorStoreIndex, load_index_from_storage, StorageContext
|
||||
from llama_index.embeddings import OpenAIEmbedding
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
# TODO: what is this?
|
||||
# DEBUG = True
|
||||
DEBUG = False
|
||||
|
||||
|
||||
def count_tokens(s: str, model: str = "gpt-4") -> int:
|
||||
@@ -26,19 +18,11 @@ def count_tokens(s: str, model: str = "gpt-4") -> int:
|
||||
return len(encoding.encode(s))
|
||||
|
||||
|
||||
# DEBUG = True
|
||||
DEBUG = False
|
||||
|
||||
|
||||
def printd(*args, **kwargs):
    """Forward the arguments to print() only while the module-level DEBUG flag is truthy."""
    if not DEBUG:
        return
    print(*args, **kwargs)
|
||||
|
||||
|
||||
def cosine_similarity(a, b):
    """Return the dot product of ``a`` and ``b`` divided by the product of their norms."""
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    return np.dot(a, b) / norm_product
|
||||
|
||||
|
||||
def united_diff(str1, str2):
|
||||
lines1 = str1.splitlines(True)
|
||||
lines2 = str2.splitlines(True)
|
||||
@@ -88,6 +72,7 @@ def get_local_time(timezone=None):
|
||||
|
||||
|
||||
def parse_json(string):
|
||||
"""Parse JSON string into JSON with both json and demjson"""
|
||||
result = None
|
||||
try:
|
||||
result = json.loads(string)
|
||||
@@ -103,273 +88,6 @@ def parse_json(string):
|
||||
raise e
|
||||
|
||||
|
||||
def prepare_archival_index(folder):
    """Load the faiss index and the passage database stored in ``folder``.

    Reads ``all_docs.index`` with faiss and ``all_docs.jsonl`` as one JSON
    value per line, where each value is iterated as a list of passages with
    ``title`` and ``text`` keys.

    :param folder: Directory containing both files.
    :return: Tuple ``(index, archival_database)``; the database is a list of
        dicts with ``content`` and ``timestamp`` keys.
    """
    import faiss

    index = faiss.read_index(os.path.join(folder, "all_docs.index"))

    with open(os.path.join(folder, "all_docs.jsonl"), "rt") as jsonl_file:
        documents = [json.loads(raw_line) for raw_line in jsonl_file]

    archival_database = []
    for doc_passages in documents:
        passage_count = len(doc_passages)
        archival_database.extend(
            {
                "content": f"[Title: {passage['title']}, {position}/{passage_count}] {passage['text']}",
                "timestamp": get_local_time(),
            }
            for position, passage in enumerate(doc_passages)
        )
    return index, archival_database
|
||||
|
||||
|
||||
def read_in_chunks(file_object, chunk_size):
    """Yield successive ``read(chunk_size)`` results until a falsy (empty) read."""
    piece = file_object.read(chunk_size)
    while piece:
        yield piece
        piece = file_object.read(chunk_size)
|
||||
|
||||
|
||||
def read_pdf_in_chunks(file, chunk_size):
    """Yield the extracted text of each page of the document at ``file``.

    :param file: Path passed to ``fitz.open``.
    :param chunk_size: Unused; kept so the signature matches the other readers.
    """
    doc = fitz.open(file)
    try:
        for page in doc:
            yield page.get_text()
    finally:
        # Release the document once iteration finishes or the generator is
        # closed; the handle was previously never closed.
        doc.close()
|
||||
|
||||
|
||||
def read_in_rows_csv(file_object, chunk_size):
    """Yield each CSV data row as a ``"header=value, header=value"`` string.

    The first row is treated as the header. Input with no rows yields nothing.

    :param file_object: Iterable of CSV lines (e.g. an open text file).
    :param chunk_size: Unused; kept so the signature matches the other readers.
    """
    csvreader = csv.reader(file_object)
    # A bare next() on empty input raises StopIteration, which inside a
    # generator is converted to RuntimeError (PEP 479). Use a default instead.
    header = next(csvreader, None)
    if header is None:
        return
    for row in csvreader:
        yield ", ".join(f"{h}={v}" for h, v in zip(header, row))
|
||||
|
||||
|
||||
def prepare_archival_index_from_files(glob_pattern, tkns_per_chunk=300, model="gpt-4"):
    """Chunk every file matching ``glob_pattern`` (recursive glob) via chunk_files.

    :param glob_pattern: Glob expression selecting the files.
    :param tkns_per_chunk: Token budget forwarded to chunk_files.
    :param model: Model name forwarded to chunk_files.
    :return: Whatever chunk_files returns for the matched files.
    """
    # The tokenizer lookup is kept (result discarded) so the call sequence is unchanged.
    tiktoken.encoding_for_model(model)
    matched_files = glob.glob(glob_pattern, recursive=True)
    return chunk_files(matched_files, tkns_per_chunk, model)
|
||||
|
||||
|
||||
def total_bytes(pattern):
    """Return the summed size in bytes of all regular files matching ``pattern`` (recursive glob)."""
    return sum(
        os.path.getsize(path)
        for path in glob.glob(pattern, recursive=True)
        if os.path.isfile(path)  # directories that match the pattern are ignored
    )
|
||||
|
||||
|
||||
def chunk_file(file, tkns_per_chunk=300, model="gpt-4"):
    """Yield text chunks of roughly ``tkns_per_chunk`` tokens read from ``file``.

    ``.db`` files yield nothing. ``.pdf`` files are read with
    read_pdf_in_chunks, ``.csv`` files with read_in_rows_csv, and everything
    else with read_in_chunks. Each piece is stripped and newline-terminated
    before being counted with the tokenizer for ``model``.

    :param file: Path of the file to chunk.
    :param tkns_per_chunk: Token count above which the current chunk is emitted.
    :param model: Model name used to select the tiktoken encoding.
    """
    encoding = tiktoken.encoding_for_model(model)

    if file.endswith(".db"):
        return  # can't read the sqlite db this way, will get handled in main.py

    with open(file, "r") as handle:
        if file.endswith(".pdf"):
            pieces = list(read_pdf_in_chunks(file, tkns_per_chunk * 8))
            if not pieces:
                print(f"Warning: {file} did not have any extractable text.")
        elif file.endswith(".csv"):
            pieces = list(read_in_rows_csv(handle, tkns_per_chunk * 8))
        else:
            pieces = list(read_in_chunks(handle, tkns_per_chunk * 4))

    pending = []
    pending_tokens = 0
    for i, piece in enumerate(pieces):
        line = piece.strip() + "\n"
        try:
            line_token_ct = len(encoding.encode(line))
        except Exception as e:
            # Fall back to a word-count-based estimate when the tokenizer fails.
            line_token_ct = len(line.split(" ")) / 0.75
            print(f"Could not encode line {i}, estimating it to be {line_token_ct} tokens")
            print(e)

        if line_token_ct > tkns_per_chunk:
            # Oversized piece: emit what is pending, then the piece itself, truncated.
            if pending:
                yield "".join(pending)
                pending = []
                pending_tokens = 0
            yield line[:3200]
            continue

        pending.append(line)
        pending_tokens += line_token_ct
        if pending_tokens > tkns_per_chunk:
            yield "".join(pending)
            pending = []
            pending_tokens = 0

    if pending:
        yield "".join(pending)
|
||||
|
||||
|
||||
def chunk_files(files, tkns_per_chunk=300, model="gpt-4"):
    """Chunk each file and return a flat list of labelled chunk records.

    Every record has a ``content`` string prefixed with the file's last path
    component and its part index, plus a ``timestamp`` string formatted from
    the file's modification time.

    :param files: Iterable of file paths.
    :param tkns_per_chunk: Token budget forwarded to chunk_file.
    :param model: Model name forwarded to chunk_file.
    :return: List of dicts with ``content`` and ``timestamp`` keys.
    """
    archival_database = []
    for path in files:
        modified_at = datetime.fromtimestamp(os.path.getmtime(path))
        formatted_time = modified_at.strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
        file_stem = path.split(os.sep)[-1]
        chunks = list(chunk_file(path, tkns_per_chunk, model))
        part_count = len(chunks)
        archival_database.extend(
            {
                "content": f"[File: {file_stem} Part {part}/{part_count}] {chunk}",
                "timestamp": formatted_time,
            }
            for part, chunk in enumerate(chunks)
        )
    return archival_database
|
||||
|
||||
|
||||
def chunk_files_for_jsonl(files, tkns_per_chunk=300, model="gpt-4"):
    """Chunk each file and return one list of ``{"title", "text"}`` dicts per file.

    :param files: Iterable of file paths.
    :param tkns_per_chunk: Token budget forwarded to chunk_file.
    :param model: Model name forwarded to chunk_file.
    :return: List (one entry per file) of lists of chunk dicts; ``title`` is
        the file's last path component.
    """
    per_file = []
    for path in files:
        file_stem = path.split(os.sep)[-1]
        per_file.append([{"title": file_stem, "text": chunk} for chunk in chunk_file(path, tkns_per_chunk, model)])
    return per_file
|
||||
|
||||
|
||||
def process_chunk(i, chunk, model):
    """Return ``(i, embedding)`` for ``chunk["content"]``; print the chunk and re-raise on failure."""
    try:
        embedding = get_embedding_with_backoff(chunk["content"], model=model)
    except Exception:
        print(chunk)
        raise
    return i, embedding
|
||||
|
||||
|
||||
def process_concurrently(archival_database, model, concurrency=10):
    """Embed every chunk on a thread pool and return the results in input order.

    :param archival_database: Sequence of chunk records passed to process_chunk.
    :param model: Embedding model name passed to process_chunk.
    :param concurrency: Maximum number of worker threads.
    :return: List of embeddings, index-aligned with ``archival_database``.
    """
    embedding_data = [0] * len(archival_database)
    with ThreadPoolExecutor(max_workers=concurrency) as executor:
        # Submit one task per chunk.
        pending = [executor.submit(process_chunk, position, chunk, model) for position, chunk in enumerate(archival_database)]

        # Tasks finish in arbitrary order; the index returned with each result
        # puts it back in its original slot.
        for finished in tqdm(as_completed(pending), total=len(archival_database), desc="Processing file chunks"):
            position, embedding = finished.result()
            embedding_data[position] = embedding
    return embedding_data
|
||||
|
||||
|
||||
def prepare_archival_index_from_files_compute_embeddings(
    glob_pattern,
    tkns_per_chunk=300,
    model="gpt-4",
    embeddings_model="text-embedding-ada-002",
):
    """Chunk the matched files, embed the chunks, and save everything to a new directory.

    Asks for confirmation on stdin (showing a rough price estimate) before
    computing embeddings, then writes ``embeddings.json``, ``all_docs.jsonl``
    and a faiss ``all_docs.index`` into a timestamped directory under
    MEMGPT_DIR.

    :param glob_pattern: Recursive glob expression selecting the files.
    :param tkns_per_chunk: Token budget per chunk.
    :param model: Model name used for tokenization while chunking.
    :param embeddings_model: Model name used to compute the embeddings.
    :return: Path of the directory the artifacts were written to.
    :raises Exception: If the user answers anything other than ``y``.
    """
    files = sorted(glob.glob(glob_pattern, recursive=True))
    dir_suffix = get_local_time().replace(" ", "_").replace(":", "_")
    save_dir = os.path.join(MEMGPT_DIR, "archival_index_from_files_" + dir_suffix)
    os.makedirs(save_dir, exist_ok=True)

    # Rough estimate: total bytes / 3 as the token count, priced per 1000 tokens.
    total_tokens = total_bytes(glob_pattern) / 3
    price_estimate = total_tokens / 1000 * 0.0001
    confirm = input(f"Computing embeddings over {len(files)} files. This will cost ~${price_estimate:.2f}. Continue? [y/n] ")
    if confirm != "y":
        raise Exception("embeddings were not computed")

    # Chunk the files and compute one embedding per chunk.
    archival_database = chunk_files(files, tkns_per_chunk, model)
    embedding_data = process_concurrently(archival_database, embeddings_model)
    embeddings_file = os.path.join(save_dir, "embeddings.json")
    with open(embeddings_file, "w") as out:
        print(f"Saving embeddings to {embeddings_file}")
        json.dump(embedding_data, out)

    # Write the per-file chunk lists, one JSON line per file.
    archival_storage_file = os.path.join(save_dir, "all_docs.jsonl")
    chunks_by_file = chunk_files_for_jsonl(files, tkns_per_chunk, model)
    with open(archival_storage_file, "w") as out:
        print(f"Saving archival storage with preloaded files to {archival_storage_file}")
        for file_chunks in chunks_by_file:
            json.dump(file_chunks, out)
            out.write("\n")

    # Build and save the faiss index over the embeddings.
    import faiss

    index = faiss.IndexFlatL2(1536)
    data = np.array(embedding_data).astype("float32")
    try:
        index.add(data)
    except Exception:
        print(data)
        raise
    index_file = os.path.join(save_dir, "all_docs.index")
    print(f"Saving faiss index {index_file}")
    faiss.write_index(index, index_file)
    return save_dir
|
||||
|
||||
|
||||
def read_database_as_list(database_name):
    """Dump every table of a SQLite database as a flat list of strings.

    For each table the list receives a ``"Table: <name>"`` line, a
    tab-separated line of column names, and one tab-separated line per row.
    Errors are not raised; an ``"Error ..."`` line is appended instead.

    :param database_name: Path of the SQLite database file.
    :return: List of strings as described above.
    """

    def quote_identifier(identifier):
        # Double-quote the name (doubling embedded quotes) so tables whose
        # names contain spaces, keywords or quotes can still be queried.
        return '"' + identifier.replace('"', '""') + '"'

    result_list = []
    conn = None

    try:
        conn = sqlite3.connect(database_name)
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        table_names = [row[0] for row in cursor.fetchall()]
        for table_name in table_names:
            quoted = quote_identifier(table_name)
            cursor.execute(f"PRAGMA table_info({quoted});")
            columns = [row[1] for row in cursor.fetchall()]
            cursor.execute(f"SELECT * FROM {quoted};")
            rows = cursor.fetchall()
            result_list.append(f"Table: {table_name}")  # Add table name to the list
            result_list.append("\t".join(columns))
            result_list.extend("\t".join(map(str, row)) for row in rows)
    except sqlite3.Error as e:
        result_list.append(f"Error reading database: {str(e)}")
    except Exception as e:
        result_list.append(f"Error: {str(e)}")
    finally:
        # Close the connection on the error paths too; it used to leak
        # whenever a query failed.
        if conn is not None:
            conn.close()
    return result_list
|
||||
|
||||
|
||||
def estimate_openai_cost(docs):
    """Estimate OpenAI embedding cost

    Builds an index over ``docs`` with a mock embedding model while counting
    the tokens that would have been embedded, then prices them at 0.0001 per
    1000 tokens. Note that this installs a global llama_index service context.

    :param docs: Documents to be embedded
    :type docs: List[Document]
    :return: Estimated cost
    :rtype: float
    """
    from llama_index import MockEmbedding
    from llama_index.callbacks import CallbackManager, TokenCountingHandler
    import tiktoken

    token_counter = TokenCountingHandler(tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode)
    service_context = ServiceContext.from_defaults(
        embed_model=MockEmbedding(embed_dim=1536),
        callback_manager=CallbackManager([token_counter]),
    )
    set_global_service_context(service_context)

    # Only the token-counting side effect matters; the index itself is discarded.
    VectorStoreIndex.from_documents(docs)

    # estimate cost
    cost = 0.0001 * token_counter.total_embedding_token_count / 1000
    token_counter.reset_counts()
    return cost
|
||||
|
||||
|
||||
def list_agent_config_files():
|
||||
"""List all agent config files, ignoring dotfiles."""
|
||||
files = os.listdir(os.path.join(MEMGPT_DIR, "agents"))
|
||||
|
||||
107
poetry.lock
generated
107
poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiohttp"
|
||||
@@ -566,40 +566,6 @@ files = [
|
||||
[package.extras]
|
||||
test = ["pytest (>=6)"]
|
||||
|
||||
[[package]]
|
||||
name = "faiss-cpu"
|
||||
version = "1.7.4"
|
||||
description = "A library for efficient similarity search and clustering of dense vectors."
|
||||
optional = true
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "faiss-cpu-1.7.4.tar.gz", hash = "sha256:265dc31b0c079bf4433303bf6010f73922490adff9188b915e2d3f5e9c82dd0a"},
|
||||
{file = "faiss_cpu-1.7.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:50d4ebe7f1869483751c558558504f818980292a9b55be36f9a1ee1009d9a686"},
|
||||
{file = "faiss_cpu-1.7.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7b1db7fae7bd8312aeedd0c41536bcd19a6e297229e1dce526bde3a73ab8c0b5"},
|
||||
{file = "faiss_cpu-1.7.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:17b7fa7194a228a84929d9e6619d0e7dbf00cc0f717e3462253766f5e3d07de8"},
|
||||
{file = "faiss_cpu-1.7.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dca531952a2e3eac56f479ff22951af4715ee44788a3fe991d208d766d3f95f3"},
|
||||
{file = "faiss_cpu-1.7.4-cp310-cp310-win_amd64.whl", hash = "sha256:7173081d605e74766f950f2e3d6568a6f00c53f32fd9318063e96728c6c62821"},
|
||||
{file = "faiss_cpu-1.7.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d0bbd6f55d7940cc0692f79e32a58c66106c3c950cee2341b05722de9da23ea3"},
|
||||
{file = "faiss_cpu-1.7.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e13c14280376100f143767d0efe47dcb32618f69e62bbd3ea5cd38c2e1755926"},
|
||||
{file = "faiss_cpu-1.7.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c521cb8462f3b00c0c7dfb11caff492bb67816528b947be28a3b76373952c41d"},
|
||||
{file = "faiss_cpu-1.7.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afdd9fe1141117fed85961fd36ee627c83fc3b9fd47bafb52d3c849cc2f088b7"},
|
||||
{file = "faiss_cpu-1.7.4-cp311-cp311-win_amd64.whl", hash = "sha256:2ff7f57889ea31d945e3b87275be3cad5d55b6261a4e3f51c7aba304d76b81fb"},
|
||||
{file = "faiss_cpu-1.7.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:eeaf92f27d76249fb53c1adafe617b0f217ab65837acf7b4ec818511caf6e3d8"},
|
||||
{file = "faiss_cpu-1.7.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:102b1bd763e9b0c281ac312590af3eaf1c8b663ccbc1145821fe6a9f92b8eaaf"},
|
||||
{file = "faiss_cpu-1.7.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5512da6707c967310c46ff712b00418b7ae28e93cb609726136e826e9f2f14fa"},
|
||||
{file = "faiss_cpu-1.7.4-cp37-cp37m-win_amd64.whl", hash = "sha256:0c2e5b9d8c28c99f990e87379d5bbcc6c914da91ebb4250166864fd12db5755b"},
|
||||
{file = "faiss_cpu-1.7.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:43f67f325393145d360171cd98786fcea6120ce50397319afd3bb78be409fb8a"},
|
||||
{file = "faiss_cpu-1.7.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6a4e4af194b8fce74c4b770cad67ad1dd1b4673677fc169723e4c50ba5bd97a8"},
|
||||
{file = "faiss_cpu-1.7.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:31bfb7b9cffc36897ae02a983e04c09fe3b8c053110a287134751a115334a1df"},
|
||||
{file = "faiss_cpu-1.7.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:52d7de96abef2340c0d373c1f5cbc78026a3cebb0f8f3a5920920a00210ead1f"},
|
||||
{file = "faiss_cpu-1.7.4-cp38-cp38-win_amd64.whl", hash = "sha256:699feef85b23c2c729d794e26ca69bebc0bee920d676028c06fd0e0becc15c7e"},
|
||||
{file = "faiss_cpu-1.7.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:559a0133f5ed44422acb09ee1ac0acffd90c6666d1bc0d671c18f6e93ad603e2"},
|
||||
{file = "faiss_cpu-1.7.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ea1d71539fe3dc0f1bed41ef954ca701678776f231046bf0ca22ccea5cf5bef6"},
|
||||
{file = "faiss_cpu-1.7.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:12d45e0157024eb3249842163162983a1ac8b458f1a8b17bbf86f01be4585a99"},
|
||||
{file = "faiss_cpu-1.7.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2f0eab359e066d32c874f51a7d4bf6440edeec068b7fe47e6d803c73605a8b4c"},
|
||||
{file = "faiss_cpu-1.7.4-cp39-cp39-win_amd64.whl", hash = "sha256:98459ceeeb735b9df1a5b94572106ffe0a6ce740eb7e4626715dd218657bb4dc"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "filelock"
|
||||
version = "3.13.1"
|
||||
@@ -1046,13 +1012,13 @@ tests = ["pandas (>=1.4)", "pytest", "pytest-asyncio", "pytest-mock", "requests"
|
||||
|
||||
[[package]]
|
||||
name = "langchain"
|
||||
version = "0.0.342"
|
||||
version = "0.0.343"
|
||||
description = "Building applications with LLMs through composability"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
files = [
|
||||
{file = "langchain-0.0.342-py3-none-any.whl", hash = "sha256:83c37898226666e0176d093f57fa49e176486608ef4c617a65aadf0b038ba0ec"},
|
||||
{file = "langchain-0.0.342.tar.gz", hash = "sha256:06341ee0b034847cbcea4b40a0a26b270abb6fd1237437735187c44d30a7a24d"},
|
||||
{file = "langchain-0.0.343-py3-none-any.whl", hash = "sha256:1959336b6076066bf233dd99dce44be2e9adccb53d799bff92c653098178b347"},
|
||||
{file = "langchain-0.0.343.tar.gz", hash = "sha256:166924d771a463009277f688f6dfc829a3af2d9cd5b41a64a7a6bd7860280e81"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -1071,14 +1037,14 @@ SQLAlchemy = ">=1.4,<3"
|
||||
tenacity = ">=8.1.0,<9.0.0"
|
||||
|
||||
[package.extras]
|
||||
all = ["O365 (>=2.0.26,<3.0.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "amadeus (>=8.1.0)", "arxiv (>=1.4,<2.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "awadb (>=0.3.9,<0.4.0)", "azure-ai-formrecognizer (>=3.2.1,<4.0.0)", "azure-ai-textanalytics (>=5.3.0,<6.0.0)", "azure-ai-vision (>=0.11.1b1,<0.12.0)", "azure-cognitiveservices-speech (>=1.28.0,<2.0.0)", "azure-cosmos (>=4.4.0b1,<5.0.0)", "azure-identity (>=1.12.0,<2.0.0)", "beautifulsoup4 (>=4,<5)", "clarifai (>=9.1.0)", "clickhouse-connect (>=0.5.14,<0.6.0)", "cohere (>=4,<5)", "deeplake (>=3.8.3,<4.0.0)", "docarray[hnswlib] (>=0.32.0,<0.33.0)", "duckduckgo-search (>=3.8.3,<4.0.0)", "elasticsearch (>=8,<9)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "google-api-python-client (==2.70.0)", "google-auth (>=2.18.1,<3.0.0)", "google-search-results (>=2,<3)", "gptcache (>=0.1.7)", "html2text (>=2020.1.16,<2021.0.0)", "huggingface_hub (>=0,<1)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "lancedb (>=0.1,<0.2)", "langkit (>=0.0.6,<0.1.0)", "lark (>=1.1.5,<2.0.0)", "librosa (>=0.10.0.post2,<0.11.0)", "lxml (>=4.9.2,<5.0.0)", "manifest-ml (>=0.0.1,<0.0.2)", "marqo (>=1.2.4,<2.0.0)", "momento (>=1.13.0,<2.0.0)", "nebula3-python (>=3.4.0,<4.0.0)", "neo4j (>=5.8.1,<6.0.0)", "networkx (>=2.6.3,<4)", "nlpcloud (>=1,<2)", "nltk (>=3,<4)", "nomic (>=1.0.43,<2.0.0)", "openai (<2)", "openlm (>=0.0.5,<0.0.6)", "opensearch-py (>=2.0.0,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pexpect (>=4.8.0,<5.0.0)", "pgvector (>=0.1.6,<0.2.0)", "pinecone-client (>=2,<3)", "pinecone-text (>=0.4.2,<0.5.0)", "psycopg2-binary (>=2.9.5,<3.0.0)", "pymongo (>=4.3.3,<5.0.0)", "pyowm (>=3.3.0,<4.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pytesseract (>=0.3.10,<0.4.0)", "python-arango (>=7.5.9,<8.0.0)", "pyvespa (>=0.33.0,<0.34.0)", "qdrant-client (>=1.3.1,<2.0.0)", "rdflib (>=6.3.2,<7.0.0)", "redis (>=4,<5)", "requests-toolbelt (>=1.0.0,<2.0.0)", "sentence-transformers (>=2,<3)", "singlestoredb (>=0.7.1,<0.8.0)", "tensorflow-text 
(>=2.11.0,<3.0.0)", "tigrisdb (>=1.0.0b6,<2.0.0)", "tiktoken (>=0.3.2,<0.6.0)", "torch (>=1,<3)", "transformers (>=4,<5)", "weaviate-client (>=3,<4)", "wikipedia (>=1,<2)", "wolframalpha (==5.0.0)"]
|
||||
all = ["O365 (>=2.0.26,<3.0.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "amadeus (>=8.1.0)", "arxiv (>=1.4,<2.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "awadb (>=0.3.9,<0.4.0)", "azure-ai-formrecognizer (>=3.2.1,<4.0.0)", "azure-ai-textanalytics (>=5.3.0,<6.0.0)", "azure-ai-vision (>=0.11.1b1,<0.12.0)", "azure-cognitiveservices-speech (>=1.28.0,<2.0.0)", "azure-cosmos (>=4.4.0b1,<5.0.0)", "azure-identity (>=1.12.0,<2.0.0)", "beautifulsoup4 (>=4,<5)", "clarifai (>=9.1.0)", "clickhouse-connect (>=0.5.14,<0.6.0)", "cohere (>=4,<5)", "deeplake (>=3.8.3,<4.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "docarray[hnswlib] (>=0.32.0,<0.33.0)", "duckduckgo-search (>=3.8.3,<4.0.0)", "elasticsearch (>=8,<9)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "google-api-python-client (==2.70.0)", "google-auth (>=2.18.1,<3.0.0)", "google-search-results (>=2,<3)", "gptcache (>=0.1.7)", "html2text (>=2020.1.16,<2021.0.0)", "huggingface_hub (>=0,<1)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "lancedb (>=0.1,<0.2)", "langkit (>=0.0.6,<0.1.0)", "lark (>=1.1.5,<2.0.0)", "librosa (>=0.10.0.post2,<0.11.0)", "lxml (>=4.9.2,<5.0.0)", "manifest-ml (>=0.0.1,<0.0.2)", "marqo (>=1.2.4,<2.0.0)", "momento (>=1.13.0,<2.0.0)", "nebula3-python (>=3.4.0,<4.0.0)", "neo4j (>=5.8.1,<6.0.0)", "networkx (>=2.6.3,<4)", "nlpcloud (>=1,<2)", "nltk (>=3,<4)", "nomic (>=1.0.43,<2.0.0)", "openai (<2)", "openlm (>=0.0.5,<0.0.6)", "opensearch-py (>=2.0.0,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pexpect (>=4.8.0,<5.0.0)", "pgvector (>=0.1.6,<0.2.0)", "pinecone-client (>=2,<3)", "pinecone-text (>=0.4.2,<0.5.0)", "psycopg2-binary (>=2.9.5,<3.0.0)", "pymongo (>=4.3.3,<5.0.0)", "pyowm (>=3.3.0,<4.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pytesseract (>=0.3.10,<0.4.0)", "python-arango (>=7.5.9,<8.0.0)", "pyvespa (>=0.33.0,<0.34.0)", "qdrant-client (>=1.3.1,<2.0.0)", "rdflib (>=6.3.2,<7.0.0)", "redis (>=4,<5)", "requests-toolbelt (>=1.0.0,<2.0.0)", "sentence-transformers (>=2,<3)", "singlestoredb (>=0.7.1,<0.8.0)", 
"tensorflow-text (>=2.11.0,<3.0.0)", "tigrisdb (>=1.0.0b6,<2.0.0)", "tiktoken (>=0.3.2,<0.6.0)", "torch (>=1,<3)", "transformers (>=4,<5)", "weaviate-client (>=3,<4)", "wikipedia (>=1,<2)", "wolframalpha (==5.0.0)"]
|
||||
azure = ["azure-ai-formrecognizer (>=3.2.1,<4.0.0)", "azure-ai-textanalytics (>=5.3.0,<6.0.0)", "azure-ai-vision (>=0.11.1b1,<0.12.0)", "azure-cognitiveservices-speech (>=1.28.0,<2.0.0)", "azure-core (>=1.26.4,<2.0.0)", "azure-cosmos (>=4.4.0b1,<5.0.0)", "azure-identity (>=1.12.0,<2.0.0)", "azure-search-documents (==11.4.0b8)", "openai (<2)"]
|
||||
clarifai = ["clarifai (>=9.1.0)"]
|
||||
cli = ["typer (>=0.9.0,<0.10.0)"]
|
||||
cohere = ["cohere (>=4,<5)"]
|
||||
docarray = ["docarray[hnswlib] (>=0.32.0,<0.33.0)"]
|
||||
embeddings = ["sentence-transformers (>=2,<3)"]
|
||||
extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "dashvector (>=1.0.1,<2.0.0)", "databricks-vectorsearch (>=0.21,<0.22)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.6.0,<0.7.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "html2text (>=2020.1.16,<2021.0.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.2,<5.0.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "upstash-redis (>=0.15.0,<0.16.0)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)"]
|
||||
extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "dashvector (>=1.0.1,<2.0.0)", "databricks-vectorsearch (>=0.21,<0.22)", "dgml-utils (>=0.3.0,<0.4.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.6.0,<0.7.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "html2text (>=2020.1.16,<2021.0.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.2,<5.0.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "upstash-redis (>=0.15.0,<0.16.0)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)"]
|
||||
javascript = ["esprima (>=4.0.1,<5.0.0)"]
|
||||
llms = ["clarifai (>=9.1.0)", "cohere (>=4,<5)", "huggingface_hub (>=0,<1)", "manifest-ml (>=0.0.1,<0.0.2)", "nlpcloud (>=1,<2)", "openai (<2)", "openlm (>=0.0.5,<0.0.6)", "torch (>=1,<3)", "transformers (>=4,<5)"]
|
||||
openai = ["openai (<2)", "tiktoken (>=0.3.2,<0.6.0)"]
|
||||
@@ -2344,64 +2310,6 @@ benchmarks = ["pytest-benchmark"]
|
||||
tests = ["duckdb", "ml_dtypes", "pandas (>=1.4,<2.1)", "polars[pandas,pyarrow]", "pytest", "semver", "tensorflow", "tqdm"]
|
||||
torch = ["torch"]
|
||||
|
||||
[[package]]
|
||||
name = "pymupdf"
|
||||
version = "1.23.6"
|
||||
description = "A high performance Python library for data extraction, analysis, conversion & manipulation of PDF (and other) documents."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "PyMuPDF-1.23.6-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:c4eb71b88a22c1008f764b3121b36a9d25340f9920b870508356050a365d9ca1"},
|
||||
{file = "PyMuPDF-1.23.6-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:3ce2d3678dbf822cff213b1902f2e59756313e543efd516a2b4f15bb0353bd6c"},
|
||||
{file = "PyMuPDF-1.23.6-cp310-none-manylinux2014_aarch64.whl", hash = "sha256:2e27857a15c8a810d0b66455b8c8a79013640b6267a9b4ea808a5fe1f47711f2"},
|
||||
{file = "PyMuPDF-1.23.6-cp310-none-manylinux2014_x86_64.whl", hash = "sha256:5cd05700c8f18c9dafef63ac2ed3b1099ca06017ca0c32deea13093cea1b8671"},
|
||||
{file = "PyMuPDF-1.23.6-cp310-none-win32.whl", hash = "sha256:951d280c1daafac2fd6a664b031f7f98b27eb2def55d39c92a19087bd8041c5d"},
|
||||
{file = "PyMuPDF-1.23.6-cp310-none-win_amd64.whl", hash = "sha256:19d1711d5908c4527ad2deef5af2d066649f3f9a12950faf30be5f7251d18abc"},
|
||||
{file = "PyMuPDF-1.23.6-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:3f0f9b76bc4f039e7587003cbd40684d93a98441549dd033cab38ca07d61988d"},
|
||||
{file = "PyMuPDF-1.23.6-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:e047571d799b30459ad7ee0bc6e68900a7f6b928876f956c976f279808814e72"},
|
||||
{file = "PyMuPDF-1.23.6-cp311-none-manylinux2014_aarch64.whl", hash = "sha256:1cbcf05c06f314fdf3042ceee674e9a0ac7fae598347d5442e2138c6046d4e82"},
|
||||
{file = "PyMuPDF-1.23.6-cp311-none-manylinux2014_x86_64.whl", hash = "sha256:e33f8ec5ba7265fe78b30332840b8f454184addfa79f9c27f160f19789aa5ffd"},
|
||||
{file = "PyMuPDF-1.23.6-cp311-none-win32.whl", hash = "sha256:2c141f33e2733e48de8524dfd2de56d889feef0c7773b20a8cd216c03ab24793"},
|
||||
{file = "PyMuPDF-1.23.6-cp311-none-win_amd64.whl", hash = "sha256:8fd9c4ee1dd4744a515b9190d8ba9133348b0d94c362293ed77726aa1c13b0a6"},
|
||||
{file = "PyMuPDF-1.23.6-cp312-none-macosx_10_9_x86_64.whl", hash = "sha256:4d06751d5cd213e96f84f2faaa71a51cf4d641851e07579247ca1190121f173b"},
|
||||
{file = "PyMuPDF-1.23.6-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:526b26a5207e923aab65877ad305644402851823a352cb92d362053426899354"},
|
||||
{file = "PyMuPDF-1.23.6-cp312-none-manylinux2014_aarch64.whl", hash = "sha256:0f852d125defc26716878b1796f4d68870e9065041d00cf46bde317fd8d30e68"},
|
||||
{file = "PyMuPDF-1.23.6-cp312-none-manylinux2014_x86_64.whl", hash = "sha256:5bdf7020b90987412381acc42427dd1b7a03d771ee9ec273de003e570164ec1a"},
|
||||
{file = "PyMuPDF-1.23.6-cp312-none-win32.whl", hash = "sha256:e2d64799c6d9a3735be9e162a5d11061c0b7fbcb1e5fc7446e0993d0f815a93a"},
|
||||
{file = "PyMuPDF-1.23.6-cp312-none-win_amd64.whl", hash = "sha256:c8ea81964c1433ea163ad4b53c56053a87a9ef6e1bd7a879d4d368a3988b60d1"},
|
||||
{file = "PyMuPDF-1.23.6-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:761501a4965264e81acdd8f2224f993020bf24474e9b34fcdb5805a6826eda1c"},
|
||||
{file = "PyMuPDF-1.23.6-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:fd8388e82b6045807d19addf310d8119d32908e89f76cc8bbf8cf1ec36fce947"},
|
||||
{file = "PyMuPDF-1.23.6-cp38-none-manylinux2014_aarch64.whl", hash = "sha256:4ac9673a6d6ee7e80cb242dacb43f9ca097b502d9c5e44687dbdffc2bce7961a"},
|
||||
{file = "PyMuPDF-1.23.6-cp38-none-manylinux2014_x86_64.whl", hash = "sha256:6e319c1f49476e07b9a12017c2d031687617713f8a46b7adcec03c636ed04607"},
|
||||
{file = "PyMuPDF-1.23.6-cp38-none-win32.whl", hash = "sha256:1103eea4ab727e32b9cb93347b35f71562033018c333a7f3a17d115e980fea4a"},
|
||||
{file = "PyMuPDF-1.23.6-cp38-none-win_amd64.whl", hash = "sha256:991a37e1cba43775ce094da87cf0bf72172a5532a09644003276bc8bfdfe9f1a"},
|
||||
{file = "PyMuPDF-1.23.6-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:57725e15872f7ab67a9fb3e06e5384d1047b2121e85755c93a6d4266d3ca8983"},
|
||||
{file = "PyMuPDF-1.23.6-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:224c341fe254adda97c8f06a4c5838cdbcf609fa89e70b1fb179752533378f2f"},
|
||||
{file = "PyMuPDF-1.23.6-cp39-none-manylinux2014_aarch64.whl", hash = "sha256:271bdf6059bb8347f9c9c6b721329bd353a933681b1fc62f43241b410e7ab7ae"},
|
||||
{file = "PyMuPDF-1.23.6-cp39-none-manylinux2014_x86_64.whl", hash = "sha256:57e22bea69690450197b34dcde16bd9fe0265ac4425b4033535ccc5c044246fb"},
|
||||
{file = "PyMuPDF-1.23.6-cp39-none-win32.whl", hash = "sha256:2885a26220a32fb45ea443443b72194bb7107d6862d8d546b59e4ad0c8a1f2c9"},
|
||||
{file = "PyMuPDF-1.23.6-cp39-none-win_amd64.whl", hash = "sha256:361cab1be45481bd3dc4e00ec82628ebc189b4f4b6fd9bd78a00cfeed54e0034"},
|
||||
{file = "PyMuPDF-1.23.6.tar.gz", hash = "sha256:618b8e884190ac1cca9df1c637f87669d2d532d421d4ee7e4763c848dc4f3a1e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
PyMuPDFb = "1.23.6"
|
||||
|
||||
[[package]]
|
||||
name = "pymupdfb"
|
||||
version = "1.23.6"
|
||||
description = "MuPDF shared libraries for PyMuPDF."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "PyMuPDFb-1.23.6-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:e5af77580aad3d1103aeec57009d156bfca429cecda14a17c573fcbe97bafb30"},
|
||||
{file = "PyMuPDFb-1.23.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:9925816cbe3e05e920f9be925e5752c2eef42b793885b62075bb0f6a69178598"},
|
||||
{file = "PyMuPDFb-1.23.6-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:009e2cff166059e13bf71f93919e688f46b8fc11d122433574cfb0cc9134690e"},
|
||||
{file = "PyMuPDFb-1.23.6-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7132b30e6ad6ff2013344e3a481b2287fe0be3710d80694807dd6e0d8635f085"},
|
||||
{file = "PyMuPDFb-1.23.6-py3-none-win32.whl", hash = "sha256:9d24ddadc204e895bee5000ddc7507c801643548e59f5a56aad6d32981d17eeb"},
|
||||
{file = "PyMuPDFb-1.23.6-py3-none-win_amd64.whl", hash = "sha256:7bef75988e6979b10ca804cf9487f817aae43b0fff1c6e315b3b9ee0cf1cc32f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "7.4.3"
|
||||
@@ -3885,11 +3793,10 @@ multidict = ">=4.0"
|
||||
autogen = ["pyautogen"]
|
||||
dev = ["black", "datasets", "pre-commit", "pytest"]
|
||||
lancedb = ["lancedb"]
|
||||
legacy = ["faiss-cpu", "numpy"]
|
||||
local = ["huggingface-hub", "torch", "transformers"]
|
||||
postgres = ["pg8000", "pgvector", "psycopg", "psycopg-binary", "psycopg2-binary"]
|
||||
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "<3.12,>=3.9"
|
||||
content-hash = "a1d04a1b10676fcb84fbce5440800706a2ae14cbe2a10bb7d59667b7c36b7709"
|
||||
content-hash = "61614071518e8b09eb7396b9f56caef3c08bd6c3c587c0048569d887e3d85601"
|
||||
|
||||
@@ -22,12 +22,7 @@ memgpt = "memgpt.main:app"
|
||||
python = "<3.12,>=3.9"
|
||||
typer = {extras = ["all"], version = "^0.9.0"}
|
||||
questionary = "^2.0.1"
|
||||
demjson3 = "^3.0.6"
|
||||
numpy = "^1.26.1"
|
||||
pytz = "^2023.3.post1"
|
||||
faiss-cpu = { version = "^1.7.4", optional = true }
|
||||
tiktoken = "^0.5.1"
|
||||
pymupdf = "^1.23.5"
|
||||
tqdm = "^4.66.1"
|
||||
black = { version = "^23.10.1", optional = true }
|
||||
pytest = { version = "^7.4.3", optional = true }
|
||||
@@ -49,10 +44,12 @@ docstring-parser = "^0.15"
|
||||
lancedb = {version = "^0.3.3", optional = true}
|
||||
httpx = "^0.25.2"
|
||||
pyautogen = {version = "0.1.14", optional = true}
|
||||
numpy = "^1.26.2"
|
||||
demjson3 = "^3.0.6"
|
||||
tiktoken = "^0.5.1"
|
||||
python-box = "^7.1.1"
|
||||
|
||||
[tool.poetry.extras]
|
||||
legacy = ["faiss-cpu", "numpy"]
|
||||
local = ["torch", "huggingface-hub", "transformers"]
|
||||
lancedb = ["lancedb"]
|
||||
postgres = ["pgvector", "psycopg", "psycopg-binary", "psycopg2-binary", "pg8000"]
|
||||
|
||||
@@ -43,4 +43,3 @@ def test_save_load():
|
||||
if __name__ == "__main__":
|
||||
test_configure_memgpt()
|
||||
test_save_load()
|
||||
# test_legacy_cli_sequence()
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", "pexpect"])
|
||||
import pexpect
|
||||
|
||||
|
||||
TIMEOUT = 30 # seconds
|
||||
|
||||
|
||||
def test_legacy_cli_sequence():
|
||||
# Start the CLI process
|
||||
child = pexpect.spawn("memgpt --first --strip_ui")
|
||||
|
||||
child.expect("Continue with legacy CLI?", timeout=TIMEOUT)
|
||||
# Send 'Y' followed by newline
|
||||
child.sendline("Y")
|
||||
|
||||
# Since .memgpt is empty, should jump immediately to "Which model?"
|
||||
child.expect("Which model would you like to use?", timeout=TIMEOUT)
|
||||
child.sendline()
|
||||
|
||||
child.expect("Which persona would you like MemGPT to use?", timeout=TIMEOUT)
|
||||
child.sendline()
|
||||
|
||||
child.expect("Which user would you like to use?", timeout=TIMEOUT)
|
||||
child.sendline()
|
||||
|
||||
child.expect("Would you like to preload anything into MemGPT's archival memory?", timeout=TIMEOUT)
|
||||
child.sendline() # Default No
|
||||
|
||||
child.expect("Testing messaging functionality", timeout=TIMEOUT)
|
||||
child.expect("Enter your message", timeout=TIMEOUT)
|
||||
child.sendline() # Send empty message
|
||||
|
||||
child.expect("Try again!", timeout=TIMEOUT) # Empty message
|
||||
child.sendline("/save")
|
||||
|
||||
child.expect("Saved checkpoint", timeout=TIMEOUT)
|
||||
child.sendline("/load")
|
||||
|
||||
child.expect("Loaded persistence manager", timeout=TIMEOUT)
|
||||
|
||||
child.sendline("/dump") # just testing no-crash
|
||||
# child.expect("", timeout=TIMEOUT)
|
||||
child.sendline("/dump 3") # just testing no-crash
|
||||
|
||||
child.sendline("/exit")
|
||||
child.expect("Finished.", timeout=TIMEOUT)
|
||||
|
||||
child.expect(pexpect.EOF, timeout=TIMEOUT) # Wait for child to exit
|
||||
child.close()
|
||||
assert child.isalive() is False, "CLI should have terminated."
|
||||
assert child.exitstatus == 0, "CLI did not exit cleanly."
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_legacy_cli_sequence()
|
||||
@@ -1,18 +1,12 @@
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, MagicMock
|
||||
|
||||
from memgpt.config import MemGPTConfig, AgentConfig
|
||||
from memgpt.server.websocket_interface import SyncWebSocketInterface
|
||||
import memgpt.presets as presets
|
||||
import memgpt.personas.personas as personas
|
||||
import memgpt.humans.humans as humans
|
||||
import memgpt.utils as utils
|
||||
import memgpt.system as system
|
||||
from memgpt.persistence_manager import InMemoryStateManager
|
||||
from memgpt.persistence_manager import LocalStateManager
|
||||
|
||||
|
||||
# def test_websockets():
|
||||
@@ -59,17 +53,20 @@ async def test_websockets():
|
||||
# Register the mock websocket as a client
|
||||
ws_interface.register_client(mock_websocket)
|
||||
|
||||
# Mock the persistence manager
|
||||
persistence_manager = InMemoryStateManager()
|
||||
|
||||
# Create an agent and hook it up to the WebSocket interface
|
||||
config = MemGPTConfig()
|
||||
|
||||
# Mock the persistence manager
|
||||
# create agents with defaults
|
||||
agent_config = AgentConfig(persona="sam_pov", human="basic", model="gpt-4-1106-preview")
|
||||
persistence_manager = LocalStateManager(agent_config=agent_config)
|
||||
|
||||
memgpt_agent = presets.use_preset(
|
||||
presets.DEFAULT_PRESET,
|
||||
config, # no agent config to provide
|
||||
"gpt-4-1106-preview",
|
||||
personas.get_persona_text("sam_pov"),
|
||||
humans.get_human_text("basic"),
|
||||
utils.get_persona_text("sam_pov"),
|
||||
utils.get_human_text("basic"),
|
||||
ws_interface,
|
||||
persistence_manager,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user