Specify model inference and embedding endpoint separately (#286)
This commit is contained in:
@@ -32,7 +32,9 @@ def configure():
|
||||
# search for key in environment
|
||||
openai_key = os.getenv("OPENAI_API_KEY")
|
||||
if not openai_key:
|
||||
openai_key = questionary.text("Open AI API keys not found in enviornment - please enter:").ask()
|
||||
print("Missing enviornment variables for OpenAI. Please set them and run `memgpt configure` again.")
|
||||
# TODO: eventually stop relying on env variables and pass in keys explicitly
|
||||
# openai_key = questionary.text("Open AI API keys not found in enviornment - please enter:").ask()
|
||||
|
||||
# azure credentials
|
||||
use_azure = questionary.confirm("Do you want to enable MemGPT with Azure?", default=False).ask()
|
||||
@@ -77,7 +79,21 @@ def configure():
|
||||
if len(endpoint_options) == 1:
|
||||
default_endpoint = endpoint_options[0]
|
||||
else:
|
||||
default_endpoint = questionary.select("Select default endpoint:", endpoint_options).ask()
|
||||
default_endpoint = questionary.select("Select default inference endpoint:", endpoint_options).ask()
|
||||
|
||||
# configure embedding provider
|
||||
endpoint_options.append("local") # can compute embeddings locally
|
||||
if len(endpoint_options) == 1:
|
||||
default_embedding_endpoint = endpoint_options[0]
|
||||
print(f"Using embedding endpoint {default_embedding_endpoint}")
|
||||
else:
|
||||
default_embedding_endpoint = questionary.select("Select default embedding endpoint:", endpoint_options).ask()
|
||||
|
||||
# configure embedding dimensions
|
||||
default_embedding_dim = 1536
|
||||
if default_embedding_endpoint == "local":
|
||||
# HF model uses lower dimensionality
|
||||
default_embedding_dim = 384
|
||||
|
||||
# configure preset
|
||||
default_preset = questionary.select("Select default preset:", preset_options, default=DEFAULT_PRESET).ask()
|
||||
@@ -127,6 +143,8 @@ def configure():
|
||||
model=default_model,
|
||||
preset=default_preset,
|
||||
model_endpoint=default_endpoint,
|
||||
embedding_model=default_embedding_endpoint,
|
||||
embedding_dim=default_embedding_dim,
|
||||
default_persona=default_persona,
|
||||
default_human=default_human,
|
||||
default_agent=default_agent,
|
||||
|
||||
@@ -45,7 +45,7 @@ def store_docs(name, docs, show_progress=True):
|
||||
text = node.text.replace("\x00", "\uFFFD") # hacky fix for error on null characters
|
||||
assert (
|
||||
len(node.embedding) == config.embedding_dim
|
||||
), f"Expected embedding dimension {config.embedding_dim}, got {len(node.embedding)}"
|
||||
), f"Expected embedding dimension {config.embedding_dim}, got {len(node.embedding)}: {node.embedding}"
|
||||
passages.append(Passage(text=text, embedding=vector))
|
||||
|
||||
# insert into storage
|
||||
|
||||
@@ -124,8 +124,6 @@ class PostgresStorageConnector(StorageConnector):
|
||||
self.db_model.__table__.drop(self.engine)
|
||||
|
||||
def save(self):
|
||||
# don't need to save
|
||||
print("Saving db")
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -84,7 +84,6 @@ class LocalStorageConnector(StorageConnector):
|
||||
|
||||
def insert(self, passage: Passage):
|
||||
nodes = [TextNode(text=passage.text, embedding=passage.embedding)]
|
||||
print("nodes", nodes)
|
||||
self.nodes += nodes
|
||||
if isinstance(self.index, EmptyIndex):
|
||||
self.index = VectorStoreIndex(self.nodes, service_context=self.service_context, show_progress=True)
|
||||
@@ -96,7 +95,6 @@ class LocalStorageConnector(StorageConnector):
|
||||
self.nodes += nodes
|
||||
if isinstance(self.index, EmptyIndex):
|
||||
self.index = VectorStoreIndex(self.nodes, service_context=self.service_context, show_progress=True)
|
||||
print("new size", len(self.get_nodes()))
|
||||
else:
|
||||
orig_size = len(self.get_nodes())
|
||||
self.index.insert_nodes(nodes)
|
||||
@@ -113,7 +111,6 @@ class LocalStorageConnector(StorageConnector):
|
||||
)
|
||||
nodes = retriever.retrieve(query)
|
||||
results = [Passage(embedding=node.embedding, text=node.text) for node in nodes]
|
||||
print(results)
|
||||
return results
|
||||
|
||||
def save(self):
|
||||
@@ -121,7 +118,6 @@ class LocalStorageConnector(StorageConnector):
|
||||
self.nodes = self.get_nodes()
|
||||
os.makedirs(self.save_directory, exist_ok=True)
|
||||
pickle.dump(self.nodes, open(self.save_path, "wb"))
|
||||
print("Saved local", self.save_path)
|
||||
|
||||
@staticmethod
|
||||
def list_loaded_data():
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import typer
|
||||
import os
|
||||
from llama_index.embeddings import OpenAIEmbedding
|
||||
|
||||
|
||||
@@ -10,10 +11,11 @@ def embedding_model():
|
||||
# load config
|
||||
config = MemGPTConfig.load()
|
||||
|
||||
# TODO: use embedding_endpoint in the future
|
||||
if config.model_endpoint == "openai":
|
||||
return OpenAIEmbedding()
|
||||
elif config.model_endpoint == "azure":
|
||||
endpoint = config.embedding_model
|
||||
if endpoint == "openai":
|
||||
model = OpenAIEmbedding(api_base="https://api.openai.com/v1", api_key=config.openai_key)
|
||||
return model
|
||||
elif endpoint == "azure":
|
||||
return OpenAIEmbedding(
|
||||
model="text-embedding-ada-002",
|
||||
deployment_name=config.azure_embedding_deployment,
|
||||
@@ -22,17 +24,14 @@ def embedding_model():
|
||||
api_type="azure",
|
||||
api_version=config.azure_version,
|
||||
)
|
||||
else:
|
||||
elif endpoint == "local":
|
||||
# default to hugging face model
|
||||
from llama_index.embeddings import HuggingFaceEmbedding
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "False"
|
||||
model = "BAAI/bge-small-en-v1.5"
|
||||
typer.secho(
|
||||
f"Warning: defaulting to HuggingFace embedding model {model} since model endpoint is not OpenAI or Azure.",
|
||||
fg=typer.colors.YELLOW,
|
||||
)
|
||||
typer.secho(f"Warning: ensure torch and transformers are installed")
|
||||
# return f"local:{model}"
|
||||
|
||||
# loads BAAI/bge-small-en-v1.5
|
||||
return HuggingFaceEmbedding(model_name=model)
|
||||
else:
|
||||
# use env variable OPENAI_API_BASE
|
||||
model = OpenAIEmbedding()
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user