Specify model inference and embedding endpoint separately (#286)

This commit is contained in:
Sarah Wooders
2023-11-06 17:19:45 -08:00
committed by GitHub
parent 8adef204e6
commit 8700158222
5 changed files with 33 additions and 22 deletions

View File

@@ -32,7 +32,9 @@ def configure():
# search for key in environment
openai_key = os.getenv("OPENAI_API_KEY")
if not openai_key:
openai_key = questionary.text("Open AI API keys not found in enviornment - please enter:").ask()
print("Missing enviornment variables for OpenAI. Please set them and run `memgpt configure` again.")
# TODO: eventually stop relying on env variables and pass in keys explicitly
# openai_key = questionary.text("Open AI API keys not found in enviornment - please enter:").ask()
# azure credentials
use_azure = questionary.confirm("Do you want to enable MemGPT with Azure?", default=False).ask()
@@ -77,7 +79,21 @@ def configure():
if len(endpoint_options) == 1:
default_endpoint = endpoint_options[0]
else:
default_endpoint = questionary.select("Select default endpoint:", endpoint_options).ask()
default_endpoint = questionary.select("Select default inference endpoint:", endpoint_options).ask()
# configure embedding provider
endpoint_options.append("local") # can compute embeddings locally
if len(endpoint_options) == 1:
default_embedding_endpoint = endpoint_options[0]
print(f"Using embedding endpoint {default_embedding_endpoint}")
else:
default_embedding_endpoint = questionary.select("Select default embedding endpoint:", endpoint_options).ask()
# configure embedding dimensions
default_embedding_dim = 1536
if default_embedding_endpoint == "local":
# HF model uses lower dimensionality
default_embedding_dim = 384
# configure preset
default_preset = questionary.select("Select default preset:", preset_options, default=DEFAULT_PRESET).ask()
@@ -127,6 +143,8 @@ def configure():
model=default_model,
preset=default_preset,
model_endpoint=default_endpoint,
embedding_model=default_embedding_endpoint,
embedding_dim=default_embedding_dim,
default_persona=default_persona,
default_human=default_human,
default_agent=default_agent,

View File

@@ -45,7 +45,7 @@ def store_docs(name, docs, show_progress=True):
text = node.text.replace("\x00", "\uFFFD") # hacky fix for error on null characters
assert (
len(node.embedding) == config.embedding_dim
), f"Expected embedding dimension {config.embedding_dim}, got {len(node.embedding)}"
), f"Expected embedding dimension {config.embedding_dim}, got {len(node.embedding)}: {node.embedding}"
passages.append(Passage(text=text, embedding=vector))
# insert into storage

View File

@@ -124,8 +124,6 @@ class PostgresStorageConnector(StorageConnector):
self.db_model.__table__.drop(self.engine)
def save(self):
# don't need to save
print("Saving db")
return
@staticmethod

View File

@@ -84,7 +84,6 @@ class LocalStorageConnector(StorageConnector):
def insert(self, passage: Passage):
nodes = [TextNode(text=passage.text, embedding=passage.embedding)]
print("nodes", nodes)
self.nodes += nodes
if isinstance(self.index, EmptyIndex):
self.index = VectorStoreIndex(self.nodes, service_context=self.service_context, show_progress=True)
@@ -96,7 +95,6 @@ class LocalStorageConnector(StorageConnector):
self.nodes += nodes
if isinstance(self.index, EmptyIndex):
self.index = VectorStoreIndex(self.nodes, service_context=self.service_context, show_progress=True)
print("new size", len(self.get_nodes()))
else:
orig_size = len(self.get_nodes())
self.index.insert_nodes(nodes)
@@ -113,7 +111,6 @@ class LocalStorageConnector(StorageConnector):
)
nodes = retriever.retrieve(query)
results = [Passage(embedding=node.embedding, text=node.text) for node in nodes]
print(results)
return results
def save(self):
@@ -121,7 +118,6 @@ class LocalStorageConnector(StorageConnector):
self.nodes = self.get_nodes()
os.makedirs(self.save_directory, exist_ok=True)
pickle.dump(self.nodes, open(self.save_path, "wb"))
print("Saved local", self.save_path)
@staticmethod
def list_loaded_data():

View File

@@ -1,4 +1,5 @@
import typer
import os
from llama_index.embeddings import OpenAIEmbedding
@@ -10,10 +11,11 @@ def embedding_model():
# load config
config = MemGPTConfig.load()
# TODO: use embedding_endpoint in the future
if config.model_endpoint == "openai":
return OpenAIEmbedding()
elif config.model_endpoint == "azure":
endpoint = config.embedding_model
if endpoint == "openai":
model = OpenAIEmbedding(api_base="https://api.openai.com/v1", api_key=config.openai_key)
return model
elif endpoint == "azure":
return OpenAIEmbedding(
model="text-embedding-ada-002",
deployment_name=config.azure_embedding_deployment,
@@ -22,17 +24,14 @@ def embedding_model():
api_type="azure",
api_version=config.azure_version,
)
else:
elif endpoint == "local":
# default to hugging face model
from llama_index.embeddings import HuggingFaceEmbedding
os.environ["TOKENIZERS_PARALLELISM"] = "False"
model = "BAAI/bge-small-en-v1.5"
typer.secho(
f"Warning: defaulting to HuggingFace embedding model {model} since model endpoint is not OpenAI or Azure.",
fg=typer.colors.YELLOW,
)
typer.secho(f"Warning: ensure torch and transformers are installed")
# return f"local:{model}"
# loads BAAI/bge-small-en-v1.5
return HuggingFaceEmbedding(model_name=model)
else:
# use env variable OPENAI_API_BASE
model = OpenAIEmbedding()
return model