feat(ci): Add coverage for self-hosted providers (#2976)
@@ -1,7 +1,7 @@
 {
     "context_window": 8192,
     "model_endpoint_type": "ollama",
     "model_endpoint": "http://127.0.0.1:11434",
-    "model": "thewindmom/hermes-3-llama-3.1-8b",
+    "model": "qwen3:32b",
     "put_inner_thoughts_in_kwargs": true
 }
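The config above only works if the referenced model is already pulled on the Ollama host. A minimal pre-flight check, assuming a local Ollama daemon on the default port (this helper is illustrative and not part of the commit):

# Illustrative pre-flight check, not part of this commit: confirm the model
# named in the config above is available locally before the suite runs.
# Ollama's /api/tags endpoint lists locally pulled models.
import json
import urllib.request

OLLAMA_URL = "http://127.0.0.1:11434"
EXPECTED_MODEL = "qwen3:32b"

with urllib.request.urlopen(f"{OLLAMA_URL}/api/tags") as resp:
    local_models = {m["name"] for m in json.load(resp).get("models", [])}

if EXPECTED_MODEL not in local_models:
    raise SystemExit(f"{EXPECTED_MODEL} is not pulled; run `ollama pull {EXPECTED_MODEL}` first")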
7  tests/configs/llm_model_configs/vllm.json  Normal file
@@ -0,0 +1,7 @@
{
    "context_window": 8192,
    "model_endpoint_type": "vllm",
    "model_endpoint": "http://127.0.0.1:8000",
    "model": "Qwen/Qwen3-32B-AWQ",
    "put_inner_thoughts_in_kwargs": true
}
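These per-provider JSON files mirror the fields of Letta's LLM config schema. A sketch of how one might be loaded in a test, assuming letta.schemas.llm_config.LLMConfig accepts these keys (the suite's actual loader may differ):

# Illustrative only: load a test config JSON into an LLMConfig.
# Assumes letta.schemas.llm_config.LLMConfig accepts these field names;
# the repository's own fixtures may construct configs differently.
import json

from letta.schemas.llm_config import LLMConfig

with open("tests/configs/llm_model_configs/vllm.json") as f:
    llm_config = LLMConfig(**json.load(f))

assert llm_config.model_endpoint_type == "vllm"
assert llm_config.context_window == 8192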
@@ -115,7 +115,7 @@ all_configs = [
     "gemini-2.5-flash-vertex.json",
     "gemini-2.5-pro-vertex.json",
     "together-qwen-2.5-72b-instruct.json",
-    # "ollama.json",  # TODO (cliandy): enable this in ollama testing
+    "ollama.json",  # TODO (cliandy): enable this in ollama testing
 ]
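Uncommenting "ollama.json" adds that config to the parametrized sweep. A minimal sketch of how a config list like this is typically consumed (the repository's actual fixture wiring may differ):

# Illustrative sketch of consuming all_configs: one test invocation per
# config file. Filenames and directory layout follow the hunk above.
import json
from pathlib import Path

import pytest

CONFIG_DIR = Path("tests/configs/llm_model_configs")
all_configs = ["ollama.json", "vllm.json"]  # subset, for illustration


@pytest.mark.parametrize("config_file", all_configs)
def test_config_is_well_formed(config_file):
    config = json.loads((CONFIG_DIR / config_file).read_text())
    assert config["context_window"] > 0
    assert config["model"]
    assert config["model_endpoint"].startswith("http")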
266  tests/test_ollama.py  Normal file
@@ -0,0 +1,266 @@
import pytest
from letta_client import Letta

from letta.schemas.providers import OllamaProvider
from letta.settings import model_settings


@pytest.fixture
def ollama_provider():
    """Create an Ollama provider for testing"""
    return OllamaProvider(
        name="ollama",
        base_url=model_settings.ollama_base_url or "http://localhost:11434",
        api_key=None,
        default_prompt_formatter="chatml",
    )


@pytest.mark.asyncio
async def test_list_llm_models_async(ollama_provider):
    """Test async listing of LLM models from Ollama"""
    models = await ollama_provider.list_llm_models_async()
    assert len(models) > 0

    model = models[0]
    assert model.handle == f"{ollama_provider.name}/{model.model}"
    assert model.model_endpoint_type == "ollama"
    assert model.model_endpoint == ollama_provider.base_url
    assert model.context_window is not None
    assert model.context_window > 0


# noinspection DuplicatedCode
@pytest.mark.asyncio
async def test_list_embedding_models_async(ollama_provider):
    """Test async listing of embedding models from Ollama"""
    embedding_models = await ollama_provider.list_embedding_models_async()
    assert len(embedding_models) > 0

    model = embedding_models[0]
    assert model.handle == f"{ollama_provider.name}/{model.embedding_model}"
    assert model.embedding_endpoint_type == "ollama"
    assert model.embedding_endpoint == ollama_provider.base_url
    assert model.embedding_dim is not None
    assert model.embedding_dim > 0


def test_send_message_with_ollama_sync(ollama_provider):
    """Test sending a message with Ollama (sync version)"""
    import os
    import threading

    from letta_client import MessageCreate

    from tests.utils import wait_for_server

    # Skip if no models available
    models = ollama_provider.list_llm_models()
    if len(models) == 0:
        pytest.skip("No Ollama models available for testing")

    # Use the first available model
    model = models[0]

    # Set up client (similar to other tests)
    def run_server():
        from letta.server.rest_api.app import start_server

        start_server(debug=True)

    server_url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
    if not os.getenv("LETTA_SERVER_URL"):
        thread = threading.Thread(target=run_server, daemon=True)
        thread.start()
        wait_for_server(server_url)

    client = Letta(base_url=server_url, token=None)

    # Create agent with Ollama model
    agent = client.agents.create(
        name="test_ollama_agent",
        memory_blocks=[{"label": "human", "value": "username: test_user"}, {"label": "persona", "value": "you are a helpful assistant"}],
        model=model.handle,
        embedding="letta/letta-free",
    )

    try:
        # Send a simple message
        response = client.agents.messages.create(
            agent_id=agent.id, messages=[MessageCreate(role="user", content="Hello, respond with just 'Hi there!'")]
        )

        # Verify response
        assert response is not None
        assert len(response.messages) > 0

        # Find the assistant response
        assistant_response = None
        for msg in response.messages:
            if msg.message_type == "assistant_message":
                assistant_response = msg
                break

        assert assistant_response is not None
        assert len(assistant_response.text) > 0

    finally:
        # Clean up
        client.agents.delete(agent.id)


@pytest.mark.asyncio
async def test_send_message_with_ollama_async_streaming(ollama_provider):
    """Test sending a message with Ollama using async streaming"""
    import os
    import threading

    from letta_client import MessageCreate

    from tests.utils import wait_for_server

    # Skip if no models available
    models = await ollama_provider.list_llm_models_async()
    if len(models) == 0:
        pytest.skip("No Ollama models available for testing")

    # Use the first available model
    model = models[0]

    # Set up client
    def run_server():
        from letta.server.rest_api.app import start_server

        start_server(debug=True)

    server_url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
    if not os.getenv("LETTA_SERVER_URL"):
        thread = threading.Thread(target=run_server, daemon=True)
        thread.start()
        wait_for_server(server_url)

    client = Letta(base_url=server_url, token=None)

    # Create agent with Ollama model
    agent = client.agents.create(
        name="test_ollama_streaming_agent",
        memory_blocks=[{"label": "human", "value": "username: test_user"}, {"label": "persona", "value": "you are a helpful assistant"}],
        model=model.handle,
        embedding="letta/letta-free",
    )

    try:
        # Test step streaming (no token streaming)
        response_stream = client.agents.messages.create_stream(
            agent_id=agent.id, messages=[MessageCreate(role="user", content="Hello, respond briefly!")], stream_tokens=False
        )

        # Collect streamed messages
        streamed_messages = []
        for chunk in response_stream:
            if hasattr(chunk, "messages"):
                streamed_messages.extend(chunk.messages)

        # Verify streaming response
        assert len(streamed_messages) > 0

        # Find assistant response in stream
        assistant_response = None
        for msg in streamed_messages:
            if msg.message_type == "assistant_message":
                assistant_response = msg
                break

        assert assistant_response is not None
        assert len(assistant_response.text) > 0

    finally:
        # Clean up
        client.agents.delete(agent.id)


@pytest.mark.asyncio
async def test_send_message_with_ollama_async_job(ollama_provider):
    """Test sending a message with Ollama using async background job"""
    import os
    import threading
    import time

    from letta_client import MessageCreate

    from tests.utils import wait_for_server

    # Skip if no models available
    models = await ollama_provider.list_llm_models_async()
    if len(models) == 0:
        pytest.skip("No Ollama models available for testing")

    # Use the first available model
    model = models[0]

    # Set up client
    def run_server():
        from letta.server.rest_api.app import start_server

        start_server(debug=True)

    server_url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
    if not os.getenv("LETTA_SERVER_URL"):
        thread = threading.Thread(target=run_server, daemon=True)
        thread.start()
        wait_for_server(server_url)

    client = Letta(base_url=server_url, token=None)

    # Create agent with Ollama model
    agent = client.agents.create(
        name="test_ollama_async_agent",
        memory_blocks=[{"label": "human", "value": "username: test_user"}, {"label": "persona", "value": "you are a helpful assistant"}],
        model=model.handle,
        embedding="letta/letta-free",
    )

    try:
        # Start async job
        run = client.agents.messages.create_async(
            agent_id=agent.id, messages=[MessageCreate(role="user", content="Hello, respond briefly!")]
        )

        # Wait for completion
        def wait_for_run_completion(run_id: str, timeout: float = 30.0):
            start = time.time()
            while True:
                current_run = client.runs.retrieve(run_id)
                if current_run.status == "completed":
                    return current_run
                if current_run.status == "failed":
                    raise RuntimeError(f"Run {run_id} failed: {current_run.metadata}")
                if time.time() - start > timeout:
                    raise TimeoutError(f"Run {run_id} timed out")
                time.sleep(0.5)

        completed_run = wait_for_run_completion(run.id)

        # Verify the job completed successfully
        assert completed_run.status == "completed"
        assert "result" in completed_run.metadata

        # Get messages from the result
        result = completed_run.metadata["result"]
        assert "messages" in result
        messages = result["messages"]
        assert len(messages) > 0

        # Find assistant response
        assistant_response = None
        for msg in messages:
            if msg.get("message_type") == "assistant_message":
                assistant_response = msg
                break

        assert assistant_response is not None
        assert len(assistant_response.get("text", "")) > 0

    finally:
        # Clean up
        client.agents.delete(agent.id)
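The three send-message tests above repeat the same bootstrap block: start a local REST server in a daemon thread unless LETTA_SERVER_URL is set, then build a client. A possible follow-up, not part of this commit, would be to factor that into a module-scoped fixture using only the calls already present in the file:

# Hypothetical refactor, not in this commit: shared client fixture for the
# three send-message tests above.
import os
import threading

import pytest
from letta_client import Letta

from tests.utils import wait_for_server


@pytest.fixture(scope="module")
def letta_test_client():
    def run_server():
        from letta.server.rest_api.app import start_server

        start_server(debug=True)

    server_url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
    if not os.getenv("LETTA_SERVER_URL"):
        threading.Thread(target=run_server, daemon=True).start()
        wait_for_server(server_url)

    return Letta(base_url=server_url, token=None)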
@@ -7,8 +7,10 @@ from letta.schemas.providers import (
     GoogleAIProvider,
     GoogleVertexProvider,
     GroqProvider,
+    OllamaProvider,
     OpenAIProvider,
     TogetherProvider,
+    VLLMChatCompletionsProvider,
 )
 from letta.settings import model_settings

@@ -98,20 +100,21 @@ def test_azure():
     assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


-# def test_ollama():
-#     provider = OllamaProvider(
-#         name="ollama",
-#         base_url=model_settings.ollama_base_url,
-#         api_key=None,
-#         default_prompt_formatter=model_settings.default_prompt_formatter,
-#     )
-#     models = provider.list_llm_models()
-#     assert len(models) > 0
-#     assert models[0].handle == f"{provider.name}/{models[0].model}"
-#
-#     embedding_models = provider.list_embedding_models()
-#     assert len(embedding_models) > 0
-#     assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
+@pytest.mark.skipif(model_settings.ollama_base_url is None, reason="Only run if OLLAMA_BASE_URL is set.")
+def test_ollama():
+    provider = OllamaProvider(
+        name="ollama",
+        base_url=model_settings.ollama_base_url,
+        api_key=None,
+        default_prompt_formatter=model_settings.default_prompt_formatter,
+    )
+    models = provider.list_llm_models()
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"
+
+    embedding_models = provider.list_embedding_models()
+    assert len(embedding_models) > 0
+    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


 def test_googleai():
@@ -225,9 +228,10 @@ def test_custom_anthropic():
     assert models[0].handle == f"{provider.name}/{models[0].model}"


-# def test_vllm():
-#     provider = VLLMProvider(base_url=os.getenv("VLLM_API_BASE"))
-#     models = provider.list_llm_models()
-#     print(models)
-#
-#     provider.list_embedding_models()
+@pytest.mark.skipif(model_settings.vllm_api_base is None, reason="Only run if VLLM_API_BASE is set.")
+def test_vllm():
+    provider = VLLMChatCompletionsProvider(base_url=model_settings.vllm_api_base)
+    models = provider.list_llm_models()
+    print(models)
+
+    provider.list_embedding_models()
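test_vllm assumes an OpenAI-compatible vLLM server is already listening at VLLM_API_BASE. One way to verify that before running the test (illustrative; the endpoint and model are the ones from tests/configs/llm_model_configs/vllm.json):

# Illustrative pre-flight check, not part of the commit: vLLM's
# OpenAI-compatible server exposes GET /v1/models, so the target model
# should be listed there before test_vllm runs.
import json
import urllib.request

VLLM_API_BASE = "http://127.0.0.1:8000"

with urllib.request.urlopen(f"{VLLM_API_BASE}/v1/models") as resp:
    served = [m["id"] for m in json.load(resp).get("data", [])]

print(served)  # expected to include "Qwen/Qwen3-32B-AWQ"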