fix: fallback to MemGPTConfig URI for postgres if no environment variables (#1216)

This commit is contained in:
Sarah Wooders
2024-04-03 13:06:44 -07:00
committed by GitHub
parent 01badfc782
commit 5677edb2b7
2 changed files with 17 additions and 13 deletions

View File

@@ -424,30 +424,32 @@ class PostgresStorageConnector(SQLStorageConnector):
super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)
# get storage URI
if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
self.uri = self.config.archival_storage_uri
if self.config.archival_storage_uri is None:
raise ValueError(f"Must specifiy archival_storage_uri in config {self.config.config_path}")
elif table_type == TableType.RECALL_MEMORY:
self.uri = self.config.recall_storage_uri
if self.config.recall_storage_uri is None:
raise ValueError(f"Must specifiy recall_storage_uri in config {self.config.config_path}")
else:
raise ValueError(f"Table type {table_type} not implemented")
# create table
self.db_model = get_db_model(config, self.table_name, table_type, user_id, agent_id)
# construct URI from enviornment variables
if os.getenv("MEMGPT_PGURI"):
self.uri = os.getenv("MEMGPT_PGURI")
else:
elif os.getenv("MEMGPT_PG_DB"):
db = os.getenv("MEMGPT_PG_DB", "memgpt")
user = os.getenv("MEMGPT_PG_USER", "memgpt")
password = os.getenv("MEMGPT_PG_PASSWORD", "memgpt")
port = os.getenv("MEMGPT_PG_PORT", "5432")
url = os.getenv("MEMGPT_PG_URL", "localhost")
self.uri = f"postgresql+pg8000://{user}:{password}@{url}:{port}/{db}"
else:
# use config URI
# TODO: remove this eventually (config should NOT contain URI)
if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
self.uri = self.config.archival_storage_uri
if self.config.archival_storage_uri is None:
raise ValueError(f"Must specifiy archival_storage_uri in config {self.config.config_path}")
elif table_type == TableType.RECALL_MEMORY:
self.uri = self.config.recall_storage_uri
if self.config.recall_storage_uri is None:
raise ValueError(f"Must specifiy recall_storage_uri in config {self.config.config_path}")
else:
raise ValueError(f"Table type {table_type} not implemented")
# create engine
self.engine = create_engine(self.uri)

View File

@@ -307,13 +307,15 @@ class MetadataStore:
# construct URI from enviornment variables
if os.getenv("MEMGPT_PGURI"):
self.uri = os.getenv("MEMGPT_PGURI")
else:
elif os.getenv("MEMGPT_PG_DB"):
db = os.getenv("MEMGPT_PG_DB", "memgpt")
user = os.getenv("MEMGPT_PG_USER", "memgpt")
password = os.getenv("MEMGPT_PG_PASSWORD", "memgpt")
port = os.getenv("MEMGPT_PG_PORT", "5432")
url = os.getenv("MEMGPT_PG_URL", "localhost")
self.uri = f"postgresql+pg8000://{user}:{password}@{url}:{port}/{db}"
else:
self.uri = config.metadata_storage_uri
elif config.metadata_storage_type == "sqlite":
path = os.path.join(config.metadata_storage_path, "sqlite.db")
self.uri = f"sqlite:///{path}"