feat: Ollama embeddings api + Ollama tests (#1433)
Co-authored-by: Krishna Mandal <krishna@mandal.us>
This commit is contained in:
38
.github/workflows/test_ollama.yml
vendored
Normal file
38
.github/workflows/test_ollama.yml
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
name: Endpoint (Ollama)
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Start Ollama Server
|
||||
run: |
|
||||
curl -fsSL https://ollama.com/install.sh | sh
|
||||
ollama serve &
|
||||
sleep 10 # wait for server
|
||||
ollama pull dolphin2.2-mistral:7b-q6_K
|
||||
ollama pull mxbai-embed-large
|
||||
|
||||
- name: "Setup Python, Poetry and Dependencies"
|
||||
uses: packetcoders/action-setup-cache-python-poetry@main
|
||||
with:
|
||||
python-version: "3.12"
|
||||
poetry-version: "1.8.2"
|
||||
install-args: "-E dev -E ollama"
|
||||
|
||||
- name: Test LLM endpoint
|
||||
run: |
|
||||
poetry run pytest -s -vv tests/test_endpoints.py::test_llm_endpoint_ollama
|
||||
|
||||
- name: Test embedding endpoint
|
||||
run: |
|
||||
poetry run pytest -s -vv tests/test_endpoints.py::test_embedding_endpoint_ollama
|
||||
3
.github/workflows/tests.yml
vendored
3
.github/workflows/tests.yml
vendored
@@ -28,8 +28,7 @@ jobs:
|
||||
install-args: "-E dev -E postgres -E milvus"
|
||||
|
||||
- name: Initialize credentials
|
||||
run: |
|
||||
poetry run memgpt quickstart --backend memgpt
|
||||
run: poetry run memgpt quickstart --backend memgpt
|
||||
|
||||
#- name: Run docker compose server
|
||||
# env:
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1015,4 +1015,4 @@ pgdata/
|
||||
|
||||
## pytest mirrors
|
||||
memgpt/.pytest_cache/
|
||||
memgpy/pytest.ini
|
||||
memgpy/pytest.ini
|
||||
|
||||
7
configs/embedding_model_configs/ollama.json
Normal file
7
configs/embedding_model_configs/ollama.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"embedding_endpoint_type": "ollama",
|
||||
"embedding_endpoint": "http://127.0.0.1:11434",
|
||||
"embedding_model": "mxbai-embed-large",
|
||||
"embedding_dim": 512,
|
||||
"embedding_chunk_size": 200
|
||||
}
|
||||
6
configs/llm_model_configs/ollama.json
Normal file
6
configs/llm_model_configs/ollama.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"context_window": 8192,
|
||||
"model_endpoint_type": "ollama",
|
||||
"model_endpoint": "http://127.0.0.1:11434",
|
||||
"model": "dolphin2.2-mistral:7b-q6_K"
|
||||
}
|
||||
@@ -866,6 +866,42 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden
|
||||
embedding_dim = int(embedding_dim)
|
||||
except Exception:
|
||||
raise ValueError(f"Failed to cast {embedding_dim} to integer.")
|
||||
elif embedding_provider == "ollama":
|
||||
# configure ollama embedding endpoint
|
||||
embedding_endpoint_type = "ollama"
|
||||
embedding_endpoint = "http://localhost:11434/api/embeddings"
|
||||
# Source: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings
|
||||
|
||||
# get endpoint (the answer to this prompt replaces the hard-coded default assigned above)
|
||||
embedding_endpoint = questionary.text("Enter Ollama API endpoint:").ask()
|
||||
if embedding_endpoint is None:
|
||||
raise KeyboardInterrupt
|
||||
while not utils.is_valid_url(embedding_endpoint):
|
||||
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
|
||||
embedding_endpoint = questionary.text("Enter Ollama API endpoint:").ask()
|
||||
if embedding_endpoint is None:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
# get model type
|
||||
default_embedding_model = (
|
||||
config.default_embedding_config.embedding_model if config.default_embedding_config else "mxbai-embed-large"
|
||||
)
|
||||
embedding_model = questionary.text(
|
||||
"Enter Ollama model tag (e.g. mxbai-embed-large):",
|
||||
default=default_embedding_model,
|
||||
).ask()
|
||||
if embedding_model is None:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
# get model dimensions
|
||||
default_embedding_dim = config.default_embedding_config.embedding_dim if config.default_embedding_config else "512"
|
||||
embedding_dim = questionary.text("Enter embedding model dimensions (e.g. 512):", default=str(default_embedding_dim)).ask()
|
||||
if embedding_dim is None:
|
||||
raise KeyboardInterrupt
|
||||
try:
|
||||
embedding_dim = int(embedding_dim)
|
||||
except Exception:
|
||||
raise ValueError(f"Failed to cast {embedding_dim} to integer.")
|
||||
else: # local models
|
||||
embedding_endpoint_type = "local"
|
||||
embedding_endpoint = None
|
||||
|
||||
@@ -199,6 +199,20 @@ def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None
|
||||
base_url=config.embedding_endpoint,
|
||||
user=user_id,
|
||||
)
|
||||
elif endpoint_type == "ollama":
|
||||
|
||||
from llama_index.embeddings.ollama import OllamaEmbedding
|
||||
|
||||
ollama_additional_kwargs = {}
|
||||
callback_manager = None
|
||||
|
||||
model = OllamaEmbedding(
|
||||
model_name=config.embedding_model,
|
||||
base_url=config.embedding_endpoint,
|
||||
ollama_additional_kwargs=ollama_additional_kwargs or {},
|
||||
callback_manager=callback_manager or None,
|
||||
)
|
||||
return model
|
||||
|
||||
else:
|
||||
return default_embedding_model()
|
||||
|
||||
1766
poetry.lock
generated
1766
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -35,7 +35,7 @@ numpy = "^1.26.2"
|
||||
demjson3 = "^3.0.6"
|
||||
tiktoken = "^0.5.1"
|
||||
pyyaml = "^6.0.1"
|
||||
chromadb = "^0.4.18"
|
||||
chromadb = "^0.5.0"
|
||||
sqlalchemy-json = "^0.7.0"
|
||||
fastapi = {version = "^0.104.1", optional = true}
|
||||
uvicorn = {version = "^0.24.0.post1", optional = true}
|
||||
@@ -62,6 +62,8 @@ pytest = { version = "^7.4.4", optional = true }
|
||||
pydantic-settings = "^2.2.1"
|
||||
httpx-sse = "^0.4.0"
|
||||
isort = { version = "^5.13.2", optional = true }
|
||||
llama-index-embeddings-ollama = {version = "^0.1.2", optional = true}
|
||||
protobuf = "3.20.0"
|
||||
|
||||
[tool.poetry.extras]
|
||||
local = ["llama-index-embeddings-huggingface"]
|
||||
@@ -70,6 +72,7 @@ milvus = ["pymilvus"]
|
||||
dev = ["pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "datasets", "pyright", "pytest-order", "autoflake", "isort"]
|
||||
server = ["websockets", "fastapi", "uvicorn"]
|
||||
autogen = ["pyautogen"]
|
||||
ollama = ["llama-index-embeddings-ollama"]
|
||||
|
||||
[tool.black]
|
||||
line-length = 140
|
||||
|
||||
@@ -83,3 +83,13 @@ def test_embedding_endpoint_memgpt_hosted():
|
||||
def test_embedding_endpoint_local():
|
||||
filename = os.path.join(embedding_config_dir, "local.json")
|
||||
run_embedding_endpoint(filename)
|
||||
|
||||
|
||||
def test_llm_endpoint_ollama():
|
||||
filename = os.path.join(llm_config_dir, "ollama.json")
|
||||
run_llm_endpoint(filename)
|
||||
|
||||
|
||||
def test_embedding_endpoint_ollama():
|
||||
filename = os.path.join(embedding_config_dir, "ollama.json")
|
||||
run_embedding_endpoint(filename)
|
||||
|
||||
Reference in New Issue
Block a user