fix: DB session management to avoid connection overflow error (#1758)

This commit is contained in:
Sarah Wooders
2024-09-18 13:14:19 -07:00
committed by GitHub
parent 930ab946f2
commit da9bbeada0
5 changed files with 271 additions and 238 deletions

View File

@@ -32,5 +32,8 @@ type = postgres
path = /root/.memgpt
uri = postgresql+pg8000://memgpt:memgpt@pgvector_db:5432/memgpt
[version]
memgpt_version = 0.4.0
[client]
anon_clientid = 00000000-0000-0000-0000-000000000000

View File

@@ -21,18 +21,20 @@ class MemGPTUser(HttpUser):
# Create a user and get the token
self.client.headers = {"Authorization": "Bearer password"}
user_data = {"name": f"User-{''.join(random.choices(string.ascii_lowercase + string.digits, k=8))}"}
response = self.client.post("/admin/users", json=user_data)
response = self.client.post("/v1/admin/users", json=user_data)
response_json = response.json()
print(response_json)
self.user_id = response_json["id"]
# create a token
response = self.client.post("/admin/users/keys", json={"user_id": self.user_id})
response = self.client.post("/v1/admin/users/keys", json={"user_id": self.user_id})
self.token = response.json()["key"]
# reset to use user token as headers
self.client.headers = {"Authorization": f"Bearer {self.token}"}
# @task(1)
# def create_agent(self):
# generate random name
name = "".join(random.choices(string.ascii_lowercase + string.digits, k=8))
request = CreateAgent(
@@ -42,7 +44,7 @@ class MemGPTUser(HttpUser):
)
# create an agent
with self.client.post("/api/agents", json=request.model_dump(), headers=self.client.headers, catch_response=True) as response:
with self.client.post("/v1/agents", json=request.model_dump(), headers=self.client.headers, catch_response=True) as response:
if response.status_code != 200:
response.failure(f"Failed to create agent: {response.text}")
@@ -57,10 +59,10 @@ class MemGPTUser(HttpUser):
request = MemGPTRequest(messages=messages, stream_steps=False, stream_tokens=False, return_message_object=False)
with self.client.post(
f"/api/agents/{self.agent_id}/messages", json=request.model_dump(), headers=self.client.headers, catch_response=True
f"/v1/agents/{self.agent_id}/messages", json=request.model_dump(), headers=self.client.headers, catch_response=True
) as response:
if response.status_code != 200:
response.failure(f"Failed to send message: {response.text}")
response.failure(f"Failed to send message {response.status_code}: {response.text}")
response = MemGPTResponse(**response.json())
print("Response", response.usage)

View File

@@ -13,13 +13,12 @@ from sqlalchemy import (
TypeDecorator,
and_,
asc,
create_engine,
desc,
or_,
select,
text,
)
from sqlalchemy.orm import declarative_base, mapped_column, sessionmaker
from sqlalchemy.orm import declarative_base, mapped_column
from sqlalchemy.orm.session import close_all_sessions
from sqlalchemy.sql import func
from sqlalchemy_json import MutableJson
@@ -36,6 +35,9 @@ from memgpt.schemas.openai.chat_completions import ToolCall
from memgpt.schemas.passage import Passage
from memgpt.settings import settings
Base = declarative_base()
config = MemGPTConfig()
class CommonVector(TypeDecorator):
"""Common type for representing vectors in SQLite"""
@@ -66,149 +68,116 @@ class CommonVector(TypeDecorator):
return np.frombuffer(value, dtype=np.float32)
# Custom serialization / de-serialization for JSON columns
class MessageModel(Base):
"""Defines data model for storing Message objects"""
__tablename__ = "messages"
__table_args__ = {"extend_existing": True}
# Assuming message_id is the primary key
id = Column(String, primary_key=True)
user_id = Column(String, nullable=False)
agent_id = Column(String, nullable=False)
# openai info
role = Column(String, nullable=False)
text = Column(String) # optional: can be null if function call
model = Column(String) # optional: can be null if LLM backend doesn't require specifying
name = Column(String) # optional: multi-agent only
# tool call request info
# if role == "assistant", this MAY be specified
# if role != "assistant", this must be null
# TODO align with OpenAI spec of multiple tool calls
# tool_calls = Column(ToolCallColumn)
tool_calls = Column(ToolCallColumn)
# tool call response info
# if role == "tool", then this must be specified
# if role != "tool", this must be null
tool_call_id = Column(String)
# Add a datetime column, with default value as the current time
created_at = Column(DateTime(timezone=True))
Index("message_idx_user", user_id, agent_id),
def __repr__(self):
return f"<Message(message_id='{self.id}', text='{self.text}')>"
def to_record(self):
# calls = (
# [ToolCall(id=tool_call["id"], function=ToolCallFunction(**tool_call["function"])) for tool_call in self.tool_calls]
# if self.tool_calls
# else None
# )
# if calls:
# assert isinstance(calls[0], ToolCall)
if self.tool_calls and len(self.tool_calls) > 0:
assert isinstance(self.tool_calls[0], ToolCall), type(self.tool_calls[0])
for tool in self.tool_calls:
assert isinstance(tool, ToolCall), type(tool)
return Message(
user_id=self.user_id,
agent_id=self.agent_id,
role=self.role,
name=self.name,
text=self.text,
model=self.model,
# tool_calls=[ToolCall(id=tool_call["id"], function=ToolCallFunction(**tool_call["function"])) for tool_call in self.tool_calls] if self.tool_calls else None,
tool_calls=self.tool_calls,
tool_call_id=self.tool_call_id,
created_at=self.created_at,
id=self.id,
)
Base = declarative_base()
class PassageModel(Base):
"""Defines data model for storing Passages (consisting of text, embedding)"""
__tablename__ = "passages"
__table_args__ = {"extend_existing": True}
def get_db_model(
config: MemGPTConfig,
table_name: str,
table_type: TableType,
user_id: str,
agent_id: Optional[str] = None,
dialect="postgresql",
):
# Define a helper function to create or get the model class
def create_or_get_model(class_name, base_model, table_name):
if class_name in globals():
return globals()[class_name]
Model = type(class_name, (base_model,), {"__tablename__": table_name, "__table_args__": {"extend_existing": True}})
globals()[class_name] = Model
return Model
# Assuming passage_id is the primary key
id = Column(String, primary_key=True)
user_id = Column(String, nullable=False)
text = Column(String)
doc_id = Column(String)
agent_id = Column(String)
source_id = Column(String)
if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
# create schema for archival memory
class PassageModel(Base):
"""Defines data model for storing Passages (consisting of text, embedding)"""
__abstract__ = True # this line is necessary
# Assuming passage_id is the primary key
id = Column(String, primary_key=True)
user_id = Column(String, nullable=False)
text = Column(String)
doc_id = Column(String)
agent_id = Column(String)
source_id = Column(String)
# vector storage
if dialect == "sqlite":
embedding = Column(CommonVector)
else:
from pgvector.sqlalchemy import Vector
embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
embedding_config = Column(EmbeddingConfigColumn)
metadata_ = Column(MutableJson)
# Add a datetime column, with default value as the current time
created_at = Column(DateTime(timezone=True))
Index("passage_idx_user", user_id, agent_id, doc_id),
def __repr__(self):
return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
def to_record(self):
return Passage(
text=self.text,
embedding=self.embedding,
embedding_config=self.embedding_config,
doc_id=self.doc_id,
user_id=self.user_id,
id=self.id,
source_id=self.source_id,
agent_id=self.agent_id,
metadata_=self.metadata_,
created_at=self.created_at,
)
"""Create database model for table_name"""
class_name = f"{table_name.capitalize()}Model" + dialect
return create_or_get_model(class_name, PassageModel, table_name)
elif table_type == TableType.RECALL_MEMORY:
class MessageModel(Base):
"""Defines data model for storing Message objects"""
__abstract__ = True # this line is necessary
# Assuming message_id is the primary key
id = Column(String, primary_key=True)
user_id = Column(String, nullable=False)
agent_id = Column(String, nullable=False)
# openai info
role = Column(String, nullable=False)
text = Column(String) # optional: can be null if function call
model = Column(String) # optional: can be null if LLM backend doesn't require specifying
name = Column(String) # optional: multi-agent only
# tool call request info
# if role == "assistant", this MAY be specified
# if role != "assistant", this must be null
# TODO align with OpenAI spec of multiple tool calls
# tool_calls = Column(ToolCallColumn)
tool_calls = Column(ToolCallColumn)
# tool call response info
# if role == "tool", then this must be specified
# if role != "tool", this must be null
tool_call_id = Column(String)
# Add a datetime column, with default value as the current time
created_at = Column(DateTime(timezone=True))
Index("message_idx_user", user_id, agent_id),
def __repr__(self):
return f"<Message(message_id='{self.id}', text='{self.text}')>"
def to_record(self):
# calls = (
# [ToolCall(id=tool_call["id"], function=ToolCallFunction(**tool_call["function"])) for tool_call in self.tool_calls]
# if self.tool_calls
# else None
# )
# if calls:
# assert isinstance(calls[0], ToolCall)
if self.tool_calls and len(self.tool_calls) > 0:
assert isinstance(self.tool_calls[0], ToolCall), type(self.tool_calls[0])
for tool in self.tool_calls:
assert isinstance(tool, ToolCall), type(tool)
return Message(
user_id=self.user_id,
agent_id=self.agent_id,
role=self.role,
name=self.name,
text=self.text,
model=self.model,
# tool_calls=[ToolCall(id=tool_call["id"], function=ToolCallFunction(**tool_call["function"])) for tool_call in self.tool_calls] if self.tool_calls else None,
tool_calls=self.tool_calls,
tool_call_id=self.tool_call_id,
created_at=self.created_at,
id=self.id,
)
"""Create database model for table_name"""
class_name = f"{table_name.capitalize()}Model" + dialect
return create_or_get_model(class_name, MessageModel, table_name)
# vector storage
if settings.memgpt_pg_uri_no_default:
from pgvector.sqlalchemy import Vector
embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
elif config.archival_storage_type == "sqlite" or config.archival_storage_type == "chroma":
embedding = Column(CommonVector)
else:
raise ValueError(f"Table type {table_type} not implemented")
raise ValueError(f"Unsupported archival_storage_type: {config.archival_storage_type}")
embedding_config = Column(EmbeddingConfigColumn)
metadata_ = Column(MutableJson)
# Add a datetime column, with default value as the current time
created_at = Column(DateTime(timezone=True))
Index("passage_idx_user", user_id, agent_id, doc_id),
def __repr__(self):
return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
def to_record(self):
return Passage(
text=self.text,
embedding=self.embedding,
embedding_config=self.embedding_config,
doc_id=self.doc_id,
user_id=self.user_id,
id=self.id,
source_id=self.source_id,
agent_id=self.agent_id,
metadata_=self.metadata_,
created_at=self.created_at,
)
class SQLStorageConnector(StorageConnector):
@@ -386,9 +355,6 @@ class PostgresStorageConnector(SQLStorageConnector):
super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)
# create table
self.db_model = get_db_model(config, self.table_name, table_type, user_id, agent_id)
# construct URI from enviornment variables
if settings.pg_uri:
self.uri = settings.pg_uri
@@ -397,29 +363,29 @@ class PostgresStorageConnector(SQLStorageConnector):
# TODO: remove this eventually (config should NOT contain URI)
if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
self.uri = self.config.archival_storage_uri
self.db_model = PassageModel
if self.config.archival_storage_uri is None:
raise ValueError(f"Must specifiy archival_storage_uri in config {self.config.config_path}")
elif table_type == TableType.RECALL_MEMORY:
self.uri = self.config.recall_storage_uri
self.db_model = MessageModel
if self.config.recall_storage_uri is None:
raise ValueError(f"Must specifiy recall_storage_uri in config {self.config.config_path}")
else:
raise ValueError(f"Table type {table_type} not implemented")
# create engine
self.engine = create_engine(self.uri)
for c in self.db_model.__table__.columns:
if c.name == "embedding":
assert isinstance(c.type, Vector), f"Embedding column must be of type Vector, got {c.type}"
self.session_maker = sessionmaker(bind=self.engine)
from memgpt.server.server import db_context
self.session_maker = db_context
# TODO: move to DB init
with self.session_maker() as session:
session.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) # Enables the vector extension
# create table
Base.metadata.create_all(self.engine, tables=[self.db_model.__table__]) # Create the table if it doesn't exist
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}):
filters = self.get_filters(filters)
with self.session_maker() as session:
@@ -432,31 +398,40 @@ class PostgresStorageConnector(SQLStorageConnector):
return records
def insert_many(self, records, exists_ok=True, show_progress=False):
from sqlalchemy.dialects.postgresql import insert
pass
# TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel)
if len(records) == 0:
return
if isinstance(records[0], Passage):
with self.engine.connect() as conn:
db_records = [vars(record) for record in records]
stmt = insert(self.db_model.__table__).values(db_records)
if exists_ok:
upsert_stmt = stmt.on_conflict_do_update(
index_elements=["id"], set_={c.name: c for c in stmt.excluded} # Replace with your primary key column
)
conn.execute(upsert_stmt)
added_ids = [] # avoid adding duplicates
# NOTE: this has not great performance due to the excessive commits
with self.session_maker() as session:
iterable = tqdm(records) if show_progress else records
for record in iterable:
# db_record = self.db_model(**vars(record))
if record.id in added_ids:
continue
existing_record = session.query(self.db_model).filter_by(id=record.id).first()
if existing_record:
if exists_ok:
fields = record.model_dump()
fields.pop("id")
session.query(self.db_model).filter(self.db_model.id == record.id).update(fields)
print(f"Updated record with id {record.id}")
session.commit()
else:
raise ValueError(f"Record with id {record.id} already exists.")
else:
conn.execute(stmt)
conn.commit()
else:
with self.session_maker() as session:
iterable = tqdm(records) if show_progress else records
for record in iterable:
# db_record = self.db_model(**vars(record))
db_record = self.db_model(**record.dict())
session.add(db_record)
session.commit()
print(f"Added record with id {record.id}")
session.commit()
added_ids.append(record.id)
def insert(self, record, exists_ok=True):
self.insert_many([record], exists_ok=exists_ok)
@@ -515,16 +490,15 @@ class SQLLiteStorageConnector(SQLStorageConnector):
self.path = self.config.recall_storage_path
if self.path is None:
raise ValueError(f"Must specifiy recall_storage_path in config {self.config.recall_storage_path}")
self.db_model = MessageModel
else:
raise ValueError(f"Table type {table_type} not implemented")
self.path = os.path.join(self.path, f"sqlite.db")
# Create the SQLAlchemy engine
self.db_model = get_db_model(config, self.table_name, table_type, user_id, agent_id, dialect="sqlite")
self.engine = create_engine(f"sqlite:///{self.path}")
Base.metadata.create_all(self.engine, tables=[self.db_model.__table__]) # Create the table if it doesn't exist
self.session_maker = sessionmaker(bind=self.engine)
from memgpt.server.server import db_context
self.session_maker = db_context
# import sqlite3
@@ -532,31 +506,18 @@ class SQLLiteStorageConnector(SQLStorageConnector):
# sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b))
def insert_many(self, records, exists_ok=True, show_progress=False):
from sqlalchemy.dialects.sqlite import insert
pass
# TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel)
if len(records) == 0:
return
if isinstance(records[0], Passage):
with self.engine.connect() as conn:
db_records = [vars(record) for record in records]
stmt = insert(self.db_model.__table__).values(db_records)
if exists_ok:
upsert_stmt = stmt.on_conflict_do_update(
index_elements=["id"], set_={c.name: c for c in stmt.excluded} # Replace with your primary key column
)
conn.execute(upsert_stmt)
else:
conn.execute(stmt)
conn.commit()
else:
with self.session_maker() as session:
iterable = tqdm(records) if show_progress else records
for record in iterable:
# db_record = self.db_model(**vars(record))
db_record = self.db_model(**record.dict())
session.add(db_record)
session.commit()
with self.session_maker() as session:
iterable = tqdm(records) if show_progress else records
for record in iterable:
# db_record = self.db_model(**vars(record))
db_record = self.db_model(**record.dict())
session.add(db_record)
session.commit()
def insert(self, record, exists_ok=True):
self.insert_many([record], exists_ok=exists_ok)

