From 8700158222af3bf4442b01edb472a456df60abe2 Mon Sep 17 00:00:00 2001
From: Sarah Wooders
Date: Mon, 6 Nov 2023 17:19:45 -0800
Subject: [PATCH] Specify model inference and embedding endpoint separately (#286)

---
 memgpt/cli/cli_config.py   | 22 ++++++++++++++++++++--
 memgpt/cli/cli_load.py     |  2 +-
 memgpt/connectors/db.py    |  2 --
 memgpt/connectors/local.py |  4 ----
 memgpt/embeddings.py       | 25 ++++++++++++-------------
 5 files changed, 33 insertions(+), 22 deletions(-)

diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py
index f19c36b0..8131ded4 100644
--- a/memgpt/cli/cli_config.py
+++ b/memgpt/cli/cli_config.py
@@ -32,7 +32,9 @@ def configure():
     # search for key in enviornment
     openai_key = os.getenv("OPENAI_API_KEY")
     if not openai_key:
-        openai_key = questionary.text("Open AI API keys not found in enviornment - please enter:").ask()
+        print("Missing enviornment variables for OpenAI. Please set them and run `memgpt configure` again.")
+        # TODO: eventually stop relying on env variables and pass in keys explicitly
+        # openai_key = questionary.text("Open AI API keys not found in enviornment - please enter:").ask()
 
     # azure credentials
     use_azure = questionary.confirm("Do you want to enable MemGPT with Azure?", default=False).ask()
@@ -77,7 +79,21 @@ def configure():
     if len(endpoint_options) == 1:
         default_endpoint = endpoint_options[0]
     else:
-        default_endpoint = questionary.select("Select default endpoint:", endpoint_options).ask()
+        default_endpoint = questionary.select("Select default inference endpoint:", endpoint_options).ask()
+
+    # configure embedding provider
+    endpoint_options.append("local")  # can compute embeddings locally
+    if len(endpoint_options) == 1:
+        default_embedding_endpoint = endpoint_options[0]
+        print(f"Using embedding endpoint {default_embedding_endpoint}")
+    else:
+        default_embedding_endpoint = questionary.select("Select default embedding endpoint:", endpoint_options).ask()
+
+    # configure embedding dimentions
+    default_embedding_dim = 1536
+    if default_embedding_endpoint == "local":
+        # HF model uses lower dimentionality
+        default_embedding_dim = 384
 
     # configure preset
     default_preset = questionary.select("Select default preset:", preset_options, default=DEFAULT_PRESET).ask()
@@ -127,6 +143,8 @@ def configure():
         model=default_model,
         preset=default_preset,
         model_endpoint=default_endpoint,
+        embedding_model=default_embedding_endpoint,
+        embedding_dim=default_embedding_dim,
         default_persona=default_persona,
         default_human=default_human,
         default_agent=default_agent,
diff --git a/memgpt/cli/cli_load.py b/memgpt/cli/cli_load.py
index 10becc75..62939a35 100644
--- a/memgpt/cli/cli_load.py
+++ b/memgpt/cli/cli_load.py
@@ -45,7 +45,7 @@ def store_docs(name, docs, show_progress=True):
         text = node.text.replace("\x00", "\uFFFD")  # hacky fix for error on null characters
         assert (
             len(node.embedding) == config.embedding_dim
-        ), f"Expected embedding dimension {config.embedding_dim}, got {len(node.embedding)}"
+        ), f"Expected embedding dimension {config.embedding_dim}, got {len(node.embedding)}: {node.embedding}"
         passages.append(Passage(text=text, embedding=vector))
 
     # insert into storage
diff --git a/memgpt/connectors/db.py b/memgpt/connectors/db.py
index 5cefc5e3..2bbaebc4 100644
--- a/memgpt/connectors/db.py
+++ b/memgpt/connectors/db.py
@@ -124,8 +124,6 @@ class PostgresStorageConnector(StorageConnector):
         self.db_model.__table__.drop(self.engine)
 
     def save(self):
-        # don't need to save
-        print("Saving db")
         return
 
     @staticmethod
diff --git a/memgpt/connectors/local.py b/memgpt/connectors/local.py
index a916aac1..dd776379 100644
--- a/memgpt/connectors/local.py
+++ b/memgpt/connectors/local.py
@@ -84,7 +84,6 @@ class LocalStorageConnector(StorageConnector):
 
     def insert(self, passage: Passage):
         nodes = [TextNode(text=passage.text, embedding=passage.embedding)]
-        print("nodes", nodes)
         self.nodes += nodes
         if isinstance(self.index, EmptyIndex):
             self.index = VectorStoreIndex(self.nodes, service_context=self.service_context, show_progress=True)
@@ -96,7 +95,6 @@ class LocalStorageConnector(StorageConnector):
         self.nodes += nodes
         if isinstance(self.index, EmptyIndex):
             self.index = VectorStoreIndex(self.nodes, service_context=self.service_context, show_progress=True)
-            print("new size", len(self.get_nodes()))
         else:
             orig_size = len(self.get_nodes())
             self.index.insert_nodes(nodes)
@@ -113,7 +111,6 @@ class LocalStorageConnector(StorageConnector):
         )
         nodes = retriever.retrieve(query)
         results = [Passage(embedding=node.embedding, text=node.text) for node in nodes]
-        print(results)
         return results
 
     def save(self):
@@ -121,7 +118,6 @@ class LocalStorageConnector(StorageConnector):
         self.nodes = self.get_nodes()
         os.makedirs(self.save_directory, exist_ok=True)
         pickle.dump(self.nodes, open(self.save_path, "wb"))
-        print("Saved local", self.save_path)
 
     @staticmethod
     def list_loaded_data():
diff --git a/memgpt/embeddings.py b/memgpt/embeddings.py
index 20c6040e..0be65558 100644
--- a/memgpt/embeddings.py
+++ b/memgpt/embeddings.py
@@ -1,4 +1,5 @@
 import typer
+import os
 
 from llama_index.embeddings import OpenAIEmbedding
 
@@ -10,10 +11,11 @@ def embedding_model():
     # load config
     config = MemGPTConfig.load()
 
-    # TODO: use embedding_endpoint in the future
-    if config.model_endpoint == "openai":
-        return OpenAIEmbedding()
-    elif config.model_endpoint == "azure":
+    endpoint = config.embedding_model
+    if endpoint == "openai":
+        model = OpenAIEmbedding(api_base="https://api.openai.com/v1", api_key=config.openai_key)
+        return model
+    elif endpoint == "azure":
         return OpenAIEmbedding(
             model="text-embedding-ada-002",
             deployment_name=config.azure_embedding_deployment,
@@ -22,17 +24,14 @@ def embedding_model():
             api_type="azure",
             api_version=config.azure_version,
         )
-    else:
+    elif endpoint == "local":
         # default to hugging face model
         from llama_index.embeddings import HuggingFaceEmbedding
 
+        os.environ["TOKENIZERS_PARALLELISM"] = "False"
         model = "BAAI/bge-small-en-v1.5"
-        typer.secho(
-            f"Warning: defaulting to HuggingFace embedding model {model} since model endpoint is not OpenAI or Azure.",
-            fg=typer.colors.YELLOW,
-        )
-        typer.secho(f"Warning: ensure torch and transformers are installed")
-        # return f"local:{model}"
-
-        # loads BAAI/bge-small-en-v1.5
         return HuggingFaceEmbedding(model_name=model)
+    else:
+        # use env variable OPENAI_API_BASE
+        model = OpenAIEmbedding()
+        return model
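
For quick reference, below is a minimal standalone sketch (not part of the patch) of the endpoint-to-embedding mapping that `memgpt configure` and `memgpt/embeddings.py` now share: the "openai" and "azure" endpoints use the 1536-dimensional text-embedding-ada-002 model, while the new "local" option switches to the 384-dimensional BAAI/bge-small-en-v1.5 HuggingFace model, and anything else falls through to OpenAIEmbedding() reading OPENAI_API_BASE. The EmbeddingSpec dataclass and choose_embedding_spec helper are hypothetical names used here for illustration only.

# Illustrative sketch only -- EmbeddingSpec and choose_embedding_spec are
# hypothetical names, not part of the MemGPT codebase.
from dataclasses import dataclass


@dataclass
class EmbeddingSpec:
    endpoint: str  # "openai", "azure", "local", or a custom OPENAI_API_BASE
    model: str     # embedding model identifier
    dim: int       # must match config.embedding_dim (asserted in cli_load.py)


def choose_embedding_spec(embedding_endpoint: str) -> EmbeddingSpec:
    """Mirror the endpoint -> model/dimension mapping introduced by the patch."""
    if embedding_endpoint == "local":
        # Local embeddings use the smaller HuggingFace BGE model.
        return EmbeddingSpec("local", "BAAI/bge-small-en-v1.5", 384)
    # "openai", "azure", and custom endpoints use text-embedding-ada-002 (1536 dims).
    return EmbeddingSpec(embedding_endpoint, "text-embedding-ada-002", 1536)


if __name__ == "__main__":
    for endpoint in ("openai", "local"):
        spec = choose_embedding_spec(endpoint)
        print(f"{endpoint}: model={spec.model}, embedding_dim={spec.dim}")

Note how the embedding dimension chosen at configure time is what the assertion in cli_load.py later checks against, so a mismatch between the configured endpoint and the stored embeddings fails fast at load time.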