feat: Add tavily search builtin tool (#2257)

This commit is contained in:
Matthew Zhou
2025-05-20 07:38:11 +08:00
committed by GitHub
parent 9542dd2fd9
commit 26ae9c4502
13 changed files with 119 additions and 13 deletions

View File

@@ -1,7 +1,6 @@
import json
import time
import traceback
import uuid
import warnings
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, Union

View File

@@ -86,7 +86,7 @@ BASE_VOICE_SLEEPTIME_TOOLS = [
MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_tags", "send_message_to_agent_async"]
# Built in tools
BUILTIN_TOOLS = ["run_code"]
BUILTIN_TOOLS = ["run_code", "web_search"]
# Set of all built-in Letta tools
LETTA_TOOL_SET = set(
@@ -241,3 +241,7 @@ RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE = 5
MAX_FILENAME_LENGTH = 255
RESERVED_FILENAMES = {"CON", "PRN", "AUX", "NUL", "COM1", "COM2", "LPT1", "LPT2"}
WEB_SEARCH_CLIP_CONTENT = False
WEB_SEARCH_INCLUDE_SCORE = False
WEB_SEARCH_SEPARATOR = "\n" + "-" * 40 + "\n"

View File

@@ -1,6 +1,18 @@
from typing import Literal
async def web_search(query: str) -> str:
    """Search the web and return the results as a string.

    Args:
        query (str): The query to search the web for.

    Returns:
        str: The search results.
    """
    # Stub only: actual execution happens server-side on the newer agent architecture.
    raise NotImplementedError("This is only available on the latest agent architecture. Please contact the Letta team.")
def run_code(code: str, language: Literal["python", "js", "ts", "r", "java"]) -> str:
"""
Run code in a sandbox. Supports Python, Javascript, Typescript, R, and Java.

View File

@@ -2,7 +2,7 @@ import os
from typing import List, Optional
import openai
from openai import AsyncOpenAI, AsyncStream, OpenAI, Stream
from openai import AsyncOpenAI, AsyncStream, OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk

View File

@@ -1,6 +1,6 @@
import uuid
from sqlalchemy import JSON, ForeignKeyConstraint, Index, String
from sqlalchemy import JSON, Index, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin

View File

@@ -9,7 +9,6 @@ from marshmallow import ValidationError
from orjson import orjson
from pydantic import Field
from sqlalchemy.exc import IntegrityError, OperationalError
from starlette.background import BackgroundTask
from starlette.responses import Response, StreamingResponse
from letta.agents.letta_agent import LettaAgent

View File

@@ -81,7 +81,7 @@ class StreamingResponseWithStatusCode(StreamingResponse):
}
)
except Exception as exc:
except Exception:
logger.exception("unhandled_streaming_error")
more_body = False
error_resp = {"error": {"message": "Internal Server Error"}}

View File

@@ -1,5 +1,3 @@
from sqlalchemy import select
from letta.orm.provider_trace import ProviderTrace as ProviderTraceModel
from letta.schemas.provider_trace import ProviderTrace as PydanticProviderTrace
from letta.schemas.provider_trace import ProviderTraceCreate

View File

@@ -1,6 +1,7 @@
import math
import traceback
from abc import ABC, abstractmethod
from textwrap import shorten
from typing import Any, Dict, Literal, Optional
from letta.constants import (
@@ -8,6 +9,9 @@ from letta.constants import (
CORE_MEMORY_LINE_NUMBER_WARNING,
READ_ONLY_BLOCK_EDIT_ERROR,
RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE,
WEB_SEARCH_CLIP_CONTENT,
WEB_SEARCH_INCLUDE_SCORE,
WEB_SEARCH_SEPARATOR,
)
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
from letta.functions.composio_helpers import execute_composio_action_async, generate_composio_action_from_func_name
@@ -689,9 +693,7 @@ class LettaBuiltinToolExecutor(ToolExecutor):
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
) -> ToolExecutionResult:
function_map = {
"run_code": self.run_code,
}
function_map = {"run_code": self.run_code, "web_search": self.web_search}
if function_name not in function_map:
raise ValueError(f"Unknown function: {function_name}")
@@ -719,3 +721,49 @@ class LettaBuiltinToolExecutor(ToolExecutor):
res = await sbx.run_code(**params)
return str(res)
async def web_search(agent_state: "AgentState", query: str) -> str:
    """
    Search the web for information via the Tavily API.

    Args:
        agent_state (AgentState): State of the calling agent (not read by this function).
        query (str): The query to search the web for.

    Returns:
        str: Search results formatted as numbered "RESULT n" blocks joined by
            WEB_SEARCH_SEPARATOR, or "No search results found." when Tavily
            returns nothing.

    Raises:
        ImportError: If the tavily package is not installed in the execution environment.
        ValueError: If TAVILY_API_KEY is not configured.
    """
    # NOTE(review): registered as `self.web_search` in function_map but defined without
    # a `self` parameter — confirm whether this should be a @staticmethod.
    try:
        from tavily import AsyncTavilyClient
    except ImportError:
        raise ImportError("tavily is not installed in the tool execution environment")

    # Check if the API key exists
    if tool_settings.tavily_api_key is None:
        raise ValueError("TAVILY_API_KEY is not set")

    # Instantiate client and search
    tavily_client = AsyncTavilyClient(api_key=tool_settings.tavily_api_key)
    search_results = await tavily_client.search(query=query, auto_parameters=True)

    results = search_results.get("results", [])
    if not results:
        return "No search results found."

    # ---- format for the LLM -------------------------------------------------
    formatted_blocks = []
    for idx, item in enumerate(results, start=1):
        title = item.get("title") or "Untitled"
        url = item.get("url") or "Unknown URL"
        # Keep each content snippet reasonably short so we don't blow up the context window.
        content = (
            shorten(item.get("content", "").strip(), width=600, placeholder="")
            if WEB_SEARCH_CLIP_CONTENT
            else item.get("content", "").strip()
        )
        score = item.get("score")
        if WEB_SEARCH_INCLUDE_SCORE:
            # Tavily may omit the relevance score; `{score:.4f}` would raise
            # TypeError on None, so fall back to "N/A" instead of crashing.
            score_str = f"{score:.4f}" if isinstance(score, (int, float)) else "N/A"
            block = f"\nRESULT {idx}:\n" f"Title: {title}\n" f"URL: {url}\n" f"Relevance score: {score_str}\n" f"Content: {content}\n"
        else:
            block = f"\nRESULT {idx}:\n" f"Title: {title}\n" f"URL: {url}\n" f"Content: {content}\n"
        formatted_blocks.append(block)

    return WEB_SEARCH_SEPARATOR.join(formatted_blocks)

View File

@@ -15,6 +15,9 @@ class ToolSettings(BaseSettings):
e2b_api_key: Optional[str] = None
e2b_sandbox_template_id: Optional[str] = None # Updated manually
# Tavily search
tavily_api_key: Optional[str] = None
# Local Sandbox configurations
tool_exec_dir: Optional[str] = None
tool_sandbox_timeout: float = 180

18
poetry.lock generated
View File

@@ -6299,6 +6299,22 @@ files = [
{file = "striprtf-0.0.26.tar.gz", hash = "sha256:fdb2bba7ac440072d1c41eab50d8d74ae88f60a8b6575c6e2c7805dc462093aa"},
]
[[package]]
name = "tavily-python"
version = "0.7.2"
description = "Python wrapper for the Tavily API"
optional = false
python-versions = ">=3.6"
files = [
{file = "tavily_python-0.7.2-py3-none-any.whl", hash = "sha256:0d7cc8b1a2f95ac10cf722094c3b5807aade67cc7750f7ca605edef7455d4c62"},
{file = "tavily_python-0.7.2.tar.gz", hash = "sha256:34f713002887df2b5e6b8d7db7bc64ae107395bdb5f53611e80a89dac9cbdf19"},
]
[package.dependencies]
httpx = "*"
requests = "*"
tiktoken = ">=0.5.1"
[[package]]
name = "tenacity"
version = "9.1.2"
@@ -7172,4 +7188,4 @@ tests = ["wikipedia"]
[metadata]
lock-version = "2.0"
python-versions = "<3.14,>=3.10"
content-hash = "e73bf0ff3ec8b6b839d69f2a6e51228fb61a20030e3b334e74e259361ca8ab43"
content-hash = "837f6a25033a01cca117f4c61bcf973bc6ccfcda442615bbf4af038061bf88ce"

View File

@@ -91,6 +91,7 @@ apscheduler = "^3.11.0"
aiomultiprocess = "^0.9.1"
matplotlib = "^3.10.1"
asyncpg = "^0.30.0"
tavily-python = "^0.7.2"
[tool.poetry.extras]

View File

@@ -88,10 +88,11 @@ def agent_state(client: Letta) -> AgentState:
send_message_tool = client.tools.list(name="send_message")[0]
run_code_tool = client.tools.list(name="run_code")[0]
web_search_tool = client.tools.list(name="web_search")[0]
agent_state_instance = client.agents.create(
name="supervisor",
include_base_tools=False,
tool_ids=[send_message_tool.id, run_code_tool.id],
tool_ids=[send_message_tool.id, run_code_tool.id, web_search_tool.id],
model="openai/gpt-4o",
embedding="letta/letta-free",
tags=["supervisor"],
@@ -187,3 +188,28 @@ def test_run_code(
assert any(expected in ret for ret in returns), (
f"For language={language!r}, expected to find '{expected}' in tool_return, " f"but got {returns!r}"
)
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS, ids=[c.model for c in TESTED_LLM_CONFIGS])
def test_web_search(
    client: Letta,
    agent_state: AgentState,
    llm_config: LLMConfig,
) -> None:
    """Ask the agent to invoke web_search and check a formatted result block comes back."""
    # One user turn that should trigger the web_search builtin tool.
    result = client.agents.messages.create(
        agent_id=agent_state.id,
        messages=[
            MessageCreate(
                role="user",
                content=("Use the web search tool to find the latest news about San Francisco."),
                otid=USER_MESSAGE_OTID,
            )
        ],
    )

    return_messages = [msg for msg in result.messages if isinstance(msg, ToolReturnMessage)]
    assert return_messages, "No ToolReturnMessage found"

    payloads = [msg.tool_return for msg in return_messages]
    marker = "RESULT 1:"
    # The executor formats hits as "RESULT <n>:" blocks, so the first hit is a reliable sentinel.
    assert any(marker in payload for payload in payloads), f"Expected to find '{marker}' in tool_return, " f"but got {payloads!r}"