feat: create memgpt server with postgres DB with docker compose up (#1183)
This commit is contained in:
6
.github/workflows/tests.yml
vendored
6
.github/workflows/tests.yml
vendored
@@ -30,7 +30,7 @@ jobs:
|
||||
|
||||
- name: Run server tests
|
||||
env:
|
||||
PGVECTOR_TEST_DB_URL: postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt
|
||||
PGVECTOR_TEST_DB_URL: postgresql+pg8000://memgpt:memgpt@localhost:5432/memgpt
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
MEMGPT_SERVER_PASS: test_server_token
|
||||
run: |
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
|
||||
- name: Run tests with pytest
|
||||
env:
|
||||
PGVECTOR_TEST_DB_URL: postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt
|
||||
PGVECTOR_TEST_DB_URL: postgresql+pg8000://memgpt:memgpt@localhost:5432/memgpt
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
MEMGPT_SERVER_PASS: test_server_token
|
||||
PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }}
|
||||
@@ -47,7 +47,7 @@ jobs:
|
||||
|
||||
- name: Run storage tests
|
||||
env:
|
||||
PGVECTOR_TEST_DB_URL: postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt
|
||||
PGVECTOR_TEST_DB_URL: postgresql+pg8000://memgpt:memgpt@localhost:5432/memgpt
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
MEMGPT_SERVER_PASS: test_server_token
|
||||
run: |
|
||||
|
||||
@@ -3,7 +3,7 @@ repos:
|
||||
rev: v2.3.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
exclude: 'docs/.*|tests/data/.*'
|
||||
exclude: 'docs/.*|tests/data/.*|configs/.*'
|
||||
- id: end-of-file-fixer
|
||||
exclude: 'docs/.*|tests/data/.*'
|
||||
- id: trailing-whitespace
|
||||
|
||||
29
compose.yaml
Normal file
29
compose.yaml
Normal file
@@ -0,0 +1,29 @@
|
||||
version: '3.8'
|
||||
services:
|
||||
pgvector_db:
|
||||
image: ankane/pgvector:latest
|
||||
environment:
|
||||
- POSTGRES_USER=${MEMGPT_PG_USER}
|
||||
- POSTGRES_PASSWORD=${MEMGPT_PG_PASSWORD}
|
||||
- POSTGRES_DB=${MEMGPT_PG_DB}
|
||||
volumes:
|
||||
- pgdata:/var/lib/postgresql/data
|
||||
- ./init.sql:/docker-entrypoint-initdb.d/init.sql
|
||||
ports:
|
||||
- "5432:5432"
|
||||
|
||||
memgpt_server:
|
||||
image: memgpt/memgpt-server:0.3.6
|
||||
depends_on:
|
||||
- pgvector_db
|
||||
environment:
|
||||
- POSTGRES_URI=postgresql://${MEMGPT_PG_USER}:${MEMGPT_PG_PASSWORD}@pgvector_db:5432/${MEMGPT_PG_DB}
|
||||
- MEMGPT_SERVER_PASS=${MEMGPT_SERVER_PASS} # memgpt server password
|
||||
volumes:
|
||||
- ./configs/server_config.yaml:/root/.memgpt/config # config file
|
||||
- ~/.memgpt/credentials:/root/.memgpt/credentials # credentials file
|
||||
ports:
|
||||
- "8083:8083"
|
||||
|
||||
volumes:
|
||||
pgdata:
|
||||
38
configs/server_config.yaml
Normal file
38
configs/server_config.yaml
Normal file
@@ -0,0 +1,38 @@
|
||||
[defaults]
|
||||
preset = memgpt_chat
|
||||
persona = sam_pov
|
||||
human = basic
|
||||
|
||||
[model]
|
||||
model = gpt-4
|
||||
model_endpoint = https://api.openai.com/v1
|
||||
model_endpoint_type = openai
|
||||
context_window = 8192
|
||||
|
||||
[embedding]
|
||||
embedding_endpoint_type = openai
|
||||
embedding_endpoint = https://api.openai.com/v1
|
||||
embedding_model = text-embedding-ada-002
|
||||
embedding_dim = 1536
|
||||
embedding_chunk_size = 300
|
||||
|
||||
[archival_storage]
|
||||
type = postgres
|
||||
path = /root/.memgpt/chroma
|
||||
uri = postgresql://memgpt:memgpt@pgvector_db:5432/memgpt
|
||||
|
||||
[recall_storage]
|
||||
type = postgres
|
||||
path = /root/.memgpt
|
||||
uri = postgresql://memgpt:memgpt@pgvector_db:5432/memgpt
|
||||
|
||||
[metadata_storage]
|
||||
type = postgres
|
||||
path = /root/.memgpt
|
||||
uri = postgresql://memgpt:memgpt@pgvector_db:5432/memgpt
|
||||
|
||||
[version]
|
||||
memgpt_version = 0.3.7
|
||||
|
||||
[client]
|
||||
anon_clientid = 00000000-0000-0000-0000-000000000000
|
||||
@@ -4,7 +4,7 @@ docker build -f db/Dockerfile.simple -t pg-test .
|
||||
# run container
|
||||
docker run -d --rm \
|
||||
--name memgpt-db-test \
|
||||
-p 8888:5432 \
|
||||
-p 5432:5432 \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-v memgpt_db_test:/var/lib/postgresql/data \
|
||||
pg-test:latest
|
||||
|
||||
37
init.sql
Normal file
37
init.sql
Normal file
@@ -0,0 +1,37 @@
|
||||
-- Title: Init MemGPT Database
|
||||
|
||||
-- Fetch the docker secrets, if they are available.
|
||||
-- Otherwise fall back to environment variables, or hardwired 'memgpt'
|
||||
\set db_user `([ -r /var/run/secrets/memgpt-user ] && cat /var/run/secrets/memgpt-user) || echo "${POSTGRES_USER:-memgpt}"`
|
||||
\set db_password `([ -r /var/run/secrets/memgpt-password ] && cat /var/run/secrets/memgpt-password) || echo "${POSTGRES_PASSWORD:-memgpt}"`
|
||||
\set db_name `([ -r /var/run/secrets/memgpt-db ] && cat /var/run/secrets/memgpt-db) || echo "${POSTGRES_DB:-memgpt}"`
|
||||
|
||||
|
||||
-- CREATE USER :"db_user"
|
||||
-- WITH PASSWORD :'db_password'
|
||||
-- NOCREATEDB
|
||||
-- NOCREATEROLE
|
||||
-- ;
|
||||
--
|
||||
-- CREATE DATABASE :"db_name"
|
||||
-- WITH
|
||||
-- OWNER = :"db_user"
|
||||
-- ENCODING = 'UTF8'
|
||||
-- LC_COLLATE = 'en_US.utf8'
|
||||
-- LC_CTYPE = 'en_US.utf8'
|
||||
-- LOCALE_PROVIDER = 'libc'
|
||||
-- TABLESPACE = pg_default
|
||||
-- CONNECTION LIMIT = -1;
|
||||
|
||||
-- Set up our schema and extensions in our new database.
|
||||
\c :"db_name"
|
||||
|
||||
CREATE SCHEMA :"db_name"
|
||||
AUTHORIZATION :"db_user";
|
||||
|
||||
ALTER DATABASE :"db_name"
|
||||
SET search_path TO :"db_name";
|
||||
|
||||
CREATE EXTENSION IF NOT EXISTS vector WITH SCHEMA :"db_name";
|
||||
|
||||
DROP SCHEMA IF EXISTS public CASCADE;
|
||||
@@ -437,17 +437,29 @@ class PostgresStorageConnector(SQLStorageConnector):
|
||||
raise ValueError(f"Table type {table_type} not implemented")
|
||||
# create table
|
||||
self.db_model = get_db_model(config, self.table_name, table_type, user_id, agent_id)
|
||||
self.engine = create_engine(self.uri)
|
||||
|
||||
# construct URI from enviornment variables
|
||||
db = os.getenv("MEMGPT_PG_DB", "memgpt")
|
||||
user = os.getenv("MEMGPT_PG_USER", "memgpt")
|
||||
password = os.getenv("MEMGPT_PG_PASSWORD", "memgpt")
|
||||
port = os.getenv("MEMGPT_PG_PORT", "5432")
|
||||
url = os.getenv("MEMGPT_PG_URL", "localhost")
|
||||
uri = f"postgresql+pg8000://{user}:{password}@{url}:{port}/{db}"
|
||||
|
||||
# create engine
|
||||
self.engine = create_engine(uri)
|
||||
|
||||
for c in self.db_model.__table__.columns:
|
||||
if c.name == "embedding":
|
||||
assert isinstance(c.type, Vector), f"Embedding column must be of type Vector, got {c.type}"
|
||||
|
||||
Base.metadata.create_all(self.engine, tables=[self.db_model.__table__]) # Create the table if it doesn't exist
|
||||
|
||||
self.session_maker = sessionmaker(bind=self.engine)
|
||||
with self.session_maker() as session:
|
||||
session.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) # Enables the vector extension
|
||||
|
||||
# create table
|
||||
Base.metadata.create_all(self.engine, tables=[self.db_model.__table__]) # Create the table if it doesn't exist
|
||||
|
||||
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[RecordType]:
|
||||
filters = self.get_filters(filters)
|
||||
with self.session_maker() as session:
|
||||
|
||||
@@ -24,7 +24,7 @@ from memgpt.constants import MEMGPT_DIR, CLI_WARNING_PREFIX, JSON_ENSURE_ASCII
|
||||
from memgpt.agent import Agent, save_agent
|
||||
from memgpt.embeddings import embedding_model
|
||||
from memgpt.server.constants import WS_DEFAULT_PORT, REST_DEFAULT_PORT
|
||||
from memgpt.data_types import AgentState, LLMConfig, EmbeddingConfig, User, Passage, Preset
|
||||
from memgpt.data_types import AgentState, LLMConfig, EmbeddingConfig, User, Passage
|
||||
from memgpt.metadata import MetadataStore
|
||||
from memgpt.migrate import migrate_all_agents, migrate_all_sources
|
||||
|
||||
@@ -301,6 +301,15 @@ def server(
|
||||
# # Add the handler to the logger
|
||||
# server_logger.addHandler(stream_handler)
|
||||
|
||||
# override config with postgres enviornment (messy, but necessary for docker compose)
|
||||
if os.getenv("POSTGRES_URI"):
|
||||
config = MemGPTConfig.load()
|
||||
config.archival_storage_uri = os.getenv("POSTGRES_URI")
|
||||
config.recall_storage_uri = os.getenv("POSTGRES_URI")
|
||||
config.metadata_storage_uri = os.getenv("POSTGRES_URI")
|
||||
print(f"Overriding DB config URI with enviornment variable: {config.archival_storage_uri}")
|
||||
config.save()
|
||||
|
||||
if type == ServerChoice.rest_api:
|
||||
import uvicorn
|
||||
from memgpt.server.rest_api.server import app
|
||||
@@ -643,7 +652,8 @@ def run(
|
||||
# create agent
|
||||
try:
|
||||
preset_obj = ms.get_preset(name=preset if preset else config.preset, user_id=user.id)
|
||||
preset_override = False
|
||||
human_obj = ms.get_human(human, user.id)
|
||||
persona_obj = ms.get_persona(persona, user.id)
|
||||
if preset_obj is None:
|
||||
# create preset records in metadata store
|
||||
from memgpt.presets.presets import add_default_presets
|
||||
@@ -654,28 +664,14 @@ def run(
|
||||
if preset_obj is None:
|
||||
typer.secho("Couldn't find presets in database, please run `memgpt configure`", fg=typer.colors.RED)
|
||||
sys.exit(1)
|
||||
|
||||
human_obj = ms.get_human(human, user.id)
|
||||
if human_obj is None:
|
||||
typer.secho("Couldn't find human {human} in database, please run `memgpt add human`", fg=typer.colors.RED)
|
||||
persona_obj = ms.get_persona(persona, user.id)
|
||||
if persona_obj is None:
|
||||
typer.secho("Couldn't find persona {persona} in database, please run `memgpt add persona`", fg=typer.colors.RED)
|
||||
|
||||
# Overwrite fields in the preset if they were specified
|
||||
if human_obj.text != preset_obj.human:
|
||||
preset_override = True
|
||||
preset_obj.human = human_obj.text
|
||||
if persona_obj.text != preset_obj.human:
|
||||
preset_override = True
|
||||
preset_obj.persona = persona_obj.text
|
||||
|
||||
# If the user overrode any parts of the preset, we need to create a new preset to refer back to
|
||||
if preset_override:
|
||||
# Change the name and uuid
|
||||
preset_obj = Preset.clone(preset_obj=preset_obj)
|
||||
# Then write out to the database for storage
|
||||
ms.create_preset(preset=preset_obj)
|
||||
preset_obj.human = ms.get_human(human, user.id).text
|
||||
preset_obj.persona = ms.get_persona(persona, user.id).text
|
||||
|
||||
typer.secho(f"-> 🤖 Using persona profile: '{preset_obj.persona_name}'", fg=typer.colors.WHITE)
|
||||
typer.secho(f"-> 🧑 Using human profile: '{preset_obj.human_name}'", fg=typer.colors.WHITE)
|
||||
|
||||
@@ -304,7 +304,13 @@ class MetadataStore:
|
||||
def __init__(self, config: MemGPTConfig):
|
||||
# TODO: get DB URI or path
|
||||
if config.metadata_storage_type == "postgres":
|
||||
self.uri = config.metadata_storage_uri
|
||||
# construct URI from enviornment variables
|
||||
db = os.getenv("MEMGPT_PG_DB", "memgpt")
|
||||
user = os.getenv("MEMGPT_PG_USER", "memgpt")
|
||||
password = os.getenv("MEMGPT_PG_PASSWORD", "memgpt")
|
||||
port = os.getenv("MEMGPT_PG_PORT", "5432")
|
||||
url = os.getenv("MEMGPT_PG_URL", "localhost")
|
||||
self.uri = f"postgresql+pg8000://{user}:{password}@{url}:{port}/{db}"
|
||||
elif config.metadata_storage_type == "sqlite":
|
||||
path = os.path.join(config.metadata_storage_path, "sqlite.db")
|
||||
self.uri = f"sqlite:///{path}"
|
||||
|
||||
@@ -38,6 +38,16 @@ Start the server with:
|
||||
cd memgpt/server/rest_api
|
||||
poetry run uvicorn server:app --reload
|
||||
"""
|
||||
# override config with postgres enviornment (messy, but necessary for docker compose)
|
||||
# TODO: do something less gross
|
||||
if os.getenv("POSTGRES_URI"):
|
||||
config = MemGPTConfig.load()
|
||||
config.archival_storage_uri = os.getenv("POSTGRES_URI")
|
||||
config.recall_storage_uri = os.getenv("POSTGRES_URI")
|
||||
config.metadata_storage_uri = os.getenv("POSTGRES_URI")
|
||||
print(f"Overriding DB config URI with enviornment variable: {config.archival_storage_uri}")
|
||||
config.save()
|
||||
|
||||
|
||||
interface: QueuingInterface = QueuingInterface()
|
||||
server: SyncServer = SyncServer(default_interface=interface)
|
||||
@@ -54,6 +64,7 @@ else:
|
||||
password = secrets.token_urlsafe(16)
|
||||
print(f"Generated admin server password for this session: {password}")
|
||||
|
||||
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import time
|
||||
import threading
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from memgpt.server.rest_api.server import start_server
|
||||
from memgpt import Admin, create_client
|
||||
from memgpt.constants import DEFAULT_PRESET
|
||||
from memgpt.data_types import Preset # TODO move to PresetModel
|
||||
@@ -29,7 +28,8 @@ client = None
|
||||
test_agent_state_post_message = None
|
||||
test_user_id = uuid.uuid4()
|
||||
|
||||
test_base_url = "http://localhost:8283"
|
||||
local_service_url = "http://localhost:8283"
|
||||
docker_compose_url = "http://localhost:8083"
|
||||
|
||||
# admin credentials
|
||||
test_server_token = "test_server_token"
|
||||
@@ -38,6 +38,7 @@ test_server_token = "test_server_token"
|
||||
def run_server():
|
||||
import uvicorn
|
||||
from memgpt.server.rest_api.server import app
|
||||
from memgpt.server.rest_api.server import start_server
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@@ -99,45 +100,39 @@ def run_server():
|
||||
start_server(debug=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def start_uvicorn_server():
|
||||
"""Starts Uvicorn server in a background thread."""
|
||||
|
||||
thread = threading.Thread(target=run_server, daemon=True)
|
||||
thread.start()
|
||||
print("Starting server...")
|
||||
time.sleep(5)
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def user_token():
|
||||
# Setup: Create a user via the client before the tests
|
||||
|
||||
admin = Admin(test_base_url, test_server_token)
|
||||
response = admin.create_user(test_user_id) # Adjust as per your client's method
|
||||
user_id = response.user_id
|
||||
token = response.api_key
|
||||
|
||||
yield token
|
||||
|
||||
# Teardown: Delete the user after the test (or after all tests if fixture scope is module/class)
|
||||
admin.delete_user(test_user_id) # Adjust as per your client's method
|
||||
|
||||
|
||||
# Fixture to create clients with different configurations
|
||||
# @pytest.fixture(params=[{"base_url": test_base_url}, {"base_url": None}], scope="module")
|
||||
@pytest.fixture(params=[{"base_url": test_base_url}], scope="module")
|
||||
def client(request, user_token):
|
||||
# use token or not
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
{"base_url": local_service_url},
|
||||
# {"base_url": docker_compose_url}, # TODO: add when docker compose added to tests
|
||||
# {"base_url": None} # TODO: add when implemented
|
||||
],
|
||||
scope="module",
|
||||
)
|
||||
# @pytest.fixture(params=[{"base_url": test_base_url}], scope="module")
|
||||
def client(request):
|
||||
if request.param["base_url"]:
|
||||
token = user_token
|
||||
if request.param["base_url"] == local_service_url:
|
||||
# start server
|
||||
print("Starting server...")
|
||||
thread = threading.Thread(target=run_server, daemon=True)
|
||||
thread.start()
|
||||
time.sleep(5)
|
||||
|
||||
admin = Admin(local_service_url, test_server_token)
|
||||
response = admin.create_user(test_user_id) # Adjust as per your client's method
|
||||
user_id = response.user_id
|
||||
token = response.api_key
|
||||
else:
|
||||
token = None
|
||||
|
||||
client = create_client(**request.param, token=token) # This yields control back to the test function
|
||||
yield client
|
||||
|
||||
# cleanup user
|
||||
if request.param["base_url"]:
|
||||
admin.delete_user(test_user_id) # Adjust as per your client's method
|
||||
|
||||
|
||||
# Fixture for test agent
|
||||
@pytest.fixture(scope="module")
|
||||
|
||||
Reference in New Issue
Block a user