From d5d71a776a44d73b360545eda9cf0f1d152a0b6a Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Wed, 18 Jun 2025 14:07:51 -0700 Subject: [PATCH] feat: Parallel web search tool (#2890) --- letta/functions/function_sets/builtin.py | 39 +-- letta/functions/prompts.py | 34 +-- letta/functions/types.py | 6 + .../chunker/llama_index_chunker.py | 90 ++++++- .../tool_executor/builtin_tool_executor.py | 239 ++++++++++++++---- 5 files changed, 330 insertions(+), 78 deletions(-) create mode 100644 letta/functions/types.py diff --git a/letta/functions/function_sets/builtin.py b/letta/functions/function_sets/builtin.py index f06409a2..a7ff8fb3 100644 --- a/letta/functions/function_sets/builtin.py +++ b/letta/functions/function_sets/builtin.py @@ -1,4 +1,6 @@ -from typing import Literal +from typing import List, Literal + +from letta.functions.types import SearchTask def run_code(code: str, language: Literal["python", "js", "ts", "r", "java"]) -> str: @@ -16,29 +18,34 @@ def run_code(code: str, language: Literal["python", "js", "ts", "r", "java"]) -> async def web_search( - query: str, - question: str, - limit: int = 5, + tasks: List[SearchTask], + limit: int = 3, return_raw: bool = False, ) -> str: """ - Search the web with the `query` and extract passages that answer the provided `question`. + Search the web with a list of query/question pairs and extract passages that answer the corresponding questions. Examples: - query -> "Tesla Q1 2025 earnings report PDF" - question -> "What was Tesla's net profit in Q1 2025?" - - query -> "Letta API prebuilt tools core_memory_append" - question -> "What does the core_memory_append tool do in Letta?" + tasks -> [ + SearchTask( + query="Tesla Q1 2025 earnings report PDF", + question="What was Tesla's net profit in Q1 2025?" + ), + SearchTask( + query="Letta API prebuilt tools core_memory_append", + question="What does the core_memory_append tool do in Letta?" + ) + ] Args: - query (str): The raw web-search query. - question (str): The information goal to answer using the retrieved pages. Consider the context and intent of the conversation so far when forming the question. - limit (int, optional): Maximum number of URLs to fetch and analyse (must be > 0). Defaults to 5. - return_raw (bool, optional): If set to True, returns the raw content of the web page. This should be False unless otherwise specified by the user. Defaults to False. + tasks (List[SearchTask]): A list of search tasks, each containing a `query` and a corresponding `question`. + limit (int, optional): Maximum number of URLs to fetch and analyse per task (must be > 0). Defaults to 3. + return_raw (bool, optional): If set to True, returns the raw content of the web pages. + This should be False unless otherwise specified by the user. Defaults to False. Returns: - str: A JSON-encoded string containing ranked snippets with their source - URLs and relevance scores. + str: A JSON-encoded string containing a list of search results. + Each result includes ranked snippets with their source URLs and relevance scores, + corresponding to each search task. """ raise NotImplementedError("This is only available on the latest agent architecture. Please contact the Letta team.") diff --git a/letta/functions/prompts.py b/letta/functions/prompts.py index 85bd99a9..780280c3 100644 --- a/letta/functions/prompts.py +++ b/letta/functions/prompts.py @@ -1,24 +1,26 @@ -"""Prompts for Letta function tools.""" +FIRECRAWL_SEARCH_SYSTEM_PROMPT = """You are an expert at extracting relevant information from web content. -FIRECRAWL_SEARCH_SYSTEM_PROMPT = """You are an expert information extraction assistant. Your task is to analyze a document and extract the most relevant passages that answer a specific question, based on a search query context. +Given a document with line numbers (format: "LINE_NUM: content"), identify passages that answer the provided question by returning line ranges: +- start_line: The starting line number (inclusive) +- end_line: The ending line number (inclusive) -Guidelines: -1. Extract substantial, lengthy text snippets that directly address the question -2. Preserve important context and details in each snippet - err on the side of including more rather than less -3. Keep thinking very brief (1 short sentence) - focus on WHY the snippet is relevant, not WHAT it says -4. Only extract snippets that actually answer or relate to the question - don't force relevance -5. Be comprehensive - include all relevant information, don't limit the number of snippets -6. Prioritize longer, information-rich passages over shorter ones""" +SELECTION PRINCIPLES: +1. Prefer comprehensive passages that include full context +2. Capture complete thoughts, examples, and explanations +3. When relevant content spans multiple paragraphs, include the entire section +4. Favor fewer, substantial passages over many fragments + +Focus on passages that can stand alone as complete, meaningful responses.""" -def get_firecrawl_search_user_prompt(query: str, question: str, markdown_content: str) -> str: - """Generate the user prompt for firecrawl search analysis.""" +def get_firecrawl_search_user_prompt(query: str, question: str, numbered_content: str) -> str: + """Generate the user prompt for line-number based search analysis.""" return f"""Search Query: {query} Question to Answer: {question} -Document Content: -```markdown -{markdown_content} -``` +Document Content (with line numbers): +{numbered_content} -Please analyze this document and extract all relevant passages that help answer the question.""" +Identify line ranges that best answer: "{question}" + +Select comprehensive passages with full context. Include entire sections when relevant.""" diff --git a/letta/functions/types.py b/letta/functions/types.py new file mode 100644 index 00000000..420ef82c --- /dev/null +++ b/letta/functions/types.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel, Field + + +class SearchTask(BaseModel): + query: str = Field(description="Search query for web search") + question: str = Field(description="Question to answer from search results, considering full conversation context") diff --git a/letta/services/file_processor/chunker/llama_index_chunker.py b/letta/services/file_processor/chunker/llama_index_chunker.py index 94f45e0a..dbb290e3 100644 --- a/letta/services/file_processor/chunker/llama_index_chunker.py +++ b/letta/services/file_processor/chunker/llama_index_chunker.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Tuple from mistralai import OCRPageObject @@ -27,3 +27,91 @@ class LlamaIndexChunker: except Exception as e: logger.error(f"Chunking failed: {str(e)}") raise + + +class MarkdownChunker: + """Markdown-specific chunker that preserves line numbers for citation purposes""" + + def __init__(self, chunk_size: int = 2048): + self.chunk_size = chunk_size + # No overlap for line-based citations to avoid ambiguity + + from llama_index.core.node_parser import MarkdownNodeParser + + self.parser = MarkdownNodeParser() + + def chunk_markdown_with_line_numbers(self, markdown_content: str) -> List[Tuple[str, int, int]]: + """ + Chunk markdown content while preserving line number mappings. + + Returns: + List of tuples: (chunk_text, start_line, end_line) + """ + try: + # Split content into lines for line number tracking + lines = markdown_content.split("\n") + + # Create nodes using MarkdownNodeParser + from llama_index.core import Document + + document = Document(text=markdown_content) + nodes = self.parser.get_nodes_from_documents([document]) + + chunks_with_line_numbers = [] + + for node in nodes: + chunk_text = node.text + + # Find the line numbers for this chunk + start_line, end_line = self._find_line_numbers(chunk_text, lines) + + chunks_with_line_numbers.append((chunk_text, start_line, end_line)) + + return chunks_with_line_numbers + + except Exception as e: + logger.error(f"Markdown chunking failed: {str(e)}") + # Fallback to simple line-based chunking + return self._fallback_line_chunking(markdown_content) + + def _find_line_numbers(self, chunk_text: str, lines: List[str]) -> Tuple[int, int]: + """Find the start and end line numbers for a given chunk of text.""" + chunk_lines = chunk_text.split("\n") + + # Find the first line of the chunk in the original document + start_line = 1 + for i, line in enumerate(lines): + if chunk_lines[0].strip() in line.strip() and len(chunk_lines[0].strip()) > 10: # Avoid matching short lines + start_line = i + 1 + break + + # Calculate end line + end_line = start_line + len(chunk_lines) - 1 + + return start_line, min(end_line, len(lines)) + + def _fallback_line_chunking(self, markdown_content: str) -> List[Tuple[str, int, int]]: + """Fallback chunking method that simply splits by lines with no overlap.""" + lines = markdown_content.split("\n") + chunks = [] + + i = 0 + while i < len(lines): + chunk_lines = [] + start_line = i + 1 + char_count = 0 + + # Build chunk until we hit size limit + while i < len(lines) and char_count < self.chunk_size: + line = lines[i] + chunk_lines.append(line) + char_count += len(line) + 1 # +1 for newline + i += 1 + + end_line = i + chunk_text = "\n".join(chunk_lines) + chunks.append((chunk_text, start_line, end_line)) + + # No overlap - continue from where we left off + + return chunks diff --git a/letta/services/tool_executor/builtin_tool_executor.py b/letta/services/tool_executor/builtin_tool_executor.py index dbeefdc8..6753bba4 100644 --- a/letta/services/tool_executor/builtin_tool_executor.py +++ b/letta/services/tool_executor/builtin_tool_executor.py @@ -1,10 +1,12 @@ import asyncio import json +import time from typing import Any, Dict, List, Literal, Optional from pydantic import BaseModel from letta.functions.prompts import FIRECRAWL_SEARCH_SYSTEM_PROMPT, get_firecrawl_search_user_prompt +from letta.functions.types import SearchTask from letta.log import get_logger from letta.otel.tracing import trace_method from letta.schemas.agent import AgentState @@ -19,10 +21,16 @@ logger = get_logger(__name__) class Citation(BaseModel): - """A relevant text snippet extracted from a document.""" + """A relevant text snippet identified by line numbers in a document.""" - text: str - thinking: str # Reasoning of why this snippet is relevant + start_line: int # Starting line number (1-indexed) + end_line: int # Ending line number (1-indexed, inclusive) + + +class CitationWithText(BaseModel): + """A citation with the actual extracted text.""" + + text: str # The actual extracted text from the lines class DocumentAnalysis(BaseModel): @@ -31,6 +39,12 @@ class DocumentAnalysis(BaseModel): citations: List[Citation] +class DocumentAnalysisWithText(BaseModel): + """Analysis with extracted text from line citations.""" + + citations: List[CitationWithText] + + class LettaBuiltinToolExecutor(ToolExecutor): """Executor for built in Letta tools.""" @@ -88,39 +102,62 @@ class LettaBuiltinToolExecutor(ToolExecutor): out["error"] = err return out + @trace_method async def web_search( self, agent_state: "AgentState", - query: str, - question: str, - limit: int = 5, + tasks: List[SearchTask], + limit: int = 3, return_raw: bool = False, ) -> str: """ - Search the web with the `query` and extract passages that answer the provided `question`. + Search the web with a list of query/question pairs and extract passages that answer the corresponding questions. Examples: - query -> "Tesla Q1 2025 earnings report PDF" - question -> "What was Tesla's net profit in Q1 2025?" - - query -> "Letta API prebuilt tools core_memory_append" - question -> "What does the core_memory_append tool do in Letta?" + tasks -> [ + SearchTask( + query="Tesla Q1 2025 earnings report PDF", + question="What was Tesla's net profit in Q1 2025?" + ), + SearchTask( + query="Letta API prebuilt tools core_memory_append", + question="What does the core_memory_append tool do in Letta?" + ) + ] Args: - query (str): The raw web-search query. - question (str): The information goal to answer using the retrieved pages. - limit (int, optional): Maximum number of URLs to fetch and analyse (must be > 0). Defaults to 5. - return_raw (bool, optional): If set to True, returns the raw content of the web page. This should be False unless otherwise specified by the user. Defaults to False. + tasks (List[SearchTask]): A list of search tasks, each containing a `query` and a corresponding `question`. + limit (int, optional): Maximum number of URLs to fetch and analyse per task (must be > 0). Defaults to 3. + return_raw (bool, optional): If set to True, returns the raw content of the web pages. + This should be False unless otherwise specified by the user. Defaults to False. Returns: - str: A JSON-encoded string containing ranked snippets with their source - URLs and relevance scores. + str: A JSON-encoded string containing a list of search results. + Each result includes ranked snippets with their source URLs and relevance scores, + corresponding to each search task. """ + # TODO: Temporary, maybe deprecate this field? + if return_raw: + logger.warning("WARNING! return_raw was set to True, we default to False always. Deprecate this field.") + return_raw = False try: - from firecrawl import AsyncFirecrawlApp, ScrapeOptions + from firecrawl import AsyncFirecrawlApp except ImportError: raise ImportError("firecrawl-py is not installed in the tool execution environment") + if not tasks: + return json.dumps({"error": "No search tasks provided."}) + + # Convert dict objects to SearchTask objects + search_tasks = [] + for task in tasks: + if isinstance(task, dict): + search_tasks.append(SearchTask(**task)) + else: + search_tasks.append(task) + + logger.info(f"[DEBUG] Starting web search with {len(search_tasks)} tasks, limit={limit}, return_raw={return_raw}") + # Check if the API key exists on the agent state agent_state_tool_env_vars = agent_state.get_agent_env_vars_as_dict() firecrawl_api_key = agent_state_tool_env_vars.get("FIRECRAWL_API_KEY") or tool_settings.firecrawl_api_key @@ -136,17 +173,64 @@ class LettaBuiltinToolExecutor(ToolExecutor): # Initialize Firecrawl client app = AsyncFirecrawlApp(api_key=firecrawl_api_key) - # Perform the search, just request markdown - search_result = await app.search(query, limit=limit, scrape_options=ScrapeOptions(formats=["markdown"])) + # Process all search tasks in parallel + search_task_coroutines = [self._process_single_search_task(app, task, limit, return_raw, api_key_source) for task in search_tasks] + + # Execute all searches concurrently + search_results = await asyncio.gather(*search_task_coroutines, return_exceptions=True) + + # Build final response as a mapping of query -> result + final_results = {} + successful_tasks = 0 + failed_tasks = 0 + + for i, result in enumerate(search_results): + query = search_tasks[i].query + if isinstance(result, Exception): + logger.error(f"Search task {i} failed: {result}") + failed_tasks += 1 + final_results[query] = {"query": query, "question": search_tasks[i].question, "error": str(result)} + else: + successful_tasks += 1 + final_results[query] = result + + logger.info(f"[DEBUG] Web search completed: {successful_tasks} successful, {failed_tasks} failed") + + # Build final response with api_key_source at top level + response = {"api_key_source": api_key_source, "results": final_results} + + return json.dumps(response, indent=2, ensure_ascii=False) + + @trace_method + async def _process_single_search_task( + self, app: "AsyncFirecrawlApp", task: SearchTask, limit: int, return_raw: bool, api_key_source: str + ) -> Dict[str, Any]: + """Process a single search task.""" + from firecrawl import ScrapeOptions + + logger.info(f"[DEBUG] Starting Firecrawl search for query: '{task.query}' with limit={limit}") + + # Perform the search for this task + search_result = await app.search(task.query, limit=limit, scrape_options=ScrapeOptions(formats=["markdown"])) + + logger.info( + f"[DEBUG] Firecrawl search completed for '{task.query}': {len(search_result.get('data', [])) if search_result else 0} results" + ) if not search_result or not search_result.get("data"): - return json.dumps({"error": "No search results found."}) + return {"query": task.query, "question": task.question, "error": "No search results found."} + + # If raw results requested, return them directly + if return_raw: + return {"query": task.query, "question": task.question, "raw_results": search_result} # Check if OpenAI API key is available for semantic parsing - if not return_raw and model_settings.openai_api_key: + if model_settings.openai_api_key: try: from openai import AsyncOpenAI + logger.info(f"[DEBUG] Starting OpenAI analysis for '{task.query}'") + # Initialize OpenAI client client = AsyncOpenAI( api_key=model_settings.openai_api_key, @@ -160,15 +244,19 @@ class LettaBuiltinToolExecutor(ToolExecutor): for result in search_result.get("data"): if result.get("markdown"): # Create async task for OpenAI analysis - task = self._analyze_document_with_openai(client, result["markdown"], query, question) - analysis_tasks.append(task) + analysis_task = self._analyze_document_with_openai(client, result["markdown"], task.query, task.question) + analysis_tasks.append(analysis_task) results_with_markdown.append(result) else: results_without_markdown.append(result) + logger.info(f"[DEBUG] Starting parallel OpenAI analysis of {len(analysis_tasks)} documents for '{task.query}'") + # Fire off all OpenAI requests concurrently analyses = await asyncio.gather(*analysis_tasks, return_exceptions=True) + logger.info(f"[DEBUG] Completed parallel OpenAI analysis of {len(analyses)} documents for '{task.query}'") + # Build processed results processed_results = [] @@ -176,16 +264,21 @@ class LettaBuiltinToolExecutor(ToolExecutor): for result, analysis in zip(results_with_markdown, analyses): if isinstance(analysis, Exception) or analysis is None: logger.error(f"Analysis failed for {result.get('url')}, falling back to raw results") - return str(search_result) + return {"query": task.query, "question": task.question, "raw_results": search_result} # All analyses succeeded, build processed results for result, analysis in zip(results_with_markdown, analyses): + # Extract actual text from line number citations + analysis_with_text = None + if analysis and analysis.citations: + analysis_with_text = self._extract_text_from_line_citations(analysis, result["markdown"]) + processed_results.append( { "url": result.get("url"), "title": result.get("title"), "description": result.get("description"), - "analysis": analysis.model_dump() if analysis else None, + "analysis": analysis_with_text.model_dump() if analysis_with_text else None, } ) @@ -195,35 +288,95 @@ class LettaBuiltinToolExecutor(ToolExecutor): {"url": result.get("url"), "title": result.get("title"), "description": result.get("description"), "analysis": None} ) - # Concatenate all relevant snippets into a final response - final_response = self._build_final_response(processed_results, query, question, api_key_source) - return final_response + # Build final response for this task + return self._build_final_response_dict(processed_results, task.query, task.question) except Exception as e: # Log error but continue with raw results - logger.error(f"Error with OpenAI processing: {e}") + logger.error(f"Error with OpenAI processing for task '{task.query}': {e}") # Return raw search results if OpenAI processing isn't available or fails - return str(search_result) + return {"query": task.query, "question": task.question, "raw_results": search_result} + @trace_method async def _analyze_document_with_openai(self, client, markdown_content: str, query: str, question: str) -> Optional[DocumentAnalysis]: - """Use OpenAI to analyze a document and extract relevant passages.""" - max_content_length = 200000 # GPT-4.1 has ~1M token context window, so we can be more generous with content length - if len(markdown_content) > max_content_length: - markdown_content = markdown_content[:max_content_length] + "..." + """Use OpenAI to analyze a document and extract relevant passages using line numbers.""" + original_length = len(markdown_content) - user_prompt = get_firecrawl_search_user_prompt(query, question, markdown_content) + # Create numbered markdown for the LLM to reference + numbered_lines = markdown_content.split("\n") + numbered_markdown = "\n".join([f"{i+1:4d}: {line}" for i, line in enumerate(numbered_lines)]) + + # Truncate if too long + max_content_length = 200000 + truncated = False + if len(numbered_markdown) > max_content_length: + numbered_markdown = numbered_markdown[:max_content_length] + "..." + truncated = True + + user_prompt = get_firecrawl_search_user_prompt(query, question, numbered_markdown) + + logger.info( + f"[DEBUG] Starting OpenAI request with line numbers - Query: '{query}', Content: {original_length} chars (truncated: {truncated})" + ) + + # Time the OpenAI request + start_time = time.time() response = await client.beta.chat.completions.parse( model="gpt-4.1-mini-2025-04-14", messages=[{"role": "system", "content": FIRECRAWL_SEARCH_SYSTEM_PROMPT}, {"role": "user", "content": user_prompt}], response_format=DocumentAnalysis, temperature=0.1, + max_tokens=300, # Limit output tokens - only need line numbers ) - return response.choices[0].message.parsed + end_time = time.time() + request_duration = end_time - start_time - def _build_final_response(self, processed_results: List[Dict], query: str, question: str, api_key_source: str = None) -> str: - """Build the final JSON response from all processed results.""" + # Get usage statistics and output length + usage = response.usage + parsed_result = response.choices[0].message.parsed + num_citations = len(parsed_result.citations) if parsed_result else 0 + + # Calculate output length (minimal now - just line numbers) + output_length = 0 + if parsed_result and parsed_result.citations: + for citation in parsed_result.citations: + output_length += 20 # ~20 chars for line numbers only + + logger.info(f"[TIMING] OpenAI request completed in {request_duration:.2f}s - Query: '{query}'") + logger.info(f"[TOKENS] Total: {usage.total_tokens} (prompt: {usage.prompt_tokens}, completion: {usage.completion_tokens})") + logger.info(f"[OUTPUT] Citations: {num_citations}, Output chars: {output_length} (line-number based)") + + return parsed_result + + def _extract_text_from_line_citations(self, analysis: DocumentAnalysis, original_markdown: str) -> DocumentAnalysisWithText: + """Extract actual text from line number citations.""" + lines = original_markdown.split("\n") + citations_with_text = [] + + for citation in analysis.citations: + try: + # Convert to 0-indexed and ensure bounds + start_idx = max(0, citation.start_line - 1) + end_idx = min(len(lines), citation.end_line) + + # Extract the lines + extracted_lines = lines[start_idx:end_idx] + extracted_text = "\n".join(extracted_lines) + + citations_with_text.append(CitationWithText(text=extracted_text)) + + except Exception as e: + logger.info(f"[DEBUG] Failed to extract text for citation lines {citation.start_line}-{citation.end_line}: {e}") + # Fall back to including the citation with empty text + citations_with_text.append(CitationWithText(text="")) + + return DocumentAnalysisWithText(citations=citations_with_text) + + @trace_method + def _build_final_response_dict(self, processed_results: List[Dict], query: str, question: str) -> Dict[str, Any]: + """Build the final response dictionary from all processed results.""" # Build sources array sources = [] @@ -250,11 +403,7 @@ class LettaBuiltinToolExecutor(ToolExecutor): "sources": sources, } - # Add API key source if provided - if api_key_source: - response["api_key_source"] = api_key_source - if total_snippets == 0: response["message"] = "No relevant passages found that directly answer the question." - return json.dumps(response, indent=2, ensure_ascii=False) + return response