View File

@@ -2,7 +2,6 @@
import os
import secrets
import traceback
from typing import List, Optional
from sqlalchemy import (
@@ -14,12 +13,10 @@ from sqlalchemy import (
Index,
String,
TypeDecorator,
create_engine,
desc,
func,
)
from sqlalchemy.exc import InterfaceError, OperationalError
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.orm import declarative_base
from sqlalchemy.sql import func
from memgpt.config import MemGPTConfig
@@ -52,7 +49,9 @@ class LLMConfigColumn(TypeDecorator):
def process_bind_param(self, value, dialect):
if value:
return vars(value)
# return vars(value)
if isinstance(value, LLMConfig):
return value.model_dump()
return value
def process_result_value(self, value, dialect):
@@ -72,7 +71,9 @@ class EmbeddingConfigColumn(TypeDecorator):
def process_bind_param(self, value, dialect):
if value:
return vars(value)
# return vars(value)
if isinstance(value, EmbeddingConfig):
return value.model_dump()
return value
def process_result_value(self, value, dialect):
@@ -399,41 +400,45 @@ class MetadataStore:
# Ensure valid URI
assert self.uri, "Database URI is not provided or is invalid."
# Check if tables need to be created
self.engine = create_engine(self.uri)
try:
Base.metadata.create_all(
self.engine,
tables=[
UserModel.__table__,
AgentModel.__table__,
SourceModel.__table__,
AgentSourceMappingModel.__table__,
APIKeyModel.__table__,
BlockModel.__table__,
ToolModel.__table__,
JobModel.__table__,
],
)
except (InterfaceError, OperationalError) as e:
traceback.print_exc()
if config.metadata_storage_type == "postgres":
raise ValueError(
f"{str(e)}\n\nMemGPT failed to connect to the database at URI '{self.uri}'. "
+ "Please make sure you configured your storage backend correctly (https://memgpt.readme.io/docs/storage). "
+ "\npostgres detected: Make sure the postgres database is running (https://memgpt.readme.io/docs/storage#postgres)."
)
elif config.metadata_storage_type == "sqlite":
raise ValueError(
f"{str(e)}\n\nMemGPT failed to connect to the database at URI '{self.uri}'. "
+ "Please make sure you configured your storage backend correctly (https://memgpt.readme.io/docs/storage). "
+ "\nsqlite detected: Make sure that the sqlite.db file exists at the URI."
)
else:
raise e
except:
raise
self.session_maker = sessionmaker(bind=self.engine)
from memgpt.server.server import db_context
self.session_maker = db_context
# # Check if tables need to be created
# self.engine = create_engine(self.uri)
# try:
# Base.metadata.create_all(
# self.engine,
# tables=[
# UserModel.__table__,
# AgentModel.__table__,
# SourceModel.__table__,
# AgentSourceMappingModel.__table__,
# APIKeyModel.__table__,
# BlockModel.__table__,
# ToolModel.__table__,
# JobModel.__table__,
# ],
# )
# except (InterfaceError, OperationalError) as e:
# traceback.print_exc()
# if config.metadata_storage_type == "postgres":
# raise ValueError(
# f"{str(e)}\n\nMemGPT failed to connect to the database at URI '{self.uri}'. "
# + "Please make sure you configured your storage backend correctly (https://memgpt.readme.io/docs/storage). "
# + "\npostgres detected: Make sure the postgres database is running (https://memgpt.readme.io/docs/storage#postgres)."
# )
# elif config.metadata_storage_type == "sqlite":
# raise ValueError(
# f"{str(e)}\n\nMemGPT failed to connect to the database at URI '{self.uri}'. "
# + "Please make sure you configured your storage backend correctly (https://memgpt.readme.io/docs/storage). "
# + "\nsqlite detected: Make sure that the sqlite.db file exists at the URI."
# )
# else:
# raise e
# except:
# raise
# self.session_maker = sessionmaker(bind=self.engine)
@enforce_types
def create_api_key(self, user_id: str, name: str) -> APIKey:

View File

@@ -133,6 +133,68 @@ class Server(object):
raise NotImplementedError
from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker
from memgpt.agent_store.db import MessageModel, PassageModel
from memgpt.config import MemGPTConfig
# NOTE: hack to see if single session management works
from memgpt.metadata import (
AgentModel,
AgentSourceMappingModel,
APIKeyModel,
BlockModel,
JobModel,
SourceModel,
ToolModel,
UserModel,
)
from memgpt.settings import settings
config = MemGPTConfig.load()
# determine the storage type
if config.recall_storage_type == "postgres":
engine = create_engine(settings.memgpt_pg_uri)
elif config.recall_storage_type == "sqlite":
engine = create_engine("sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db"))
else:
raise ValueError(f"Unknown recall_storage_type: {config.recall_storage_type}")
Base = declarative_base()
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(
engine,
tables=[
UserModel.__table__,
AgentModel.__table__,
SourceModel.__table__,
AgentSourceMappingModel.__table__,
APIKeyModel.__table__,
BlockModel.__table__,
ToolModel.__table__,
JobModel.__table__,
PassageModel.__table__,
MessageModel.__table__,
],
)
# Dependency
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
from contextlib import contextmanager
db_context = contextmanager(get_db)
class SyncServer(Server):
"""Simple single-threaded / blocking server process"""