diff --git a/tests/integration_test_builtin_tools.py b/tests/integration_test_builtin_tools.py index 92efbf3a..99a61a7c 100644 --- a/tests/integration_test_builtin_tools.py +++ b/tests/integration_test_builtin_tools.py @@ -177,12 +177,14 @@ def test_run_code( ) -@patch("exa_py.Exa") -def test_web_search( - mock_exa_class, - client: Letta, - agent_state: AgentState, -) -> None: +@pytest.mark.asyncio(scope="function") +async def test_web_search() -> None: + """Test web search tool with mocked Exa API.""" + + # create mock agent state with exa api key + mock_agent_state = MagicMock() + mock_agent_state.get_agent_env_vars_as_dict.return_value = {"EXA_API_KEY": "test-exa-key"} + # Mock Exa search result with education information mock_exa_result = MagicMock() mock_exa_result.results = [ @@ -191,7 +193,7 @@ def test_web_search( url="https://example.com/charles-packer-profile", published_date="2023-01-01", author="UC Berkeley", - text=None, # include_text=False by default + text=None, highlights=["Charles Packer completed his PhD at UC Berkeley", "Research in artificial intelligence and machine learning"], summary="Charles Packer is the CEO of Letta who earned his PhD in Computer Science from UC Berkeley, specializing in AI research.", ), @@ -206,68 +208,70 @@ def test_web_search( ), ] - # Setup mock - mock_exa_client = MagicMock() - mock_exa_class.return_value = mock_exa_client - mock_exa_client.search_and_contents.return_value = mock_exa_result + with patch("exa_py.Exa") as mock_exa_class: + # Setup mock + mock_exa_client = MagicMock() + mock_exa_class.return_value = mock_exa_client + mock_exa_client.search_and_contents.return_value = mock_exa_result - user_message = MessageCreate( - role="user", - content="I am executing a test. Use the web search tool to find where I, Charles Packer, the CEO of Letta, went to school.", - otid=USER_MESSAGE_OTID, - ) + # create executor with mock dependencies + executor = LettaBuiltinToolExecutor( + message_manager=MagicMock(), + agent_manager=MagicMock(), + block_manager=MagicMock(), + run_manager=MagicMock(), + passage_manager=MagicMock(), + actor=MagicMock(), + ) - response = client.agents.messages.create( - agent_id=agent_state.id, - messages=[user_message], - ) + # call web_search directly + result = await executor.web_search( + agent_state=mock_agent_state, + query="where did Charles Packer, CEO of Letta, go to school", + num_results=10, + include_text=False, + ) - tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)] - assert tool_returns, "No ToolReturnMessage found" + # Parse the JSON response from web_search + response_json = json.loads(result) - returns = [m.tool_return for m in tool_returns] - print(returns) + # Basic structure assertions for new Exa format + assert "query" in response_json, "Missing 'query' field in response" + assert "results" in response_json, "Missing 'results' field in response" - # Parse the JSON response from web_search - assert len(returns) > 0, "No tool returns found" - response_json = json.loads(returns[0]) + # Verify we got search results + results = response_json["results"] + assert len(results) == 2, "Should have found exactly 2 search results from mock" - # Basic structure assertions for new Exa format - assert "query" in response_json, "Missing 'query' field in response" - assert "results" in response_json, "Missing 'results' field in response" + # Check each result has the expected structure + found_education_info = False + for result in results: + assert "title" in result, "Result missing title" + assert "url" in result, "Result missing URL" - # Verify we got search results - results = response_json["results"] - assert len(results) == 2, "Should have found exactly 2 search results from mock" + # text should not be present since include_text=False by default + assert "text" not in result or result["text"] is None, "Text should not be included by default" - # Check each result has the expected structure - found_education_info = False - for result in results: - assert "title" in result, "Result missing title" - assert "url" in result, "Result missing URL" + # Check for education-related information in summary and highlights + result_text = "" + if "summary" in result and result["summary"]: + result_text += " " + result["summary"].lower() + if "highlights" in result and result["highlights"]: + for highlight in result["highlights"]: + result_text += " " + highlight.lower() - # text should not be present since include_text=False by default - assert "text" not in result or result["text"] is None, "Text should not be included by default" + # Look for education keywords + if any(keyword in result_text for keyword in ["berkeley", "university", "phd", "ph.d", "education", "student"]): + found_education_info = True - # Check for education-related information in summary and highlights - result_text = "" - if "summary" in result and result["summary"]: - result_text += " " + result["summary"].lower() - if "highlights" in result and result["highlights"]: - for highlight in result["highlights"]: - result_text += " " + highlight.lower() + assert found_education_info, "Should have found education-related information about Charles Packer" - # Look for education keywords - if any(keyword in result_text for keyword in ["berkeley", "university", "phd", "ph.d", "education", "student"]): - found_education_info = True - - assert found_education_info, "Should have found education-related information about Charles Packer" - - # Verify Exa was called with correct parameters - mock_exa_client.search_and_contents.assert_called_once() - call_args = mock_exa_client.search_and_contents.call_args - assert call_args[1]["type"] == "auto" - assert call_args[1]["text"] is False # Default is False now + # Verify Exa was called with correct parameters + mock_exa_class.assert_called_once_with(api_key="test-exa-key") + mock_exa_client.search_and_contents.assert_called_once() + call_args = mock_exa_client.search_and_contents.call_args + assert call_args[1]["type"] == "auto" + assert call_args[1]["text"] is False # Default is False now @pytest.mark.asyncio(scope="function")