feat: Add tavily search builtin tool (#2257)
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import json
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
@@ -86,7 +86,7 @@ BASE_VOICE_SLEEPTIME_TOOLS = [
|
||||
MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_tags", "send_message_to_agent_async"]
|
||||
|
||||
# Built in tools
|
||||
BUILTIN_TOOLS = ["run_code"]
|
||||
BUILTIN_TOOLS = ["run_code", "web_search"]
|
||||
|
||||
# Set of all built-in Letta tools
|
||||
LETTA_TOOL_SET = set(
|
||||
@@ -241,3 +241,7 @@ RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE = 5
|
||||
|
||||
MAX_FILENAME_LENGTH = 255
|
||||
RESERVED_FILENAMES = {"CON", "PRN", "AUX", "NUL", "COM1", "COM2", "LPT1", "LPT2"}
|
||||
|
||||
WEB_SEARCH_CLIP_CONTENT = False
|
||||
WEB_SEARCH_INCLUDE_SCORE = False
|
||||
WEB_SEARCH_SEPARATOR = "\n" + "-" * 40 + "\n"
|
||||
|
||||
@@ -1,6 +1,18 @@
|
||||
from typing import Literal
|
||||
|
||||
|
||||
async def web_search(query: str) -> str:
    """
    Search the web for information.

    Args:
        query (str): The query to search the web for.

    Returns:
        str: The search results.
    """

    # Placeholder body: the real search runs server-side via the builtin-tool
    # executor (see LettaBuiltinToolExecutor.web_search), so this stub exists
    # only to register the tool. NOTE(review): the docstring presumably becomes
    # the LLM-facing tool description — keep it in sync with the executor.
    raise NotImplementedError("This is only available on the latest agent architecture. Please contact the Letta team.")
|
||||
|
||||
|
||||
def run_code(code: str, language: Literal["python", "js", "ts", "r", "java"]) -> str:
|
||||
"""
|
||||
Run code in a sandbox. Supports Python, Javascript, Typescript, R, and Java.
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
from typing import List, Optional
|
||||
|
||||
import openai
|
||||
from openai import AsyncOpenAI, AsyncStream, OpenAI, Stream
|
||||
from openai import AsyncOpenAI, AsyncStream, OpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import JSON, ForeignKeyConstraint, Index, String
|
||||
from sqlalchemy import JSON, Index, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
|
||||
@@ -9,7 +9,6 @@ from marshmallow import ValidationError
|
||||
from orjson import orjson
|
||||
from pydantic import Field
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from starlette.background import BackgroundTask
|
||||
from starlette.responses import Response, StreamingResponse
|
||||
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
|
||||
@@ -81,7 +81,7 @@ class StreamingResponseWithStatusCode(StreamingResponse):
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
except Exception:
|
||||
logger.exception("unhandled_streaming_error")
|
||||
more_body = False
|
||||
error_resp = {"error": {"message": "Internal Server Error"}}
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from sqlalchemy import select
|
||||
|
||||
from letta.orm.provider_trace import ProviderTrace as ProviderTraceModel
|
||||
from letta.schemas.provider_trace import ProviderTrace as PydanticProviderTrace
|
||||
from letta.schemas.provider_trace import ProviderTraceCreate
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import math
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from textwrap import shorten
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
from letta.constants import (
|
||||
@@ -8,6 +9,9 @@ from letta.constants import (
|
||||
CORE_MEMORY_LINE_NUMBER_WARNING,
|
||||
READ_ONLY_BLOCK_EDIT_ERROR,
|
||||
RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE,
|
||||
WEB_SEARCH_CLIP_CONTENT,
|
||||
WEB_SEARCH_INCLUDE_SCORE,
|
||||
WEB_SEARCH_SEPARATOR,
|
||||
)
|
||||
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
|
||||
from letta.functions.composio_helpers import execute_composio_action_async, generate_composio_action_from_func_name
|
||||
@@ -689,9 +693,7 @@ class LettaBuiltinToolExecutor(ToolExecutor):
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> ToolExecutionResult:
|
||||
function_map = {
|
||||
"run_code": self.run_code,
|
||||
}
|
||||
function_map = {"run_code": self.run_code, "web_search": self.web_search}
|
||||
|
||||
if function_name not in function_map:
|
||||
raise ValueError(f"Unknown function: {function_name}")
|
||||
@@ -719,3 +721,49 @@ class LettaBuiltinToolExecutor(ToolExecutor):
|
||||
|
||||
res = await sbx.run_code(**params)
|
||||
return str(res)
|
||||
|
||||
async def web_search(agent_state: "AgentState", query: str) -> str:
    """
    Search the web for information using the Tavily API.

    Args:
        agent_state (AgentState): State of the calling agent (part of the
            builtin-tool executor calling convention; not used by the search
            itself).
        query (str): The query to search the web for.

    Returns:
        str: The search results, rendered as numbered "RESULT n:" blocks joined
            by WEB_SEARCH_SEPARATOR, or "No search results found." when Tavily
            returns nothing.

    Raises:
        ImportError: If the tavily client library is not installed in the tool
            execution environment.
        ValueError: If TAVILY_API_KEY is not configured.
    """

    try:
        from tavily import AsyncTavilyClient
    except ImportError:
        raise ImportError("tavily is not installed in the tool execution environment")

    # Fail fast with an actionable message instead of letting the client
    # surface an opaque auth error.
    if tool_settings.tavily_api_key is None:
        raise ValueError("TAVILY_API_KEY is not set")

    # Instantiate client and search
    tavily_client = AsyncTavilyClient(api_key=tool_settings.tavily_api_key)
    search_results = await tavily_client.search(query=query, auto_parameters=True)

    results = search_results.get("results", [])
    if not results:
        return "No search results found."

    # ---- format for the LLM -------------------------------------------------
    formatted_blocks = []
    for idx, item in enumerate(results, start=1):
        title = item.get("title") or "Untitled"
        url = item.get("url") or "Unknown URL"
        # "content" may be present but None — normalize before .strip().
        raw_content = (item.get("content") or "").strip()
        # keep each content snippet reasonably short so you don't blow up context
        content = shorten(raw_content, width=600, placeholder=" …") if WEB_SEARCH_CLIP_CONTENT else raw_content
        score = item.get("score")
        if WEB_SEARCH_INCLUDE_SCORE:
            # Tavily may omit the score; formatting None with ":.4f" would raise.
            score_text = f"{score:.4f}" if isinstance(score, (int, float)) else "n/a"
            block = f"\nRESULT {idx}:\nTitle: {title}\nURL: {url}\nRelevance score: {score_text}\nContent: {content}\n"
        else:
            block = f"\nRESULT {idx}:\nTitle: {title}\nURL: {url}\nContent: {content}\n"
        formatted_blocks.append(block)

    return WEB_SEARCH_SEPARATOR.join(formatted_blocks)
|
||||
|
||||
@@ -15,6 +15,9 @@ class ToolSettings(BaseSettings):
|
||||
e2b_api_key: Optional[str] = None
|
||||
e2b_sandbox_template_id: Optional[str] = None # Updated manually
|
||||
|
||||
# Tavily search
|
||||
tavily_api_key: Optional[str] = None
|
||||
|
||||
# Local Sandbox configurations
|
||||
tool_exec_dir: Optional[str] = None
|
||||
tool_sandbox_timeout: float = 180
|
||||
|
||||
18
poetry.lock
generated
18
poetry.lock
generated
@@ -6299,6 +6299,22 @@ files = [
|
||||
{file = "striprtf-0.0.26.tar.gz", hash = "sha256:fdb2bba7ac440072d1c41eab50d8d74ae88f60a8b6575c6e2c7805dc462093aa"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tavily-python"
|
||||
version = "0.7.2"
|
||||
description = "Python wrapper for the Tavily API"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
{file = "tavily_python-0.7.2-py3-none-any.whl", hash = "sha256:0d7cc8b1a2f95ac10cf722094c3b5807aade67cc7750f7ca605edef7455d4c62"},
|
||||
{file = "tavily_python-0.7.2.tar.gz", hash = "sha256:34f713002887df2b5e6b8d7db7bc64ae107395bdb5f53611e80a89dac9cbdf19"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
httpx = "*"
|
||||
requests = "*"
|
||||
tiktoken = ">=0.5.1"
|
||||
|
||||
[[package]]
|
||||
name = "tenacity"
|
||||
version = "9.1.2"
|
||||
@@ -7172,4 +7188,4 @@ tests = ["wikipedia"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "<3.14,>=3.10"
|
||||
content-hash = "e73bf0ff3ec8b6b839d69f2a6e51228fb61a20030e3b334e74e259361ca8ab43"
|
||||
content-hash = "837f6a25033a01cca117f4c61bcf973bc6ccfcda442615bbf4af038061bf88ce"
|
||||
|
||||
@@ -91,6 +91,7 @@ apscheduler = "^3.11.0"
|
||||
aiomultiprocess = "^0.9.1"
|
||||
matplotlib = "^3.10.1"
|
||||
asyncpg = "^0.30.0"
|
||||
tavily-python = "^0.7.2"
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
|
||||
@@ -88,10 +88,11 @@ def agent_state(client: Letta) -> AgentState:
|
||||
|
||||
send_message_tool = client.tools.list(name="send_message")[0]
|
||||
run_code_tool = client.tools.list(name="run_code")[0]
|
||||
web_search_tool = client.tools.list(name="web_search")[0]
|
||||
agent_state_instance = client.agents.create(
|
||||
name="supervisor",
|
||||
include_base_tools=False,
|
||||
tool_ids=[send_message_tool.id, run_code_tool.id],
|
||||
tool_ids=[send_message_tool.id, run_code_tool.id, web_search_tool.id],
|
||||
model="openai/gpt-4o",
|
||||
embedding="letta/letta-free",
|
||||
tags=["supervisor"],
|
||||
@@ -187,3 +188,28 @@ def test_run_code(
|
||||
assert any(expected in ret for ret in returns), (
|
||||
f"For language={language!r}, expected to find '{expected}' in tool_return, " f"but got {returns!r}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS, ids=[c.model for c in TESTED_LLM_CONFIGS])
def test_web_search(
    client: Letta,
    agent_state: AgentState,
    llm_config: LLMConfig,
) -> None:
    """End-to-end check: the agent calls web_search and a formatted result block comes back."""
    response = client.agents.messages.create(
        agent_id=agent_state.id,
        messages=[
            MessageCreate(
                role="user",
                content=("Use the web search tool to find the latest news about San Francisco."),
                otid=USER_MESSAGE_OTID,
            )
        ],
    )

    # Collect the payload of every tool-return message in the response.
    returns = [msg.tool_return for msg in response.messages if isinstance(msg, ToolReturnMessage)]
    assert returns, "No ToolReturnMessage found"

    expected = "RESULT 1:"
    assert any(expected in ret for ret in returns), f"Expected to find '{expected}' in tool_return, but got {returns!r}"
|
||||
|
||||
Reference in New Issue
Block a user