feat: move ollama to new agent loop (#3615)
This commit is contained in:
@@ -58,7 +58,7 @@ class LLMClient:
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
actor=actor,
|
||||
)
|
||||
case ProviderType.openai | ProviderType.together:
|
||||
case ProviderType.openai | ProviderType.together | ProviderType.ollama:
|
||||
from letta.llm_api.openai_client import OpenAIClient
|
||||
|
||||
return OpenAIClient(
|
||||
|
||||
@@ -13,6 +13,8 @@ from letta.schemas.providers.openai import OpenAIProvider
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
ollama_prefix = "/v1"
|
||||
|
||||
|
||||
class OllamaProvider(OpenAIProvider):
|
||||
"""Ollama provider that uses the native /api/generate endpoint
|
||||
@@ -43,13 +45,13 @@ class OllamaProvider(OpenAIProvider):
|
||||
for model in response_json["models"]:
|
||||
context_window = self.get_model_context_window(model["name"])
|
||||
if context_window is None:
|
||||
print(f"Ollama model {model['name']} has no context window")
|
||||
continue
|
||||
print(f"Ollama model {model['name']} has no context window, using default 32000")
|
||||
context_window = 32000
|
||||
configs.append(
|
||||
LLMConfig(
|
||||
model=model["name"],
|
||||
model_endpoint_type="ollama",
|
||||
model_endpoint=self.base_url,
|
||||
model_endpoint_type=ProviderType.ollama,
|
||||
model_endpoint=f"{self.base_url}{ollama_prefix}",
|
||||
model_wrapper=self.default_prompt_formatter,
|
||||
context_window=context_window,
|
||||
handle=self.get_handle(model["name"]),
|
||||
@@ -75,13 +77,14 @@ class OllamaProvider(OpenAIProvider):
|
||||
for model in response_json["models"]:
|
||||
embedding_dim = await self._get_model_embedding_dim_async(model["name"])
|
||||
if not embedding_dim:
|
||||
print(f"Ollama model {model['name']} has no embedding dimension")
|
||||
continue
|
||||
print(f"Ollama model {model['name']} has no embedding dimension, using default 1024")
|
||||
# continue
|
||||
embedding_dim = 1024
|
||||
configs.append(
|
||||
EmbeddingConfig(
|
||||
embedding_model=model["name"],
|
||||
embedding_endpoint_type="ollama",
|
||||
embedding_endpoint=self.base_url,
|
||||
embedding_endpoint_type=ProviderType.ollama,
|
||||
embedding_endpoint=f"{self.base_url}{ollama_prefix}",
|
||||
embedding_dim=embedding_dim,
|
||||
embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
|
||||
handle=self.get_handle(model["name"], is_embedding=True),
|
||||
|
||||
@@ -865,7 +865,15 @@ async def send_message(
|
||||
# TODO: This is redundant, remove soon
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor, include_relationships=["multi_agent_group"])
|
||||
agent_eligible = agent.multi_agent_group is None or agent.multi_agent_group.manager_type in ["sleeptime", "voice_sleeptime"]
|
||||
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex", "bedrock"]
|
||||
model_compatible = agent.llm_config.model_endpoint_type in [
|
||||
"anthropic",
|
||||
"openai",
|
||||
"together",
|
||||
"google_ai",
|
||||
"google_vertex",
|
||||
"bedrock",
|
||||
"ollama",
|
||||
]
|
||||
|
||||
# Create a new run for execution tracking
|
||||
if settings.track_agent_run:
|
||||
@@ -999,7 +1007,15 @@ async def send_message_streaming(
|
||||
# TODO: This is redundant, remove soon
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor, include_relationships=["multi_agent_group"])
|
||||
agent_eligible = agent.multi_agent_group is None or agent.multi_agent_group.manager_type in ["sleeptime", "voice_sleeptime"]
|
||||
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex", "bedrock"]
|
||||
model_compatible = agent.llm_config.model_endpoint_type in [
|
||||
"anthropic",
|
||||
"openai",
|
||||
"together",
|
||||
"google_ai",
|
||||
"google_vertex",
|
||||
"bedrock",
|
||||
"ollama",
|
||||
]
|
||||
model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"]
|
||||
not_letta_endpoint = agent.llm_config.model_endpoint != LETTA_MODEL_ENDPOINT
|
||||
|
||||
@@ -1194,6 +1210,7 @@ async def _process_message_background(
|
||||
"google_ai",
|
||||
"google_vertex",
|
||||
"bedrock",
|
||||
"ollama",
|
||||
]
|
||||
if agent_eligible and model_compatible:
|
||||
if agent.enable_sleeptime and agent.agent_type != AgentType.voice_convo_agent:
|
||||
@@ -1373,7 +1390,15 @@ async def preview_raw_payload(
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor, include_relationships=["multi_agent_group"])
|
||||
agent_eligible = agent.multi_agent_group is None or agent.multi_agent_group.manager_type in ["sleeptime", "voice_sleeptime"]
|
||||
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex", "bedrock"]
|
||||
model_compatible = agent.llm_config.model_endpoint_type in [
|
||||
"anthropic",
|
||||
"openai",
|
||||
"together",
|
||||
"google_ai",
|
||||
"google_vertex",
|
||||
"bedrock",
|
||||
"ollama",
|
||||
]
|
||||
|
||||
if agent_eligible and model_compatible:
|
||||
if agent.enable_sleeptime:
|
||||
@@ -1433,7 +1458,15 @@ async def summarize_agent_conversation(
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor, include_relationships=["multi_agent_group"])
|
||||
agent_eligible = agent.multi_agent_group is None or agent.multi_agent_group.manager_type in ["sleeptime", "voice_sleeptime"]
|
||||
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex", "bedrock"]
|
||||
model_compatible = agent.llm_config.model_endpoint_type in [
|
||||
"anthropic",
|
||||
"openai",
|
||||
"together",
|
||||
"google_ai",
|
||||
"google_vertex",
|
||||
"bedrock",
|
||||
"ollama",
|
||||
]
|
||||
|
||||
if agent_eligible and model_compatible:
|
||||
agent = LettaAgent(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"context_window": 8192,
|
||||
"model_endpoint_type": "ollama",
|
||||
"model_endpoint": "http://127.0.0.1:11434",
|
||||
"model": "qwen3:32b",
|
||||
"model_endpoint": "http://127.0.0.1:11434/v1",
|
||||
"model": "qwen2.5:7b",
|
||||
"put_inner_thoughts_in_kwargs": true
|
||||
}
|
||||
|
||||
@@ -108,47 +108,35 @@ USER_MESSAGE_BASE64_IMAGE: List[MessageCreate] = [
|
||||
otid=USER_MESSAGE_OTID,
|
||||
)
|
||||
]
|
||||
|
||||
# configs for models that are to dumb to do much other than messaging
|
||||
limited_configs = [
|
||||
"ollama.json",
|
||||
"together-qwen-2.5-72b-instruct.json",
|
||||
]
|
||||
|
||||
all_configs = [
|
||||
# "openai-gpt-4o-mini.json",
|
||||
# "openai-o1.json",
|
||||
# "openai-o3.json",
|
||||
# "openai-o4-mini.json",
|
||||
# "azure-gpt-4o-mini.json",
|
||||
# "claude-4-sonnet.json",
|
||||
# "claude-3-5-sonnet.json",
|
||||
# "claude-3-7-sonnet.json",
|
||||
"openai-gpt-4o-mini.json",
|
||||
"openai-o1.json",
|
||||
"openai-o3.json",
|
||||
"openai-o4-mini.json",
|
||||
"azure-gpt-4o-mini.json",
|
||||
"claude-4-sonnet.json",
|
||||
"claude-3-5-sonnet.json",
|
||||
"claude-3-7-sonnet.json",
|
||||
"claude-3-7-sonnet-extended.json",
|
||||
# "bedrock-claude-4-sonnet.json",
|
||||
# "gemini-2.5-pro.json",
|
||||
# "gemini-2.5-flash.json",
|
||||
# "gemini-2.5-flash-vertex.json",
|
||||
# "gemini-2.5-pro-vertex.json",
|
||||
# "together-qwen-2.5-72b-instruct.json",
|
||||
# "ollama.json", # TODO (cliandy): enable this in ollama testing
|
||||
"bedrock-claude-4-sonnet.json",
|
||||
"gemini-1.5-pro.json",
|
||||
"gemini-2.5-flash-vertex.json",
|
||||
"gemini-2.5-pro-vertex.json",
|
||||
"ollama.json",
|
||||
"together-qwen-2.5-72b-instruct.json",
|
||||
]
|
||||
|
||||
reasoning_configs = [
|
||||
"openai-o1.json",
|
||||
"openai-o3.json",
|
||||
"openai-o4-mini.json",
|
||||
# "azure-gpt-4o-mini.json",
|
||||
# "claude-4-sonnet.json",
|
||||
# "claude-3-5-sonnet.json",
|
||||
# "claude-3-7-sonnet.json",
|
||||
# "claude-3-7-sonnet-extended.json",
|
||||
# "bedrock-claude-4-sonnet.json",
|
||||
# "gemini-1.5-pro.json",
|
||||
# "gemini-2.5-flash-vertex.json",
|
||||
# "gemini-2.5-pro-vertex.json",
|
||||
# "together-qwen-2.5-72b-instruct.json",
|
||||
# "ollama.json", # TODO (cliandy): enable this in ollama testing
|
||||
# TODO @jnjpng: not supported in CI yet, uncomment to test locally (requires lmstudio running locally with respective models loaded)
|
||||
# "lmstudio-meta-llama-3.1-8b-instruct.json",
|
||||
# "lmstudio-qwen-2.5-7b-instruct.json",
|
||||
# "mlx-qwen-2.5-7b-instruct.json",
|
||||
# "mlx-meta-llama-3.1-8b-instruct-8bit.json",
|
||||
# "mlx-ministral-8b-instruct-2410.json",
|
||||
# "bartowski-ministral-8b-instruct-2410.json"
|
||||
]
|
||||
|
||||
|
||||
@@ -360,7 +348,6 @@ def assert_tool_call_response(
|
||||
def validate_openai_format_scrubbing(messages: List[Dict[str, Any]]) -> None:
|
||||
"""
|
||||
Validate that OpenAI format assistant messages with tool calls have no inner thoughts content.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries in OpenAI format
|
||||
"""
|
||||
@@ -383,7 +370,6 @@ def validate_openai_format_scrubbing(messages: List[Dict[str, Any]]) -> None:
|
||||
def validate_anthropic_format_scrubbing(messages: List[Dict[str, Any]]) -> None:
|
||||
"""
|
||||
Validate that Anthropic/Claude format assistant messages with tool_use have no <thinking> tags.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries in Anthropic format
|
||||
"""
|
||||
@@ -424,7 +410,6 @@ def validate_anthropic_format_scrubbing(messages: List[Dict[str, Any]]) -> None:
|
||||
def validate_google_format_scrubbing(contents: List[Dict[str, Any]]) -> None:
|
||||
"""
|
||||
Validate that Google/Gemini format model messages with functionCall have no thinking field.
|
||||
|
||||
Args:
|
||||
contents: List of content dictionaries in Google format (uses 'contents' instead of 'messages')
|
||||
"""
|
||||
@@ -753,6 +738,18 @@ def test_url_image_input(
|
||||
Tests sending a message with a synchronous client.
|
||||
Verifies that the response messages follow the expected order.
|
||||
"""
|
||||
# get the config filename
|
||||
config_filename = None
|
||||
for filename in filenames:
|
||||
config = get_llm_config(filename)
|
||||
if config.model_dump() == llm_config.model_dump():
|
||||
config_filename = filename
|
||||
break
|
||||
|
||||
# skip if this is a limited model
|
||||
if not config_filename or config_filename in limited_configs:
|
||||
pytest.skip(f"Skipping test for limited model {llm_config.model}")
|
||||
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create(
|
||||
@@ -779,6 +776,18 @@ def test_base64_image_input(
|
||||
Tests sending a message with a synchronous client.
|
||||
Verifies that the response messages follow the expected order.
|
||||
"""
|
||||
# get the config filename
|
||||
config_filename = None
|
||||
for filename in filenames:
|
||||
config = get_llm_config(filename)
|
||||
if config.model_dump() == llm_config.model_dump():
|
||||
config_filename = filename
|
||||
break
|
||||
|
||||
# skip if this is a limited model
|
||||
if not config_filename or config_filename in limited_configs:
|
||||
pytest.skip(f"Skipping test for limited model {llm_config.model}")
|
||||
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create(
|
||||
@@ -886,6 +895,18 @@ def test_step_streaming_tool_call(
|
||||
Tests sending a streaming message with a synchronous client.
|
||||
Checks that each chunk in the stream has the correct message types.
|
||||
"""
|
||||
# get the config filename
|
||||
config_filename = None
|
||||
for filename in filenames:
|
||||
config = get_llm_config(filename)
|
||||
if config.model_dump() == llm_config.model_dump():
|
||||
config_filename = filename
|
||||
break
|
||||
|
||||
# skip if this is a limited model
|
||||
if not config_filename or config_filename in limited_configs:
|
||||
pytest.skip(f"Skipping test for limited model {llm_config.model}")
|
||||
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create_stream(
|
||||
@@ -1004,6 +1025,18 @@ def test_token_streaming_tool_call(
|
||||
Tests sending a streaming message with a synchronous client.
|
||||
Checks that each chunk in the stream has the correct message types.
|
||||
"""
|
||||
# get the config filename
|
||||
config_filename = None
|
||||
for filename in filenames:
|
||||
config = get_llm_config(filename)
|
||||
if config.model_dump() == llm_config.model_dump():
|
||||
config_filename = filename
|
||||
break
|
||||
|
||||
# skip if this is a limited model
|
||||
if not config_filename or config_filename in limited_configs:
|
||||
pytest.skip(f"Skipping test for limited model {llm_config.model}")
|
||||
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create_stream(
|
||||
@@ -1152,6 +1185,17 @@ def test_async_tool_call(
|
||||
Tests sending a message as an asynchronous job using the synchronous client.
|
||||
Waits for job completion and asserts that the result messages are as expected.
|
||||
"""
|
||||
config_filename = None
|
||||
for filename in filenames:
|
||||
config = get_llm_config(filename)
|
||||
if config.model_dump() == llm_config.model_dump():
|
||||
config_filename = filename
|
||||
break
|
||||
|
||||
# skip if this is a limited model
|
||||
if not config_filename or config_filename in limited_configs:
|
||||
pytest.skip(f"Skipping test for limited model {llm_config.model}")
|
||||
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
|
||||
@@ -1272,6 +1316,17 @@ def test_async_greeting_with_callback_url(
|
||||
Tests sending a message as an asynchronous job with callback URL functionality.
|
||||
Validates that callbacks are properly sent with correct payload structure.
|
||||
"""
|
||||
config_filename = None
|
||||
for filename in filenames:
|
||||
config = get_llm_config(filename)
|
||||
if config.model_dump() == llm_config.model_dump():
|
||||
config_filename = filename
|
||||
break
|
||||
|
||||
# skip if this is a limited model
|
||||
if not config_filename or config_filename in limited_configs:
|
||||
pytest.skip(f"Skipping test for limited model {llm_config.model}")
|
||||
|
||||
client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
|
||||
with callback_server() as server:
|
||||
@@ -1337,6 +1392,17 @@ def test_async_callback_failure_scenarios(
|
||||
Tests that job completion works even when callback URLs fail.
|
||||
This ensures callback failures don't affect job processing.
|
||||
"""
|
||||
config_filename = None
|
||||
for filename in filenames:
|
||||
config = get_llm_config(filename)
|
||||
if config.model_dump() == llm_config.model_dump():
|
||||
config_filename = filename
|
||||
break
|
||||
|
||||
# skip if this is a limited model
|
||||
if not config_filename or config_filename in limited_configs:
|
||||
pytest.skip(f"Skipping test for limited model {llm_config.model}")
|
||||
|
||||
client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
|
||||
# Test with invalid callback URL - job should still complete
|
||||
@@ -1375,6 +1441,18 @@ def test_async_callback_failure_scenarios(
|
||||
)
|
||||
def test_auto_summarize(disable_e2b_api_key: Any, client: Letta, llm_config: LLMConfig):
|
||||
"""Test that summarization is automatically triggered."""
|
||||
# get the config filename
|
||||
config_filename = None
|
||||
for filename in filenames:
|
||||
config = get_llm_config(filename)
|
||||
if config.model_dump() == llm_config.model_dump():
|
||||
config_filename = filename
|
||||
break
|
||||
|
||||
# skip if this is a limited model (runs too slow)
|
||||
if not config_filename or config_filename in limited_configs:
|
||||
pytest.skip(f"Skipping test for limited model {llm_config.model}")
|
||||
|
||||
# pydantic prevents us for overriding the context window paramter in the passed LLMConfig
|
||||
new_llm_config = llm_config.model_dump()
|
||||
new_llm_config["context_window"] = 3000
|
||||
@@ -1641,6 +1719,10 @@ def test_inner_thoughts_false_non_reasoner_models(
|
||||
config_filename = filename
|
||||
break
|
||||
|
||||
# skip if this is a limited model
|
||||
if not config_filename or config_filename in limited_configs:
|
||||
pytest.skip(f"Skipping test for limited model {llm_config.model}")
|
||||
|
||||
# skip if this is a reasoning model
|
||||
if not config_filename or config_filename in reasoning_configs:
|
||||
pytest.skip(f"Skipping test for reasoning model {llm_config.model}")
|
||||
@@ -1682,6 +1764,10 @@ def test_inner_thoughts_false_non_reasoner_models_streaming(
|
||||
config_filename = filename
|
||||
break
|
||||
|
||||
# skip if this is a limited model
|
||||
if not config_filename or config_filename in limited_configs:
|
||||
pytest.skip(f"Skipping test for limited model {llm_config.model}")
|
||||
|
||||
# skip if this is a reasoning model
|
||||
if not config_filename or config_filename in reasoning_configs:
|
||||
pytest.skip(f"Skipping test for reasoning model {llm_config.model}")
|
||||
|
||||
@@ -1,266 +0,0 @@
|
||||
import pytest
|
||||
from letta_client import Letta
|
||||
|
||||
from letta.schemas.providers import OllamaProvider
|
||||
from letta.settings import model_settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ollama_provider():
|
||||
"""Create an Ollama provider for testing"""
|
||||
return OllamaProvider(
|
||||
name="ollama",
|
||||
base_url=model_settings.ollama_base_url or "http://localhost:11434",
|
||||
api_key=None,
|
||||
default_prompt_formatter="chatml",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_llm_models_async(ollama_provider):
|
||||
"""Test async listing of LLM models from Ollama"""
|
||||
models = await ollama_provider.list_llm_models_async()
|
||||
assert len(models) >= 0
|
||||
|
||||
model = models[0]
|
||||
assert model.handle == f"{ollama_provider.name}/{model.model}"
|
||||
assert model.model_endpoint_type == "ollama"
|
||||
assert model.model_endpoint == ollama_provider.base_url
|
||||
assert model.context_window is not None
|
||||
assert model.context_window > 0
|
||||
|
||||
|
||||
# noinspection DuplicatedCode
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_embedding_models_async(ollama_provider):
|
||||
"""Test async listing of embedding models from Ollama"""
|
||||
embedding_models = await ollama_provider.list_embedding_models_async()
|
||||
assert len(embedding_models) >= 0
|
||||
|
||||
model = embedding_models[0]
|
||||
assert model.handle == f"{ollama_provider.name}/{model.embedding_model}"
|
||||
assert model.embedding_endpoint_type == "ollama"
|
||||
assert model.embedding_endpoint == ollama_provider.base_url
|
||||
assert model.embedding_dim is not None
|
||||
assert model.embedding_dim > 0
|
||||
|
||||
|
||||
def test_send_message_with_ollama_sync(ollama_provider):
|
||||
"""Test sending a message with Ollama (sync version)"""
|
||||
import os
|
||||
import threading
|
||||
|
||||
from letta_client import MessageCreate
|
||||
|
||||
from tests.utils import wait_for_server
|
||||
|
||||
# Skip if no models available
|
||||
models = ollama_provider.list_llm_models()
|
||||
if len(models) == 0:
|
||||
pytest.skip("No Ollama models available for testing")
|
||||
|
||||
# Use the first available model
|
||||
model = models[0]
|
||||
|
||||
# Set up client (similar to other tests)
|
||||
def run_server():
|
||||
from letta.server.rest_api.app import start_server
|
||||
|
||||
start_server(debug=True)
|
||||
|
||||
server_url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
thread = threading.Thread(target=run_server, daemon=True)
|
||||
thread.start()
|
||||
wait_for_server(server_url)
|
||||
|
||||
client = Letta(base_url=server_url, token=None)
|
||||
|
||||
# Create agent with Ollama model
|
||||
agent = client.agents.create(
|
||||
name="test_ollama_agent",
|
||||
memory_blocks=[{"label": "human", "value": "username: test_user"}, {"label": "persona", "value": "you are a helpful assistant"}],
|
||||
model=model.handle,
|
||||
embedding="letta/letta-free",
|
||||
)
|
||||
|
||||
try:
|
||||
# Send a simple message
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent.id, messages=[MessageCreate(role="user", content="Hello, respond with just 'Hi there!'")]
|
||||
)
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
|
||||
# Find the assistant response
|
||||
assistant_response = None
|
||||
for msg in response.messages:
|
||||
if msg.message_type == "assistant_message":
|
||||
assistant_response = msg
|
||||
break
|
||||
|
||||
assert assistant_response is not None
|
||||
assert len(assistant_response.text) > 0
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
client.agents.delete(agent.id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_with_ollama_async_streaming(ollama_provider):
|
||||
"""Test sending a message with Ollama using async streaming"""
|
||||
import os
|
||||
import threading
|
||||
|
||||
from letta_client import MessageCreate
|
||||
|
||||
from tests.utils import wait_for_server
|
||||
|
||||
# Skip if no models available
|
||||
models = await ollama_provider.list_llm_models_async()
|
||||
if len(models) == 0:
|
||||
pytest.skip("No Ollama models available for testing")
|
||||
|
||||
# Use the first available model
|
||||
model = models[0]
|
||||
|
||||
# Set up client
|
||||
def run_server():
|
||||
from letta.server.rest_api.app import start_server
|
||||
|
||||
start_server(debug=True)
|
||||
|
||||
server_url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
thread = threading.Thread(target=run_server, daemon=True)
|
||||
thread.start()
|
||||
wait_for_server(server_url)
|
||||
|
||||
client = Letta(base_url=server_url, token=None)
|
||||
|
||||
# Create agent with Ollama model
|
||||
agent = client.agents.create(
|
||||
name="test_ollama_streaming_agent",
|
||||
memory_blocks=[{"label": "human", "value": "username: test_user"}, {"label": "persona", "value": "you are a helpful assistant"}],
|
||||
model=model.handle,
|
||||
embedding="letta/letta-free",
|
||||
)
|
||||
|
||||
try:
|
||||
# Test step streaming (no token streaming)
|
||||
response_stream = client.agents.messages.create_stream(
|
||||
agent_id=agent.id, messages=[MessageCreate(role="user", content="Hello, respond briefly!")], stream_tokens=False
|
||||
)
|
||||
|
||||
# Collect streamed messages
|
||||
streamed_messages = []
|
||||
for chunk in response_stream:
|
||||
if hasattr(chunk, "messages"):
|
||||
streamed_messages.extend(chunk.messages)
|
||||
|
||||
# Verify streaming response
|
||||
assert len(streamed_messages) > 0
|
||||
|
||||
# Find assistant response in stream
|
||||
assistant_response = None
|
||||
for msg in streamed_messages:
|
||||
if msg.message_type == "assistant_message":
|
||||
assistant_response = msg
|
||||
break
|
||||
|
||||
assert assistant_response is not None
|
||||
assert len(assistant_response.text) > 0
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
client.agents.delete(agent.id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_with_ollama_async_job(ollama_provider):
|
||||
"""Test sending a message with Ollama using async background job"""
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
from letta_client import MessageCreate
|
||||
|
||||
from tests.utils import wait_for_server
|
||||
|
||||
# Skip if no models available
|
||||
models = await ollama_provider.list_llm_models_async()
|
||||
if len(models) == 0:
|
||||
pytest.skip("No Ollama models available for testing")
|
||||
|
||||
# Use the first available model
|
||||
model = models[0]
|
||||
|
||||
# Set up client
|
||||
def run_server():
|
||||
from letta.server.rest_api.app import start_server
|
||||
|
||||
start_server(debug=True)
|
||||
|
||||
server_url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
thread = threading.Thread(target=run_server, daemon=True)
|
||||
thread.start()
|
||||
wait_for_server(server_url)
|
||||
|
||||
client = Letta(base_url=server_url, token=None)
|
||||
|
||||
# Create agent with Ollama model
|
||||
agent = client.agents.create(
|
||||
name="test_ollama_async_agent",
|
||||
memory_blocks=[{"label": "human", "value": "username: test_user"}, {"label": "persona", "value": "you are a helpful assistant"}],
|
||||
model=model.handle,
|
||||
embedding="letta/letta-free",
|
||||
)
|
||||
|
||||
try:
|
||||
# Start async job
|
||||
run = client.agents.messages.create_async(
|
||||
agent_id=agent.id, messages=[MessageCreate(role="user", content="Hello, respond briefly!")]
|
||||
)
|
||||
|
||||
# Wait for completion
|
||||
def wait_for_run_completion(run_id: str, timeout: float = 30.0):
|
||||
start = time.time()
|
||||
while True:
|
||||
current_run = client.runs.retrieve(run_id)
|
||||
if current_run.status == "completed":
|
||||
return current_run
|
||||
if current_run.status == "failed":
|
||||
raise RuntimeError(f"Run {run_id} failed: {current_run.metadata}")
|
||||
if time.time() - start > timeout:
|
||||
raise TimeoutError(f"Run {run_id} timed out")
|
||||
time.sleep(0.5)
|
||||
|
||||
completed_run = wait_for_run_completion(run.id)
|
||||
|
||||
# Verify the job completed successfully
|
||||
assert completed_run.status == "completed"
|
||||
assert "result" in completed_run.metadata
|
||||
|
||||
# Get messages from the result
|
||||
result = completed_run.metadata["result"]
|
||||
assert "messages" in result
|
||||
messages = result["messages"]
|
||||
assert len(messages) > 0
|
||||
|
||||
# Find assistant response
|
||||
assistant_response = None
|
||||
for msg in messages:
|
||||
if msg.get("message_type") == "assistant_message":
|
||||
assistant_response = msg
|
||||
break
|
||||
|
||||
assert assistant_response is not None
|
||||
assert len(assistant_response.get("text", "")) > 0
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
client.agents.delete(agent.id)
|
||||
Reference in New Issue
Block a user