From 26ae9c450239fdadff88fd9dcd6d13e5c1acb129 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Tue, 20 May 2025 07:38:11 +0800 Subject: [PATCH] feat: Add tavily search builtin tool (#2257) --- letta/agent.py | 1 - letta/constants.py | 6 ++- letta/functions/function_sets/builtin.py | 12 +++++ letta/llm_api/openai_client.py | 2 +- letta/orm/provider_trace.py | 2 +- letta/server/rest_api/routers/v1/agents.py | 1 - letta/server/rest_api/streaming_response.py | 2 +- letta/services/telemetry_manager.py | 2 - letta/services/tool_executor/tool_executor.py | 54 +++++++++++++++++-- letta/settings.py | 3 ++ poetry.lock | 18 ++++++- pyproject.toml | 1 + tests/integration_test_builtin_tools.py | 28 +++++++++- 13 files changed, 119 insertions(+), 13 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index c156dd83..b276ed79 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -1,7 +1,6 @@ import json import time import traceback -import uuid import warnings from abc import ABC, abstractmethod from typing import Dict, List, Optional, Tuple, Union diff --git a/letta/constants.py b/letta/constants.py index 741de313..7b3392ee 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -86,7 +86,7 @@ BASE_VOICE_SLEEPTIME_TOOLS = [ MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_tags", "send_message_to_agent_async"] # Built in tools -BUILTIN_TOOLS = ["run_code"] +BUILTIN_TOOLS = ["run_code", "web_search"] # Set of all built-in Letta tools LETTA_TOOL_SET = set( @@ -241,3 +241,7 @@ RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE = 5 MAX_FILENAME_LENGTH = 255 RESERVED_FILENAMES = {"CON", "PRN", "AUX", "NUL", "COM1", "COM2", "LPT1", "LPT2"} + +WEB_SEARCH_CLIP_CONTENT = False +WEB_SEARCH_INCLUDE_SCORE = False +WEB_SEARCH_SEPARATOR = "\n" + "-" * 40 + "\n" diff --git a/letta/functions/function_sets/builtin.py b/letta/functions/function_sets/builtin.py index 3f839a62..c8d69568 100644 --- a/letta/functions/function_sets/builtin.py +++ 
b/letta/functions/function_sets/builtin.py @@ -1,6 +1,18 @@ from typing import Literal +async def web_search(query: str) -> str: + """ + Search the web for information. + Args: + query (str): The query to search the web for. + Returns: + str: The search results. + """ + + raise NotImplementedError("This is only available on the latest agent architecture. Please contact the Letta team.") + + def run_code(code: str, language: Literal["python", "js", "ts", "r", "java"]) -> str: """ Run code in a sandbox. Supports Python, Javascript, Typescript, R, and Java. diff --git a/letta/llm_api/openai_client.py b/letta/llm_api/openai_client.py index f3353bed..639a550d 100644 --- a/letta/llm_api/openai_client.py +++ b/letta/llm_api/openai_client.py @@ -2,7 +2,7 @@ import os from typing import List, Optional import openai -from openai import AsyncOpenAI, AsyncStream, OpenAI, Stream +from openai import AsyncOpenAI, AsyncStream, OpenAI from openai.types.chat.chat_completion import ChatCompletion from openai.types.chat.chat_completion_chunk import ChatCompletionChunk diff --git a/letta/orm/provider_trace.py b/letta/orm/provider_trace.py index c957636e..69b7df14 100644 --- a/letta/orm/provider_trace.py +++ b/letta/orm/provider_trace.py @@ -1,6 +1,6 @@ import uuid -from sqlalchemy import JSON, ForeignKeyConstraint, Index, String +from sqlalchemy import JSON, Index, String from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.mixins import OrganizationMixin diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index e26dbc62..bc8609f4 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -9,7 +9,6 @@ from marshmallow import ValidationError from orjson import orjson from pydantic import Field from sqlalchemy.exc import IntegrityError, OperationalError -from starlette.background import BackgroundTask from starlette.responses import Response, StreamingResponse from 
letta.agents.letta_agent import LettaAgent diff --git a/letta/server/rest_api/streaming_response.py b/letta/server/rest_api/streaming_response.py index 06e019b3..13d57e87 100644 --- a/letta/server/rest_api/streaming_response.py +++ b/letta/server/rest_api/streaming_response.py @@ -81,7 +81,7 @@ class StreamingResponseWithStatusCode(StreamingResponse): } ) - except Exception as exc: + except Exception: logger.exception("unhandled_streaming_error") more_body = False error_resp = {"error": {"message": "Internal Server Error"}} diff --git a/letta/services/telemetry_manager.py b/letta/services/telemetry_manager.py index e6ab218c..10e99c9b 100644 --- a/letta/services/telemetry_manager.py +++ b/letta/services/telemetry_manager.py @@ -1,5 +1,3 @@ -from sqlalchemy import select - from letta.orm.provider_trace import ProviderTrace as ProviderTraceModel from letta.schemas.provider_trace import ProviderTrace as PydanticProviderTrace from letta.schemas.provider_trace import ProviderTraceCreate diff --git a/letta/services/tool_executor/tool_executor.py b/letta/services/tool_executor/tool_executor.py index 82e88650..0cf1eb75 100644 --- a/letta/services/tool_executor/tool_executor.py +++ b/letta/services/tool_executor/tool_executor.py @@ -1,6 +1,7 @@ import math import traceback from abc import ABC, abstractmethod +from textwrap import shorten from typing import Any, Dict, Literal, Optional from letta.constants import ( @@ -8,6 +9,9 @@ from letta.constants import ( CORE_MEMORY_LINE_NUMBER_WARNING, READ_ONLY_BLOCK_EDIT_ERROR, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE, + WEB_SEARCH_CLIP_CONTENT, + WEB_SEARCH_INCLUDE_SCORE, + WEB_SEARCH_SEPARATOR, ) from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source from letta.functions.composio_helpers import execute_composio_action_async, generate_composio_action_from_func_name @@ -689,9 +693,7 @@ class LettaBuiltinToolExecutor(ToolExecutor): sandbox_config: Optional[SandboxConfig] = None, 
sandbox_env_vars: Optional[Dict[str, Any]] = None, ) -> ToolExecutionResult: - function_map = { - "run_code": self.run_code, - } + function_map = {"run_code": self.run_code, "web_search": self.web_search} if function_name not in function_map: raise ValueError(f"Unknown function: {function_name}") @@ -719,3 +721,49 @@ class LettaBuiltinToolExecutor(ToolExecutor): res = await sbx.run_code(**params) return str(res) + + async def web_search(self, agent_state: "AgentState", query: str) -> str: + """ + Search the web for information. + Args: + query (str): The query to search the web for. + Returns: + str: The search results. + """ + + try: + from tavily import AsyncTavilyClient + except ImportError: + raise ImportError("tavily is not installed in the tool execution environment") + + # Check if the API key exists + if tool_settings.tavily_api_key is None: + raise ValueError("TAVILY_API_KEY is not set") + + # Instantiate client and search + tavily_client = AsyncTavilyClient(api_key=tool_settings.tavily_api_key) + search_results = await tavily_client.search(query=query, auto_parameters=True) + + results = search_results.get("results", []) + if not results: + return "No search results found." 
+ + # ---- format for the LLM ------------------------------------------------- + formatted_blocks = [] + for idx, item in enumerate(results, start=1): + title = item.get("title") or "Untitled" + url = item.get("url") or "Unknown URL" + # keep each content snippet reasonably short so you don’t blow up context + content = ( + shorten(item.get("content", "").strip(), width=600, placeholder=" …") + if WEB_SEARCH_CLIP_CONTENT + else item.get("content", "").strip() + ) + score = item.get("score") + if WEB_SEARCH_INCLUDE_SCORE: + block = f"\nRESULT {idx}:\n" f"Title: {title}\n" f"URL: {url}\n" f"Relevance score: {score:.4f}\n" f"Content: {content}\n" + else: + block = f"\nRESULT {idx}:\n" f"Title: {title}\n" f"URL: {url}\n" f"Content: {content}\n" + formatted_blocks.append(block) + + return WEB_SEARCH_SEPARATOR.join(formatted_blocks) diff --git a/letta/settings.py b/letta/settings.py index 19c4adc9..a1cc61f6 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -15,6 +15,9 @@ class ToolSettings(BaseSettings): e2b_api_key: Optional[str] = None e2b_sandbox_template_id: Optional[str] = None # Updated manually + # Tavily search + tavily_api_key: Optional[str] = None + # Local Sandbox configurations tool_exec_dir: Optional[str] = None tool_sandbox_timeout: float = 180 diff --git a/poetry.lock b/poetry.lock index f123bb9e..cad3a0a3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -6299,6 +6299,22 @@ files = [ {file = "striprtf-0.0.26.tar.gz", hash = "sha256:fdb2bba7ac440072d1c41eab50d8d74ae88f60a8b6575c6e2c7805dc462093aa"}, ] +[[package]] +name = "tavily-python" +version = "0.7.2" +description = "Python wrapper for the Tavily API" +optional = false +python-versions = ">=3.6" +files = [ + {file = "tavily_python-0.7.2-py3-none-any.whl", hash = "sha256:0d7cc8b1a2f95ac10cf722094c3b5807aade67cc7750f7ca605edef7455d4c62"}, + {file = "tavily_python-0.7.2.tar.gz", hash = "sha256:34f713002887df2b5e6b8d7db7bc64ae107395bdb5f53611e80a89dac9cbdf19"}, +] + +[package.dependencies] +httpx = 
"*" +requests = "*" +tiktoken = ">=0.5.1" + [[package]] name = "tenacity" version = "9.1.2" @@ -7172,4 +7188,4 @@ tests = ["wikipedia"] [metadata] lock-version = "2.0" python-versions = "<3.14,>=3.10" -content-hash = "e73bf0ff3ec8b6b839d69f2a6e51228fb61a20030e3b334e74e259361ca8ab43" +content-hash = "837f6a25033a01cca117f4c61bcf973bc6ccfcda442615bbf4af038061bf88ce" diff --git a/pyproject.toml b/pyproject.toml index 9cc8b155..30716cee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,7 @@ apscheduler = "^3.11.0" aiomultiprocess = "^0.9.1" matplotlib = "^3.10.1" asyncpg = "^0.30.0" +tavily-python = "^0.7.2" [tool.poetry.extras] diff --git a/tests/integration_test_builtin_tools.py b/tests/integration_test_builtin_tools.py index e8781762..7e3faf78 100644 --- a/tests/integration_test_builtin_tools.py +++ b/tests/integration_test_builtin_tools.py @@ -88,10 +88,11 @@ def agent_state(client: Letta) -> AgentState: send_message_tool = client.tools.list(name="send_message")[0] run_code_tool = client.tools.list(name="run_code")[0] + web_search_tool = client.tools.list(name="web_search")[0] agent_state_instance = client.agents.create( name="supervisor", include_base_tools=False, - tool_ids=[send_message_tool.id, run_code_tool.id], + tool_ids=[send_message_tool.id, run_code_tool.id, web_search_tool.id], model="openai/gpt-4o", embedding="letta/letta-free", tags=["supervisor"], @@ -187,3 +188,28 @@ def test_run_code( assert any(expected in ret for ret in returns), ( f"For language={language!r}, expected to find '{expected}' in tool_return, " f"but got {returns!r}" ) + + +@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS, ids=[c.model for c in TESTED_LLM_CONFIGS]) +def test_web_search( + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + user_message = MessageCreate( + role="user", + content=("Use the web search tool to find the latest news about San Francisco."), + otid=USER_MESSAGE_OTID, + ) + + response = 
client.agents.messages.create( + agent_id=agent_state.id, + messages=[user_message], + ) + + tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)] + assert tool_returns, "No ToolReturnMessage found" + + returns = [m.tool_return for m in tool_returns] + expected = "RESULT 1:" + assert any(expected in ret for ret in returns), f"Expected to find '{expected}' in tool_return, " f"but got {returns!r}"