feat: Parallel web search tool (#2890)

This commit is contained in:
Matthew Zhou
2025-06-18 14:07:51 -07:00
committed by GitHub
parent c6aca63d56
commit d5d71a776a
5 changed files with 330 additions and 78 deletions

View File

@@ -1,4 +1,6 @@
from typing import Literal
from typing import List, Literal
from letta.functions.types import SearchTask
def run_code(code: str, language: Literal["python", "js", "ts", "r", "java"]) -> str:
@@ -16,29 +18,34 @@ def run_code(code: str, language: Literal["python", "js", "ts", "r", "java"]) ->
async def web_search(
query: str,
question: str,
limit: int = 5,
tasks: List[SearchTask],
limit: int = 3,
return_raw: bool = False,
) -> str:
"""
Search the web with the `query` and extract passages that answer the provided `question`.
Search the web with a list of query/question pairs and extract passages that answer the corresponding questions.
Examples:
query -> "Tesla Q1 2025 earnings report PDF"
question -> "What was Tesla's net profit in Q1 2025?"
query -> "Letta API prebuilt tools core_memory_append"
question -> "What does the core_memory_append tool do in Letta?"
tasks -> [
SearchTask(
query="Tesla Q1 2025 earnings report PDF",
question="What was Tesla's net profit in Q1 2025?"
),
SearchTask(
query="Letta API prebuilt tools core_memory_append",
question="What does the core_memory_append tool do in Letta?"
)
]
Args:
query (str): The raw web-search query.
question (str): The information goal to answer using the retrieved pages. Consider the context and intent of the conversation so far when forming the question.
limit (int, optional): Maximum number of URLs to fetch and analyse (must be > 0). Defaults to 5.
return_raw (bool, optional): If set to True, returns the raw content of the web page. This should be False unless otherwise specified by the user. Defaults to False.
tasks (List[SearchTask]): A list of search tasks, each containing a `query` and a corresponding `question`.
limit (int, optional): Maximum number of URLs to fetch and analyse per task (must be > 0). Defaults to 3.
return_raw (bool, optional): If set to True, returns the raw content of the web pages.
This should be False unless otherwise specified by the user. Defaults to False.
Returns:
str: A JSON-encoded string containing ranked snippets with their source
URLs and relevance scores.
str: A JSON-encoded string containing a list of search results.
Each result includes ranked snippets with their source URLs and relevance scores,
corresponding to each search task.
"""
raise NotImplementedError("This is only available on the latest agent architecture. Please contact the Letta team.")

View File

@@ -1,24 +1,26 @@
"""Prompts for Letta function tools."""
FIRECRAWL_SEARCH_SYSTEM_PROMPT = """You are an expert at extracting relevant information from web content.
FIRECRAWL_SEARCH_SYSTEM_PROMPT = """You are an expert information extraction assistant. Your task is to analyze a document and extract the most relevant passages that answer a specific question, based on a search query context.
Given a document with line numbers (format: "LINE_NUM: content"), identify passages that answer the provided question by returning line ranges:
- start_line: The starting line number (inclusive)
- end_line: The ending line number (inclusive)
Guidelines:
1. Extract substantial, lengthy text snippets that directly address the question
2. Preserve important context and details in each snippet - err on the side of including more rather than less
3. Keep thinking very brief (1 short sentence) - focus on WHY the snippet is relevant, not WHAT it says
4. Only extract snippets that actually answer or relate to the question - don't force relevance
5. Be comprehensive - include all relevant information, don't limit the number of snippets
6. Prioritize longer, information-rich passages over shorter ones"""
SELECTION PRINCIPLES:
1. Prefer comprehensive passages that include full context
2. Capture complete thoughts, examples, and explanations
3. When relevant content spans multiple paragraphs, include the entire section
4. Favor fewer, substantial passages over many fragments
Focus on passages that can stand alone as complete, meaningful responses."""
def get_firecrawl_search_user_prompt(query: str, question: str, markdown_content: str) -> str:
"""Generate the user prompt for firecrawl search analysis."""
def get_firecrawl_search_user_prompt(query: str, question: str, numbered_content: str) -> str:
"""Generate the user prompt for line-number based search analysis."""
return f"""Search Query: {query}
Question to Answer: {question}
Document Content:
```markdown
{markdown_content}
```
Document Content (with line numbers):
{numbered_content}
Please analyze this document and extract all relevant passages that help answer the question."""
Identify line ranges that best answer: "{question}"
Select comprehensive passages with full context. Include entire sections when relevant."""

6
letta/functions/types.py Normal file
View File

@@ -0,0 +1,6 @@
from pydantic import BaseModel, Field
class SearchTask(BaseModel):
    """A single web-search task: the raw query to run and the question the results should answer."""

    query: str = Field(description="Search query for web search")
    question: str = Field(description="Question to answer from search results, considering full conversation context")

View File

@@ -1,4 +1,4 @@
from typing import List
from typing import List, Tuple
from mistralai import OCRPageObject
@@ -27,3 +27,91 @@ class LlamaIndexChunker:
except Exception as e:
logger.error(f"Chunking failed: {str(e)}")
raise
class MarkdownChunker:
    """Markdown-specific chunker that preserves line numbers for citation purposes.

    Chunks are returned together with 1-indexed, inclusive (start_line, end_line)
    ranges into the original document so extracted passages can be cited by line.
    """

    def __init__(self, chunk_size: int = 2048):
        # Size limit is in characters, not tokens.
        self.chunk_size = chunk_size
        # No overlap between chunks: overlapping line ranges would make citations ambiguous.
        from llama_index.core.node_parser import MarkdownNodeParser

        self.parser = MarkdownNodeParser()

    def chunk_markdown_with_line_numbers(self, markdown_content: str) -> List[Tuple[str, int, int]]:
        """
        Chunk markdown content while preserving line number mappings.

        Args:
            markdown_content: The full markdown document.

        Returns:
            List of tuples: (chunk_text, start_line, end_line), with 1-indexed
            inclusive line numbers into the original document.
        """
        try:
            # Keep the original lines around for line-number lookups.
            lines = markdown_content.split("\n")

            # Structure-aware chunking via MarkdownNodeParser.
            from llama_index.core import Document

            document = Document(text=markdown_content)
            nodes = self.parser.get_nodes_from_documents([document])

            chunks_with_line_numbers = []
            for node in nodes:
                chunk_text = node.text
                # Map each parser-produced chunk back onto document line numbers.
                start_line, end_line = self._find_line_numbers(chunk_text, lines)
                chunks_with_line_numbers.append((chunk_text, start_line, end_line))
            return chunks_with_line_numbers

        except Exception as e:
            logger.error(f"Markdown chunking failed: {str(e)}")
            # Fallback to simple line-based chunking
            return self._fallback_line_chunking(markdown_content)

    def _find_line_numbers(self, chunk_text: str, lines: List[str]) -> Tuple[int, int]:
        """Find the 1-indexed (start, end) line numbers for a given chunk of text.

        Matching is heuristic: long first lines are matched by substring (tolerant
        of parser whitespace changes); short first lines are matched exactly.
        Returns (1, ...) when no match can be found.
        """
        chunk_lines = chunk_text.split("\n")
        first = chunk_lines[0].strip()

        start_line = 1
        if len(first) > 10:
            # Substring match; the length guard avoids accidentally matching short lines.
            for i, line in enumerate(lines):
                if first in line.strip():
                    start_line = i + 1
                    break
        elif first:
            # Fix: the original skipped short first lines entirely and always fell back
            # to line 1. An exact match is unambiguous even for short lines.
            for i, line in enumerate(lines):
                if line.strip() == first:
                    start_line = i + 1
                    break

        # Assume the chunk occupies consecutive lines starting at start_line.
        end_line = start_line + len(chunk_lines) - 1
        return start_line, min(end_line, len(lines))

    def _fallback_line_chunking(self, markdown_content: str) -> List[Tuple[str, int, int]]:
        """Fallback chunking that simply splits by lines with no overlap.

        Greedily packs whole lines until the character budget is reached; a chunk
        may slightly exceed chunk_size because lines are never split.
        """
        lines = markdown_content.split("\n")
        chunks = []
        i = 0

        while i < len(lines):
            chunk_lines = []
            start_line = i + 1  # 1-indexed
            char_count = 0

            # Build chunk until we hit the size limit.
            while i < len(lines) and char_count < self.chunk_size:
                line = lines[i]
                chunk_lines.append(line)
                char_count += len(line) + 1  # +1 for the newline
                i += 1

            end_line = i  # i already points past the last consumed line
            chunks.append(("\n".join(chunk_lines), start_line, end_line))
            # No overlap - continue from where we left off.

        return chunks

View File

@@ -1,10 +1,12 @@
import asyncio
import json
import time
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel
from letta.functions.prompts import FIRECRAWL_SEARCH_SYSTEM_PROMPT, get_firecrawl_search_user_prompt
from letta.functions.types import SearchTask
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState
@@ -19,10 +21,16 @@ logger = get_logger(__name__)
class Citation(BaseModel):
"""A relevant text snippet extracted from a document."""
"""A relevant text snippet identified by line numbers in a document."""
text: str
thinking: str # Reasoning of why this snippet is relevant
start_line: int # Starting line number (1-indexed)
end_line: int # Ending line number (1-indexed, inclusive)
class CitationWithText(BaseModel):
    """A citation with the actual extracted text."""

    text: str  # The actual extracted text from the cited lines
class DocumentAnalysis(BaseModel):
@@ -31,6 +39,12 @@ class DocumentAnalysis(BaseModel):
citations: List[Citation]
class DocumentAnalysisWithText(BaseModel):
    """Analysis with extracted text resolved from line-number citations."""

    citations: List[CitationWithText]
class LettaBuiltinToolExecutor(ToolExecutor):
"""Executor for built in Letta tools."""
@@ -88,39 +102,62 @@ class LettaBuiltinToolExecutor(ToolExecutor):
out["error"] = err
return out
@trace_method
async def web_search(
self,
agent_state: "AgentState",
query: str,
question: str,
limit: int = 5,
tasks: List[SearchTask],
limit: int = 3,
return_raw: bool = False,
) -> str:
"""
Search the web with the `query` and extract passages that answer the provided `question`.
Search the web with a list of query/question pairs and extract passages that answer the corresponding questions.
Examples:
query -> "Tesla Q1 2025 earnings report PDF"
question -> "What was Tesla's net profit in Q1 2025?"
query -> "Letta API prebuilt tools core_memory_append"
question -> "What does the core_memory_append tool do in Letta?"
tasks -> [
SearchTask(
query="Tesla Q1 2025 earnings report PDF",
question="What was Tesla's net profit in Q1 2025?"
),
SearchTask(
query="Letta API prebuilt tools core_memory_append",
question="What does the core_memory_append tool do in Letta?"
)
]
Args:
query (str): The raw web-search query.
question (str): The information goal to answer using the retrieved pages.
limit (int, optional): Maximum number of URLs to fetch and analyse (must be > 0). Defaults to 5.
return_raw (bool, optional): If set to True, returns the raw content of the web page. This should be False unless otherwise specified by the user. Defaults to False.
tasks (List[SearchTask]): A list of search tasks, each containing a `query` and a corresponding `question`.
limit (int, optional): Maximum number of URLs to fetch and analyse per task (must be > 0). Defaults to 3.
return_raw (bool, optional): If set to True, returns the raw content of the web pages.
This should be False unless otherwise specified by the user. Defaults to False.
Returns:
str: A JSON-encoded string containing ranked snippets with their source
URLs and relevance scores.
str: A JSON-encoded string containing a list of search results.
Each result includes ranked snippets with their source URLs and relevance scores,
corresponding to each search task.
"""
# TODO: Temporary, maybe deprecate this field?
if return_raw:
logger.warning("WARNING! return_raw was set to True, we default to False always. Deprecate this field.")
return_raw = False
try:
from firecrawl import AsyncFirecrawlApp, ScrapeOptions
from firecrawl import AsyncFirecrawlApp
except ImportError:
raise ImportError("firecrawl-py is not installed in the tool execution environment")
if not tasks:
return json.dumps({"error": "No search tasks provided."})
# Convert dict objects to SearchTask objects
search_tasks = []
for task in tasks:
if isinstance(task, dict):
search_tasks.append(SearchTask(**task))
else:
search_tasks.append(task)
logger.info(f"[DEBUG] Starting web search with {len(search_tasks)} tasks, limit={limit}, return_raw={return_raw}")
# Check if the API key exists on the agent state
agent_state_tool_env_vars = agent_state.get_agent_env_vars_as_dict()
firecrawl_api_key = agent_state_tool_env_vars.get("FIRECRAWL_API_KEY") or tool_settings.firecrawl_api_key
@@ -136,17 +173,64 @@ class LettaBuiltinToolExecutor(ToolExecutor):
# Initialize Firecrawl client
app = AsyncFirecrawlApp(api_key=firecrawl_api_key)
# Perform the search, just request markdown
search_result = await app.search(query, limit=limit, scrape_options=ScrapeOptions(formats=["markdown"]))
# Process all search tasks in parallel
search_task_coroutines = [self._process_single_search_task(app, task, limit, return_raw, api_key_source) for task in search_tasks]
# Execute all searches concurrently
search_results = await asyncio.gather(*search_task_coroutines, return_exceptions=True)
# Build final response as a mapping of query -> result
final_results = {}
successful_tasks = 0
failed_tasks = 0
for i, result in enumerate(search_results):
query = search_tasks[i].query
if isinstance(result, Exception):
logger.error(f"Search task {i} failed: {result}")
failed_tasks += 1
final_results[query] = {"query": query, "question": search_tasks[i].question, "error": str(result)}
else:
successful_tasks += 1
final_results[query] = result
logger.info(f"[DEBUG] Web search completed: {successful_tasks} successful, {failed_tasks} failed")
# Build final response with api_key_source at top level
response = {"api_key_source": api_key_source, "results": final_results}
return json.dumps(response, indent=2, ensure_ascii=False)
@trace_method
async def _process_single_search_task(
self, app: "AsyncFirecrawlApp", task: SearchTask, limit: int, return_raw: bool, api_key_source: str
) -> Dict[str, Any]:
"""Process a single search task."""
from firecrawl import ScrapeOptions
logger.info(f"[DEBUG] Starting Firecrawl search for query: '{task.query}' with limit={limit}")
# Perform the search for this task
search_result = await app.search(task.query, limit=limit, scrape_options=ScrapeOptions(formats=["markdown"]))
logger.info(
f"[DEBUG] Firecrawl search completed for '{task.query}': {len(search_result.get('data', [])) if search_result else 0} results"
)
if not search_result or not search_result.get("data"):
return json.dumps({"error": "No search results found."})
return {"query": task.query, "question": task.question, "error": "No search results found."}
# If raw results requested, return them directly
if return_raw:
return {"query": task.query, "question": task.question, "raw_results": search_result}
# Check if OpenAI API key is available for semantic parsing
if not return_raw and model_settings.openai_api_key:
if model_settings.openai_api_key:
try:
from openai import AsyncOpenAI
logger.info(f"[DEBUG] Starting OpenAI analysis for '{task.query}'")
# Initialize OpenAI client
client = AsyncOpenAI(
api_key=model_settings.openai_api_key,
@@ -160,15 +244,19 @@ class LettaBuiltinToolExecutor(ToolExecutor):
for result in search_result.get("data"):
if result.get("markdown"):
# Create async task for OpenAI analysis
task = self._analyze_document_with_openai(client, result["markdown"], query, question)
analysis_tasks.append(task)
analysis_task = self._analyze_document_with_openai(client, result["markdown"], task.query, task.question)
analysis_tasks.append(analysis_task)
results_with_markdown.append(result)
else:
results_without_markdown.append(result)
logger.info(f"[DEBUG] Starting parallel OpenAI analysis of {len(analysis_tasks)} documents for '{task.query}'")
# Fire off all OpenAI requests concurrently
analyses = await asyncio.gather(*analysis_tasks, return_exceptions=True)
logger.info(f"[DEBUG] Completed parallel OpenAI analysis of {len(analyses)} documents for '{task.query}'")
# Build processed results
processed_results = []
@@ -176,16 +264,21 @@ class LettaBuiltinToolExecutor(ToolExecutor):
for result, analysis in zip(results_with_markdown, analyses):
if isinstance(analysis, Exception) or analysis is None:
logger.error(f"Analysis failed for {result.get('url')}, falling back to raw results")
return str(search_result)
return {"query": task.query, "question": task.question, "raw_results": search_result}
# All analyses succeeded, build processed results
for result, analysis in zip(results_with_markdown, analyses):
# Extract actual text from line number citations
analysis_with_text = None
if analysis and analysis.citations:
analysis_with_text = self._extract_text_from_line_citations(analysis, result["markdown"])
processed_results.append(
{
"url": result.get("url"),
"title": result.get("title"),
"description": result.get("description"),
"analysis": analysis.model_dump() if analysis else None,
"analysis": analysis_with_text.model_dump() if analysis_with_text else None,
}
)
@@ -195,35 +288,95 @@ class LettaBuiltinToolExecutor(ToolExecutor):
{"url": result.get("url"), "title": result.get("title"), "description": result.get("description"), "analysis": None}
)
# Concatenate all relevant snippets into a final response
final_response = self._build_final_response(processed_results, query, question, api_key_source)
return final_response
# Build final response for this task
return self._build_final_response_dict(processed_results, task.query, task.question)
except Exception as e:
# Log error but continue with raw results
logger.error(f"Error with OpenAI processing: {e}")
logger.error(f"Error with OpenAI processing for task '{task.query}': {e}")
# Return raw search results if OpenAI processing isn't available or fails
return str(search_result)
return {"query": task.query, "question": task.question, "raw_results": search_result}
@trace_method
async def _analyze_document_with_openai(self, client, markdown_content: str, query: str, question: str) -> Optional[DocumentAnalysis]:
"""Use OpenAI to analyze a document and extract relevant passages."""
max_content_length = 200000 # GPT-4.1 has ~1M token context window, so we can be more generous with content length
if len(markdown_content) > max_content_length:
markdown_content = markdown_content[:max_content_length] + "..."
"""Use OpenAI to analyze a document and extract relevant passages using line numbers."""
original_length = len(markdown_content)
user_prompt = get_firecrawl_search_user_prompt(query, question, markdown_content)
# Create numbered markdown for the LLM to reference
numbered_lines = markdown_content.split("\n")
numbered_markdown = "\n".join([f"{i+1:4d}: {line}" for i, line in enumerate(numbered_lines)])
# Truncate if too long
max_content_length = 200000
truncated = False
if len(numbered_markdown) > max_content_length:
numbered_markdown = numbered_markdown[:max_content_length] + "..."
truncated = True
user_prompt = get_firecrawl_search_user_prompt(query, question, numbered_markdown)
logger.info(
f"[DEBUG] Starting OpenAI request with line numbers - Query: '{query}', Content: {original_length} chars (truncated: {truncated})"
)
# Time the OpenAI request
start_time = time.time()
response = await client.beta.chat.completions.parse(
model="gpt-4.1-mini-2025-04-14",
messages=[{"role": "system", "content": FIRECRAWL_SEARCH_SYSTEM_PROMPT}, {"role": "user", "content": user_prompt}],
response_format=DocumentAnalysis,
temperature=0.1,
max_tokens=300, # Limit output tokens - only need line numbers
)
return response.choices[0].message.parsed
end_time = time.time()
request_duration = end_time - start_time
def _build_final_response(self, processed_results: List[Dict], query: str, question: str, api_key_source: str = None) -> str:
"""Build the final JSON response from all processed results."""
# Get usage statistics and output length
usage = response.usage
parsed_result = response.choices[0].message.parsed
num_citations = len(parsed_result.citations) if parsed_result else 0
# Calculate output length (minimal now - just line numbers)
output_length = 0
if parsed_result and parsed_result.citations:
for citation in parsed_result.citations:
output_length += 20 # ~20 chars for line numbers only
logger.info(f"[TIMING] OpenAI request completed in {request_duration:.2f}s - Query: '{query}'")
logger.info(f"[TOKENS] Total: {usage.total_tokens} (prompt: {usage.prompt_tokens}, completion: {usage.completion_tokens})")
logger.info(f"[OUTPUT] Citations: {num_citations}, Output chars: {output_length} (line-number based)")
return parsed_result
def _extract_text_from_line_citations(self, analysis: DocumentAnalysis, original_markdown: str) -> DocumentAnalysisWithText:
    """Resolve line-number citations back into the literal text they reference.

    Each citation's 1-indexed inclusive (start_line, end_line) range is clamped
    to the document bounds and the corresponding lines are joined into text.
    A citation that fails to resolve is kept with empty text rather than dropped.
    """
    doc_lines = original_markdown.split("\n")

    resolved = []
    for cite in analysis.citations:
        try:
            # Convert the 1-indexed inclusive range into clamped 0-indexed slice bounds.
            lo = max(0, cite.start_line - 1)
            hi = min(len(doc_lines), cite.end_line)
            resolved.append(CitationWithText(text="\n".join(doc_lines[lo:hi])))
        except Exception as e:
            logger.info(f"[DEBUG] Failed to extract text for citation lines {cite.start_line}-{cite.end_line}: {e}")
            # Fall back to including the citation with empty text
            resolved.append(CitationWithText(text=""))

    return DocumentAnalysisWithText(citations=resolved)
@trace_method
def _build_final_response_dict(self, processed_results: List[Dict], query: str, question: str) -> Dict[str, Any]:
"""Build the final response dictionary from all processed results."""
# Build sources array
sources = []
@@ -250,11 +403,7 @@ class LettaBuiltinToolExecutor(ToolExecutor):
"sources": sources,
}
# Add API key source if provided
if api_key_source:
response["api_key_source"] = api_key_source
if total_snippets == 0:
response["message"] = "No relevant passages found that directly answer the question."
return json.dumps(response, indent=2, ensure_ascii=False)
return response