From 33d99be1573a61798d891be891aa2b8429dcf03f Mon Sep 17 00:00:00 2001
From: Matthew Zhou
Date: Fri, 18 Jul 2025 16:45:59 -0700
Subject: [PATCH] feat: Add env var to control model within builtin `web_search` tool (#3417)

---
 letta/constants.py                            |  4 +
 .../tool_executor/builtin_tool_executor.py    |  7 +-
 tests/integration_test_builtin_tools.py       | 85 +++++++++++++++++++
 3 files changed, 95 insertions(+), 1 deletion(-)

diff --git a/letta/constants.py b/letta/constants.py
index 210319ab..33c9745f 100644
--- a/letta/constants.py
+++ b/letta/constants.py
@@ -379,3 +379,7 @@ PINECONE_RETRY_BASE_DELAY = 1.0 # seconds
 PINECONE_RETRY_MAX_DELAY = 60.0  # seconds
 PINECONE_RETRY_BACKOFF_FACTOR = 2.0
 PINECONE_THROTTLE_DELAY = 0.75  # seconds base delay between batches
+
+# builtin web search
+WEB_SEARCH_MODEL_ENV_VAR_NAME = "LETTA_BUILTIN_WEBSEARCH_OPENAI_MODEL_NAME"
+WEB_SEARCH_MODEL_ENV_VAR_DEFAULT_VALUE = "gpt-4.1-mini-2025-04-14"
diff --git a/letta/services/tool_executor/builtin_tool_executor.py b/letta/services/tool_executor/builtin_tool_executor.py
index 22c0a38d..d8cd24a6 100644
--- a/letta/services/tool_executor/builtin_tool_executor.py
+++ b/letta/services/tool_executor/builtin_tool_executor.py
@@ -1,10 +1,12 @@
 import asyncio
 import json
+import os
 import time
 from typing import Any, Dict, List, Literal, Optional
 
 from pydantic import BaseModel
 
+from letta.constants import WEB_SEARCH_MODEL_ENV_VAR_DEFAULT_VALUE, WEB_SEARCH_MODEL_ENV_VAR_NAME
 from letta.functions.prompts import FIRECRAWL_SEARCH_SYSTEM_PROMPT, get_firecrawl_search_user_prompt
 from letta.functions.types import SearchTask
 from letta.log import get_logger
@@ -322,8 +324,11 @@ class LettaBuiltinToolExecutor(ToolExecutor):
         # Time the OpenAI request
         start_time = time.time()
 
+        # `or` (rather than a getenv default) so an unset *or empty* variable falls back to the default model
+        model = os.getenv(WEB_SEARCH_MODEL_ENV_VAR_NAME) or WEB_SEARCH_MODEL_ENV_VAR_DEFAULT_VALUE
+        logger.info(f"Using model {model} for web search result parsing")
         response = await client.beta.chat.completions.parse(
-            model="gpt-4.1-mini-2025-04-14",
+            model=model,
             messages=[{"role": "system", "content": FIRECRAWL_SEARCH_SYSTEM_PROMPT}, {"role": "user", "content": user_prompt}],
             response_format=DocumentAnalysis,
             temperature=0.1,
diff --git a/tests/integration_test_builtin_tools.py b/tests/integration_test_builtin_tools.py
index f9c6df6a..5ca33c05 100644
--- a/tests/integration_test_builtin_tools.py
+++ b/tests/integration_test_builtin_tools.py
@@ -4,6 +4,7 @@ import threading
 import time
 import uuid
 from typing import List
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 import requests
@@ -11,8 +12,11 @@ from dotenv import load_dotenv
 from letta_client import Letta, MessageCreate
 from letta_client.types import ToolReturnMessage
 
+from letta.constants import WEB_SEARCH_MODEL_ENV_VAR_NAME
+from letta.functions.types import SearchTask
 from letta.schemas.agent import AgentState
 from letta.schemas.llm_config import LLMConfig
+from letta.services.tool_executor.builtin_tool_executor import LettaBuiltinToolExecutor
 from letta.settings import tool_settings
 
 # ------------------------------
@@ -109,6 +113,28 @@ def agent_state_with_firecrawl_key(client: Letta) -> AgentState:
     yield agent_state_instance
 
 
+@pytest.fixture(scope="module")
+def agent_state_with_web_search_env_var(client: Letta) -> AgentState:
+    """
+    Create an agent whose tool-exec environment variables set WEB_SEARCH_MODEL_ENV_VAR_NAME to "gpt-4o".
+    """
+    client.tools.upsert_base_tools()
+
+    send_message_tool = client.tools.list(name="send_message")[0]
+    run_code_tool = client.tools.list(name="run_code")[0]
+    web_search_tool = client.tools.list(name="web_search")[0]
+    agent_state_instance = client.agents.create(
+        name="test_builtin_tools_agent",
+        include_base_tools=False,
+        tool_ids=[send_message_tool.id, run_code_tool.id, web_search_tool.id],
+        model="openai/gpt-4o",
+        embedding="letta/letta-free",
+        tags=["test_builtin_tools_agent"],
+        tool_exec_environment_variables={WEB_SEARCH_MODEL_ENV_VAR_NAME: "gpt-4o"},
+    )
+    yield agent_state_instance
+
+
 # ------------------------------
 # Helper Functions and Constants
 # ------------------------------
@@ -273,6 +299,65 @@ def test_web_search(
     ], f"Invalid api_key_source: {response_json['api_key_source']}"
 
 
+@pytest.mark.asyncio
+async def test_web_search_uses_agent_env_var_model(agent_state_with_web_search_env_var):
+    """Test that web search passes the model named by the env var (patched into os.environ below) to OpenAI."""
+
+    # sample firecrawl payload -- NOTE(review): never wired into a patch below, so Firecrawl itself is not mocked
+    mock_search_result = {
+        "data": [
+            {
+                "url": "https://example.com",
+                "title": "Example Title",
+                "description": "Example description",
+                "markdown": "Line 1: Test content\nLine 2: More content",
+            }
+        ]
+    }
+
+    # mock openai response
+    mock_openai_response = MagicMock()
+    mock_openai_response.usage = MagicMock()
+    mock_openai_response.usage.total_tokens = 100
+    mock_openai_response.usage.prompt_tokens = 80
+    mock_openai_response.usage.completion_tokens = 20
+    mock_openai_response.choices = [MagicMock()]
+    mock_openai_response.choices[0].message.parsed = MagicMock()
+    mock_openai_response.choices[0].message.parsed.citations = []
+
+    with (
+        patch("openai.AsyncOpenAI") as mock_openai_class,
+        patch("letta.services.tool_executor.builtin_tool_executor.model_settings") as mock_model_settings,
+        patch.dict(os.environ, {WEB_SEARCH_MODEL_ENV_VAR_NAME: "gpt-4o"}),
+    ):
+
+        # setup mocks
+        mock_model_settings.openai_api_key = "test-key"
+
+        mock_openai_client = AsyncMock()
+        mock_openai_class.return_value = mock_openai_client
+        mock_openai_client.beta.chat.completions.parse.return_value = mock_openai_response
+
+        # create executor with mock dependencies
+        executor = LettaBuiltinToolExecutor(
+            message_manager=MagicMock(),
+            agent_manager=MagicMock(),
+            block_manager=MagicMock(),
+            job_manager=MagicMock(),
+            passage_manager=MagicMock(),
+            actor=MagicMock(),
+        )
+
+        task = SearchTask(query="test query", question="test question")
+
+        await executor.web_search(agent_state=agent_state_with_web_search_env_var, tasks=[task], limit=1)
+
+        # verify the model keyword sent to OpenAI is the one from the environment
+        mock_openai_client.beta.chat.completions.parse.assert_called_once()
+        call_args = mock_openai_client.beta.chat.completions.parse.call_args
+        assert call_args.kwargs["model"] == "gpt-4o"
+
+
 @pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS, ids=[c.model for c in TESTED_LLM_CONFIGS])
 def test_web_search_using_agent_state_env_var(
     client: Letta,