feat: add tool embedding and search [LET-6333] (#6398)
* feat: add tool embedding and search * fix ci * add env variable for embedding tools --------- Co-authored-by: Ari Webb <ari@letta.com>
This commit is contained in:
@@ -16,6 +16,7 @@ from letta_client.types import CreateBlockParam, MessageCreateParam
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.schemas.message import MessageSearchResult
|
||||
from letta.schemas.tool import ToolSearchResult
|
||||
from letta.server.rest_api.routers.v1.passages import PassageSearchResult
|
||||
from letta.server.server import SyncServer
|
||||
from letta.settings import model_settings, settings
|
||||
@@ -53,6 +54,34 @@ def cleanup_agent_with_messages(client: Letta, agent_id: str):
|
||||
print(f"Warning: Failed to clean up agent {agent_id}: {e}")
|
||||
|
||||
|
||||
def cleanup_tool(client: Letta, tool_id: str):
    """
    Best-effort teardown of a tool created during a test.

    The tool's Turbopuffer entry is removed first (only when
    ``should_use_tpuf_for_tools()`` says tools are being embedded), then the
    tool itself is deleted through the SDK. Nothing is raised: a failure at
    either stage is printed as a warning, and a Turbopuffer failure does not
    prevent the SDK delete from being attempted.

    Args:
        client: Letta SDK client
        tool_id: ID of the tool to clean up
    """

    def _drop_turbopuffer_entry():
        # Imports stay local so that an import failure is reported as a
        # Turbopuffer cleanup warning instead of breaking the whole helper.
        import asyncio

        from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_tools

        if not should_use_tpuf_for_tools():
            return

        # delete_tools is a coroutine; drive it to completion from sync code.
        asyncio.run(TurbopufferClient().delete_tools(DEFAULT_ORG_ID, [tool_id]))

    try:
        # Stage 1: vector-store entry (failure is logged and tolerated).
        try:
            _drop_turbopuffer_entry()
        except Exception as tpuf_error:
            print(f"Warning: Failed to clean up Turbopuffer tool {tool_id}: {tpuf_error}")

        # Stage 2: the tool record itself.
        client.tools.delete(tool_id=tool_id)
    except Exception as delete_error:
        print(f"Warning: Failed to clean up tool {tool_id}: {delete_error}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
"""Server fixture for testing"""
|
||||
@@ -524,3 +553,163 @@ def test_passage_search_org_wide(client: Letta, enable_turbopuffer):
|
||||
# Clean up agents
|
||||
cleanup_agent_with_messages(client, agent1.id)
|
||||
cleanup_agent_with_messages(client, agent2.id)
|
||||
|
||||
|
||||
@pytest.fixture
def enable_tool_embedding():
    """
    Temporarily switch the global settings into "tool embedding on" mode.

    For the duration of the test, ``use_tpuf`` and ``embed_tools`` are forced
    to True, ``environment`` is set to "DEV", and ``tpuf_api_key`` falls back
    to the placeholder "test-key" when no key is configured. The previous
    values of all four settings are put back once the test has finished.
    """
    # Insertion order matters only for readability; it mirrors the order in
    # which the settings are applied and later restored.
    overrides = {
        "use_tpuf": True,
        "tpuf_api_key": settings.tpuf_api_key or "test-key",
        "embed_tools": True,
        "environment": "DEV",
    }

    # Snapshot before mutating so the teardown can restore exactly.
    previous = {field: getattr(settings, field) for field in overrides}

    for field, value in overrides.items():
        setattr(settings, field, value)

    yield

    for field, value in previous.items():
        setattr(settings, field, value)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not (settings.use_tpuf and settings.tpuf_api_key and model_settings.openai_api_key and settings.embed_tools),
    reason="Tool search requires Turbopuffer, OpenAI, and tool embedding to be enabled",
)
def test_tool_search_basic(client: Letta, enable_tool_embedding):
    """
    Exercise the /v1/tools/search endpoint through the SDK.

    Three tools with clearly different purposes (email, weather, finance) are
    created, then searched for with paraphrased queries in hybrid mode. The
    test checks that the matching tool shows up in the results, that every
    result carries a ``tool`` and a float ``combined_score``, and that a tag
    filter still surfaces the finance tool whenever it returns anything.
    All created tools are removed afterwards, even if an assertion fails.
    """

    def search_tools(query: str, **extra_fields):
        # Every search in this test uses hybrid mode with a limit of 10;
        # extra_fields carries optional filters such as "tags".
        return client.post(
            "/v1/tools/search",
            cast_to=list[ToolSearchResult],
            body={"query": query, **extra_fields, "search_mode": "hybrid", "limit": 10},
        )

    def names_of(search_results):
        return [entry["tool"]["name"] for entry in search_results]

    email_source = '''
def send_email_to_user(recipient: str, subject: str, body: str) -> str:
    """Send an email message to a specified recipient.

    Args:
        recipient: Email address of the recipient
        subject: Subject line of the email
        body: Body content of the email message

    Returns:
        Confirmation message
    """
    return f"Email sent to {recipient}"
'''

    weather_source = '''
def fetch_weather_data(city: str, units: str = "celsius") -> str:
    """Fetch current weather information for a city.

    Args:
        city: Name of the city to get weather for
        units: Temperature units (celsius or fahrenheit)

    Returns:
        Weather information string
    """
    return f"Weather in {city}: sunny, 25 {units}"
'''

    finance_source = '''
def calculate_compound_interest(principal: float, rate: float, years: int) -> float:
    """Calculate compound interest on an investment.

    Args:
        principal: Initial investment amount
        rate: Annual interest rate as decimal
        years: Number of years

    Returns:
        Final amount after compound interest
    """
    return principal * (1 + rate) ** years
'''

    # (source_code, description, tags) for each tool; the descriptions are
    # deliberately distinct so semantic search can tell the tools apart.
    tool_specs = [
        (
            email_source,
            "Send an email message to a specified recipient with subject and body.",
            ["communication", "email"],
        ),
        (
            weather_source,
            "Fetch current weather information for a specified city.",
            ["weather", "api"],
        ),
        (
            finance_source,
            "Calculate compound interest on a financial investment over time.",
            ["finance", "calculator"],
        ),
    ]

    tool_ids = []

    try:
        # Record each id as soon as the tool exists so that a failure partway
        # through creation still cleans up the tools already made.
        for source_code, description, tags in tool_specs:
            created = client.tools.create(
                source_code=source_code,
                description=description,
                tags=tags,
            )
            tool_ids.append(created.id)

        # Give the embedding/indexing pipeline time to catch up.
        time.sleep(3)

        # Paraphrased messaging query should surface the email tool.
        results = search_tools("send message to someone")

        assert len(results) > 0, "Should find at least one tool"
        assert "send_email_to_user" in names_of(results), "Should find email tool for messaging query"

        # Every hit must expose the tool and a float relevance score.
        for result in results:
            assert "tool" in result, "Result should have tool field"
            assert "combined_score" in result, "Result should have combined_score field"
            assert isinstance(result["combined_score"], float), "combined_score should be a float"

        # A different topic should surface the weather tool.
        weather_results = search_tools("get temperature forecast")

        assert len(weather_results) > 0, "Should find tools for weather query"
        assert "fetch_weather_data" in names_of(weather_results), "Should find weather tool"

        # Tag-filtered search: only checked when the filter returns something.
        finance_results = search_tools("money calculation", tags=["finance"])

        if len(finance_results) > 0:
            assert "calculate_compound_interest" in names_of(finance_results), "Should find finance tool with tag filter"

    finally:
        # Remove every tool that was successfully created.
        for tool_id in tool_ids:
            cleanup_tool(client, tool_id)
|
||||
|
||||
Reference in New Issue
Block a user