feat: Change web search to Exa [LET-4190] (#4401)

* Change web search to Exa

* Fix tf/justfile

* Add Exa API key for integration test

* Mock Exa

---------

Co-authored-by: Kian Jones <kian@letta.com>
This commit is contained in:
Matthew Zhou
2025-09-03 15:52:10 -07:00
committed by GitHub
parent 12d8242d94
commit 2af6049d6f
9 changed files with 254 additions and 579 deletions

View File

@@ -1,6 +1,4 @@
from typing import List, Literal
from letta.functions.types import SearchTask
from typing import List, Literal, Optional
def run_code(code: str, language: Literal["python", "js", "ts", "r", "java"]) -> str:
@@ -17,32 +15,40 @@ def run_code(code: str, language: Literal["python", "js", "ts", "r", "java"]) ->
raise NotImplementedError("This is only available on the latest agent architecture. Please contact the Letta team.")
async def web_search(tasks: List[SearchTask], limit: int = 1, return_raw: bool = True) -> str:
async def web_search(
query: str,
num_results: int = 10,
category: Optional[
Literal["company", "research paper", "news", "pdf", "github", "tweet", "personal site", "linkedin profile", "financial report"]
] = None,
include_text: bool = False,
include_domains: Optional[List[str]] = None,
exclude_domains: Optional[List[str]] = None,
start_published_date: Optional[str] = None,
end_published_date: Optional[str] = None,
user_location: Optional[str] = None,
) -> str:
"""
Search the web with a list of query/question pairs and extract passages that answer the corresponding questions.
Search the web using Exa's AI-powered search engine and retrieve relevant content.
Examples:
tasks -> [
SearchTask(
query="Tesla Q1 2025 earnings report PDF",
question="What was Tesla's net profit in Q1 2025?"
),
SearchTask(
query="Letta API prebuilt tools core_memory_append",
question="What does the core_memory_append tool do in Letta?"
)
]
web_search("Tesla Q1 2025 earnings report", num_results=5, category="financial report")
web_search("Latest research in large language models", category="research paper", include_domains=["arxiv.org", "paperswithcode.com"])
web_search("Letta API documentation core_memory_append", num_results=3)
Args:
tasks (List[SearchTask]): A list of search tasks, each containing a `query` and a corresponding `question`.
limit (int, optional): Maximum number of URLs to fetch and analyse per task (must be > 0). Defaults to 1.
return_raw (bool, optional): If set to True, returns the raw content of the web pages.
This should be True unless otherwise specified by the user. Defaults to True.
query (str): The search query to find relevant web content.
num_results (int, optional): Number of results to return (1-100). Defaults to 10.
category (Optional[Literal], optional): Focus search on specific content types. Defaults to None.
include_text (bool, optional): Whether to retrieve full page content. Defaults to False (only returns summary and highlights, since the full text usually will overflow the context window).
include_domains (Optional[List[str]], optional): List of domains to include in search results. Defaults to None.
exclude_domains (Optional[List[str]], optional): List of domains to exclude from search results. Defaults to None.
start_published_date (Optional[str], optional): Only return content published after this date (ISO format). Defaults to None.
end_published_date (Optional[str], optional): Only return content published before this date (ISO format). Defaults to None.
user_location (Optional[str], optional): Two-letter country code for localized results (e.g., "US"). Defaults to None.
Returns:
str: A JSON-encoded string containing a list of search results.
Each result includes ranked snippets with their source URLs and relevance scores,
corresponding to each search task.
str: A JSON-encoded string containing search results with title, URL, content, highlights, and summary.
"""
raise NotImplementedError("This is only available on the latest agent architecture. Please contact the Letta team.")

View File

@@ -1,13 +1,7 @@
import asyncio
import json
import time
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel
from letta.constants import WEB_SEARCH_MODEL_ENV_VAR_DEFAULT_VALUE, WEB_SEARCH_MODEL_ENV_VAR_NAME
from letta.functions.prompts import FIRECRAWL_SEARCH_SYSTEM_PROMPT, get_firecrawl_search_user_prompt
from letta.functions.types import SearchTask
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState
@@ -16,36 +10,11 @@ from letta.schemas.tool import Tool
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.user import User
from letta.services.tool_executor.tool_executor_base import ToolExecutor
from letta.settings import model_settings, tool_settings
from letta.settings import tool_settings
logger = get_logger(__name__)
class Citation(BaseModel):
"""A relevant text snippet identified by line numbers in a document."""
start_line: int # Starting line number (1-indexed)
end_line: int # Ending line number (1-indexed, inclusive)
class CitationWithText(BaseModel):
"""A citation with the actual extracted text."""
text: str # The actual extracted text from the lines
class DocumentAnalysis(BaseModel):
"""Analysis of a document's relevance to a search question."""
citations: List[Citation]
class DocumentAnalysisWithText(BaseModel):
"""Analysis with extracted text from line citations."""
citations: List[CitationWithText]
class LettaBuiltinToolExecutor(ToolExecutor):
"""Executor for built in Letta tools."""
@@ -104,317 +73,124 @@ class LettaBuiltinToolExecutor(ToolExecutor):
return out
@trace_method
async def web_search(self, agent_state: "AgentState", tasks: List[SearchTask], limit: int = 1, return_raw: bool = True) -> str:
async def web_search(
self,
agent_state: "AgentState",
query: str,
num_results: int = 10,
category: Optional[
Literal["company", "research paper", "news", "pdf", "github", "tweet", "personal site", "linkedin profile", "financial report"]
] = None,
include_text: bool = False,
include_domains: Optional[List[str]] = None,
exclude_domains: Optional[List[str]] = None,
start_published_date: Optional[str] = None,
end_published_date: Optional[str] = None,
user_location: Optional[str] = None,
) -> str:
"""
Search the web with a list of query/question pairs and extract passages that answer the corresponding questions.
Examples:
tasks -> [
SearchTask(
query="Tesla Q1 2025 earnings report PDF",
question="What was Tesla's net profit in Q1 2025?"
),
SearchTask(
query="Letta API prebuilt tools core_memory_append",
question="What does the core_memory_append tool do in Letta?"
)
]
Search the web using Exa's AI-powered search engine and retrieve relevant content.
Args:
tasks (List[SearchTask]): A list of search tasks, each containing a `query` and a corresponding `question`.
limit (int, optional): Maximum number of URLs to fetch and analyse per task (must be > 0). Defaults to 3.
return_raw (bool, optional): If set to True, returns the raw content of the web pages.
This should be False unless otherwise specified by the user. Defaults to False.
query: The search query to find relevant web content
num_results: Number of results to return (1-100)
category: Focus search on specific content types
include_text: Whether to retrieve full page content (default: False, only returns summary and highlights)
include_domains: List of domains to include in search results
exclude_domains: List of domains to exclude from search results
start_published_date: Only return content published after this date (ISO format)
end_published_date: Only return content published before this date (ISO format)
user_location: Two-letter country code for localized results
Returns:
str: A JSON-encoded string containing a list of search results.
Each result includes ranked snippets with their source URLs and relevance scores,
corresponding to each search task.
JSON-encoded string containing search results
"""
# # TODO: Temporary, maybe deprecate this field?
# if return_raw:
# logger.warning("WARNING! return_raw was set to True, we default to False always. Deprecate this field.")
# return_raw = False
try:
from firecrawl import AsyncFirecrawlApp
from exa_py import Exa
except ImportError:
raise ImportError("firecrawl-py is not installed in the tool execution environment")
raise ImportError("exa-py is not installed in the tool execution environment")
if not tasks:
return json.dumps({"error": "No search tasks provided."})
if not query.strip():
return json.dumps({"error": "Query cannot be empty", "query": query})
# Convert dict objects to SearchTask objects
search_tasks = []
for task in tasks:
if isinstance(task, dict):
search_tasks.append(SearchTask(**task))
else:
search_tasks.append(task)
logger.info(f"[DEBUG] Starting web search with {len(search_tasks)} tasks, limit={limit}, return_raw={return_raw}")
# Check if the API key exists on the agent state
# Get EXA API key from agent environment or tool settings
agent_state_tool_env_vars = agent_state.get_agent_env_vars_as_dict()
firecrawl_api_key = agent_state_tool_env_vars.get("FIRECRAWL_API_KEY") or tool_settings.firecrawl_api_key
if not firecrawl_api_key:
raise ValueError("FIRECRAWL_API_KEY is not set in environment or on agent_state tool exec environment variables.")
exa_api_key = agent_state_tool_env_vars.get("EXA_API_KEY") or tool_settings.exa_api_key
if not exa_api_key:
raise ValueError("EXA_API_KEY is not set in environment or on agent_state tool execution environment variables.")
# Track which API key source was used
api_key_source = "agent_environment" if agent_state_tool_env_vars.get("FIRECRAWL_API_KEY") else "system_settings"
logger.info(f"[DEBUG] Starting Exa web search for query: '{query}' with {num_results} results")
if limit <= 0:
raise ValueError("limit must be greater than 0")
# Initialize Firecrawl client
app = AsyncFirecrawlApp(api_key=firecrawl_api_key)
# Process all search tasks serially
search_results = []
for task in search_tasks:
try:
result = await self._process_single_search_task(app, task, limit, return_raw, api_key_source, agent_state)
search_results.append(result)
except Exception as e:
search_results.append(e)
# Build final response as a mapping of query -> result
final_results = {}
successful_tasks = 0
failed_tasks = 0
for i, result in enumerate(search_results):
query = search_tasks[i].query
if isinstance(result, Exception):
logger.error(f"Search task {i} failed: {result}")
failed_tasks += 1
final_results[query] = {"query": query, "question": search_tasks[i].question, "error": str(result)}
else:
successful_tasks += 1
final_results[query] = result
logger.info(f"[DEBUG] Web search completed: {successful_tasks} successful, {failed_tasks} failed")
# Build final response with api_key_source at top level
response = {"api_key_source": api_key_source, "results": final_results}
return json.dumps(response, indent=2, ensure_ascii=False)
@trace_method
async def _process_single_search_task(
self, app: "AsyncFirecrawlApp", task: SearchTask, limit: int, return_raw: bool, api_key_source: str, agent_state: "AgentState"
) -> Dict[str, Any]:
"""Process a single search task."""
from firecrawl import ScrapeOptions
logger.info(f"[DEBUG] Starting Firecrawl search for query: '{task.query}' with limit={limit}")
# Perform the search for this task
scrape_options = ScrapeOptions(
formats=["markdown"], excludeTags=["#ad", "#footer"], onlyMainContent=True, parsePDF=True, removeBase64Images=True
)
search_result = await app.search(task.query, limit=limit, scrape_options=scrape_options)
logger.info(
f"[DEBUG] Firecrawl search completed for '{task.query}': {len(search_result.get('data', [])) if search_result else 0} results"
)
if not search_result or not search_result.get("data"):
return {"query": task.query, "question": task.question, "error": "No search results found."}
# If raw results requested, return them directly
if return_raw:
return {"query": task.query, "question": task.question, "raw_results": search_result}
# Check if OpenAI API key is available for semantic parsing
if model_settings.openai_api_key:
try:
from openai import AsyncOpenAI
logger.info(f"[DEBUG] Starting OpenAI analysis for '{task.query}'")
# Initialize OpenAI client
client = AsyncOpenAI(
api_key=model_settings.openai_api_key,
)
# Process each result with OpenAI concurrently
analysis_tasks = []
results_with_markdown = []
results_without_markdown = []
for result in search_result.get("data"):
if result.get("markdown"):
# Create async task for OpenAI analysis
analysis_task = self._analyze_document_with_openai(
client, result["markdown"], task.query, task.question, agent_state
)
analysis_tasks.append(analysis_task)
results_with_markdown.append(result)
else:
results_without_markdown.append(result)
logger.info(f"[DEBUG] Starting parallel OpenAI analysis of {len(analysis_tasks)} documents for '{task.query}'")
# Fire off all OpenAI requests concurrently
analyses = await asyncio.gather(*analysis_tasks, return_exceptions=True)
logger.info(f"[DEBUG] Completed parallel OpenAI analysis of {len(analyses)} documents for '{task.query}'")
# Build processed results
processed_results = []
# Check if any analysis failed - if so, fall back to raw results
for result, analysis in zip(results_with_markdown, analyses):
if isinstance(analysis, Exception) or analysis is None:
logger.error(f"Analysis failed for {result.get('url')}, falling back to raw results")
return {"query": task.query, "question": task.question, "raw_results": search_result}
# All analyses succeeded, build processed results
for result, analysis in zip(results_with_markdown, analyses):
# Extract actual text from line number citations
analysis_with_text = None
if analysis and analysis.citations:
analysis_with_text = self._extract_text_from_line_citations(analysis, result["markdown"])
processed_results.append(
{
"url": result.get("url"),
"title": result.get("title"),
"description": result.get("description"),
"analysis": analysis_with_text.model_dump() if analysis_with_text else None,
}
)
# Add results without markdown
for result in results_without_markdown:
processed_results.append(
{"url": result.get("url"), "title": result.get("title"), "description": result.get("description"), "analysis": None}
)
# Build final response for this task
return self._build_final_response_dict(processed_results, task.query, task.question)
except Exception as e:
# Log error but continue with raw results
logger.error(f"Error with OpenAI processing for task '{task.query}': {e}")
# Return raw search results if OpenAI processing isn't available or fails
return {"query": task.query, "question": task.question, "raw_results": search_result}
@trace_method
async def _analyze_document_with_openai(
self, client, markdown_content: str, query: str, question: str, agent_state: "AgentState"
) -> Optional[DocumentAnalysis]:
"""Use OpenAI to analyze a document and extract relevant passages using line numbers."""
original_length = len(markdown_content)
# Create numbered markdown for the LLM to reference
numbered_lines = markdown_content.split("\n")
numbered_markdown = "\n".join([f"{i + 1:4d}: {line}" for i, line in enumerate(numbered_lines)])
# Truncate if too long
max_content_length = 200000
truncated = False
if len(numbered_markdown) > max_content_length:
numbered_markdown = numbered_markdown[:max_content_length] + "..."
truncated = True
user_prompt = get_firecrawl_search_user_prompt(query, question, numbered_markdown)
logger.info(
f"[DEBUG] Starting OpenAI request with line numbers - Query: '{query}', Content: {original_length} chars (truncated: {truncated})"
)
# Time the OpenAI request
start_time = time.time()
# Check agent state env vars first, then fall back to os.getenv
agent_state_tool_env_vars = agent_state.get_agent_env_vars_as_dict()
model = agent_state_tool_env_vars.get(WEB_SEARCH_MODEL_ENV_VAR_NAME) or WEB_SEARCH_MODEL_ENV_VAR_DEFAULT_VALUE
logger.info(f"Using model {model} for web search result parsing")
response = await client.beta.chat.completions.parse(
model=model,
messages=[{"role": "system", "content": FIRECRAWL_SEARCH_SYSTEM_PROMPT}, {"role": "user", "content": user_prompt}],
response_format=DocumentAnalysis,
temperature=0.1,
)
end_time = time.time()
request_duration = end_time - start_time
# Get usage statistics and output length
usage = response.usage
parsed_result = response.choices[0].message.parsed
num_citations = len(parsed_result.citations) if parsed_result else 0
# Calculate output length (minimal now - just line numbers)
output_length = 0
if parsed_result and parsed_result.citations:
for citation in parsed_result.citations:
output_length += 20 # ~20 chars for line numbers only
logger.info(f"[TIMING] OpenAI request completed in {request_duration:.2f}s - Query: '{query}'")
logger.info(f"[TOKENS] Total: {usage.total_tokens} (prompt: {usage.prompt_tokens}, completion: {usage.completion_tokens})")
logger.info(f"[OUTPUT] Citations: {num_citations}, Output chars: {output_length} (line-number based)")
return parsed_result
def _extract_text_from_line_citations(self, analysis: DocumentAnalysis, original_markdown: str) -> DocumentAnalysisWithText:
"""Extract actual text from line number citations."""
lines = original_markdown.split("\n")
citations_with_text = []
for citation in analysis.citations:
try:
# Convert to 0-indexed and ensure bounds
start_idx = max(0, citation.start_line - 1)
end_idx = min(len(lines), citation.end_line)
# Extract the lines
extracted_lines = lines[start_idx:end_idx]
extracted_text = "\n".join(extracted_lines)
citations_with_text.append(CitationWithText(text=extracted_text))
except Exception as e:
logger.info(f"[DEBUG] Failed to extract text for citation lines {citation.start_line}-{citation.end_line}: {e}")
# Fall back to including the citation with empty text
citations_with_text.append(CitationWithText(text=""))
return DocumentAnalysisWithText(citations=citations_with_text)
@trace_method
def _build_final_response_dict(self, processed_results: List[Dict], query: str, question: str) -> Dict[str, Any]:
"""Build the final response dictionary from all processed results."""
# Build sources array
sources = []
total_snippets = 0
for result in processed_results:
source = {"url": result.get("url"), "title": result.get("title"), "description": result.get("description")}
if result.get("analysis") and result["analysis"].get("citations"):
analysis = result["analysis"]
source["citations"] = analysis["citations"]
total_snippets += len(analysis["citations"])
else:
source["citations"] = []
sources.append(source)
# Build final response structure
response = {
# Build search parameters
search_params = {
"query": query,
"question": question,
"total_sources": len(sources),
"total_citations": total_snippets,
"sources": sources,
"num_results": min(max(num_results, 1), 100), # Clamp between 1-100
"type": "auto", # Always use auto search type
}
if total_snippets == 0:
response["message"] = "No relevant passages found that directly answer the question."
# Add optional parameters if provided
if category:
search_params["category"] = category
if include_domains:
search_params["include_domains"] = include_domains
if exclude_domains:
search_params["exclude_domains"] = exclude_domains
if start_published_date:
search_params["start_published_date"] = start_published_date
if end_published_date:
search_params["end_published_date"] = end_published_date
if user_location:
search_params["user_location"] = user_location
return response
# Configure contents retrieval
contents_params = {
"text": include_text,
"highlights": {"num_sentences": 2, "highlights_per_url": 3, "query": query},
"summary": {"query": f"Summarize the key information from this content related to: {query}"},
}
def _sync_exa_search():
"""Synchronous Exa API call to run in thread pool."""
exa = Exa(api_key=exa_api_key)
return exa.search_and_contents(**search_params, **contents_params)
try:
# Perform search with content retrieval in thread pool to avoid blocking event loop
logger.info(f"[DEBUG] Making async Exa API call with params: {search_params}")
result = await asyncio.to_thread(_sync_exa_search)
# Format results
formatted_results = []
for res in result.results:
formatted_result = {
"title": res.title,
"url": res.url,
"published_date": res.published_date,
"author": res.author,
}
# Add content if requested
if include_text and hasattr(res, "text") and res.text:
formatted_result["text"] = res.text
# Add highlights if available
if hasattr(res, "highlights") and res.highlights:
formatted_result["highlights"] = res.highlights
# Add summary if available
if hasattr(res, "summary") and res.summary:
formatted_result["summary"] = res.summary
formatted_results.append(formatted_result)
response = {"query": query, "results": formatted_results}
logger.info(f"[DEBUG] Exa search completed successfully with {len(formatted_results)} results")
return json.dumps(response, indent=2, ensure_ascii=False)
except Exception as e:
logger.error(f"Exa search failed for query '{query}': {str(e)}")
return json.dumps({"query": query, "error": f"Search failed: {str(e)}"})
async def fetch_webpage(self, agent_state: "AgentState", url: str) -> str:
"""

View File

@@ -23,7 +23,7 @@ class ToolSettings(BaseSettings):
# Search Providers
tavily_api_key: str | None = Field(default=None, description="API key for using Tavily as a search provider.")
firecrawl_api_key: str | None = Field(default=None, description="API key for using Firecrawl as a search provider.")
exa_api_key: str | None = Field(default=None, description="API key for using Exa as a search provider.")
# Local Sandbox configurations
tool_exec_dir: Optional[str] = None

View File

@@ -57,7 +57,7 @@ dependencies = [
"marshmallow-sqlalchemy>=1.4.1",
"datamodel-code-generator[http]>=0.25.0",
"mcp[cli]>=1.9.4",
"firecrawl-py>=2.8.0,<3.0.0",
"exa-py>=1.15.4",
"apscheduler>=3.11.0",
"aiomultiprocess>=0.9.1",
"matplotlib>=3.10.1",
@@ -125,7 +125,7 @@ external-tools = [
"langchain>=0.3.7",
"wikipedia>=1.4.0",
"langchain-community>=0.3.7",
"firecrawl-py>=2.8.0,<3.0.0",
"exa-py>=1.15.4",
"turbopuffer>=0.5.17",
]
desktop = [

View File

@@ -4,7 +4,7 @@ import threading
import time
import uuid
from typing import List
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import MagicMock, patch
import pytest
import requests
@@ -12,8 +12,6 @@ from dotenv import load_dotenv
from letta_client import Letta, MessageCreate
from letta_client.types import ToolReturnMessage
from letta.constants import WEB_SEARCH_MODEL_ENV_VAR_NAME
from letta.functions.types import SearchTask
from letta.schemas.agent import AgentState
from letta.schemas.llm_config import LLMConfig
from letta.services.tool_executor.builtin_tool_executor import LettaBuiltinToolExecutor
@@ -74,6 +72,7 @@ def client(server_url: str) -> Letta:
def agent_state(client: Letta) -> AgentState:
"""
Creates and returns an agent state for testing with a pre-configured agent.
Uses system-level EXA_API_KEY setting.
"""
client.tools.upsert_base_tools()
@@ -91,50 +90,6 @@ def agent_state(client: Letta) -> AgentState:
yield agent_state_instance
@pytest.fixture(scope="module")
def agent_state_with_firecrawl_key(client: Letta) -> AgentState:
"""
Creates and returns an agent state for testing with a pre-configured agent.
"""
client.tools.upsert_base_tools()
send_message_tool = client.tools.list(name="send_message")[0]
run_code_tool = client.tools.list(name="run_code")[0]
web_search_tool = client.tools.list(name="web_search")[0]
agent_state_instance = client.agents.create(
name="test_builtin_tools_agent",
include_base_tools=False,
tool_ids=[send_message_tool.id, run_code_tool.id, web_search_tool.id],
model="openai/gpt-4o",
embedding="letta/letta-free",
tags=["test_builtin_tools_agent"],
tool_exec_environment_variables={"FIRECRAWL_API_KEY": tool_settings.firecrawl_api_key},
)
yield agent_state_instance
@pytest.fixture(scope="module")
def agent_state_with_web_search_env_var(client: Letta) -> AgentState:
"""
Creates and returns an agent state for testing with a pre-configured agent.
"""
client.tools.upsert_base_tools()
send_message_tool = client.tools.list(name="send_message")[0]
run_code_tool = client.tools.list(name="run_code")[0]
web_search_tool = client.tools.list(name="web_search")[0]
agent_state_instance = client.agents.create(
name="test_builtin_tools_agent",
include_base_tools=False,
tool_ids=[send_message_tool.id, run_code_tool.id, web_search_tool.id],
model="openai/gpt-4o",
embedding="letta/letta-free",
tags=["test_builtin_tools_agent"],
tool_exec_environment_variables={WEB_SEARCH_MODEL_ENV_VAR_NAME: "gpt-4o"},
)
yield agent_state_instance
# ------------------------------
# Helper Functions and Constants
# ------------------------------
@@ -185,11 +140,9 @@ def reference_partition(n: int) -> int:
@pytest.mark.parametrize("language", TEST_LANGUAGES, ids=TEST_LANGUAGES)
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS, ids=[c.model for c in TESTED_LLM_CONFIGS])
def test_run_code(
client: Letta,
agent_state: AgentState,
llm_config: LLMConfig,
language: str,
) -> None:
"""
@@ -223,12 +176,40 @@ def test_run_code(
)
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS, ids=[c.model for c in TESTED_LLM_CONFIGS])
@patch("exa_py.Exa")
def test_web_search(
mock_exa_class,
client: Letta,
agent_state: AgentState,
llm_config: LLMConfig,
) -> None:
# Mock Exa search result with education information
mock_exa_result = MagicMock()
mock_exa_result.results = [
MagicMock(
title="Charles Packer - UC Berkeley PhD in Computer Science",
url="https://example.com/charles-packer-profile",
published_date="2023-01-01",
author="UC Berkeley",
text=None, # include_text=False by default
highlights=["Charles Packer completed his PhD at UC Berkeley", "Research in artificial intelligence and machine learning"],
summary="Charles Packer is the CEO of Letta who earned his PhD in Computer Science from UC Berkeley, specializing in AI research.",
),
MagicMock(
title="Letta Leadership Team",
url="https://letta.com/team",
published_date="2023-06-01",
author="Letta",
text=None,
highlights=["CEO Charles Packer brings academic expertise"],
summary="Leadership team page featuring CEO Charles Packer's educational background.",
),
]
# Setup mock
mock_exa_client = MagicMock()
mock_exa_class.return_value = mock_exa_client
mock_exa_client.search_and_contents.return_value = mock_exa_result
user_message = MessageCreate(
role="user",
content="I am executing a test. Use the web search tool to find where I, Charles Packer, the CEO of Letta, went to school.",
@@ -250,136 +231,71 @@ def test_web_search(
assert len(returns) > 0, "No tool returns found"
response_json = json.loads(returns[0])
# Basic structure assertions
assert "api_key_source" in response_json, "Missing 'api_key_source' field in response"
# Basic structure assertions for new Exa format
assert "query" in response_json, "Missing 'query' field in response"
assert "results" in response_json, "Missing 'results' field in response"
assert response_json["api_key_source"] == "system_settings"
# Get the first result from the results dictionary
# Verify we got search results
results = response_json["results"]
assert len(results) > 0, "No results found in response"
assert len(results) == 2, "Should have found exactly 2 search results from mock"
# Get the first (and typically only) result
first_result_key = list(results.keys())[0]
result_data = results[first_result_key]
# Check each result has the expected structure
found_education_info = False
for result in results:
assert "title" in result, "Result missing title"
assert "url" in result, "Result missing URL"
# Basic structure assertions for the result data
assert "query" in result_data, "Missing 'query' field in result"
assert "question" in result_data, "Missing 'question' field in result"
# text should not be present since include_text=False by default
assert "text" not in result or result["text"] is None, "Text should not be included by default"
# Check if we have the new response format with raw_results
if "raw_results" in result_data:
# New format with raw_results
assert "raw_results" in result_data, "Missing 'raw_results' field in result"
raw_results = result_data["raw_results"]
# Check for education-related information in summary and highlights
result_text = ""
if "summary" in result and result["summary"]:
result_text += " " + result["summary"].lower()
if "highlights" in result and result["highlights"]:
for highlight in result["highlights"]:
result_text += " " + highlight.lower()
assert "success" in raw_results, "Missing 'success' field in raw_results"
assert "data" in raw_results, "Missing 'data' field in raw_results"
# Look for education keywords
if any(keyword in result_text for keyword in ["berkeley", "university", "phd", "ph.d", "education", "student"]):
found_education_info = True
# Verify we got search results
assert len(raw_results["data"]) > 0, "Should have found at least one search result"
assert found_education_info, "Should have found education-related information about Charles Packer"
# Check if we found education-related information in the search results
found_education_info = False
for item in raw_results["data"]:
# Check in description
if "description" in item:
desc_lower = item["description"].lower()
if any(keyword in desc_lower for keyword in ["berkeley", "university", "education", "phd", "student"]):
found_education_info = True
break
# Also check in markdown content if available
if "markdown" in item:
markdown_lower = item["markdown"].lower()
if any(keyword in markdown_lower for keyword in ["berkeley", "university", "phd", "student"]):
found_education_info = True
break
# We should find education info since we now have successful scraping with markdown content
assert found_education_info, "Should have found education-related information about Charles Packer"
else:
# Parsed format with total_sources, total_citations, sources
assert "total_sources" in result_data, "Missing 'total_sources' field in result"
assert "total_citations" in result_data, "Missing 'total_citations' field in result"
assert "sources" in result_data, "Missing 'sources' field in result"
# Content assertions
assert result_data["total_sources"] > 0, "Should have found at least one source"
assert result_data["total_citations"] > 0, "Should have found at least one citation"
assert len(result_data["sources"]) == result_data["total_sources"], "Sources count mismatch"
# Verify we found information about Charles Packer's education
found_education_info = False
for source in result_data["sources"]:
assert "url" in source, "Source missing URL"
assert "title" in source, "Source missing title"
assert "citations" in source, "Source missing citations"
for citation in source["citations"]:
assert "text" in citation, "Citation missing text"
# Check if we found education-related information
if any(keyword in citation["text"].lower() for keyword in ["berkeley", "phd", "ph.d", "university", "student"]):
found_education_info = True
assert found_education_info, "Should have found education-related information about Charles Packer"
# API key source should be valid
assert response_json["api_key_source"] in [
"agent_environment",
"system_settings",
], f"Invalid api_key_source: {response_json['api_key_source']}"
# Verify Exa was called with correct parameters
mock_exa_client.search_and_contents.assert_called_once()
call_args = mock_exa_client.search_and_contents.call_args
assert call_args[1]["type"] == "auto"
assert call_args[1]["text"] is False # Default is False now
@pytest.mark.asyncio(scope="function")
async def test_web_search_uses_agent_env_var_model():
"""Test that web search uses the model specified in agent tool exec env vars."""
async def test_web_search_uses_exa():
"""Test that web search uses Exa API correctly."""
# create mock agent state with web search model env var
# create mock agent state with exa api key
mock_agent_state = MagicMock()
mock_agent_state.get_agent_env_vars_as_dict.return_value = {WEB_SEARCH_MODEL_ENV_VAR_NAME: "gpt-4o"}
mock_agent_state.get_agent_env_vars_as_dict.return_value = {"EXA_API_KEY": "test-exa-key"}
# mock openai response
mock_openai_response = MagicMock()
mock_openai_response.usage = MagicMock()
mock_openai_response.usage.total_tokens = 100
mock_openai_response.usage.prompt_tokens = 80
mock_openai_response.usage.completion_tokens = 20
mock_openai_response.choices = [MagicMock()]
mock_openai_response.choices[0].message.parsed = MagicMock()
mock_openai_response.choices[0].message.parsed.citations = []
# Mock exa search result
mock_exa_result = MagicMock()
mock_exa_result.results = [
MagicMock(
title="Test Result",
url="https://example.com/test",
published_date="2023-01-01",
author="Test Author",
text="This is test content from the search result.",
highlights=["This is a highlight"],
summary="This is a summary of the content.",
)
]
with (
patch("openai.AsyncOpenAI") as mock_openai_class,
patch("letta.services.tool_executor.builtin_tool_executor.model_settings") as mock_model_settings,
patch.dict(os.environ, {WEB_SEARCH_MODEL_ENV_VAR_NAME: "gpt-4o"}),
patch("firecrawl.AsyncFirecrawlApp") as mock_firecrawl_class,
):
# setup mocks
mock_model_settings.openai_api_key = "test-key"
mock_openai_client = AsyncMock()
mock_openai_class.return_value = mock_openai_client
mock_openai_client.beta.chat.completions.parse.return_value = mock_openai_response
# Mock Firecrawl
mock_firecrawl_app = AsyncMock()
mock_firecrawl_class.return_value = mock_firecrawl_app
# Mock search results with markdown content
mock_search_result = {
"data": [
{
"url": "https://example.com/test",
"title": "Test Result",
"description": "Test description",
"markdown": "This is test markdown content for the search result.",
}
]
}
mock_firecrawl_app.search.return_value = mock_search_result
with patch("exa_py.Exa") as mock_exa_class:
# Mock Exa
mock_exa_client = MagicMock()
mock_exa_class.return_value = mock_exa_client
mock_exa_client.search_and_contents.return_value = mock_exa_result
# create executor with mock dependencies
executor = LettaBuiltinToolExecutor(
@@ -391,44 +307,22 @@ async def test_web_search_uses_agent_env_var_model():
actor=MagicMock(),
)
task = SearchTask(query="test query", question="test question")
result = await executor.web_search(agent_state=mock_agent_state, query="test query", num_results=3, include_text=True)
await executor.web_search(agent_state=mock_agent_state, tasks=[task], limit=1, return_raw=False)
# Verify Exa was called correctly
mock_exa_class.assert_called_once_with(api_key="test-exa-key")
mock_exa_client.search_and_contents.assert_called_once()
# verify correct model was used
mock_openai_client.beta.chat.completions.parse.assert_called_once()
call_args = mock_openai_client.beta.chat.completions.parse.call_args
assert call_args[1]["model"] == "gpt-4o"
# Check the call arguments
call_args = mock_exa_client.search_and_contents.call_args
assert call_args[1]["query"] == "test query"
assert call_args[1]["num_results"] == 3
assert call_args[1]["type"] == "auto"
assert call_args[1]["text"] == True
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS, ids=[c.model for c in TESTED_LLM_CONFIGS])
def test_web_search_using_agent_state_env_var(
client: Letta,
agent_state_with_firecrawl_key: AgentState,
llm_config: LLMConfig,
) -> None:
user_message = MessageCreate(
role="user",
content="I am executing a test. Use the web search tool to find where I, Charles Packer, the CEO of Letta, went to school.",
otid=USER_MESSAGE_OTID,
)
response = client.agents.messages.create(
agent_id=agent_state_with_firecrawl_key.id,
messages=[user_message],
)
tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)]
assert tool_returns, "No ToolReturnMessage found"
returns = [m.tool_return for m in tool_returns]
print(returns)
# Parse the JSON response from web search
assert len(returns) > 0, "No tool returns found"
response_json = json.loads(returns[0])
# Basic structure assertions
assert "api_key_source" in response_json, "Missing 'api_key_source' field in response"
assert "results" in response_json, "Missing 'results' field in response"
assert response_json["api_key_source"] == "agent_environment"
# Verify the response format
response_json = json.loads(result)
assert "query" in response_json
assert "results" in response_json
assert response_json["query"] == "test query"
assert len(response_json["results"]) == 1

