diff --git a/memgpt/agent_store/db.py b/memgpt/agent_store/db.py index d36a584a..6ffc558b 100644 --- a/memgpt/agent_store/db.py +++ b/memgpt/agent_store/db.py @@ -424,30 +424,32 @@ class PostgresStorageConnector(SQLStorageConnector): super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id) - # get storage URI - if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES: - self.uri = self.config.archival_storage_uri - if self.config.archival_storage_uri is None: - raise ValueError(f"Must specifiy archival_storage_uri in config {self.config.config_path}") - elif table_type == TableType.RECALL_MEMORY: - self.uri = self.config.recall_storage_uri - if self.config.recall_storage_uri is None: - raise ValueError(f"Must specifiy recall_storage_uri in config {self.config.config_path}") - else: - raise ValueError(f"Table type {table_type} not implemented") # create table self.db_model = get_db_model(config, self.table_name, table_type, user_id, agent_id) # construct URI from enviornment variables if os.getenv("MEMGPT_PGURI"): self.uri = os.getenv("MEMGPT_PGURI") - else: + elif os.getenv("MEMGPT_PG_DB"): db = os.getenv("MEMGPT_PG_DB", "memgpt") user = os.getenv("MEMGPT_PG_USER", "memgpt") password = os.getenv("MEMGPT_PG_PASSWORD", "memgpt") port = os.getenv("MEMGPT_PG_PORT", "5432") url = os.getenv("MEMGPT_PG_URL", "localhost") self.uri = f"postgresql+pg8000://{user}:{password}@{url}:{port}/{db}" + else: + # use config URI + # TODO: remove this eventually (config should NOT contain URI) + if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES: + self.uri = self.config.archival_storage_uri + if self.config.archival_storage_uri is None: + raise ValueError(f"Must specifiy archival_storage_uri in config {self.config.config_path}") + elif table_type == TableType.RECALL_MEMORY: + self.uri = self.config.recall_storage_uri + if self.config.recall_storage_uri is None: + raise ValueError(f"Must specifiy recall_storage_uri in config {self.config.config_path}") + else: + raise ValueError(f"Table type {table_type} not implemented") # create engine self.engine = create_engine(self.uri) diff --git a/memgpt/metadata.py b/memgpt/metadata.py index 7d34ddee..8da4846f 100644 --- a/memgpt/metadata.py +++ b/memgpt/metadata.py @@ -307,13 +307,15 @@ class MetadataStore: # construct URI from enviornment variables if os.getenv("MEMGPT_PGURI"): self.uri = os.getenv("MEMGPT_PGURI") - else: + elif os.getenv("MEMGPT_PG_DB"): db = os.getenv("MEMGPT_PG_DB", "memgpt") user = os.getenv("MEMGPT_PG_USER", "memgpt") password = os.getenv("MEMGPT_PG_PASSWORD", "memgpt") port = os.getenv("MEMGPT_PG_PORT", "5432") url = os.getenv("MEMGPT_PG_URL", "localhost") self.uri = f"postgresql+pg8000://{user}:{password}@{url}:{port}/{db}" + else: + self.uri = config.metadata_storage_uri elif config.metadata_storage_type == "sqlite": path = os.path.join(config.metadata_storage_path, "sqlite.db") self.uri = f"sqlite:///{path}"