fix: fix DB session management to avoid connection overflow error (#1758)
This commit is contained in:
@@ -32,5 +32,8 @@ type = postgres
|
||||
path = /root/.memgpt
|
||||
uri = postgresql+pg8000://memgpt:memgpt@pgvector_db:5432/memgpt
|
||||
|
||||
[version]
|
||||
memgpt_version = 0.4.0
|
||||
|
||||
[client]
|
||||
anon_clientid = 00000000-0000-0000-0000-000000000000
|
||||
|
||||
@@ -21,18 +21,20 @@ class MemGPTUser(HttpUser):
|
||||
# Create a user and get the token
|
||||
self.client.headers = {"Authorization": "Bearer password"}
|
||||
user_data = {"name": f"User-{''.join(random.choices(string.ascii_lowercase + string.digits, k=8))}"}
|
||||
response = self.client.post("/admin/users", json=user_data)
|
||||
response = self.client.post("/v1/admin/users", json=user_data)
|
||||
response_json = response.json()
|
||||
print(response_json)
|
||||
self.user_id = response_json["id"]
|
||||
|
||||
# create a token
|
||||
response = self.client.post("/admin/users/keys", json={"user_id": self.user_id})
|
||||
response = self.client.post("/v1/admin/users/keys", json={"user_id": self.user_id})
|
||||
self.token = response.json()["key"]
|
||||
|
||||
# reset to use user token as headers
|
||||
self.client.headers = {"Authorization": f"Bearer {self.token}"}
|
||||
|
||||
# @task(1)
|
||||
# def create_agent(self):
|
||||
# generate random name
|
||||
name = "".join(random.choices(string.ascii_lowercase + string.digits, k=8))
|
||||
request = CreateAgent(
|
||||
@@ -42,7 +44,7 @@ class MemGPTUser(HttpUser):
|
||||
)
|
||||
|
||||
# create an agent
|
||||
with self.client.post("/api/agents", json=request.model_dump(), headers=self.client.headers, catch_response=True) as response:
|
||||
with self.client.post("/v1/agents", json=request.model_dump(), headers=self.client.headers, catch_response=True) as response:
|
||||
if response.status_code != 200:
|
||||
response.failure(f"Failed to create agent: {response.text}")
|
||||
|
||||
@@ -57,10 +59,10 @@ class MemGPTUser(HttpUser):
|
||||
request = MemGPTRequest(messages=messages, stream_steps=False, stream_tokens=False, return_message_object=False)
|
||||
|
||||
with self.client.post(
|
||||
f"/api/agents/{self.agent_id}/messages", json=request.model_dump(), headers=self.client.headers, catch_response=True
|
||||
f"/v1/agents/{self.agent_id}/messages", json=request.model_dump(), headers=self.client.headers, catch_response=True
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
response.failure(f"Failed to send message: {response.text}")
|
||||
response.failure(f"Failed to send message {response.status_code}: {response.text}")
|
||||
|
||||
response = MemGPTResponse(**response.json())
|
||||
print("Response", response.usage)
|
||||
|
||||
@@ -13,13 +13,12 @@ from sqlalchemy import (
|
||||
TypeDecorator,
|
||||
and_,
|
||||
asc,
|
||||
create_engine,
|
||||
desc,
|
||||
or_,
|
||||
select,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.orm import declarative_base, mapped_column, sessionmaker
|
||||
from sqlalchemy.orm import declarative_base, mapped_column
|
||||
from sqlalchemy.orm.session import close_all_sessions
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy_json import MutableJson
|
||||
@@ -36,6 +35,9 @@ from memgpt.schemas.openai.chat_completions import ToolCall
|
||||
from memgpt.schemas.passage import Passage
|
||||
from memgpt.settings import settings
|
||||
|
||||
Base = declarative_base()
|
||||
config = MemGPTConfig()
|
||||
|
||||
|
||||
class CommonVector(TypeDecorator):
|
||||
"""Common type for representing vectors in SQLite"""
|
||||
@@ -66,149 +68,116 @@ class CommonVector(TypeDecorator):
|
||||
return np.frombuffer(value, dtype=np.float32)
|
||||
|
||||
|
||||
# Custom serialization / de-serialization for JSON columns
|
||||
class MessageModel(Base):
|
||||
"""Defines data model for storing Message objects"""
|
||||
|
||||
__tablename__ = "messages"
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
||||
# Assuming message_id is the primary key
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String, nullable=False)
|
||||
agent_id = Column(String, nullable=False)
|
||||
|
||||
# openai info
|
||||
role = Column(String, nullable=False)
|
||||
text = Column(String) # optional: can be null if function call
|
||||
model = Column(String) # optional: can be null if LLM backend doesn't require specifying
|
||||
name = Column(String) # optional: multi-agent only
|
||||
|
||||
# tool call request info
|
||||
# if role == "assistant", this MAY be specified
|
||||
# if role != "assistant", this must be null
|
||||
# TODO align with OpenAI spec of multiple tool calls
|
||||
# tool_calls = Column(ToolCallColumn)
|
||||
tool_calls = Column(ToolCallColumn)
|
||||
|
||||
# tool call response info
|
||||
# if role == "tool", then this must be specified
|
||||
# if role != "tool", this must be null
|
||||
tool_call_id = Column(String)
|
||||
|
||||
# Add a datetime column, with default value as the current time
|
||||
created_at = Column(DateTime(timezone=True))
|
||||
Index("message_idx_user", user_id, agent_id),
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Message(message_id='{self.id}', text='{self.text}')>"
|
||||
|
||||
def to_record(self):
|
||||
# calls = (
|
||||
# [ToolCall(id=tool_call["id"], function=ToolCallFunction(**tool_call["function"])) for tool_call in self.tool_calls]
|
||||
# if self.tool_calls
|
||||
# else None
|
||||
# )
|
||||
# if calls:
|
||||
# assert isinstance(calls[0], ToolCall)
|
||||
if self.tool_calls and len(self.tool_calls) > 0:
|
||||
assert isinstance(self.tool_calls[0], ToolCall), type(self.tool_calls[0])
|
||||
for tool in self.tool_calls:
|
||||
assert isinstance(tool, ToolCall), type(tool)
|
||||
return Message(
|
||||
user_id=self.user_id,
|
||||
agent_id=self.agent_id,
|
||||
role=self.role,
|
||||
name=self.name,
|
||||
text=self.text,
|
||||
model=self.model,
|
||||
# tool_calls=[ToolCall(id=tool_call["id"], function=ToolCallFunction(**tool_call["function"])) for tool_call in self.tool_calls] if self.tool_calls else None,
|
||||
tool_calls=self.tool_calls,
|
||||
tool_call_id=self.tool_call_id,
|
||||
created_at=self.created_at,
|
||||
id=self.id,
|
||||
)
|
||||
|
||||
|
||||
Base = declarative_base()
|
||||
class PassageModel(Base):
|
||||
"""Defines data model for storing Passages (consisting of text, embedding)"""
|
||||
|
||||
__tablename__ = "passages"
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
||||
def get_db_model(
|
||||
config: MemGPTConfig,
|
||||
table_name: str,
|
||||
table_type: TableType,
|
||||
user_id: str,
|
||||
agent_id: Optional[str] = None,
|
||||
dialect="postgresql",
|
||||
):
|
||||
# Define a helper function to create or get the model class
|
||||
def create_or_get_model(class_name, base_model, table_name):
|
||||
if class_name in globals():
|
||||
return globals()[class_name]
|
||||
Model = type(class_name, (base_model,), {"__tablename__": table_name, "__table_args__": {"extend_existing": True}})
|
||||
globals()[class_name] = Model
|
||||
return Model
|
||||
# Assuming passage_id is the primary key
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String, nullable=False)
|
||||
text = Column(String)
|
||||
doc_id = Column(String)
|
||||
agent_id = Column(String)
|
||||
source_id = Column(String)
|
||||
|
||||
if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
|
||||
# create schema for archival memory
|
||||
class PassageModel(Base):
|
||||
"""Defines data model for storing Passages (consisting of text, embedding)"""
|
||||
|
||||
__abstract__ = True # this line is necessary
|
||||
|
||||
# Assuming passage_id is the primary key
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String, nullable=False)
|
||||
text = Column(String)
|
||||
doc_id = Column(String)
|
||||
agent_id = Column(String)
|
||||
source_id = Column(String)
|
||||
|
||||
# vector storage
|
||||
if dialect == "sqlite":
|
||||
embedding = Column(CommonVector)
|
||||
else:
|
||||
from pgvector.sqlalchemy import Vector
|
||||
|
||||
embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
|
||||
|
||||
embedding_config = Column(EmbeddingConfigColumn)
|
||||
metadata_ = Column(MutableJson)
|
||||
|
||||
# Add a datetime column, with default value as the current time
|
||||
created_at = Column(DateTime(timezone=True))
|
||||
|
||||
Index("passage_idx_user", user_id, agent_id, doc_id),
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
|
||||
|
||||
def to_record(self):
|
||||
return Passage(
|
||||
text=self.text,
|
||||
embedding=self.embedding,
|
||||
embedding_config=self.embedding_config,
|
||||
doc_id=self.doc_id,
|
||||
user_id=self.user_id,
|
||||
id=self.id,
|
||||
source_id=self.source_id,
|
||||
agent_id=self.agent_id,
|
||||
metadata_=self.metadata_,
|
||||
created_at=self.created_at,
|
||||
)
|
||||
|
||||
"""Create database model for table_name"""
|
||||
class_name = f"{table_name.capitalize()}Model" + dialect
|
||||
return create_or_get_model(class_name, PassageModel, table_name)
|
||||
|
||||
elif table_type == TableType.RECALL_MEMORY:
|
||||
|
||||
class MessageModel(Base):
|
||||
"""Defines data model for storing Message objects"""
|
||||
|
||||
__abstract__ = True # this line is necessary
|
||||
|
||||
# Assuming message_id is the primary key
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String, nullable=False)
|
||||
agent_id = Column(String, nullable=False)
|
||||
|
||||
# openai info
|
||||
role = Column(String, nullable=False)
|
||||
text = Column(String) # optional: can be null if function call
|
||||
model = Column(String) # optional: can be null if LLM backend doesn't require specifying
|
||||
name = Column(String) # optional: multi-agent only
|
||||
|
||||
# tool call request info
|
||||
# if role == "assistant", this MAY be specified
|
||||
# if role != "assistant", this must be null
|
||||
# TODO align with OpenAI spec of multiple tool calls
|
||||
# tool_calls = Column(ToolCallColumn)
|
||||
tool_calls = Column(ToolCallColumn)
|
||||
|
||||
# tool call response info
|
||||
# if role == "tool", then this must be specified
|
||||
# if role != "tool", this must be null
|
||||
tool_call_id = Column(String)
|
||||
|
||||
# Add a datetime column, with default value as the current time
|
||||
created_at = Column(DateTime(timezone=True))
|
||||
Index("message_idx_user", user_id, agent_id),
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Message(message_id='{self.id}', text='{self.text}')>"
|
||||
|
||||
def to_record(self):
|
||||
# calls = (
|
||||
# [ToolCall(id=tool_call["id"], function=ToolCallFunction(**tool_call["function"])) for tool_call in self.tool_calls]
|
||||
# if self.tool_calls
|
||||
# else None
|
||||
# )
|
||||
# if calls:
|
||||
# assert isinstance(calls[0], ToolCall)
|
||||
if self.tool_calls and len(self.tool_calls) > 0:
|
||||
assert isinstance(self.tool_calls[0], ToolCall), type(self.tool_calls[0])
|
||||
for tool in self.tool_calls:
|
||||
assert isinstance(tool, ToolCall), type(tool)
|
||||
return Message(
|
||||
user_id=self.user_id,
|
||||
agent_id=self.agent_id,
|
||||
role=self.role,
|
||||
name=self.name,
|
||||
text=self.text,
|
||||
model=self.model,
|
||||
# tool_calls=[ToolCall(id=tool_call["id"], function=ToolCallFunction(**tool_call["function"])) for tool_call in self.tool_calls] if self.tool_calls else None,
|
||||
tool_calls=self.tool_calls,
|
||||
tool_call_id=self.tool_call_id,
|
||||
created_at=self.created_at,
|
||||
id=self.id,
|
||||
)
|
||||
|
||||
"""Create database model for table_name"""
|
||||
class_name = f"{table_name.capitalize()}Model" + dialect
|
||||
return create_or_get_model(class_name, MessageModel, table_name)
|
||||
# vector storage
|
||||
if settings.memgpt_pg_uri_no_default:
|
||||
from pgvector.sqlalchemy import Vector
|
||||
|
||||
embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
|
||||
elif config.archival_storage_type == "sqlite" or config.archival_storage_type == "chroma":
|
||||
embedding = Column(CommonVector)
|
||||
else:
|
||||
raise ValueError(f"Table type {table_type} not implemented")
|
||||
raise ValueError(f"Unsupported archival_storage_type: {config.archival_storage_type}")
|
||||
embedding_config = Column(EmbeddingConfigColumn)
|
||||
metadata_ = Column(MutableJson)
|
||||
|
||||
# Add a datetime column, with default value as the current time
|
||||
created_at = Column(DateTime(timezone=True))
|
||||
|
||||
Index("passage_idx_user", user_id, agent_id, doc_id),
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
|
||||
|
||||
def to_record(self):
|
||||
return Passage(
|
||||
text=self.text,
|
||||
embedding=self.embedding,
|
||||
embedding_config=self.embedding_config,
|
||||
doc_id=self.doc_id,
|
||||
user_id=self.user_id,
|
||||
id=self.id,
|
||||
source_id=self.source_id,
|
||||
agent_id=self.agent_id,
|
||||
metadata_=self.metadata_,
|
||||
created_at=self.created_at,
|
||||
)
|
||||
|
||||
|
||||
class SQLStorageConnector(StorageConnector):
|
||||
@@ -386,9 +355,6 @@ class PostgresStorageConnector(SQLStorageConnector):
|
||||
|
||||
super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)
|
||||
|
||||
# create table
|
||||
self.db_model = get_db_model(config, self.table_name, table_type, user_id, agent_id)
|
||||
|
||||
# construct URI from enviornment variables
|
||||
if settings.pg_uri:
|
||||
self.uri = settings.pg_uri
|
||||
@@ -397,29 +363,29 @@ class PostgresStorageConnector(SQLStorageConnector):
|
||||
# TODO: remove this eventually (config should NOT contain URI)
|
||||
if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
|
||||
self.uri = self.config.archival_storage_uri
|
||||
self.db_model = PassageModel
|
||||
if self.config.archival_storage_uri is None:
|
||||
raise ValueError(f"Must specifiy archival_storage_uri in config {self.config.config_path}")
|
||||
elif table_type == TableType.RECALL_MEMORY:
|
||||
self.uri = self.config.recall_storage_uri
|
||||
self.db_model = MessageModel
|
||||
if self.config.recall_storage_uri is None:
|
||||
raise ValueError(f"Must specifiy recall_storage_uri in config {self.config.config_path}")
|
||||
else:
|
||||
raise ValueError(f"Table type {table_type} not implemented")
|
||||
|
||||
# create engine
|
||||
self.engine = create_engine(self.uri)
|
||||
|
||||
for c in self.db_model.__table__.columns:
|
||||
if c.name == "embedding":
|
||||
assert isinstance(c.type, Vector), f"Embedding column must be of type Vector, got {c.type}"
|
||||
|
||||
self.session_maker = sessionmaker(bind=self.engine)
|
||||
from memgpt.server.server import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
# TODO: move to DB init
|
||||
with self.session_maker() as session:
|
||||
session.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) # Enables the vector extension
|
||||
|
||||
# create table
|
||||
Base.metadata.create_all(self.engine, tables=[self.db_model.__table__]) # Create the table if it doesn't exist
|
||||
|
||||
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}):
|
||||
filters = self.get_filters(filters)
|
||||
with self.session_maker() as session:
|
||||
@@ -432,31 +398,40 @@ class PostgresStorageConnector(SQLStorageConnector):
|
||||
return records
|
||||
|
||||
def insert_many(self, records, exists_ok=True, show_progress=False):
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
pass
|
||||
|
||||
# TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel)
|
||||
if len(records) == 0:
|
||||
return
|
||||
if isinstance(records[0], Passage):
|
||||
with self.engine.connect() as conn:
|
||||
db_records = [vars(record) for record in records]
|
||||
stmt = insert(self.db_model.__table__).values(db_records)
|
||||
if exists_ok:
|
||||
upsert_stmt = stmt.on_conflict_do_update(
|
||||
index_elements=["id"], set_={c.name: c for c in stmt.excluded} # Replace with your primary key column
|
||||
)
|
||||
conn.execute(upsert_stmt)
|
||||
|
||||
added_ids = [] # avoid adding duplicates
|
||||
# NOTE: this has not great performance due to the excessive commits
|
||||
with self.session_maker() as session:
|
||||
iterable = tqdm(records) if show_progress else records
|
||||
for record in iterable:
|
||||
# db_record = self.db_model(**vars(record))
|
||||
|
||||
if record.id in added_ids:
|
||||
continue
|
||||
|
||||
existing_record = session.query(self.db_model).filter_by(id=record.id).first()
|
||||
if existing_record:
|
||||
if exists_ok:
|
||||
fields = record.model_dump()
|
||||
fields.pop("id")
|
||||
session.query(self.db_model).filter(self.db_model.id == record.id).update(fields)
|
||||
print(f"Updated record with id {record.id}")
|
||||
session.commit()
|
||||
else:
|
||||
raise ValueError(f"Record with id {record.id} already exists.")
|
||||
|
||||
else:
|
||||
conn.execute(stmt)
|
||||
conn.commit()
|
||||
else:
|
||||
with self.session_maker() as session:
|
||||
iterable = tqdm(records) if show_progress else records
|
||||
for record in iterable:
|
||||
# db_record = self.db_model(**vars(record))
|
||||
db_record = self.db_model(**record.dict())
|
||||
session.add(db_record)
|
||||
session.commit()
|
||||
print(f"Added record with id {record.id}")
|
||||
session.commit()
|
||||
|
||||
added_ids.append(record.id)
|
||||
|
||||
def insert(self, record, exists_ok=True):
|
||||
self.insert_many([record], exists_ok=exists_ok)
|
||||
@@ -515,16 +490,15 @@ class SQLLiteStorageConnector(SQLStorageConnector):
|
||||
self.path = self.config.recall_storage_path
|
||||
if self.path is None:
|
||||
raise ValueError(f"Must specifiy recall_storage_path in config {self.config.recall_storage_path}")
|
||||
self.db_model = MessageModel
|
||||
else:
|
||||
raise ValueError(f"Table type {table_type} not implemented")
|
||||
|
||||
self.path = os.path.join(self.path, f"sqlite.db")
|
||||
|
||||
# Create the SQLAlchemy engine
|
||||
self.db_model = get_db_model(config, self.table_name, table_type, user_id, agent_id, dialect="sqlite")
|
||||
self.engine = create_engine(f"sqlite:///{self.path}")
|
||||
Base.metadata.create_all(self.engine, tables=[self.db_model.__table__]) # Create the table if it doesn't exist
|
||||
self.session_maker = sessionmaker(bind=self.engine)
|
||||
from memgpt.server.server import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
# import sqlite3
|
||||
|
||||
@@ -532,31 +506,18 @@ class SQLLiteStorageConnector(SQLStorageConnector):
|
||||
# sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b))
|
||||
|
||||
def insert_many(self, records, exists_ok=True, show_progress=False):
|
||||
from sqlalchemy.dialects.sqlite import insert
|
||||
pass
|
||||
|
||||
# TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel)
|
||||
if len(records) == 0:
|
||||
return
|
||||
if isinstance(records[0], Passage):
|
||||
with self.engine.connect() as conn:
|
||||
db_records = [vars(record) for record in records]
|
||||
stmt = insert(self.db_model.__table__).values(db_records)
|
||||
if exists_ok:
|
||||
upsert_stmt = stmt.on_conflict_do_update(
|
||||
index_elements=["id"], set_={c.name: c for c in stmt.excluded} # Replace with your primary key column
|
||||
)
|
||||
conn.execute(upsert_stmt)
|
||||
else:
|
||||
conn.execute(stmt)
|
||||
conn.commit()
|
||||
else:
|
||||
with self.session_maker() as session:
|
||||
iterable = tqdm(records) if show_progress else records
|
||||
for record in iterable:
|
||||
# db_record = self.db_model(**vars(record))
|
||||
db_record = self.db_model(**record.dict())
|
||||
session.add(db_record)
|
||||
session.commit()
|
||||
with self.session_maker() as session:
|
||||
iterable = tqdm(records) if show_progress else records
|
||||
for record in iterable:
|
||||
# db_record = self.db_model(**vars(record))
|
||||
db_record = self.db_model(**record.dict())
|
||||
session.add(db_record)
|
||||
session.commit()
|
||||
|
||||
def insert(self, record, exists_ok=True):
|
||||
self.insert_many([record], exists_ok=exists_ok)
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import os
|
||||
import secrets
|
||||
import traceback
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import (
|
||||
@@ -14,12 +13,10 @@ from sqlalchemy import (
|
||||
Index,
|
||||
String,
|
||||
TypeDecorator,
|
||||
create_engine,
|
||||
desc,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.exc import InterfaceError, OperationalError
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||
from sqlalchemy.orm import declarative_base
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from memgpt.config import MemGPTConfig
|
||||
@@ -52,7 +49,9 @@ class LLMConfigColumn(TypeDecorator):
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value:
|
||||
return vars(value)
|
||||
# return vars(value)
|
||||
if isinstance(value, LLMConfig):
|
||||
return value.model_dump()
|
||||
return value
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
@@ -72,7 +71,9 @@ class EmbeddingConfigColumn(TypeDecorator):
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value:
|
||||
return vars(value)
|
||||
# return vars(value)
|
||||
if isinstance(value, EmbeddingConfig):
|
||||
return value.model_dump()
|
||||
return value
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
@@ -399,41 +400,45 @@ class MetadataStore:
|
||||
# Ensure valid URI
|
||||
assert self.uri, "Database URI is not provided or is invalid."
|
||||
|
||||
# Check if tables need to be created
|
||||
self.engine = create_engine(self.uri)
|
||||
try:
|
||||
Base.metadata.create_all(
|
||||
self.engine,
|
||||
tables=[
|
||||
UserModel.__table__,
|
||||
AgentModel.__table__,
|
||||
SourceModel.__table__,
|
||||
AgentSourceMappingModel.__table__,
|
||||
APIKeyModel.__table__,
|
||||
BlockModel.__table__,
|
||||
ToolModel.__table__,
|
||||
JobModel.__table__,
|
||||
],
|
||||
)
|
||||
except (InterfaceError, OperationalError) as e:
|
||||
traceback.print_exc()
|
||||
if config.metadata_storage_type == "postgres":
|
||||
raise ValueError(
|
||||
f"{str(e)}\n\nMemGPT failed to connect to the database at URI '{self.uri}'. "
|
||||
+ "Please make sure you configured your storage backend correctly (https://memgpt.readme.io/docs/storage). "
|
||||
+ "\npostgres detected: Make sure the postgres database is running (https://memgpt.readme.io/docs/storage#postgres)."
|
||||
)
|
||||
elif config.metadata_storage_type == "sqlite":
|
||||
raise ValueError(
|
||||
f"{str(e)}\n\nMemGPT failed to connect to the database at URI '{self.uri}'. "
|
||||
+ "Please make sure you configured your storage backend correctly (https://memgpt.readme.io/docs/storage). "
|
||||
+ "\nsqlite detected: Make sure that the sqlite.db file exists at the URI."
|
||||
)
|
||||
else:
|
||||
raise e
|
||||
except:
|
||||
raise
|
||||
self.session_maker = sessionmaker(bind=self.engine)
|
||||
from memgpt.server.server import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
# # Check if tables need to be created
|
||||
# self.engine = create_engine(self.uri)
|
||||
# try:
|
||||
# Base.metadata.create_all(
|
||||
# self.engine,
|
||||
# tables=[
|
||||
# UserModel.__table__,
|
||||
# AgentModel.__table__,
|
||||
# SourceModel.__table__,
|
||||
# AgentSourceMappingModel.__table__,
|
||||
# APIKeyModel.__table__,
|
||||
# BlockModel.__table__,
|
||||
# ToolModel.__table__,
|
||||
# JobModel.__table__,
|
||||
# ],
|
||||
# )
|
||||
# except (InterfaceError, OperationalError) as e:
|
||||
# traceback.print_exc()
|
||||
# if config.metadata_storage_type == "postgres":
|
||||
# raise ValueError(
|
||||
# f"{str(e)}\n\nMemGPT failed to connect to the database at URI '{self.uri}'. "
|
||||
# + "Please make sure you configured your storage backend correctly (https://memgpt.readme.io/docs/storage). "
|
||||
# + "\npostgres detected: Make sure the postgres database is running (https://memgpt.readme.io/docs/storage#postgres)."
|
||||
# )
|
||||
# elif config.metadata_storage_type == "sqlite":
|
||||
# raise ValueError(
|
||||
# f"{str(e)}\n\nMemGPT failed to connect to the database at URI '{self.uri}'. "
|
||||
# + "Please make sure you configured your storage backend correctly (https://memgpt.readme.io/docs/storage). "
|
||||
# + "\nsqlite detected: Make sure that the sqlite.db file exists at the URI."
|
||||
# )
|
||||
# else:
|
||||
# raise e
|
||||
# except:
|
||||
# raise
|
||||
# self.session_maker = sessionmaker(bind=self.engine)
|
||||
|
||||
@enforce_types
|
||||
def create_api_key(self, user_id: str, name: str) -> APIKey:
|
||||
|
||||
@@ -133,6 +133,68 @@ class Server(object):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||
|
||||
from memgpt.agent_store.db import MessageModel, PassageModel
|
||||
from memgpt.config import MemGPTConfig
|
||||
|
||||
# NOTE: hack to see if single session management works
|
||||
from memgpt.metadata import (
|
||||
AgentModel,
|
||||
AgentSourceMappingModel,
|
||||
APIKeyModel,
|
||||
BlockModel,
|
||||
JobModel,
|
||||
SourceModel,
|
||||
ToolModel,
|
||||
UserModel,
|
||||
)
|
||||
from memgpt.settings import settings
|
||||
|
||||
config = MemGPTConfig.load()
|
||||
|
||||
# determine the storage type
|
||||
if config.recall_storage_type == "postgres":
|
||||
engine = create_engine(settings.memgpt_pg_uri)
|
||||
elif config.recall_storage_type == "sqlite":
|
||||
engine = create_engine("sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db"))
|
||||
else:
|
||||
raise ValueError(f"Unknown recall_storage_type: {config.recall_storage_type}")
|
||||
|
||||
Base = declarative_base()
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(
|
||||
engine,
|
||||
tables=[
|
||||
UserModel.__table__,
|
||||
AgentModel.__table__,
|
||||
SourceModel.__table__,
|
||||
AgentSourceMappingModel.__table__,
|
||||
APIKeyModel.__table__,
|
||||
BlockModel.__table__,
|
||||
ToolModel.__table__,
|
||||
JobModel.__table__,
|
||||
PassageModel.__table__,
|
||||
MessageModel.__table__,
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
# Dependency
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
db_context = contextmanager(get_db)
|
||||
|
||||
|
||||
class SyncServer(Server):
|
||||
"""Simple single-threaded / blocking server process"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user