feat: Port firecrawl search into web search (#2868)
This commit is contained in:
@@ -125,7 +125,7 @@ MEMORY_TOOLS_LINE_NUMBER_PREFIX_REGEX = re.compile(
|
||||
)
|
||||
|
||||
# Built in tools
|
||||
BUILTIN_TOOLS = ["run_code", "web_search", "firecrawl_search"]
|
||||
BUILTIN_TOOLS = ["run_code", "web_search"]
|
||||
|
||||
# Built in tools
|
||||
FILES_TOOLS = ["open_file", "close_file", "grep", "search_files"]
|
||||
|
||||
@@ -1,18 +1,6 @@
|
||||
from typing import Literal
|
||||
|
||||
|
||||
async def web_search(query: str) -> str:
|
||||
"""
|
||||
Search the web for information.
|
||||
Args:
|
||||
query (str): The query to search the web for.
|
||||
Returns:
|
||||
str: The search results.
|
||||
"""
|
||||
|
||||
raise NotImplementedError("This is only available on the latest agent architecture. Please contact the Letta team.")
|
||||
|
||||
|
||||
def run_code(code: str, language: Literal["python", "js", "ts", "r", "java"]) -> str:
|
||||
"""
|
||||
Run code in a sandbox. Supports Python, Javascript, Typescript, R, and Java.
|
||||
@@ -27,7 +15,7 @@ def run_code(code: str, language: Literal["python", "js", "ts", "r", "java"]) ->
|
||||
raise NotImplementedError("This is only available on the latest agent architecture. Please contact the Letta team.")
|
||||
|
||||
|
||||
async def firecrawl_search(
|
||||
async def web_search(
|
||||
query: str,
|
||||
question: str,
|
||||
limit: int = 5,
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import asyncio
|
||||
import json
|
||||
from textwrap import shorten
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from letta.constants import WEB_SEARCH_CLIP_CONTENT, WEB_SEARCH_INCLUDE_SCORE, WEB_SEARCH_SEPARATOR
|
||||
from letta.functions.prompts import FIRECRAWL_SEARCH_SYSTEM_PROMPT, get_firecrawl_search_user_prompt
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import trace_method
|
||||
@@ -47,7 +45,7 @@ class LettaBuiltinToolExecutor(ToolExecutor):
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> ToolExecutionResult:
|
||||
function_map = {"run_code": self.run_code, "web_search": self.web_search, "firecrawl_search": self.firecrawl_search}
|
||||
function_map = {"run_code": self.run_code, "web_search": self.web_search}
|
||||
|
||||
if function_name not in function_map:
|
||||
raise ValueError(f"Unknown function: {function_name}")
|
||||
@@ -90,53 +88,7 @@ class LettaBuiltinToolExecutor(ToolExecutor):
|
||||
out["error"] = err
|
||||
return out
|
||||
|
||||
async def web_search(self, agent_state: "AgentState", query: str) -> str:
|
||||
"""
|
||||
Search the web for information.
|
||||
Args:
|
||||
query (str): The query to search the web for.
|
||||
Returns:
|
||||
str: The search results.
|
||||
"""
|
||||
|
||||
try:
|
||||
from tavily import AsyncTavilyClient
|
||||
except ImportError:
|
||||
raise ImportError("tavily is not installed in the tool execution environment")
|
||||
|
||||
# Check if the API key exists
|
||||
if tool_settings.tavily_api_key is None:
|
||||
raise ValueError("TAVILY_API_KEY is not set")
|
||||
|
||||
# Instantiate client and search
|
||||
tavily_client = AsyncTavilyClient(api_key=tool_settings.tavily_api_key)
|
||||
search_results = await tavily_client.search(query=query, auto_parameters=True)
|
||||
|
||||
results = search_results.get("results", [])
|
||||
if not results:
|
||||
return "No search results found."
|
||||
|
||||
# ---- format for the LLM -------------------------------------------------
|
||||
formatted_blocks = []
|
||||
for idx, item in enumerate(results, start=1):
|
||||
title = item.get("title") or "Untitled"
|
||||
url = item.get("url") or "Unknown URL"
|
||||
# keep each content snippet reasonably short so you don’t blow up context
|
||||
content = (
|
||||
shorten(item.get("content", "").strip(), width=600, placeholder=" …")
|
||||
if WEB_SEARCH_CLIP_CONTENT
|
||||
else item.get("content", "").strip()
|
||||
)
|
||||
score = item.get("score")
|
||||
if WEB_SEARCH_INCLUDE_SCORE:
|
||||
block = f"\nRESULT {idx}:\n" f"Title: {title}\n" f"URL: {url}\n" f"Relevance score: {score:.4f}\n" f"Content: {content}\n"
|
||||
else:
|
||||
block = f"\nRESULT {idx}:\n" f"Title: {title}\n" f"URL: {url}\n" f"Content: {content}\n"
|
||||
formatted_blocks.append(block)
|
||||
|
||||
return WEB_SEARCH_SEPARATOR.join(formatted_blocks)
|
||||
|
||||
async def firecrawl_search(
|
||||
async def web_search(
|
||||
self,
|
||||
agent_state: "AgentState",
|
||||
query: str,
|
||||
|
||||
@@ -76,11 +76,10 @@ def agent_state(client: Letta) -> AgentState:
|
||||
send_message_tool = client.tools.list(name="send_message")[0]
|
||||
run_code_tool = client.tools.list(name="run_code")[0]
|
||||
web_search_tool = client.tools.list(name="web_search")[0]
|
||||
firecrawl_search_tool = client.tools.list(name="firecrawl_search")[0]
|
||||
agent_state_instance = client.agents.create(
|
||||
name="test_builtin_tools_agent",
|
||||
include_base_tools=False,
|
||||
tool_ids=[send_message_tool.id, run_code_tool.id, web_search_tool.id, firecrawl_search_tool.id],
|
||||
tool_ids=[send_message_tool.id, run_code_tool.id, web_search_tool.id],
|
||||
model="openai/gpt-4o",
|
||||
embedding="letta/letta-free",
|
||||
tags=["test_builtin_tools_agent"],
|
||||
@@ -98,11 +97,10 @@ def agent_state_with_firecrawl_key(client: Letta) -> AgentState:
|
||||
send_message_tool = client.tools.list(name="send_message")[0]
|
||||
run_code_tool = client.tools.list(name="run_code")[0]
|
||||
web_search_tool = client.tools.list(name="web_search")[0]
|
||||
firecrawl_search_tool = client.tools.list(name="firecrawl_search")[0]
|
||||
agent_state_instance = client.agents.create(
|
||||
name="test_builtin_tools_agent",
|
||||
include_base_tools=False,
|
||||
tool_ids=[send_message_tool.id, run_code_tool.id, web_search_tool.id, firecrawl_search_tool.id],
|
||||
tool_ids=[send_message_tool.id, run_code_tool.id, web_search_tool.id],
|
||||
model="openai/gpt-4o",
|
||||
embedding="letta/letta-free",
|
||||
tags=["test_builtin_tools_agent"],
|
||||
@@ -207,32 +205,7 @@ def test_web_search(
|
||||
) -> None:
|
||||
user_message = MessageCreate(
|
||||
role="user",
|
||||
content="Use the web search tool to find the latest news about San Francisco.",
|
||||
otid=USER_MESSAGE_OTID,
|
||||
)
|
||||
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=[user_message],
|
||||
)
|
||||
|
||||
tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)]
|
||||
assert tool_returns, "No ToolReturnMessage found"
|
||||
|
||||
returns = [m.tool_return for m in tool_returns]
|
||||
expected = "RESULT 1:"
|
||||
assert any(expected in ret for ret in returns), f"Expected to find '{expected}' in tool_return, " f"but got {returns!r}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS, ids=[c.model for c in TESTED_LLM_CONFIGS])
|
||||
def test_firecrawl_search(
|
||||
client: Letta,
|
||||
agent_state: AgentState,
|
||||
llm_config: LLMConfig,
|
||||
) -> None:
|
||||
user_message = MessageCreate(
|
||||
role="user",
|
||||
content="I am executing a test. Use the firecrawl search tool to find where I, Charles Packer, the CEO of Letta, went to school.",
|
||||
content="I am executing a test. Use the web search tool to find where I, Charles Packer, the CEO of Letta, went to school.",
|
||||
otid=USER_MESSAGE_OTID,
|
||||
)
|
||||
|
||||
@@ -247,7 +220,7 @@ def test_firecrawl_search(
|
||||
returns = [m.tool_return for m in tool_returns]
|
||||
print(returns)
|
||||
|
||||
# Parse the JSON response from firecrawl_search
|
||||
# Parse the JSON response from web_search
|
||||
assert len(returns) > 0, "No tool returns found"
|
||||
response_json = json.loads(returns[0])
|
||||
|
||||
@@ -290,14 +263,14 @@ def test_firecrawl_search(
|
||||
|
||||
|
||||
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS, ids=[c.model for c in TESTED_LLM_CONFIGS])
|
||||
def test_firecrawl_search_using_agent_state_env_var(
|
||||
def test_web_search_using_agent_state_env_var(
|
||||
client: Letta,
|
||||
agent_state_with_firecrawl_key: AgentState,
|
||||
llm_config: LLMConfig,
|
||||
) -> None:
|
||||
user_message = MessageCreate(
|
||||
role="user",
|
||||
content="I am executing a test. Use the firecrawl search tool to find where I, Charles Packer, the CEO of Letta, went to school.",
|
||||
content="I am executing a test. Use the web search tool to find where I, Charles Packer, the CEO of Letta, went to school.",
|
||||
otid=USER_MESSAGE_OTID,
|
||||
)
|
||||
|
||||
@@ -312,7 +285,7 @@ def test_firecrawl_search_using_agent_state_env_var(
|
||||
returns = [m.tool_return for m in tool_returns]
|
||||
print(returns)
|
||||
|
||||
# Parse the JSON response from firecrawl_search
|
||||
# Parse the JSON response from web search
|
||||
assert len(returns) > 0, "No tool returns found"
|
||||
response_json = json.loads(returns[0])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user