From cef4eb191085c214fb07aae28bfe5bb1503b29c8 Mon Sep 17 00:00:00 2001
From: Kian Jones <11655409+kianjones9@users.noreply.github.com>
Date: Mon, 21 Jul 2025 18:26:23 -0700
Subject: [PATCH] feat(ci): Add coverage for self-hosted providers (#2976)

---
 tests/configs/llm_model_configs/ollama.json |  10 +-
 tests/configs/llm_model_configs/vllm.json   |   7 +
 tests/integration_test_send_message.py      |   2 +-
 tests/test_ollama.py                        | 266 ++++++++++++++++++++
 tests/test_providers.py                     |  44 ++--
 5 files changed, 303 insertions(+), 26 deletions(-)
 create mode 100644 tests/configs/llm_model_configs/vllm.json
 create mode 100644 tests/test_ollama.py

diff --git a/tests/configs/llm_model_configs/ollama.json b/tests/configs/llm_model_configs/ollama.json
index bce3ba74..3db25192 100644
--- a/tests/configs/llm_model_configs/ollama.json
+++ b/tests/configs/llm_model_configs/ollama.json
@@ -1,7 +1,7 @@
 {
-    "context_window": 8192,
-    "model_endpoint_type": "ollama",
-    "model_endpoint": "http://127.0.0.1:11434",
-    "model": "thewindmom/hermes-3-llama-3.1-8b",
-    "put_inner_thoughts_in_kwargs": true
+  "context_window": 8192,
+  "model_endpoint_type": "ollama",
+  "model_endpoint": "http://127.0.0.1:11434",
+  "model": "qwen3:32b",
+  "put_inner_thoughts_in_kwargs": true
 }
diff --git a/tests/configs/llm_model_configs/vllm.json b/tests/configs/llm_model_configs/vllm.json
new file mode 100644
index 00000000..54440ac4
--- /dev/null
+++ b/tests/configs/llm_model_configs/vllm.json
@@ -0,0 +1,7 @@
+{
+  "context_window": 8192,
+  "model_endpoint_type": "vllm",
+  "model_endpoint": "http://127.0.0.1:8000",
+  "model": "Qwen/Qwen3-32B-AWQ",
+  "put_inner_thoughts_in_kwargs": true
+}
diff --git a/tests/integration_test_send_message.py b/tests/integration_test_send_message.py
index 4493c6b5..1a48dc8c 100644
--- a/tests/integration_test_send_message.py
+++ b/tests/integration_test_send_message.py
@@ -115,7 +115,7 @@ all_configs = [
     "gemini-2.5-flash-vertex.json",
     "gemini-2.5-pro-vertex.json",
     "together-qwen-2.5-72b-instruct.json",
-    # "ollama.json",  # TODO (cliandy): enable this in ollama testing
+    "ollama.json",
 ]
 
 
diff --git a/tests/test_ollama.py b/tests/test_ollama.py
new file mode 100644
index 00000000..3da98480
--- /dev/null
+++ b/tests/test_ollama.py
@@ -0,0 +1,266 @@
+import pytest
+from letta_client import Letta
+
+from letta.schemas.providers import OllamaProvider
+from letta.settings import model_settings
+
+
+@pytest.fixture
+def ollama_provider():
+    """Create an Ollama provider for testing"""
+    return OllamaProvider(
+        name="ollama",
+        base_url=model_settings.ollama_base_url or "http://localhost:11434",
+        api_key=None,
+        default_prompt_formatter="chatml",
+    )
+
+
+@pytest.mark.asyncio
+async def test_list_llm_models_async(ollama_provider):
+    """Test async listing of LLM models from Ollama"""
+    models = await ollama_provider.list_llm_models_async()
+    assert len(models) > 0
+
+    model = models[0]
+    assert model.handle == f"{ollama_provider.name}/{model.model}"
+    assert model.model_endpoint_type == "ollama"
+    assert model.model_endpoint == ollama_provider.base_url
+    assert model.context_window is not None
+    assert model.context_window > 0
+
+
+# noinspection DuplicatedCode
+@pytest.mark.asyncio
+async def test_list_embedding_models_async(ollama_provider):
+    """Test async listing of embedding models from Ollama"""
+    embedding_models = await ollama_provider.list_embedding_models_async()
+    assert len(embedding_models) > 0
+
+    model = embedding_models[0]
+    assert model.handle == f"{ollama_provider.name}/{model.embedding_model}"
f"{ollama_provider.name}/{model.embedding_model}" + assert model.embedding_endpoint_type == "ollama" + assert model.embedding_endpoint == ollama_provider.base_url + assert model.embedding_dim is not None + assert model.embedding_dim > 0 + + +def test_send_message_with_ollama_sync(ollama_provider): + """Test sending a message with Ollama (sync version)""" + import os + import threading + + from letta_client import MessageCreate + + from tests.utils import wait_for_server + + # Skip if no models available + models = ollama_provider.list_llm_models() + if len(models) == 0: + pytest.skip("No Ollama models available for testing") + + # Use the first available model + model = models[0] + + # Set up client (similar to other tests) + def run_server(): + from letta.server.rest_api.app import start_server + + start_server(debug=True) + + server_url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") + if not os.getenv("LETTA_SERVER_URL"): + thread = threading.Thread(target=run_server, daemon=True) + thread.start() + wait_for_server(server_url) + + client = Letta(base_url=server_url, token=None) + + # Create agent with Ollama model + agent = client.agents.create( + name="test_ollama_agent", + memory_blocks=[{"label": "human", "value": "username: test_user"}, {"label": "persona", "value": "you are a helpful assistant"}], + model=model.handle, + embedding="letta/letta-free", + ) + + try: + # Send a simple message + response = client.agents.messages.create( + agent_id=agent.id, messages=[MessageCreate(role="user", content="Hello, respond with just 'Hi there!'")] + ) + + # Verify response + assert response is not None + assert len(response.messages) > 0 + + # Find the assistant response + assistant_response = None + for msg in response.messages: + if msg.message_type == "assistant_message": + assistant_response = msg + break + + assert assistant_response is not None + assert len(assistant_response.text) > 0 + + finally: + # Clean up + client.agents.delete(agent.id) + + +@pytest.mark.asyncio +async def test_send_message_with_ollama_async_streaming(ollama_provider): + """Test sending a message with Ollama using async streaming""" + import os + import threading + + from letta_client import MessageCreate + + from tests.utils import wait_for_server + + # Skip if no models available + models = await ollama_provider.list_llm_models_async() + if len(models) == 0: + pytest.skip("No Ollama models available for testing") + + # Use the first available model + model = models[0] + + # Set up client + def run_server(): + from letta.server.rest_api.app import start_server + + start_server(debug=True) + + server_url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") + if not os.getenv("LETTA_SERVER_URL"): + thread = threading.Thread(target=run_server, daemon=True) + thread.start() + wait_for_server(server_url) + + client = Letta(base_url=server_url, token=None) + + # Create agent with Ollama model + agent = client.agents.create( + name="test_ollama_streaming_agent", + memory_blocks=[{"label": "human", "value": "username: test_user"}, {"label": "persona", "value": "you are a helpful assistant"}], + model=model.handle, + embedding="letta/letta-free", + ) + + try: + # Test step streaming (no token streaming) + response_stream = client.agents.messages.create_stream( + agent_id=agent.id, messages=[MessageCreate(role="user", content="Hello, respond briefly!")], stream_tokens=False + ) + + # Collect streamed messages + streamed_messages = [] + for chunk in response_stream: + if hasattr(chunk, "messages"): + 
streamed_messages.extend(chunk.messages) + + # Verify streaming response + assert len(streamed_messages) > 0 + + # Find assistant response in stream + assistant_response = None + for msg in streamed_messages: + if msg.message_type == "assistant_message": + assistant_response = msg + break + + assert assistant_response is not None + assert len(assistant_response.text) > 0 + + finally: + # Clean up + client.agents.delete(agent.id) + + +@pytest.mark.asyncio +async def test_send_message_with_ollama_async_job(ollama_provider): + """Test sending a message with Ollama using async background job""" + import os + import threading + import time + + from letta_client import MessageCreate + + from tests.utils import wait_for_server + + # Skip if no models available + models = await ollama_provider.list_llm_models_async() + if len(models) == 0: + pytest.skip("No Ollama models available for testing") + + # Use the first available model + model = models[0] + + # Set up client + def run_server(): + from letta.server.rest_api.app import start_server + + start_server(debug=True) + + server_url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") + if not os.getenv("LETTA_SERVER_URL"): + thread = threading.Thread(target=run_server, daemon=True) + thread.start() + wait_for_server(server_url) + + client = Letta(base_url=server_url, token=None) + + # Create agent with Ollama model + agent = client.agents.create( + name="test_ollama_async_agent", + memory_blocks=[{"label": "human", "value": "username: test_user"}, {"label": "persona", "value": "you are a helpful assistant"}], + model=model.handle, + embedding="letta/letta-free", + ) + + try: + # Start async job + run = client.agents.messages.create_async( + agent_id=agent.id, messages=[MessageCreate(role="user", content="Hello, respond briefly!")] + ) + + # Wait for completion + def wait_for_run_completion(run_id: str, timeout: float = 30.0): + start = time.time() + while True: + current_run = client.runs.retrieve(run_id) + if current_run.status == "completed": + return current_run + if current_run.status == "failed": + raise RuntimeError(f"Run {run_id} failed: {current_run.metadata}") + if time.time() - start > timeout: + raise TimeoutError(f"Run {run_id} timed out") + time.sleep(0.5) + + completed_run = wait_for_run_completion(run.id) + + # Verify the job completed successfully + assert completed_run.status == "completed" + assert "result" in completed_run.metadata + + # Get messages from the result + result = completed_run.metadata["result"] + assert "messages" in result + messages = result["messages"] + assert len(messages) > 0 + + # Find assistant response + assistant_response = None + for msg in messages: + if msg.get("message_type") == "assistant_message": + assistant_response = msg + break + + assert assistant_response is not None + assert len(assistant_response.get("text", "")) > 0 + + finally: + # Clean up + client.agents.delete(agent.id) diff --git a/tests/test_providers.py b/tests/test_providers.py index 50d03e5f..439d3417 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -7,8 +7,10 @@ from letta.schemas.providers import ( GoogleAIProvider, GoogleVertexProvider, GroqProvider, + OllamaProvider, OpenAIProvider, TogetherProvider, + VLLMChatCompletionsProvider, ) from letta.settings import model_settings @@ -98,20 +100,21 @@ def test_azure(): assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" -# def test_ollama(): -# provider = OllamaProvider( -# name="ollama", -# 
base_url=model_settings.ollama_base_url, -# api_key=None, -# default_prompt_formatter=model_settings.default_prompt_formatter, -# ) -# models = provider.list_llm_models() -# assert len(models) > 0 -# assert models[0].handle == f"{provider.name}/{models[0].model}" -# -# embedding_models = provider.list_embedding_models() -# assert len(embedding_models) > 0 -# assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" +@pytest.mark.skipif(model_settings.ollama_base_url is None, reason="Only run if OLLAMA_BASE_URL is set.") +def test_ollama(): + provider = OllamaProvider( + name="ollama", + base_url=model_settings.ollama_base_url, + api_key=None, + default_prompt_formatter=model_settings.default_prompt_formatter, + ) + models = provider.list_llm_models() + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" + + embedding_models = provider.list_embedding_models() + assert len(embedding_models) > 0 + assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" def test_googleai(): @@ -225,9 +228,10 @@ def test_custom_anthropic(): assert models[0].handle == f"{provider.name}/{models[0].model}" -# def test_vllm(): -# provider = VLLMProvider(base_url=os.getenv("VLLM_API_BASE")) -# models = provider.list_llm_models() -# print(models) -# -# provider.list_embedding_models() +@pytest.mark.skipif(model_settings.vllm_api_base is None, reason="Only run if VLLM_API_BASE is set.") +def test_vllm(): + provider = VLLMChatCompletionsProvider(base_url=model_settings.vllm_api_base) + models = provider.list_llm_models() + print(models) + + provider.list_embedding_models()