feat: move ollama to new agent loop (#3615)

This commit is contained in:
Sarah Wooders
2025-07-31 13:40:26 -07:00
committed by GitHub
parent cb04b877d6
commit c2b2d976b6
6 changed files with 173 additions and 317 deletions

View File

@@ -58,7 +58,7 @@ class LLMClient:
put_inner_thoughts_first=put_inner_thoughts_first,
actor=actor,
)
case ProviderType.openai | ProviderType.together:
case ProviderType.openai | ProviderType.together | ProviderType.ollama:
from letta.llm_api.openai_client import OpenAIClient
return OpenAIClient(

View File

@@ -13,6 +13,8 @@ from letta.schemas.providers.openai import OpenAIProvider
logger = get_logger(__name__)
ollama_prefix = "/v1"
class OllamaProvider(OpenAIProvider):
"""Ollama provider that uses the native /api/generate endpoint
@@ -43,13 +45,13 @@ class OllamaProvider(OpenAIProvider):
for model in response_json["models"]:
context_window = self.get_model_context_window(model["name"])
if context_window is None:
print(f"Ollama model {model['name']} has no context window")
continue
print(f"Ollama model {model['name']} has no context window, using default 32000")
context_window = 32000
configs.append(
LLMConfig(
model=model["name"],
model_endpoint_type="ollama",
model_endpoint=self.base_url,
model_endpoint_type=ProviderType.ollama,
model_endpoint=f"{self.base_url}{ollama_prefix}",
model_wrapper=self.default_prompt_formatter,
context_window=context_window,
handle=self.get_handle(model["name"]),
@@ -75,13 +77,14 @@ class OllamaProvider(OpenAIProvider):
for model in response_json["models"]:
embedding_dim = await self._get_model_embedding_dim_async(model["name"])
if not embedding_dim:
print(f"Ollama model {model['name']} has no embedding dimension")
continue
print(f"Ollama model {model['name']} has no embedding dimension, using default 1024")
# continue
embedding_dim = 1024
configs.append(
EmbeddingConfig(
embedding_model=model["name"],
embedding_endpoint_type="ollama",
embedding_endpoint=self.base_url,
embedding_endpoint_type=ProviderType.ollama,
embedding_endpoint=f"{self.base_url}{ollama_prefix}",
embedding_dim=embedding_dim,
embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
handle=self.get_handle(model["name"], is_embedding=True),

View File

@@ -865,7 +865,15 @@ async def send_message(
# TODO: This is redundant, remove soon
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor, include_relationships=["multi_agent_group"])
agent_eligible = agent.multi_agent_group is None or agent.multi_agent_group.manager_type in ["sleeptime", "voice_sleeptime"]
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex", "bedrock"]
model_compatible = agent.llm_config.model_endpoint_type in [
"anthropic",
"openai",
"together",
"google_ai",
"google_vertex",
"bedrock",
"ollama",
]
# Create a new run for execution tracking
if settings.track_agent_run:
@@ -999,7 +1007,15 @@ async def send_message_streaming(
# TODO: This is redundant, remove soon
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor, include_relationships=["multi_agent_group"])
agent_eligible = agent.multi_agent_group is None or agent.multi_agent_group.manager_type in ["sleeptime", "voice_sleeptime"]
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex", "bedrock"]
model_compatible = agent.llm_config.model_endpoint_type in [
"anthropic",
"openai",
"together",
"google_ai",
"google_vertex",
"bedrock",
"ollama",
]
model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"]
not_letta_endpoint = agent.llm_config.model_endpoint != LETTA_MODEL_ENDPOINT
@@ -1194,6 +1210,7 @@ async def _process_message_background(
"google_ai",
"google_vertex",
"bedrock",
"ollama",
]
if agent_eligible and model_compatible:
if agent.enable_sleeptime and agent.agent_type != AgentType.voice_convo_agent:
@@ -1373,7 +1390,15 @@ async def preview_raw_payload(
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor, include_relationships=["multi_agent_group"])
agent_eligible = agent.multi_agent_group is None or agent.multi_agent_group.manager_type in ["sleeptime", "voice_sleeptime"]
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex", "bedrock"]
model_compatible = agent.llm_config.model_endpoint_type in [
"anthropic",
"openai",
"together",
"google_ai",
"google_vertex",
"bedrock",
"ollama",
]
if agent_eligible and model_compatible:
if agent.enable_sleeptime:
@@ -1433,7 +1458,15 @@ async def summarize_agent_conversation(
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor, include_relationships=["multi_agent_group"])
agent_eligible = agent.multi_agent_group is None or agent.multi_agent_group.manager_type in ["sleeptime", "voice_sleeptime"]
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex", "bedrock"]
model_compatible = agent.llm_config.model_endpoint_type in [
"anthropic",
"openai",
"together",
"google_ai",
"google_vertex",
"bedrock",
"ollama",
]
if agent_eligible and model_compatible:
agent = LettaAgent(

View File

@@ -1,7 +1,7 @@
{
"context_window": 8192,
"model_endpoint_type": "ollama",
"model_endpoint": "http://127.0.0.1:11434",
"model": "qwen3:32b",
"model_endpoint": "http://127.0.0.1:11434/v1",
"model": "qwen2.5:7b",
"put_inner_thoughts_in_kwargs": true
}

View File

@@ -108,47 +108,35 @@ USER_MESSAGE_BASE64_IMAGE: List[MessageCreate] = [
otid=USER_MESSAGE_OTID,
)
]
# configs for models that are too dumb to do much other than messaging
limited_configs = [
"ollama.json",
"together-qwen-2.5-72b-instruct.json",
]
all_configs = [
# "openai-gpt-4o-mini.json",
# "openai-o1.json",
# "openai-o3.json",
# "openai-o4-mini.json",
# "azure-gpt-4o-mini.json",
# "claude-4-sonnet.json",
# "claude-3-5-sonnet.json",
# "claude-3-7-sonnet.json",
"openai-gpt-4o-mini.json",
"openai-o1.json",
"openai-o3.json",
"openai-o4-mini.json",
"azure-gpt-4o-mini.json",
"claude-4-sonnet.json",
"claude-3-5-sonnet.json",
"claude-3-7-sonnet.json",
"claude-3-7-sonnet-extended.json",
# "bedrock-claude-4-sonnet.json",
# "gemini-2.5-pro.json",
# "gemini-2.5-flash.json",
# "gemini-2.5-flash-vertex.json",
# "gemini-2.5-pro-vertex.json",
# "together-qwen-2.5-72b-instruct.json",
# "ollama.json", # TODO (cliandy): enable this in ollama testing
"bedrock-claude-4-sonnet.json",
"gemini-1.5-pro.json",
"gemini-2.5-flash-vertex.json",
"gemini-2.5-pro-vertex.json",
"ollama.json",
"together-qwen-2.5-72b-instruct.json",
]
reasoning_configs = [
"openai-o1.json",
"openai-o3.json",
"openai-o4-mini.json",
# "azure-gpt-4o-mini.json",
# "claude-4-sonnet.json",
# "claude-3-5-sonnet.json",
# "claude-3-7-sonnet.json",
# "claude-3-7-sonnet-extended.json",
# "bedrock-claude-4-sonnet.json",
# "gemini-1.5-pro.json",
# "gemini-2.5-flash-vertex.json",
# "gemini-2.5-pro-vertex.json",
# "together-qwen-2.5-72b-instruct.json",
# "ollama.json", # TODO (cliandy): enable this in ollama testing
# TODO @jnjpng: not supported in CI yet, uncomment to test locally (requires lmstudio running locally with respective models loaded)
# "lmstudio-meta-llama-3.1-8b-instruct.json",
# "lmstudio-qwen-2.5-7b-instruct.json",
# "mlx-qwen-2.5-7b-instruct.json",
# "mlx-meta-llama-3.1-8b-instruct-8bit.json",
# "mlx-ministral-8b-instruct-2410.json",
# "bartowski-ministral-8b-instruct-2410.json"
]
@@ -360,7 +348,6 @@ def assert_tool_call_response(
def validate_openai_format_scrubbing(messages: List[Dict[str, Any]]) -> None:
"""
Validate that OpenAI format assistant messages with tool calls have no inner thoughts content.
Args:
messages: List of message dictionaries in OpenAI format
"""
@@ -383,7 +370,6 @@ def validate_openai_format_scrubbing(messages: List[Dict[str, Any]]) -> None:
def validate_anthropic_format_scrubbing(messages: List[Dict[str, Any]]) -> None:
"""
Validate that Anthropic/Claude format assistant messages with tool_use have no <thinking> tags.
Args:
messages: List of message dictionaries in Anthropic format
"""
@@ -424,7 +410,6 @@ def validate_anthropic_format_scrubbing(messages: List[Dict[str, Any]]) -> None:
def validate_google_format_scrubbing(contents: List[Dict[str, Any]]) -> None:
"""
Validate that Google/Gemini format model messages with functionCall have no thinking field.
Args:
contents: List of content dictionaries in Google format (uses 'contents' instead of 'messages')
"""
@@ -753,6 +738,18 @@ def test_url_image_input(
Tests sending a message with a synchronous client.
Verifies that the response messages follow the expected order.
"""
# get the config filename
config_filename = None
for filename in filenames:
config = get_llm_config(filename)
if config.model_dump() == llm_config.model_dump():
config_filename = filename
break
# skip if this is a limited model
if not config_filename or config_filename in limited_configs:
pytest.skip(f"Skipping test for limited model {llm_config.model}")
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = client.agents.messages.create(
@@ -779,6 +776,18 @@ def test_base64_image_input(
Tests sending a message with a synchronous client.
Verifies that the response messages follow the expected order.
"""
# get the config filename
config_filename = None
for filename in filenames:
config = get_llm_config(filename)
if config.model_dump() == llm_config.model_dump():
config_filename = filename
break
# skip if this is a limited model
if not config_filename or config_filename in limited_configs:
pytest.skip(f"Skipping test for limited model {llm_config.model}")
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = client.agents.messages.create(
@@ -886,6 +895,18 @@ def test_step_streaming_tool_call(
Tests sending a streaming message with a synchronous client.
Checks that each chunk in the stream has the correct message types.
"""
# get the config filename
config_filename = None
for filename in filenames:
config = get_llm_config(filename)
if config.model_dump() == llm_config.model_dump():
config_filename = filename
break
# skip if this is a limited model
if not config_filename or config_filename in limited_configs:
pytest.skip(f"Skipping test for limited model {llm_config.model}")
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = client.agents.messages.create_stream(
@@ -1004,6 +1025,18 @@ def test_token_streaming_tool_call(
Tests sending a streaming message with a synchronous client.
Checks that each chunk in the stream has the correct message types.
"""
# get the config filename
config_filename = None
for filename in filenames:
config = get_llm_config(filename)
if config.model_dump() == llm_config.model_dump():
config_filename = filename
break
# skip if this is a limited model
if not config_filename or config_filename in limited_configs:
pytest.skip(f"Skipping test for limited model {llm_config.model}")
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = client.agents.messages.create_stream(
@@ -1152,6 +1185,17 @@ def test_async_tool_call(
Tests sending a message as an asynchronous job using the synchronous client.
Waits for job completion and asserts that the result messages are as expected.
"""
config_filename = None
for filename in filenames:
config = get_llm_config(filename)
if config.model_dump() == llm_config.model_dump():
config_filename = filename
break
# skip if this is a limited model
if not config_filename or config_filename in limited_configs:
pytest.skip(f"Skipping test for limited model {llm_config.model}")
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
@@ -1272,6 +1316,17 @@ def test_async_greeting_with_callback_url(
Tests sending a message as an asynchronous job with callback URL functionality.
Validates that callbacks are properly sent with correct payload structure.
"""
config_filename = None
for filename in filenames:
config = get_llm_config(filename)
if config.model_dump() == llm_config.model_dump():
config_filename = filename
break
# skip if this is a limited model
if not config_filename or config_filename in limited_configs:
pytest.skip(f"Skipping test for limited model {llm_config.model}")
client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
with callback_server() as server:
@@ -1337,6 +1392,17 @@ def test_async_callback_failure_scenarios(
Tests that job completion works even when callback URLs fail.
This ensures callback failures don't affect job processing.
"""
config_filename = None
for filename in filenames:
config = get_llm_config(filename)
if config.model_dump() == llm_config.model_dump():
config_filename = filename
break
# skip if this is a limited model
if not config_filename or config_filename in limited_configs:
pytest.skip(f"Skipping test for limited model {llm_config.model}")
client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
# Test with invalid callback URL - job should still complete
@@ -1375,6 +1441,18 @@ def test_async_callback_failure_scenarios(
)
def test_auto_summarize(disable_e2b_api_key: Any, client: Letta, llm_config: LLMConfig):
"""Test that summarization is automatically triggered."""
# get the config filename
config_filename = None
for filename in filenames:
config = get_llm_config(filename)
if config.model_dump() == llm_config.model_dump():
config_filename = filename
break
# skip if this is a limited model (runs too slow)
if not config_filename or config_filename in limited_configs:
pytest.skip(f"Skipping test for limited model {llm_config.model}")
# pydantic prevents us from overriding the context window parameter in the passed LLMConfig
new_llm_config = llm_config.model_dump()
new_llm_config["context_window"] = 3000
@@ -1641,6 +1719,10 @@ def test_inner_thoughts_false_non_reasoner_models(
config_filename = filename
break
# skip if this is a limited model
if not config_filename or config_filename in limited_configs:
pytest.skip(f"Skipping test for limited model {llm_config.model}")
# skip if this is a reasoning model
if not config_filename or config_filename in reasoning_configs:
pytest.skip(f"Skipping test for reasoning model {llm_config.model}")
@@ -1682,6 +1764,10 @@ def test_inner_thoughts_false_non_reasoner_models_streaming(
config_filename = filename
break
# skip if this is a limited model
if not config_filename or config_filename in limited_configs:
pytest.skip(f"Skipping test for limited model {llm_config.model}")
# skip if this is a reasoning model
if not config_filename or config_filename in reasoning_configs:
pytest.skip(f"Skipping test for reasoning model {llm_config.model}")

View File

@@ -1,266 +0,0 @@
import pytest
from letta_client import Letta
from letta.schemas.providers import OllamaProvider
from letta.settings import model_settings
@pytest.fixture
def ollama_provider():
    """Create an Ollama provider for testing.

    Uses ``model_settings.ollama_base_url`` when configured, otherwise
    falls back to the default local Ollama endpoint on port 11434.
    No API key is needed for a local Ollama server.
    """
    return OllamaProvider(
        name="ollama",
        base_url=model_settings.ollama_base_url or "http://localhost:11434",
        api_key=None,
        default_prompt_formatter="chatml",
    )
@pytest.mark.asyncio
async def test_list_llm_models_async(ollama_provider):
    """Test async listing of LLM models from Ollama.

    Verifies that each listed model carries a well-formed handle, the
    ollama endpoint type/URL, and a positive context window.
    """
    models = await ollama_provider.list_llm_models_async()
    # `assert len(models) >= 0` was vacuously true and the unconditional
    # `models[0]` below raised IndexError on a host with no models pulled;
    # skip instead, matching the other tests in this file.
    if not models:
        pytest.skip("No Ollama models available for testing")
    model = models[0]
    assert model.handle == f"{ollama_provider.name}/{model.model}"
    assert model.model_endpoint_type == "ollama"
    assert model.model_endpoint == ollama_provider.base_url
    assert model.context_window is not None
    assert model.context_window > 0
# noinspection DuplicatedCode
@pytest.mark.asyncio
async def test_list_embedding_models_async(ollama_provider):
    """Test async listing of embedding models from Ollama.

    Verifies that each listed embedding model carries a well-formed handle,
    the ollama endpoint type/URL, and a positive embedding dimension.
    """
    embedding_models = await ollama_provider.list_embedding_models_async()
    # `assert len(...) >= 0` was vacuously true and the unconditional
    # `embedding_models[0]` below raised IndexError on a host with no
    # embedding models pulled; skip instead, matching the other tests.
    if not embedding_models:
        pytest.skip("No Ollama embedding models available for testing")
    model = embedding_models[0]
    assert model.handle == f"{ollama_provider.name}/{model.embedding_model}"
    assert model.embedding_endpoint_type == "ollama"
    assert model.embedding_endpoint == ollama_provider.base_url
    assert model.embedding_dim is not None
    assert model.embedding_dim > 0
def test_send_message_with_ollama_sync(ollama_provider):
    """Test sending a message with Ollama (sync version).

    End-to-end: optionally boots a local Letta server in a daemon thread,
    creates an agent backed by the first available Ollama model, sends one
    message, and asserts an assistant reply comes back.
    """
    import os
    import threading

    from letta_client import MessageCreate
    from tests.utils import wait_for_server

    # Skip if no models available
    models = ollama_provider.list_llm_models()
    if len(models) == 0:
        pytest.skip("No Ollama models available for testing")

    # Use the first available model
    model = models[0]

    # Set up client (similar to other tests)
    def run_server():
        from letta.server.rest_api.app import start_server

        start_server(debug=True)

    server_url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
    # Only spawn an in-process server when no external one is configured.
    # NOTE(review): original indentation was lost; start/wait presumably
    # belong inside this branch — confirm against the file history.
    if not os.getenv("LETTA_SERVER_URL"):
        thread = threading.Thread(target=run_server, daemon=True)
        thread.start()
        wait_for_server(server_url)

    client = Letta(base_url=server_url, token=None)

    # Create agent with Ollama model
    agent = client.agents.create(
        name="test_ollama_agent",
        memory_blocks=[{"label": "human", "value": "username: test_user"}, {"label": "persona", "value": "you are a helpful assistant"}],
        model=model.handle,
        embedding="letta/letta-free",
    )
    try:
        # Send a simple message
        response = client.agents.messages.create(
            agent_id=agent.id, messages=[MessageCreate(role="user", content="Hello, respond with just 'Hi there!'")]
        )
        # Verify response
        assert response is not None
        assert len(response.messages) > 0
        # Find the assistant response
        assistant_response = None
        for msg in response.messages:
            if msg.message_type == "assistant_message":
                assistant_response = msg
                break
        assert assistant_response is not None
        assert len(assistant_response.text) > 0
    finally:
        # Clean up: delete the agent even if assertions above fail
        client.agents.delete(agent.id)
@pytest.mark.asyncio
async def test_send_message_with_ollama_async_streaming(ollama_provider):
    """Test sending a message with Ollama using async streaming.

    Uses step streaming (``stream_tokens=False``) and asserts that an
    assistant message appears among the streamed chunks.
    """
    import os
    import threading

    from letta_client import MessageCreate
    from tests.utils import wait_for_server

    # Skip if no models available
    models = await ollama_provider.list_llm_models_async()
    if len(models) == 0:
        pytest.skip("No Ollama models available for testing")

    # Use the first available model
    model = models[0]

    # Set up client
    def run_server():
        from letta.server.rest_api.app import start_server

        start_server(debug=True)

    server_url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
    # Only spawn an in-process server when no external one is configured.
    if not os.getenv("LETTA_SERVER_URL"):
        thread = threading.Thread(target=run_server, daemon=True)
        thread.start()
        wait_for_server(server_url)

    client = Letta(base_url=server_url, token=None)

    # Create agent with Ollama model
    agent = client.agents.create(
        name="test_ollama_streaming_agent",
        memory_blocks=[{"label": "human", "value": "username: test_user"}, {"label": "persona", "value": "you are a helpful assistant"}],
        model=model.handle,
        embedding="letta/letta-free",
    )
    try:
        # Test step streaming (no token streaming)
        response_stream = client.agents.messages.create_stream(
            agent_id=agent.id, messages=[MessageCreate(role="user", content="Hello, respond briefly!")], stream_tokens=False
        )
        # Collect streamed messages
        streamed_messages = []
        for chunk in response_stream:
            # Some chunks may be status/usage events without messages
            if hasattr(chunk, "messages"):
                streamed_messages.extend(chunk.messages)
        # Verify streaming response
        assert len(streamed_messages) > 0
        # Find assistant response in stream
        assistant_response = None
        for msg in streamed_messages:
            if msg.message_type == "assistant_message":
                assistant_response = msg
                break
        assert assistant_response is not None
        assert len(assistant_response.text) > 0
    finally:
        # Clean up: delete the agent even if assertions above fail
        client.agents.delete(agent.id)
@pytest.mark.asyncio
async def test_send_message_with_ollama_async_job(ollama_provider):
    """Test sending a message with Ollama using async background job.

    Submits a message via ``create_async``, polls the run until it
    completes (or times out), then validates the result payload stored
    in the run metadata.
    """
    import os
    import threading
    import time

    from letta_client import MessageCreate
    from tests.utils import wait_for_server

    # Skip if no models available
    models = await ollama_provider.list_llm_models_async()
    if len(models) == 0:
        pytest.skip("No Ollama models available for testing")

    # Use the first available model
    model = models[0]

    # Set up client
    def run_server():
        from letta.server.rest_api.app import start_server

        start_server(debug=True)

    server_url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
    # Only spawn an in-process server when no external one is configured.
    if not os.getenv("LETTA_SERVER_URL"):
        thread = threading.Thread(target=run_server, daemon=True)
        thread.start()
        wait_for_server(server_url)

    client = Letta(base_url=server_url, token=None)

    # Create agent with Ollama model
    agent = client.agents.create(
        name="test_ollama_async_agent",
        memory_blocks=[{"label": "human", "value": "username: test_user"}, {"label": "persona", "value": "you are a helpful assistant"}],
        model=model.handle,
        embedding="letta/letta-free",
    )
    try:
        # Start async job
        run = client.agents.messages.create_async(
            agent_id=agent.id, messages=[MessageCreate(role="user", content="Hello, respond briefly!")]
        )

        # Wait for completion
        def wait_for_run_completion(run_id: str, timeout: float = 30.0):
            # Poll every 0.5s until the run finishes, fails, or times out
            start = time.time()
            while True:
                current_run = client.runs.retrieve(run_id)
                if current_run.status == "completed":
                    return current_run
                if current_run.status == "failed":
                    raise RuntimeError(f"Run {run_id} failed: {current_run.metadata}")
                if time.time() - start > timeout:
                    raise TimeoutError(f"Run {run_id} timed out")
                time.sleep(0.5)

        completed_run = wait_for_run_completion(run.id)
        # Verify the job completed successfully
        assert completed_run.status == "completed"
        assert "result" in completed_run.metadata
        # Get messages from the result
        result = completed_run.metadata["result"]
        assert "messages" in result
        messages = result["messages"]
        assert len(messages) > 0
        # Find assistant response (messages here are plain dicts, not models)
        assistant_response = None
        for msg in messages:
            if msg.get("message_type") == "assistant_message":
                assistant_response = msg
                break
        assert assistant_response is not None
        assert len(assistant_response.get("text", "")) > 0
    finally:
        # Clean up: delete the agent even if assertions above fail
        client.agents.delete(agent.id)