From bbacf0fb3325f80bc7490864283d3bf5a750fba9 Mon Sep 17 00:00:00 2001
From: Sarah Wooders
Date: Thu, 26 Oct 2023 15:30:31 -0700
Subject: [PATCH] add database test

---
 memgpt/config.py               | 21 ++++++++++++
 memgpt/connectors/connector.py | 52 +++++++++++++++++++-----------
 memgpt/memory.py               | 15 ---------
 tests/test_load_archival.py    | 58 ++++++++++++++++++++++++++++++++--
 4 files changed, 110 insertions(+), 36 deletions(-)

diff --git a/memgpt/config.py b/memgpt/config.py
index 3fa8804d..857872e9 100644
--- a/memgpt/config.py
+++ b/memgpt/config.py
@@ -25,6 +25,27 @@ model_choices = [
 ]
 
 
+class MemGPTConfig:
+
+    # Model configuration
+    openai_key: str = None
+    azure_key: str = None
+    azure_endpoint: str = None
+    model_endpoint: str = None
+
+    # Storage (archival/recall) configuration
+    storage_type: str = "local"  # ["local", "vectordb"]
+    storage_url: str = None
+
+    # Persona configuration
+    default_person = ""
+
+    # Human configuration
+    default_human = ""
+
+
+
+
 class Config:
     personas_dir = os.path.join("memgpt", "personas", "examples")
     custom_personas_dir = os.path.join(MEMGPT_DIR, "personas")
diff --git a/memgpt/connectors/connector.py b/memgpt/connectors/connector.py
index e06d3418..549c2d7d 100644
--- a/memgpt/connectors/connector.py
+++ b/memgpt/connectors/connector.py
@@ -59,32 +59,48 @@ def load_webpage(
 
     # embed docs
     print("Indexing documents...")
-    index = index_docs(docs)
+    index = get_index(docs)
 
     # save connector information into .memgpt metadata file
     save_index(index, name)
 
-
 @app.command("database")
 def load_database(
     name: str = typer.Option(help="Name of dataset to load."),
-    scheme: str = typer.Option(help="Database scheme."),
-    host: str = typer.Option(help="Database host."),
-    port: int = typer.Option(help="Database port."),
-    user: str = typer.Option(help="Database user."),
-    password: str = typer.Option(help="Database password."),
-    dbname: str = typer.Option(help="Database name."),
-    query: str = typer.Option(None, help="Database query."),
+    query: str = typer.Option(help="Database query."),
+    dump_path: str = typer.Option(None, help="Path to dump file."),
+    scheme: str = typer.Option(None, help="Database scheme."),
+    host: str = typer.Option(None, help="Database host."),
+    port: int = typer.Option(None, help="Database port."),
+    user: str = typer.Option(None, help="Database user."),
+    password: str = typer.Option(None, help="Database password."),
+    dbname: str = typer.Option(None, help="Database name."),
 ):
     from llama_index.readers.database import DatabaseReader
-
-    db = DatabaseReader(
-        scheme=scheme,  # Database Scheme
-        host=host,  # Database Host
-        port=port,  # Database Port
-        user=user,  # Database User
-        password=password,  # Database Password
-        dbname=dbname,  # Database Name
-    )
+    print(dump_path, scheme)
+
+    if dump_path is not None:
+        # read from database dump file
+        from sqlalchemy import create_engine, MetaData
+        engine = create_engine(f'sqlite:///{dump_path}')
+
+        db = DatabaseReader(engine=engine)
+    else:
+        assert dump_path is None, "Cannot provide both dump_path and database connection parameters."
+        assert scheme is not None, "Must provide database scheme."
+        assert host is not None, "Must provide database host."
+        assert port is not None, "Must provide database port."
+        assert user is not None, "Must provide database user."
+        assert password is not None, "Must provide database password."
+        assert dbname is not None, "Must provide database name."
+
+        db = DatabaseReader(
+            scheme=scheme,  # Database Scheme
+            host=host,  # Database Host
+            port=port,  # Database Port
+            user=user,  # Database User
+            password=password,  # Database Password
+            dbname=dbname,  # Database Name
+        )
 
     # load data
     docs = db.load_data(query=query)
diff --git a/memgpt/memory.py b/memgpt/memory.py
index c2dcaf6f..f4fa09ec 100644
--- a/memgpt/memory.py
+++ b/memgpt/memory.py
@@ -554,20 +554,6 @@ class LocalArchivalMemory(ArchivalMemory):
             index=self.index,  # does this get refreshed?
             similarity_top_k=self.top_k,
         )
