feat: Add env var to control model within builtin web_search tool (#3417)
This commit is contained in:
@@ -379,3 +379,7 @@ PINECONE_RETRY_BASE_DELAY = 1.0 # seconds
|
||||
PINECONE_RETRY_MAX_DELAY = 60.0  # seconds — cap for exponential backoff between Pinecone retries
PINECONE_RETRY_BACKOFF_FACTOR = 2.0  # multiplier applied to the delay after each failed attempt
PINECONE_THROTTLE_DELAY = 0.75  # seconds base delay between batches

# builtin web search
# Name of the environment variable that overrides the OpenAI model used by the
# builtin web_search tool when parsing search results.
WEB_SEARCH_MODEL_ENV_VAR_NAME = "LETTA_BUILTIN_WEBSEARCH_OPENAI_MODEL_NAME"
# Model used by the builtin web_search tool when the env var above is unset.
WEB_SEARCH_MODEL_ENV_VAR_DEFAULT_VALUE = "gpt-4.1-mini-2025-04-14"
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from letta.constants import WEB_SEARCH_MODEL_ENV_VAR_DEFAULT_VALUE, WEB_SEARCH_MODEL_ENV_VAR_NAME
|
||||
from letta.functions.prompts import FIRECRAWL_SEARCH_SYSTEM_PROMPT, get_firecrawl_search_user_prompt
|
||||
from letta.functions.types import SearchTask
|
||||
from letta.log import get_logger
|
||||
@@ -322,8 +324,10 @@ class LettaBuiltinToolExecutor(ToolExecutor):
|
||||
# Time the OpenAI request
|
||||
start_time = time.time()
|
||||
|
||||
model = os.getenv(WEB_SEARCH_MODEL_ENV_VAR_NAME, WEB_SEARCH_MODEL_ENV_VAR_DEFAULT_VALUE)
|
||||
logger.info(f"Using model {model} for web search result parsing")
|
||||
response = await client.beta.chat.completions.parse(
|
||||
model="gpt-4.1-mini-2025-04-14",
|
||||
model=model,
|
||||
messages=[{"role": "system", "content": FIRECRAWL_SEARCH_SYSTEM_PROMPT}, {"role": "user", "content": user_prompt}],
|
||||
response_format=DocumentAnalysis,
|
||||
temperature=0.1,
|
||||
|
||||
@@ -4,6 +4,7 @@ import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import List
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
@@ -11,8 +12,11 @@ from dotenv import load_dotenv
|
||||
from letta_client import Letta, MessageCreate
|
||||
from letta_client.types import ToolReturnMessage
|
||||
|
||||
from letta.constants import WEB_SEARCH_MODEL_ENV_VAR_NAME
|
||||
from letta.functions.types import SearchTask
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.services.tool_executor.builtin_tool_executor import LettaBuiltinToolExecutor
|
||||
from letta.settings import tool_settings
|
||||
|
||||
# ------------------------------
|
||||
@@ -109,6 +113,28 @@ def agent_state_with_firecrawl_key(client: Letta) -> AgentState:
|
||||
yield agent_state_instance
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def agent_state_with_web_search_env_var(client: Letta) -> AgentState:
    """
    Creates and returns an agent state for testing with a pre-configured agent.

    The agent carries a tool-exec environment variable that pins the builtin
    web_search tool to a specific OpenAI model ("gpt-4o").
    """
    client.tools.upsert_base_tools()

    # Resolve the three builtin tools the agent needs, by name, in a fixed order.
    builtin_tool_ids = [client.tools.list(name=tool_name)[0].id for tool_name in ("send_message", "run_code", "web_search")]

    agent = client.agents.create(
        name="test_builtin_tools_agent",
        include_base_tools=False,
        tool_ids=builtin_tool_ids,
        model="openai/gpt-4o",
        embedding="letta/letta-free",
        tags=["test_builtin_tools_agent"],
        tool_exec_environment_variables={WEB_SEARCH_MODEL_ENV_VAR_NAME: "gpt-4o"},
    )
    yield agent
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# Helper Functions and Constants
|
||||
# ------------------------------
|
||||
@@ -273,6 +299,65 @@ def test_web_search(
|
||||
], f"Invalid api_key_source: {response_json['api_key_source']}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_web_search_uses_agent_env_var_model(agent_state_with_web_search_env_var):
    """Test that web search uses the model specified in agent tool exec env vars."""

    # mock firecrawl response
    # NOTE(review): mock_search_result is built here but no firecrawl patch is
    # visible in this function — confirm the search call is stubbed elsewhere,
    # otherwise this test may reach the real search backend.
    mock_search_result = {
        "data": [
            {
                "url": "https://example.com",
                "title": "Example Title",
                "description": "Example description",
                "markdown": "Line 1: Test content\nLine 2: More content",
            }
        ]
    }

    # mock openai response: token-usage stats plus one parsed choice with no citations
    mock_openai_response = MagicMock()
    mock_openai_response.usage = MagicMock()
    mock_openai_response.usage.total_tokens = 100
    mock_openai_response.usage.prompt_tokens = 80
    mock_openai_response.usage.completion_tokens = 20
    mock_openai_response.choices = [MagicMock()]
    mock_openai_response.choices[0].message.parsed = MagicMock()
    mock_openai_response.choices[0].message.parsed.citations = []

    with (
        patch("openai.AsyncOpenAI") as mock_openai_class,
        patch("letta.services.tool_executor.builtin_tool_executor.model_settings") as mock_model_settings,
        # simulate the agent's tool-exec env var being present in the process environment
        patch.dict(os.environ, {WEB_SEARCH_MODEL_ENV_VAR_NAME: "gpt-4o"}),
    ):

        # setup mocks
        mock_model_settings.openai_api_key = "test-key"

        mock_openai_client = AsyncMock()
        mock_openai_class.return_value = mock_openai_client
        mock_openai_client.beta.chat.completions.parse.return_value = mock_openai_response

        # create executor with mock dependencies
        executor = LettaBuiltinToolExecutor(
            message_manager=MagicMock(),
            agent_manager=MagicMock(),
            block_manager=MagicMock(),
            job_manager=MagicMock(),
            passage_manager=MagicMock(),
            actor=MagicMock(),
        )

        task = SearchTask(query="test query", question="test question")

        await executor.web_search(agent_state=agent_state_with_web_search_env_var, tasks=[task], limit=1)

        # verify correct model was used: the env var value must reach the parse() call
        mock_openai_client.beta.chat.completions.parse.assert_called_once()
        call_args = mock_openai_client.beta.chat.completions.parse.call_args
        assert call_args[1]["model"] == "gpt-4o"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS, ids=[c.model for c in TESTED_LLM_CONFIGS])
|
||||
def test_web_search_using_agent_state_env_var(
|
||||
client: Letta,
|
||||
|
||||
Reference in New Issue
Block a user