File diff suppressed because one or more lines are too long

View File

@@ -771,7 +771,7 @@
{
"created_at": "2025-08-08T18:14:10.519658+00:00",
"description": null,
"key": "FIRECRAWL_API_KEY",
"key": "EXA_API_KEY",
"updated_at": "2025-08-08T18:14:10.519658+00:00",
"value": ""
},

View File

@@ -142,7 +142,7 @@
{
"created_at": "2025-08-14T21:31:27.793445+00:00",
"description": null,
"key": "FIRECRAWL_API_KEY",
"key": "EXA_API_KEY",
"updated_at": "2025-08-14T21:31:27.793445+00:00",
"value": ""
},
@@ -455,4 +455,4 @@
],
"updated_at": "2025-08-14T22:49:29.169737+00:00",
"version": "0.10.0"
}
}

41
uv.lock generated
View File

@@ -1110,6 +1110,22 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/ce/31/55cd413eaccd39125368be33c46de24a1f639f2e12349b0361b4678f3915/eval_type_backport-0.2.2-py3-none-any.whl", hash = "sha256:cb6ad7c393517f476f96d456d0412ea80f0a8cf96f6892834cd9340149111b0a", size = 5830 },
]
[[package]]
name = "exa-py"
version = "1.15.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "httpx" },
{ name = "openai" },
{ name = "pydantic" },
{ name = "requests" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/5b/39/750a4a62e25afad962cf4f6d2c7e998ac25e8001b448ab6906292f86a88f/exa_py-1.15.4.tar.gz", hash = "sha256:707781bead63e495576375385729ebff4f5e843559e8acd78a0f0feb120292ca", size = 40250 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/bb/9a/3503696c939aee273a55a20f434d1f61ea623ff12f09a612f0efed1087c7/exa_py-1.15.4-py3-none-any.whl", hash = "sha256:2c29e74f130a086e061bab10cb042f5c7894beb471eddb58dad8b2ed5e916cba", size = 55176 },
]
[[package]]
name = "executing"
version = "2.2.0"
@@ -1163,23 +1179,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/18/79/1b8fa1bb3568781e84c9200f951c735f3f157429f44be0495da55894d620/filetype-1.2.0-py2.py3-none-any.whl", hash = "sha256:7ce71b6880181241cf7ac8697a2f1eb6a8bd9b429f7ad6d27b8db9ba5f1c2d25", size = 19970 },
]
[[package]]
name = "firecrawl-py"
version = "2.16.5"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "aiohttp" },
{ name = "nest-asyncio" },
{ name = "pydantic" },
{ name = "python-dotenv" },
{ name = "requests" },
{ name = "websockets" },
]
sdist = { url = "https://files.pythonhosted.org/packages/97/7a/e6911fe11d140db5b5deda7a3ad6ee80c7615fcb248097ec7a4bc784eff6/firecrawl_py-2.16.5.tar.gz", hash = "sha256:7f5186bba359a426140a6827b550a604e62bfbeda33ded757952899b1cca4c83", size = 40154 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/0f/99/fa2a41e2c1c5ff3e01150470b0a8a630792a363724c188bf6522e6798be7/firecrawl_py-2.16.5-py3-none-any.whl", hash = "sha256:3caed19b8f21522ab9c2193c2226990f2468e6bc5669ef54aa156e4230a5e35e", size = 75978 },
]
[[package]]
name = "flask"
version = "3.1.2"
@@ -2418,8 +2417,8 @@ dependencies = [
{ name = "datamodel-code-generator", extra = ["http"] },
{ name = "demjson3" },
{ name = "docstring-parser" },
{ name = "exa-py" },
{ name = "faker" },
{ name = "firecrawl-py" },
{ name = "grpcio" },
{ name = "grpcio-tools" },
{ name = "html2text" },
@@ -2509,7 +2508,7 @@ experimental = [
]
external-tools = [
{ name = "docker" },
{ name = "firecrawl-py" },
{ name = "exa-py" },
{ name = "langchain" },
{ name = "langchain-community" },
{ name = "turbopuffer" },
@@ -2566,11 +2565,11 @@ requires-dist = [
{ name = "docker", marker = "extra == 'external-tools'", specifier = ">=7.1.0" },
{ name = "docstring-parser", specifier = ">=0.16,<0.17" },
{ name = "e2b-code-interpreter", marker = "extra == 'cloud-tool-sandbox'", specifier = ">=1.0.3" },
{ name = "exa-py", specifier = ">=1.15.4" },
{ name = "exa-py", marker = "extra == 'external-tools'", specifier = ">=1.15.4" },
{ name = "faker", specifier = ">=36.1.0" },
{ name = "fastapi", marker = "extra == 'desktop'", specifier = ">=0.115.6" },
{ name = "fastapi", marker = "extra == 'server'", specifier = ">=0.115.6" },
{ name = "firecrawl-py", specifier = ">=2.8.0,<3.0.0" },
{ name = "firecrawl-py", marker = "extra == 'external-tools'", specifier = ">=2.8.0,<3.0.0" },
{ name = "google-cloud-profiler", marker = "extra == 'experimental'", specifier = ">=4.1.0" },
{ name = "google-genai", marker = "extra == 'google'", specifier = ">=1.15.0" },
{ name = "granian", extras = ["uvloop", "reload"], marker = "extra == 'experimental'", specifier = ">=2.3.2" },