diff --git a/letta/llm_api/llm_client.py b/letta/llm_api/llm_client.py index b21e4da2..1e03c5f7 100644 --- a/letta/llm_api/llm_client.py +++ b/letta/llm_api/llm_client.py @@ -58,7 +58,7 @@ class LLMClient: put_inner_thoughts_first=put_inner_thoughts_first, actor=actor, ) - case ProviderType.openai | ProviderType.together: + case ProviderType.openai | ProviderType.together | ProviderType.ollama: from letta.llm_api.openai_client import OpenAIClient return OpenAIClient( diff --git a/letta/schemas/providers/ollama.py b/letta/schemas/providers/ollama.py index b9ddaa2c..8cc8f720 100644 --- a/letta/schemas/providers/ollama.py +++ b/letta/schemas/providers/ollama.py @@ -13,6 +13,8 @@ from letta.schemas.providers.openai import OpenAIProvider logger = get_logger(__name__) +ollama_prefix = "/v1" + class OllamaProvider(OpenAIProvider): """Ollama provider that uses the native /api/generate endpoint @@ -43,13 +45,13 @@ class OllamaProvider(OpenAIProvider): for model in response_json["models"]: context_window = self.get_model_context_window(model["name"]) if context_window is None: - print(f"Ollama model {model['name']} has no context window") - continue + print(f"Ollama model {model['name']} has no context window, using default 32000") + context_window = 32000 configs.append( LLMConfig( model=model["name"], - model_endpoint_type="ollama", - model_endpoint=self.base_url, + model_endpoint_type=ProviderType.ollama, + model_endpoint=f"{self.base_url}{ollama_prefix}", model_wrapper=self.default_prompt_formatter, context_window=context_window, handle=self.get_handle(model["name"]), @@ -75,13 +77,14 @@ class OllamaProvider(OpenAIProvider): for model in response_json["models"]: embedding_dim = await self._get_model_embedding_dim_async(model["name"]) if not embedding_dim: - print(f"Ollama model {model['name']} has no embedding dimension") - continue + print(f"Ollama model {model['name']} has no embedding dimension, using default 1024") + # continue + embedding_dim = 1024 configs.append( EmbeddingConfig( embedding_model=model["name"], - embedding_endpoint_type="ollama", - embedding_endpoint=self.base_url, + embedding_endpoint_type=ProviderType.ollama, + embedding_endpoint=f"{self.base_url}{ollama_prefix}", embedding_dim=embedding_dim, embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, handle=self.get_handle(model["name"], is_embedding=True), diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 1b7eecf9..1cc10f8c 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -865,7 +865,15 @@ async def send_message( # TODO: This is redundant, remove soon agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor, include_relationships=["multi_agent_group"]) agent_eligible = agent.multi_agent_group is None or agent.multi_agent_group.manager_type in ["sleeptime", "voice_sleeptime"] - model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex", "bedrock"] + model_compatible = agent.llm_config.model_endpoint_type in [ + "anthropic", + "openai", + "together", + "google_ai", + "google_vertex", + "bedrock", + "ollama", + ] # Create a new run for execution tracking if settings.track_agent_run: @@ -999,7 +1007,15 @@ async def send_message_streaming( # TODO: This is redundant, remove soon agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor, include_relationships=["multi_agent_group"]) agent_eligible = agent.multi_agent_group is None or agent.multi_agent_group.manager_type in ["sleeptime", "voice_sleeptime"] - model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex", "bedrock"] + model_compatible = agent.llm_config.model_endpoint_type in [ + "anthropic", + "openai", + "together", + "google_ai", + "google_vertex", + "bedrock", + "ollama", + ] model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"] not_letta_endpoint = agent.llm_config.model_endpoint != LETTA_MODEL_ENDPOINT @@ -1194,6 +1210,7 @@ async def _process_message_background( "google_ai", "google_vertex", "bedrock", + "ollama", ] if agent_eligible and model_compatible: if agent.enable_sleeptime and agent.agent_type != AgentType.voice_convo_agent: @@ -1373,7 +1390,15 @@ async def preview_raw_payload( actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor, include_relationships=["multi_agent_group"]) agent_eligible = agent.multi_agent_group is None or agent.multi_agent_group.manager_type in ["sleeptime", "voice_sleeptime"] - model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex", "bedrock"] + model_compatible = agent.llm_config.model_endpoint_type in [ + "anthropic", + "openai", + "together", + "google_ai", + "google_vertex", + "bedrock", + "ollama", + ] if agent_eligible and model_compatible: if agent.enable_sleeptime: @@ -1433,7 +1458,15 @@ async def summarize_agent_conversation( actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor, include_relationships=["multi_agent_group"]) agent_eligible = agent.multi_agent_group is None or agent.multi_agent_group.manager_type in ["sleeptime", "voice_sleeptime"] - model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex", "bedrock"] + model_compatible = agent.llm_config.model_endpoint_type in [ + "anthropic", + "openai", + "together", + "google_ai", + "google_vertex", + "bedrock", + "ollama", + ] if agent_eligible and model_compatible: agent = LettaAgent( diff --git a/tests/configs/llm_model_configs/ollama.json b/tests/configs/llm_model_configs/ollama.json index 3db25192..a4212689 100644 --- a/tests/configs/llm_model_configs/ollama.json +++ b/tests/configs/llm_model_configs/ollama.json @@ -1,7 +1,7 @@ { "context_window": 8192, "model_endpoint_type": "ollama", - "model_endpoint": "http://127.0.0.1:11434", - "model": "qwen3:32b", + "model_endpoint": "http://127.0.0.1:11434/v1", + "model": "qwen2.5:7b", "put_inner_thoughts_in_kwargs": true } diff --git a/tests/integration_test_send_message.py b/tests/integration_test_send_message.py index 2952e2fd..b30aa42c 100644 --- a/tests/integration_test_send_message.py +++ b/tests/integration_test_send_message.py @@ -108,47 +108,35 @@ USER_MESSAGE_BASE64_IMAGE: List[MessageCreate] = [ otid=USER_MESSAGE_OTID, ) ] + +# configs for models that are to dumb to do much other than messaging +limited_configs = [ + "ollama.json", + "together-qwen-2.5-72b-instruct.json", +] + all_configs = [ - # "openai-gpt-4o-mini.json", - # "openai-o1.json", - # "openai-o3.json", - # "openai-o4-mini.json", - # "azure-gpt-4o-mini.json", - # "claude-4-sonnet.json", - # "claude-3-5-sonnet.json", - # "claude-3-7-sonnet.json", + "openai-gpt-4o-mini.json", + "openai-o1.json", + "openai-o3.json", + "openai-o4-mini.json", + "azure-gpt-4o-mini.json", + "claude-4-sonnet.json", + "claude-3-5-sonnet.json", + "claude-3-7-sonnet.json", "claude-3-7-sonnet-extended.json", - # "bedrock-claude-4-sonnet.json", - # "gemini-2.5-pro.json", - # "gemini-2.5-flash.json", - # "gemini-2.5-flash-vertex.json", - # "gemini-2.5-pro-vertex.json", - # "together-qwen-2.5-72b-instruct.json", - # "ollama.json", # TODO (cliandy): enable this in ollama testing + "bedrock-claude-4-sonnet.json", + "gemini-1.5-pro.json", + "gemini-2.5-flash-vertex.json", + "gemini-2.5-pro-vertex.json", + "ollama.json", + "together-qwen-2.5-72b-instruct.json", ] reasoning_configs = [ "openai-o1.json", "openai-o3.json", "openai-o4-mini.json", - # "azure-gpt-4o-mini.json", - # "claude-4-sonnet.json", - # "claude-3-5-sonnet.json", - # "claude-3-7-sonnet.json", - # "claude-3-7-sonnet-extended.json", - # "bedrock-claude-4-sonnet.json", - # "gemini-1.5-pro.json", - # "gemini-2.5-flash-vertex.json", - # "gemini-2.5-pro-vertex.json", - # "together-qwen-2.5-72b-instruct.json", - # "ollama.json", # TODO (cliandy): enable this in ollama testing - # TODO @jnjpng: not supported in CI yet, uncomment to test locally (requires lmstudio running locally with respective models loaded) - # "lmstudio-meta-llama-3.1-8b-instruct.json", - # "lmstudio-qwen-2.5-7b-instruct.json", - # "mlx-qwen-2.5-7b-instruct.json", - # "mlx-meta-llama-3.1-8b-instruct-8bit.json", - # "mlx-ministral-8b-instruct-2410.json", - # "bartowski-ministral-8b-instruct-2410.json" ] @@ -360,7 +348,6 @@ def assert_tool_call_response( def validate_openai_format_scrubbing(messages: List[Dict[str, Any]]) -> None: """ Validate that OpenAI format assistant messages with tool calls have no inner thoughts content. - Args: messages: List of message dictionaries in OpenAI format """ @@ -383,7 +370,6 @@ def validate_openai_format_scrubbing(messages: List[Dict[str, Any]]) -> None: def validate_anthropic_format_scrubbing(messages: List[Dict[str, Any]]) -> None: """ Validate that Anthropic/Claude format assistant messages with tool_use have no tags. - Args: messages: List of message dictionaries in Anthropic format """ @@ -424,7 +410,6 @@ def validate_anthropic_format_scrubbing(messages: List[Dict[str, Any]]) -> None: def validate_google_format_scrubbing(contents: List[Dict[str, Any]]) -> None: """ Validate that Google/Gemini format model messages with functionCall have no thinking field. - Args: contents: List of content dictionaries in Google format (uses 'contents' instead of 'messages') """ @@ -753,6 +738,18 @@ def test_url_image_input( Tests sending a message with a synchronous client. Verifies that the response messages follow the expected order. """ + # get the config filename + config_filename = None + for filename in filenames: + config = get_llm_config(filename) + if config.model_dump() == llm_config.model_dump(): + config_filename = filename + break + + # skip if this is a limited model + if not config_filename or config_filename in limited_configs: + pytest.skip(f"Skipping test for limited model {llm_config.model}") + last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1) agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = client.agents.messages.create( @@ -779,6 +776,18 @@ def test_base64_image_input( Tests sending a message with a synchronous client. Verifies that the response messages follow the expected order. """ + # get the config filename + config_filename = None + for filename in filenames: + config = get_llm_config(filename) + if config.model_dump() == llm_config.model_dump(): + config_filename = filename + break + + # skip if this is a limited model + if not config_filename or config_filename in limited_configs: + pytest.skip(f"Skipping test for limited model {llm_config.model}") + last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1) agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = client.agents.messages.create( @@ -886,6 +895,18 @@ def test_step_streaming_tool_call( Tests sending a streaming message with a synchronous client. Checks that each chunk in the stream has the correct message types. """ + # get the config filename + config_filename = None + for filename in filenames: + config = get_llm_config(filename) + if config.model_dump() == llm_config.model_dump(): + config_filename = filename + break + + # skip if this is a limited model + if not config_filename or config_filename in limited_configs: + pytest.skip(f"Skipping test for limited model {llm_config.model}") + last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1) agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = client.agents.messages.create_stream( @@ -1004,6 +1025,18 @@ def test_token_streaming_tool_call( Tests sending a streaming message with a synchronous client. Checks that each chunk in the stream has the correct message types. """ + # get the config filename + config_filename = None + for filename in filenames: + config = get_llm_config(filename) + if config.model_dump() == llm_config.model_dump(): + config_filename = filename + break + + # skip if this is a limited model + if not config_filename or config_filename in limited_configs: + pytest.skip(f"Skipping test for limited model {llm_config.model}") + last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1) agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = client.agents.messages.create_stream( @@ -1152,6 +1185,17 @@ def test_async_tool_call( Tests sending a message as an asynchronous job using the synchronous client. Waits for job completion and asserts that the result messages are as expected. """ + config_filename = None + for filename in filenames: + config = get_llm_config(filename) + if config.model_dump() == llm_config.model_dump(): + config_filename = filename + break + + # skip if this is a limited model + if not config_filename or config_filename in limited_configs: + pytest.skip(f"Skipping test for limited model {llm_config.model}") + last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1) client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) @@ -1272,6 +1316,17 @@ def test_async_greeting_with_callback_url( Tests sending a message as an asynchronous job with callback URL functionality. Validates that callbacks are properly sent with correct payload structure. """ + config_filename = None + for filename in filenames: + config = get_llm_config(filename) + if config.model_dump() == llm_config.model_dump(): + config_filename = filename + break + + # skip if this is a limited model + if not config_filename or config_filename in limited_configs: + pytest.skip(f"Skipping test for limited model {llm_config.model}") + client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) with callback_server() as server: @@ -1337,6 +1392,17 @@ def test_async_callback_failure_scenarios( Tests that job completion works even when callback URLs fail. This ensures callback failures don't affect job processing. """ + config_filename = None + for filename in filenames: + config = get_llm_config(filename) + if config.model_dump() == llm_config.model_dump(): + config_filename = filename + break + + # skip if this is a limited model + if not config_filename or config_filename in limited_configs: + pytest.skip(f"Skipping test for limited model {llm_config.model}") + client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) # Test with invalid callback URL - job should still complete @@ -1375,6 +1441,18 @@ def test_async_callback_failure_scenarios( ) def test_auto_summarize(disable_e2b_api_key: Any, client: Letta, llm_config: LLMConfig): """Test that summarization is automatically triggered.""" + # get the config filename + config_filename = None + for filename in filenames: + config = get_llm_config(filename) + if config.model_dump() == llm_config.model_dump(): + config_filename = filename + break + + # skip if this is a limited model (runs too slow) + if not config_filename or config_filename in limited_configs: + pytest.skip(f"Skipping test for limited model {llm_config.model}") + # pydantic prevents us for overriding the context window paramter in the passed LLMConfig new_llm_config = llm_config.model_dump() new_llm_config["context_window"] = 3000 @@ -1641,6 +1719,10 @@ def test_inner_thoughts_false_non_reasoner_models( config_filename = filename break + # skip if this is a limited model + if not config_filename or config_filename in limited_configs: + pytest.skip(f"Skipping test for limited model {llm_config.model}") + # skip if this is a reasoning model if not config_filename or config_filename in reasoning_configs: pytest.skip(f"Skipping test for reasoning model {llm_config.model}") @@ -1682,6 +1764,10 @@ def test_inner_thoughts_false_non_reasoner_models_streaming( config_filename = filename break + # skip if this is a limited model + if not config_filename or config_filename in limited_configs: + pytest.skip(f"Skipping test for limited model {llm_config.model}") + # skip if this is a reasoning model if not config_filename or config_filename in reasoning_configs: pytest.skip(f"Skipping test for reasoning model {llm_config.model}") diff --git a/tests/test_ollama.py b/tests/test_ollama.py deleted file mode 100644 index 3da98480..00000000 --- a/tests/test_ollama.py +++ /dev/null @@ -1,266 +0,0 @@ -import pytest -from letta_client import Letta - -from letta.schemas.providers import OllamaProvider -from letta.settings import model_settings - - -@pytest.fixture -def ollama_provider(): - """Create an Ollama provider for testing""" - return OllamaProvider( - name="ollama", - base_url=model_settings.ollama_base_url or "http://localhost:11434", - api_key=None, - default_prompt_formatter="chatml", - ) - - -@pytest.mark.asyncio -async def test_list_llm_models_async(ollama_provider): - """Test async listing of LLM models from Ollama""" - models = await ollama_provider.list_llm_models_async() - assert len(models) >= 0 - - model = models[0] - assert model.handle == f"{ollama_provider.name}/{model.model}" - assert model.model_endpoint_type == "ollama" - assert model.model_endpoint == ollama_provider.base_url - assert model.context_window is not None - assert model.context_window > 0 - - -# noinspection DuplicatedCode -@pytest.mark.asyncio -async def test_list_embedding_models_async(ollama_provider): - """Test async listing of embedding models from Ollama""" - embedding_models = await ollama_provider.list_embedding_models_async() - assert len(embedding_models) >= 0 - - model = embedding_models[0] - assert model.handle == f"{ollama_provider.name}/{model.embedding_model}" - assert model.embedding_endpoint_type == "ollama" - assert model.embedding_endpoint == ollama_provider.base_url - assert model.embedding_dim is not None - assert model.embedding_dim > 0 - - -def test_send_message_with_ollama_sync(ollama_provider): - """Test sending a message with Ollama (sync version)""" - import os - import threading - - from letta_client import MessageCreate - - from tests.utils import wait_for_server - - # Skip if no models available - models = ollama_provider.list_llm_models() - if len(models) == 0: - pytest.skip("No Ollama models available for testing") - - # Use the first available model - model = models[0] - - # Set up client (similar to other tests) - def run_server(): - from letta.server.rest_api.app import start_server - - start_server(debug=True) - - server_url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") - if not os.getenv("LETTA_SERVER_URL"): - thread = threading.Thread(target=run_server, daemon=True) - thread.start() - wait_for_server(server_url) - - client = Letta(base_url=server_url, token=None) - - # Create agent with Ollama model - agent = client.agents.create( - name="test_ollama_agent", - memory_blocks=[{"label": "human", "value": "username: test_user"}, {"label": "persona", "value": "you are a helpful assistant"}], - model=model.handle, - embedding="letta/letta-free", - ) - - try: - # Send a simple message - response = client.agents.messages.create( - agent_id=agent.id, messages=[MessageCreate(role="user", content="Hello, respond with just 'Hi there!'")] - ) - - # Verify response - assert response is not None - assert len(response.messages) > 0 - - # Find the assistant response - assistant_response = None - for msg in response.messages: - if msg.message_type == "assistant_message": - assistant_response = msg - break - - assert assistant_response is not None - assert len(assistant_response.text) > 0 - - finally: - # Clean up - client.agents.delete(agent.id) - - -@pytest.mark.asyncio -async def test_send_message_with_ollama_async_streaming(ollama_provider): - """Test sending a message with Ollama using async streaming""" - import os - import threading - - from letta_client import MessageCreate - - from tests.utils import wait_for_server - - # Skip if no models available - models = await ollama_provider.list_llm_models_async() - if len(models) == 0: - pytest.skip("No Ollama models available for testing") - - # Use the first available model - model = models[0] - - # Set up client - def run_server(): - from letta.server.rest_api.app import start_server - - start_server(debug=True) - - server_url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") - if not os.getenv("LETTA_SERVER_URL"): - thread = threading.Thread(target=run_server, daemon=True) - thread.start() - wait_for_server(server_url) - - client = Letta(base_url=server_url, token=None) - - # Create agent with Ollama model - agent = client.agents.create( - name="test_ollama_streaming_agent", - memory_blocks=[{"label": "human", "value": "username: test_user"}, {"label": "persona", "value": "you are a helpful assistant"}], - model=model.handle, - embedding="letta/letta-free", - ) - - try: - # Test step streaming (no token streaming) - response_stream = client.agents.messages.create_stream( - agent_id=agent.id, messages=[MessageCreate(role="user", content="Hello, respond briefly!")], stream_tokens=False - ) - - # Collect streamed messages - streamed_messages = [] - for chunk in response_stream: - if hasattr(chunk, "messages"): - streamed_messages.extend(chunk.messages) - - # Verify streaming response - assert len(streamed_messages) > 0 - - # Find assistant response in stream - assistant_response = None - for msg in streamed_messages: - if msg.message_type == "assistant_message": - assistant_response = msg - break - - assert assistant_response is not None - assert len(assistant_response.text) > 0 - - finally: - # Clean up - client.agents.delete(agent.id) - - -@pytest.mark.asyncio -async def test_send_message_with_ollama_async_job(ollama_provider): - """Test sending a message with Ollama using async background job""" - import os - import threading - import time - - from letta_client import MessageCreate - - from tests.utils import wait_for_server - - # Skip if no models available - models = await ollama_provider.list_llm_models_async() - if len(models) == 0: - pytest.skip("No Ollama models available for testing") - - # Use the first available model - model = models[0] - - # Set up client - def run_server(): - from letta.server.rest_api.app import start_server - - start_server(debug=True) - - server_url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") - if not os.getenv("LETTA_SERVER_URL"): - thread = threading.Thread(target=run_server, daemon=True) - thread.start() - wait_for_server(server_url) - - client = Letta(base_url=server_url, token=None) - - # Create agent with Ollama model - agent = client.agents.create( - name="test_ollama_async_agent", - memory_blocks=[{"label": "human", "value": "username: test_user"}, {"label": "persona", "value": "you are a helpful assistant"}], - model=model.handle, - embedding="letta/letta-free", - ) - - try: - # Start async job - run = client.agents.messages.create_async( - agent_id=agent.id, messages=[MessageCreate(role="user", content="Hello, respond briefly!")] - ) - - # Wait for completion - def wait_for_run_completion(run_id: str, timeout: float = 30.0): - start = time.time() - while True: - current_run = client.runs.retrieve(run_id) - if current_run.status == "completed": - return current_run - if current_run.status == "failed": - raise RuntimeError(f"Run {run_id} failed: {current_run.metadata}") - if time.time() - start > timeout: - raise TimeoutError(f"Run {run_id} timed out") - time.sleep(0.5) - - completed_run = wait_for_run_completion(run.id) - - # Verify the job completed successfully - assert completed_run.status == "completed" - assert "result" in completed_run.metadata - - # Get messages from the result - result = completed_run.metadata["result"] - assert "messages" in result - messages = result["messages"] - assert len(messages) > 0 - - # Find assistant response - assistant_response = None - for msg in messages: - if msg.get("message_type") == "assistant_message": - assistant_response = msg - break - - assert assistant_response is not None - assert len(assistant_response.get("text", "")) > 0 - - finally: - # Clean up - client.agents.delete(agent.id)