feat: Ollama embeddings api + Ollama tests (#1433)

Co-authored-by: Krishna Mandal <krishna@mandal.us>
This commit is contained in:
Sarah Wooders
2024-06-04 20:49:20 -07:00
committed by GitHub
parent f56179050c
commit 97ef8ba022
10 changed files with 964 additions and 923 deletions

38
.github/workflows/test_ollama.yml vendored Normal file
View File

@@ -0,0 +1,38 @@
# CI workflow: exercise the MemGPT Ollama LLM + embedding endpoints against a
# live local Ollama server on every push / pull request targeting main.
# NOTE: indentation reconstructed — the diff rendering had stripped it.
name: Endpoint (Ollama)

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  test:
    runs-on: ubuntu-latest
    timeout-minutes: 15  # guard against a hung model pull or server start

    steps:
      - name: Checkout
        uses: actions/checkout@v4

      # Install Ollama, launch it in the background, then pre-pull the two
      # models the tests expect: a chat model and an embedding model.
      - name: Start Ollama Server
        run: |
          curl -fsSL https://ollama.com/install.sh | sh
          ollama serve &
          sleep 10 # wait for server
          ollama pull dolphin2.2-mistral:7b-q6_K
          ollama pull mxbai-embed-large

      - name: "Setup Python, Poetry and Dependencies"
        uses: packetcoders/action-setup-cache-python-poetry@main
        with:
          python-version: "3.12"
          poetry-version: "1.8.2"
          # dev extra for pytest, ollama extra for llama-index-embeddings-ollama
          install-args: "-E dev -E ollama"

      # Run only the Ollama-specific tests by pytest node id.
      - name: Test LLM endpoint
        run: |
          poetry run pytest -s -vv tests/test_endpoints.py::test_llm_endpoint_ollama

      - name: Test embedding endpoint
        run: |
          poetry run pytest -s -vv tests/test_endpoints.py::test_embedding_endpoint_ollama

View File

@@ -28,8 +28,7 @@ jobs:
install-args: "-E dev -E postgres -E milvus"
- name: Initialize credentials
run: |
poetry run memgpt quickstart --backend memgpt
run: poetry run memgpt quickstart --backend memgpt
#- name: Run docker compose server
# env:

2
.gitignore vendored
View File

@@ -1015,4 +1015,4 @@ pgdata/
## pytest mirrors
memgpt/.pytest_cache/
memgpy/pytest.ini
memgpy/pytest.ini

View File

@@ -0,0 +1,7 @@
{
"embedding_endpoint_type": "ollama",
"embedding_endpoint": "http://127.0.0.1:11434",
"embedding_model": "mxbai-embed-large",
"embedding_dim": 512,
"embedding_chunk_size": 200
}

View File

@@ -0,0 +1,6 @@
{
"context_window": 8192,
"model_endpoint_type": "ollama",
"model_endpoint": "http://127.0.0.1:11434",
"model": "dolphin2.2-mistral:7b-q6_K"
}

View File

@@ -866,6 +866,42 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden
embedding_dim = int(embedding_dim)
except Exception:
raise ValueError(f"Failed to cast {embedding_dim} to integer.")
elif embedding_provider == "ollama":
# configure ollama embedding endpoint
embedding_endpoint_type = "ollama"
embedding_endpoint = "http://localhost:11434/api/embeddings"
# Source: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings:~:text=http%3A//localhost%3A11434/api/embeddings
# get endpoint (is this necessary?)
embedding_endpoint = questionary.text("Enter Ollama API endpoint:").ask()
if embedding_endpoint is None:
raise KeyboardInterrupt
while not utils.is_valid_url(embedding_endpoint):
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
embedding_endpoint = questionary.text("Enter Ollama API endpoint:").ask()
if embedding_endpoint is None:
raise KeyboardInterrupt
# get model type
default_embedding_model = (
config.default_embedding_config.embedding_model if config.default_embedding_config else "mxbai-embed-large"
)
embedding_model = questionary.text(
"Enter Ollama model tag (e.g. mxbai-embed-large):",
default=default_embedding_model,
).ask()
if embedding_model is None:
raise KeyboardInterrupt
# get model dimensions
default_embedding_dim = config.default_embedding_config.embedding_dim if config.default_embedding_config else "512"
embedding_dim = questionary.text("Enter embedding model dimensions (e.g. 512):", default=str(default_embedding_dim)).ask()
if embedding_dim is None:
raise KeyboardInterrupt
try:
embedding_dim = int(embedding_dim)
except Exception:
raise ValueError(f"Failed to cast {embedding_dim} to integer.")
else: # local models
embedding_endpoint_type = "local"
embedding_endpoint = None

View File

@@ -199,6 +199,20 @@ def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None
base_url=config.embedding_endpoint,
user=user_id,
)
elif endpoint_type == "ollama":
from llama_index.embeddings.ollama import OllamaEmbedding
ollama_additional_kwargs = {}
callback_manager = None
model = OllamaEmbedding(
model_name=config.embedding_model,
base_url=config.embedding_endpoint,
ollama_additional_kwargs=ollama_additional_kwargs or {},
callback_manager=callback_manager or None,
)
return model
else:
return default_embedding_model()

1766
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -35,7 +35,7 @@ numpy = "^1.26.2"
demjson3 = "^3.0.6"
tiktoken = "^0.5.1"
pyyaml = "^6.0.1"
chromadb = "^0.4.18"
chromadb = "^0.5.0"
sqlalchemy-json = "^0.7.0"
fastapi = {version = "^0.104.1", optional = true}
uvicorn = {version = "^0.24.0.post1", optional = true}
@@ -62,6 +62,8 @@ pytest = { version = "^7.4.4", optional = true }
pydantic-settings = "^2.2.1"
httpx-sse = "^0.4.0"
isort = { version = "^5.13.2", optional = true }
llama-index-embeddings-ollama = {version = "^0.1.2", optional = true}
protobuf = "3.20.0"
[tool.poetry.extras]
local = ["llama-index-embeddings-huggingface"]
@@ -70,6 +72,7 @@ milvus = ["pymilvus"]
dev = ["pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "datasets", "pyright", "pytest-order", "autoflake", "isort"]
server = ["websockets", "fastapi", "uvicorn"]
autogen = ["pyautogen"]
ollama = ["llama-index-embeddings-ollama"]
[tool.black]
line-length = 140

View File

@@ -83,3 +83,13 @@ def test_embedding_endpoint_memgpt_hosted():
def test_embedding_endpoint_local():
    """Smoke-test the embedding endpoint using the local backend config."""
    config_path = os.path.join(embedding_config_dir, "local.json")
    run_embedding_endpoint(config_path)
def test_llm_endpoint_ollama():
    """Smoke-test the LLM endpoint using the Ollama backend config."""
    run_llm_endpoint(os.path.join(llm_config_dir, "ollama.json"))
def test_embedding_endpoint_ollama():
    """Smoke-test the embedding endpoint using the Ollama backend config."""
    run_embedding_endpoint(os.path.join(embedding_config_dir, "ollama.json"))