Files
letta-server/memgpt/embeddings.py
Sarah Wooders 23f3d42fae Refactoring CLI to use config file, connect to Llama Index data sources, and allow for multiple agents (#154)
* Migrate to `memgpt run` and `memgpt configure` 
* Add Llama index data sources via `memgpt load` 
* Save config files for defaults and agents
2023-10-30 16:47:54 -07:00

33 lines
1.1 KiB
Python

from memgpt.config import MemGPTConfig
import typer
from llama_index.embeddings import OpenAIEmbedding
def embedding_model(config: MemGPTConfig):
    """Return the embedding model implied by the configured model endpoint.

    Args:
        config: The MemGPT configuration; ``model_endpoint`` selects the
            backend, and the ``azure_*`` fields supply Azure credentials.

    Returns:
        An OpenAI, Azure-OpenAI, or local HuggingFace llama-index embedding model.
    """
    # TODO: use embedding_endpoint in the future
    if config.model_endpoint == "openai":
        return OpenAIEmbedding()
    elif config.model_endpoint == "azure":
        # Azure requires the deployment name and API credentials in addition
        # to the base model name.
        return OpenAIEmbedding(
            model="text-embedding-ada-002",
            deployment_name=config.azure_embedding_deployment,
            api_key=config.azure_key,
            api_base=config.azure_endpoint,
            api_type="azure",
            api_version=config.azure_version,
        )
    else:
        # Default to a local HuggingFace model; imported lazily so that the
        # torch/transformers dependency is only required on this path.
        from llama_index.embeddings import HuggingFaceEmbedding

        model = "BAAI/bge-small-en-v1.5"  # loads BAAI/bge-small-en-v1.5
        typer.secho(
            f"Warning: defaulting to HuggingFace embedding model {model} since model endpoint is not OpenAI or Azure.",
            fg=typer.colors.YELLOW,
        )
        # Match the warning style above (removed no-op f-string prefix).
        typer.secho(
            "Warning: ensure torch and transformers are installed",
            fg=typer.colors.YELLOW,
        )
        return HuggingFaceEmbedding(model_name=model)