-
-        # configure response synthesizer
-        response_synthesizer = get_response_synthesizer()
-
-        # assemble query engine
-        self.query_engine = RetrieverQueryEngine(
-            retriever=self.retriever,
-            #response_synthesizer=response_synthesizer,
-            #node_postprocessors=[
-            #    SimilarityPostprocessor(similarity_cutoff=0)  # TODO: tune this
-            #]
-        )
-
 
         # cache for repeated queries  # TODO: have some mechanism for cleanup otherwise will lead to OOM
         self.cache = {}
@@ -581,7 +567,6 @@ class LocalArchivalMemory(ArchivalMemory):
         count = min(count + start, self.top_k)
 
         if query_string not in self.cache:
-            #self.cache[query_string] = self.query_engine.query(query_string)
             self.cache[query_string] = self.retriever.retrieve(query_string)
 
         results = self.cache[query_string][start:start+count]
diff --git a/tests/test_load_archival.py b/tests/test_load_archival.py
index ddfa1f50..dc857372 100644
--- a/tests/test_load_archival.py
+++ b/tests/test_load_archival.py
@@ -1,7 +1,7 @@
 import tempfile
 import asyncio
 import os
-from memgpt.connectors.connector import load_directory
+from memgpt.connectors.connector import load_directory, load_database, load_webpage
 import memgpt.agent as agent
 import memgpt.system as system
 import memgpt.utils as utils
@@ -20,7 +20,7 @@ import memgpt.interface  # for printing to terminal
 import asyncio
 from datasets import load_dataset
 
-def test_archival():
+def test_load_directory():
 
     # downloading hugging face dataset (if does not exist)
     dataset = load_dataset("MemGPT/example_short_stories")
@@ -58,4 +58,56 @@ def test_archival():
     results = query("cinderella be getting sick")
     assert "Cinderella" in results, f"Expected 'Cinderella' in results, but got {results}"
 
-test_archival()
\ No newline at end of file
+def test_load_webpage():
+    pass
+
+def test_load_database():
+
+    from sqlalchemy import create_engine, MetaData
+    import pandas as pd
+
+    db_path = "memgpt/personas/examples/sqldb/test.db"
+    engine = create_engine(f'sqlite:///{db_path}')
+
+    # Create a MetaData object and reflect the database to get table information.
+    metadata = MetaData()
+    metadata.reflect(bind=engine)
+
+    # Get a list of table names from the reflected metadata.
+    table_names = metadata.tables.keys()
+
+    print(table_names)
+
+    # Define a SQL query to retrieve data from a table (replace 'your_table_name' with your actual table name).
+    query = f"SELECT * FROM {list(table_names)[0]}"
+
+    # Use Pandas to read data from the database into a DataFrame.
+    df = pd.read_sql_query(query, engine)
+    print(df)
+
+    load_database(
+        name="tmp_db_dataset",
+        #engine=engine,
+        dump_path=db_path,
+        query=f"SELECT * FROM {list(table_names)[0]}",
+    )
+
+    persistence_manager = LocalStateManager(archival_memory_db="tmp_db_dataset")
+
+    # create agent
+    memgpt_agent = presets.use_preset(
+        presets.DEFAULT,
+        DEFAULT_MEMGPT_MODEL,
+        personas.get_persona_text(personas.DEFAULT),
+        humans.get_human_text(humans.DEFAULT),
+        memgpt.interface,
+        persistence_manager,
+    )
+    print("Successfully loaded into index")
+    assert True
+
+
+
+
+#test_load_directory()
+test_load_database()
\ No newline at end of file