feat: add tool embedding and search [LET-6333] (#6398)

* feat: add tool embedding and search

* fix ci

* add env variable for embedding tools

---------

Co-authored-by: Ari Webb <ari@letta.com>
This commit is contained in:
Ari Webb
2025-11-25 15:51:43 -08:00
committed by Caren Thomas
parent 2c702785d7
commit 3e02f12dfd
7 changed files with 858 additions and 3 deletions

View File

@@ -16,6 +16,7 @@ from letta_client.types import CreateBlockParam, MessageCreateParam
from letta.config import LettaConfig
from letta.schemas.message import MessageSearchResult
from letta.schemas.tool import ToolSearchResult
from letta.server.rest_api.routers.v1.passages import PassageSearchResult
from letta.server.server import SyncServer
from letta.settings import model_settings, settings
@@ -53,6 +54,34 @@ def cleanup_agent_with_messages(client: Letta, agent_id: str):
print(f"Warning: Failed to clean up agent {agent_id}: {e}")
def cleanup_tool(client: Letta, tool_id: str):
"""
Helper function to properly clean up a tool by deleting it from both
Turbopuffer and the database.
Args:
client: Letta SDK client
tool_id: ID of the tool to clean up
"""
try:
# First, delete from Turbopuffer if tool embedding is enabled
try:
import asyncio
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_tools
if should_use_tpuf_for_tools():
tpuf_client = TurbopufferClient()
asyncio.run(tpuf_client.delete_tools(DEFAULT_ORG_ID, [tool_id]))
except Exception as e:
print(f"Warning: Failed to clean up Turbopuffer tool {tool_id}: {e}")
# Now delete the tool from the database
client.tools.delete(tool_id=tool_id)
except Exception as e:
print(f"Warning: Failed to clean up tool {tool_id}: {e}")
@pytest.fixture(scope="module")
def server():
"""Server fixture for testing"""
@@ -524,3 +553,163 @@ def test_passage_search_org_wide(client: Letta, enable_turbopuffer):
# Clean up agents
cleanup_agent_with_messages(client, agent1.id)
cleanup_agent_with_messages(client, agent2.id)
@pytest.fixture
def enable_tool_embedding():
"""Enable both Turbopuffer and tool embedding"""
original_use_tpuf = settings.use_tpuf
original_api_key = settings.tpuf_api_key
original_embed_tools = settings.embed_tools
original_environment = settings.environment
settings.use_tpuf = True
settings.tpuf_api_key = settings.tpuf_api_key or "test-key"
settings.embed_tools = True
settings.environment = "DEV"
yield
settings.use_tpuf = original_use_tpuf
settings.tpuf_api_key = original_api_key
settings.embed_tools = original_embed_tools
settings.environment = original_environment
@pytest.mark.skipif(
not (settings.use_tpuf and settings.tpuf_api_key and model_settings.openai_api_key and settings.embed_tools),
reason="Tool search requires Turbopuffer, OpenAI, and tool embedding to be enabled",
)
def test_tool_search_basic(client: Letta, enable_tool_embedding):
"""Test basic tool search functionality through the SDK"""
tool_ids = []
try:
# Create test tools with distinct descriptions for semantic search
test_tools = [
{
"source_code": '''
def send_email_to_user(recipient: str, subject: str, body: str) -> str:
"""Send an email message to a specified recipient.
Args:
recipient: Email address of the recipient
subject: Subject line of the email
body: Body content of the email message
Returns:
Confirmation message
"""
return f"Email sent to {recipient}"
''',
"description": "Send an email message to a specified recipient with subject and body.",
"tags": ["communication", "email"],
},
{
"source_code": '''
def fetch_weather_data(city: str, units: str = "celsius") -> str:
"""Fetch current weather information for a city.
Args:
city: Name of the city to get weather for
units: Temperature units (celsius or fahrenheit)
Returns:
Weather information string
"""
return f"Weather in {city}: sunny, 25 {units}"
''',
"description": "Fetch current weather information for a specified city.",
"tags": ["weather", "api"],
},
{
"source_code": '''
def calculate_compound_interest(principal: float, rate: float, years: int) -> float:
"""Calculate compound interest on an investment.
Args:
principal: Initial investment amount
rate: Annual interest rate as decimal
years: Number of years
Returns:
Final amount after compound interest
"""
return principal * (1 + rate) ** years
''',
"description": "Calculate compound interest on a financial investment over time.",
"tags": ["finance", "calculator"],
},
]
# Create tools via SDK
for tool_data in test_tools:
tool = client.tools.create(
source_code=tool_data["source_code"],
description=tool_data["description"],
tags=tool_data["tags"],
)
tool_ids.append(tool.id)
# Wait for embeddings to be indexed
time.sleep(3)
# Test semantic search - should find email-related tool
results = client.post(
"/v1/tools/search",
cast_to=list[ToolSearchResult],
body={
"query": "send message to someone",
"search_mode": "hybrid",
"limit": 10,
},
)
assert len(results) > 0, "Should find at least one tool"
# The email tool should be ranked highly for this query
tool_names = [result["tool"]["name"] for result in results]
assert "send_email_to_user" in tool_names, "Should find email tool for messaging query"
# Verify result structure
for result in results:
assert "tool" in result, "Result should have tool field"
assert "combined_score" in result, "Result should have combined_score field"
assert isinstance(result["combined_score"], float), "combined_score should be a float"
# Test search with different query - should find weather tool
weather_results = client.post(
"/v1/tools/search",
cast_to=list[ToolSearchResult],
body={
"query": "get temperature forecast",
"search_mode": "hybrid",
"limit": 10,
},
)
assert len(weather_results) > 0, "Should find tools for weather query"
weather_tool_names = [result["tool"]["name"] for result in weather_results]
assert "fetch_weather_data" in weather_tool_names, "Should find weather tool"
# Test search with tag filter
finance_results = client.post(
"/v1/tools/search",
cast_to=list[ToolSearchResult],
body={
"query": "money calculation",
"tags": ["finance"],
"search_mode": "hybrid",
"limit": 10,
},
)
# Should find the finance tool when filtering by tag
if len(finance_results) > 0:
finance_tool_names = [result["tool"]["name"] for result in finance_results]
assert "calculate_compound_interest" in finance_tool_names, "Should find finance tool with tag filter"
finally:
# Clean up all created tools
for tool_id in tool_ids:
cleanup_tool(client, tool_id)