# letta-server/tests/test_ollama.py
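"""Tests for the Ollama provider integration.

Assumes an Ollama server is reachable at ``model_settings.ollama_base_url``
(defaulting to http://localhost:11434) with at least one model pulled; each
test skips itself when no models are available.
"""
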
import pytest
from letta_client import Letta
from letta.schemas.providers import OllamaProvider
from letta.settings import model_settings


@pytest.fixture
def ollama_provider():
    """Create an Ollama provider for testing"""
    return OllamaProvider(
        name="ollama",
        base_url=model_settings.ollama_base_url or "http://localhost:11434",
        api_key=None,
        default_prompt_formatter="chatml",
    )


@pytest.mark.asyncio
async def test_list_llm_models_async(ollama_provider):
    """Test async listing of LLM models from Ollama"""
    models = await ollama_provider.list_llm_models_async()
    if len(models) == 0:
        pytest.skip("No Ollama models available for testing")

    model = models[0]
    assert model.handle == f"{ollama_provider.name}/{model.model}"
    assert model.model_endpoint_type == "ollama"
    assert model.model_endpoint == ollama_provider.base_url
    assert model.context_window is not None
    assert model.context_window > 0


# noinspection DuplicatedCode
@pytest.mark.asyncio
async def test_list_embedding_models_async(ollama_provider):
    """Test async listing of embedding models from Ollama"""
    embedding_models = await ollama_provider.list_embedding_models_async()
    if len(embedding_models) == 0:
        pytest.skip("No Ollama embedding models available for testing")

    model = embedding_models[0]
    assert model.handle == f"{ollama_provider.name}/{model.embedding_model}"
    assert model.embedding_endpoint_type == "ollama"
    assert model.embedding_endpoint == ollama_provider.base_url
    assert model.embedding_dim is not None
    assert model.embedding_dim > 0
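

# The message-sending tests below run end to end against a Letta server. If
# LETTA_SERVER_URL is unset, each test starts an in-process server on a
# background daemon thread and waits for it to come up before connecting.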
def test_send_message_with_ollama_sync(ollama_provider):
    """Test sending a message with Ollama (sync version)"""
    import os
    import threading

    from letta_client import MessageCreate
    from tests.utils import wait_for_server

    # Skip if no models available
    models = ollama_provider.list_llm_models()
    if len(models) == 0:
        pytest.skip("No Ollama models available for testing")

    # Use the first available model
    model = models[0]

    # Set up client (similar to other tests)
    def run_server():
        from letta.server.rest_api.app import start_server

        start_server(debug=True)

    server_url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
    if not os.getenv("LETTA_SERVER_URL"):
        thread = threading.Thread(target=run_server, daemon=True)
        thread.start()
        wait_for_server(server_url)

    client = Letta(base_url=server_url, token=None)

    # Create agent with Ollama model
    agent = client.agents.create(
        name="test_ollama_agent",
        memory_blocks=[
            {"label": "human", "value": "username: test_user"},
            {"label": "persona", "value": "you are a helpful assistant"},
        ],
        model=model.handle,
        embedding="letta/letta-free",
    )

    try:
        # Send a simple message
        response = client.agents.messages.create(
            agent_id=agent.id,
            messages=[MessageCreate(role="user", content="Hello, respond with just 'Hi there!'")],
        )

        # Verify response
        assert response is not None
        assert len(response.messages) > 0

        # Find the assistant response
        assistant_response = None
        for msg in response.messages:
            if msg.message_type == "assistant_message":
                assistant_response = msg
                break
        assert assistant_response is not None
        assert len(assistant_response.text) > 0
    finally:
        # Clean up
        client.agents.delete(agent.id)


@pytest.mark.asyncio
async def test_send_message_with_ollama_async_streaming(ollama_provider):
    """Test sending a message with Ollama using async streaming"""
    import os
    import threading

    from letta_client import MessageCreate
    from tests.utils import wait_for_server

    # Skip if no models available
    models = await ollama_provider.list_llm_models_async()
    if len(models) == 0:
        pytest.skip("No Ollama models available for testing")

    # Use the first available model
    model = models[0]

    # Set up client
    def run_server():
        from letta.server.rest_api.app import start_server

        start_server(debug=True)

    server_url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
    if not os.getenv("LETTA_SERVER_URL"):
        thread = threading.Thread(target=run_server, daemon=True)
        thread.start()
        wait_for_server(server_url)

    client = Letta(base_url=server_url, token=None)

    # Create agent with Ollama model
    agent = client.agents.create(
        name="test_ollama_streaming_agent",
        memory_blocks=[
            {"label": "human", "value": "username: test_user"},
            {"label": "persona", "value": "you are a helpful assistant"},
        ],
        model=model.handle,
        embedding="letta/letta-free",
    )

    try:
        # Test step streaming (no token streaming)
        response_stream = client.agents.messages.create_stream(
            agent_id=agent.id,
            messages=[MessageCreate(role="user", content="Hello, respond briefly!")],
            stream_tokens=False,
        )

        # Collect streamed messages
        streamed_messages = []
        for chunk in response_stream:
            if hasattr(chunk, "messages"):
                streamed_messages.extend(chunk.messages)

        # Verify streaming response
        assert len(streamed_messages) > 0

        # Find assistant response in stream
        assistant_response = None
        for msg in streamed_messages:
            if msg.message_type == "assistant_message":
                assistant_response = msg
                break
        assert assistant_response is not None
        assert len(assistant_response.text) > 0
    finally:
        # Clean up
        client.agents.delete(agent.id)


@pytest.mark.asyncio
async def test_send_message_with_ollama_async_job(ollama_provider):
    """Test sending a message with Ollama using async background job"""
    import os
    import threading
    import time

    from letta_client import MessageCreate
    from tests.utils import wait_for_server

    # Skip if no models available
    models = await ollama_provider.list_llm_models_async()
    if len(models) == 0:
        pytest.skip("No Ollama models available for testing")

    # Use the first available model
    model = models[0]

    # Set up client
    def run_server():
        from letta.server.rest_api.app import start_server

        start_server(debug=True)

    server_url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
    if not os.getenv("LETTA_SERVER_URL"):
        thread = threading.Thread(target=run_server, daemon=True)
        thread.start()
        wait_for_server(server_url)

    client = Letta(base_url=server_url, token=None)

    # Create agent with Ollama model
    agent = client.agents.create(
        name="test_ollama_async_agent",
        memory_blocks=[
            {"label": "human", "value": "username: test_user"},
            {"label": "persona", "value": "you are a helpful assistant"},
        ],
        model=model.handle,
        embedding="letta/letta-free",
    )

    try:
        # Start async job
        run = client.agents.messages.create_async(
            agent_id=agent.id,
            messages=[MessageCreate(role="user", content="Hello, respond briefly!")],
        )

        # Wait for completion: poll until the run completes, fails, or times out
        def wait_for_run_completion(run_id: str, timeout: float = 30.0):
            start = time.time()
            while True:
                current_run = client.runs.retrieve(run_id)
                if current_run.status == "completed":
                    return current_run
                if current_run.status == "failed":
                    raise RuntimeError(f"Run {run_id} failed: {current_run.metadata}")
                if time.time() - start > timeout:
                    raise TimeoutError(f"Run {run_id} timed out")
                time.sleep(0.5)

        completed_run = wait_for_run_completion(run.id)

        # Verify the job completed successfully
        assert completed_run.status == "completed"
        assert "result" in completed_run.metadata

        # Get messages from the result
        result = completed_run.metadata["result"]
        assert "messages" in result
        messages = result["messages"]
        assert len(messages) > 0

        # Find assistant response
        assistant_response = None
        for msg in messages:
            if msg.get("message_type") == "assistant_message":
                assistant_response = msg
                break
        assert assistant_response is not None
        assert len(assistant_response.get("text", "")) > 0
    finally:
        # Clean up
        client.agents.delete(agent.id)
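

# Run with, e.g., `pytest tests/test_ollama.py -v` from the letta-server
# directory (assumes a local Ollama server with at least one model pulled).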