feat: add tool embedding and search [LET-6333] (#6398)
* feat: add tool embedding and search * fix ci * add env variable for embedding tools --------- Co-authored-by: Ari Webb <ari@letta.com>
This commit is contained in:
@@ -1275,6 +1275,51 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/tools/search": {
|
||||
"post": {
|
||||
"tags": ["tools"],
|
||||
"summary": "Search Tools",
|
||||
"description": "Search tools using semantic search.\n\nRequires tool embedding to be enabled (embed_tools=True). Uses vector search,\nfull-text search, or hybrid mode to find tools matching the query.\n\nReturns tools ranked by relevance with their search scores.",
|
||||
"operationId": "search_tools",
|
||||
"parameters": [],
|
||||
"requestBody": {
|
||||
"required": true,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ToolSearchRequest"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/ToolSearchResult"
|
||||
},
|
||||
"title": "Response Search Tools"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/HTTPValidationError"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/tools/add-base-tools": {
|
||||
"post": {
|
||||
"tags": ["tools"],
|
||||
@@ -36392,6 +36437,125 @@
|
||||
"required": ["source_code", "args"],
|
||||
"title": "ToolRunFromSource"
|
||||
},
|
||||
"ToolSearchRequest": {
|
||||
"properties": {
|
||||
"query": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Query",
|
||||
"description": "Text query for semantic search."
|
||||
},
|
||||
"search_mode": {
|
||||
"type": "string",
|
||||
"enum": ["vector", "fts", "hybrid"],
|
||||
"title": "Search Mode",
|
||||
"description": "Search mode: vector, fts, or hybrid.",
|
||||
"default": "hybrid"
|
||||
},
|
||||
"tool_types": {
|
||||
"anyOf": [
|
||||
{
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"type": "array"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Tool Types",
|
||||
"description": "Filter by tool types (e.g., 'custom', 'letta_core')."
|
||||
},
|
||||
"tags": {
|
||||
"anyOf": [
|
||||
{
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"type": "array"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Tags",
|
||||
"description": "Filter by tags (match any)."
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"maximum": 100,
|
||||
"minimum": 1,
|
||||
"title": "Limit",
|
||||
"description": "Maximum number of results to return.",
|
||||
"default": 50
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"type": "object",
|
||||
"title": "ToolSearchRequest",
|
||||
"description": "Request model for searching tools using semantic search."
|
||||
},
|
||||
"ToolSearchResult": {
|
||||
"properties": {
|
||||
"tool": {
|
||||
"$ref": "#/components/schemas/Tool",
|
||||
"description": "The matched tool."
|
||||
},
|
||||
"embedded_text": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Embedded Text",
|
||||
"description": "The embedded text content used for matching."
|
||||
},
|
||||
"fts_rank": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "integer"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Fts Rank",
|
||||
"description": "Full-text search rank position."
|
||||
},
|
||||
"vector_rank": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "integer"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Vector Rank",
|
||||
"description": "Vector search rank position."
|
||||
},
|
||||
"combined_score": {
|
||||
"type": "number",
|
||||
"title": "Combined Score",
|
||||
"description": "Combined relevance score (RRF for hybrid mode)."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"type": "object",
|
||||
"required": ["tool", "combined_score"],
|
||||
"title": "ToolSearchResult",
|
||||
"description": "Result from a tool search operation."
|
||||
},
|
||||
"ToolType": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Turbopuffer utilities for archival memory storage."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
@@ -29,6 +30,11 @@ def should_use_tpuf_for_messages() -> bool:
|
||||
return should_use_tpuf() and bool(settings.embed_all_messages)
|
||||
|
||||
|
||||
def should_use_tpuf_for_tools() -> bool:
    """Report whether tools should be embedded in and searched through Turbopuffer.

    The answer is truthy only when ``should_use_tpuf()`` is truthy and the
    ``embed_tools`` setting is switched on.
    """
    tpuf_enabled = should_use_tpuf()
    # The setting is only consulted once Turbopuffer itself is confirmed usable
    return tpuf_enabled and bool(settings.embed_tools)
|
||||
|
||||
|
||||
class TurbopufferClient:
|
||||
"""Client for managing archival memory with Turbopuffer vector database."""
|
||||
|
||||
@@ -104,6 +110,157 @@ class TurbopufferClient:
|
||||
|
||||
return namespace_name
|
||||
|
||||
@trace_method
async def _get_tool_namespace_name(self, organization_id: str) -> str:
    """Build the org-scoped Turbopuffer namespace name used for tools.

    Args:
        organization_id: Organization whose tools live in the namespace

    Returns:
        ``tools_<organization_id>``, with a lower-cased ``_<environment>`` suffix
        appended when ``settings.environment`` is set
    """
    env = settings.environment
    if not env:
        return f"tools_{organization_id}"
    return f"tools_{organization_id}_{env.lower()}"
|
||||
|
||||
def _extract_tool_text(self, tool: "PydanticTool") -> str:
|
||||
"""Extract searchable text from a tool for embedding.
|
||||
|
||||
Combines name, description, and JSON schema into a structured format
|
||||
that provides rich context for semantic search.
|
||||
|
||||
Args:
|
||||
tool: The tool to extract text from
|
||||
|
||||
Returns:
|
||||
JSON-formatted string containing tool information
|
||||
"""
|
||||
|
||||
parts = {
|
||||
"name": tool.name or "",
|
||||
"description": tool.description or "",
|
||||
}
|
||||
|
||||
# Extract parameter information from JSON schema
|
||||
if tool.json_schema:
|
||||
# Include function description from schema if different from tool description
|
||||
schema_description = tool.json_schema.get("description", "")
|
||||
if schema_description and schema_description != tool.description:
|
||||
parts["schema_description"] = schema_description
|
||||
|
||||
# Extract parameter information
|
||||
parameters = tool.json_schema.get("parameters", {})
|
||||
if parameters:
|
||||
properties = parameters.get("properties", {})
|
||||
param_descriptions = []
|
||||
for param_name, param_info in properties.items():
|
||||
param_desc = param_info.get("description", "")
|
||||
param_type = param_info.get("type", "any")
|
||||
if param_desc:
|
||||
param_descriptions.append(f"{param_name} ({param_type}): {param_desc}")
|
||||
else:
|
||||
param_descriptions.append(f"{param_name} ({param_type})")
|
||||
if param_descriptions:
|
||||
parts["parameters"] = param_descriptions
|
||||
|
||||
# Include tags for additional context
|
||||
if tool.tags:
|
||||
parts["tags"] = tool.tags
|
||||
|
||||
return json.dumps(parts)
|
||||
|
||||
@trace_method
async def insert_tools(
    self,
    tools: List["PydanticTool"],
    organization_id: str,
    actor: "PydanticUser",
) -> bool:
    """Embed the given tools and upsert them into the organization's tool namespace.

    Args:
        tools: List of tools to store
        organization_id: Organization ID for the tools
        actor: User actor for embedding generation

    Returns:
        True if successful (also when there was nothing to insert)

    Raises:
        Exception: any failure of the Turbopuffer write is logged and re-raised
    """
    from turbopuffer import AsyncTurbopuffer

    if not tools:
        return True

    # Pair every tool with its embeddable text and drop the ones whose text is blank.
    # NOTE(review): _extract_tool_text returns a JSON document, so the blank check
    # looks like it can never fire with the current extractor - confirm intent.
    extracted = [(tool, self._extract_tool_text(tool)) for tool in tools]
    kept = [(tool, text) for tool, text in extracted if text.strip()]

    if not kept:
        logger.warning("All tools had empty text content, skipping insertion")
        return True

    tool_texts = [text for _, text in kept]
    embeddings = await self._generate_embeddings(tool_texts, actor)

    namespace_name = await self._get_tool_namespace_name(organization_id)

    # Column-oriented payload: one list per attribute, aligned by position
    rows = [(tool, text, vector) for (tool, text), vector in zip(kept, embeddings)]
    upsert_columns = {
        "id": [tool.id for tool, _, _ in rows],
        "vector": [vector for _, _, vector in rows],
        "text": [text for _, text, _ in rows],
        "name": [tool.name or "" for tool, _, _ in rows],
        "organization_id": [organization_id for _ in rows],
        "tool_type": [tool.tool_type.value if tool.tool_type else "custom" for tool, _, _ in rows],
        "tags": [tool.tags or [] for tool, _, _ in rows],
        # Tools without a creation timestamp are stamped with the current UTC time
        "created_at": [getattr(tool, "created_at", None) or datetime.now(timezone.utc) for tool, _, _ in rows],
    }

    try:
        async with _GLOBAL_TURBOPUFFER_SEMAPHORE:
            async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
                await client.namespace(namespace_name).write(
                    upsert_columns=upsert_columns,
                    distance_metric="cosine_distance",
                    schema={"text": {"type": "string", "full_text_search": True}},
                )
                logger.info(f"Successfully inserted {len(rows)} tools to Turbopuffer")
                return True

    except Exception as e:
        logger.error(f"Failed to insert tools to Turbopuffer: {e}")
        raise
|
||||
|
||||
@trace_method
|
||||
async def insert_archival_memories(
|
||||
self,
|
||||
@@ -1468,3 +1625,150 @@ class TurbopufferClient:
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete source passages from Turbopuffer: {e}")
|
||||
raise
|
||||
|
||||
# tool methods
|
||||
|
||||
@trace_method
async def delete_tools(self, organization_id: str, tool_ids: List[str]) -> bool:
    """Remove the given tool ids from the organization's tool namespace.

    Args:
        organization_id: Organization ID for namespace lookup
        tool_ids: List of tool IDs to delete

    Returns:
        True if successful (an empty id list is a no-op that also returns True)

    Raises:
        Exception: any failure of the Turbopuffer write is logged and re-raised
    """
    from turbopuffer import AsyncTurbopuffer

    # Nothing to delete: skip the namespace lookup and the network round trip
    if not tool_ids:
        return True

    namespace_name = await self._get_tool_namespace_name(organization_id)

    try:
        async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as tpuf:
            await tpuf.namespace(namespace_name).write(deletes=tool_ids)
            deleted_count = len(tool_ids)
            logger.info(f"Successfully deleted {deleted_count} tools from Turbopuffer")
            return True
    except Exception as e:
        logger.error(f"Failed to delete tools from Turbopuffer: {e}")
        raise
|
||||
|
||||
@trace_method
async def query_tools(
    self,
    organization_id: str,
    actor: "PydanticUser",
    query_text: Optional[str] = None,
    search_mode: str = "hybrid",  # "vector", "fts", "hybrid", "timestamp"
    top_k: int = 50,
    tool_types: Optional[List[str]] = None,
    tags: Optional[List[str]] = None,
    vector_weight: float = 0.5,
    fts_weight: float = 0.5,
) -> List[Tuple[dict, float, dict]]:
    """Query tools from Turbopuffer using semantic search.

    A missing, empty, or whitespace-only ``query_text`` cannot drive vector,
    full-text, or hybrid search, so in that case the query falls back to
    ``"timestamp"`` mode instead of being sent with an unusable query.

    Args:
        organization_id: Organization ID for namespace lookup
        actor: User actor for embedding generation
        query_text: Text query for search
        search_mode: Search mode - "vector", "fts", "hybrid", or "timestamp"
        top_k: Number of results to return
        tool_types: Optional list of tool types to filter by
        tags: Optional list of tags to filter by (match any)
        vector_weight: Weight for vector search in hybrid mode
        fts_weight: Weight for FTS in hybrid mode

    Returns:
        List of (tool_dict, score, metadata) tuples. In hybrid mode this is
        whatever ``_reciprocal_rank_fusion`` produces; in the other modes the
        score is ``1 / rank`` (1-based) and the metadata carries
        ``combined_score``, ``search_mode`` and ``<search_mode>_rank``.

    Raises:
        Exception: any failure while querying is logged and re-raised
    """
    # Normalise blank text to "no query". Previously "" skipped embedding generation
    # (falsy) yet also skipped the timestamp fallback (not None), leaving the chosen
    # mode to run without a usable query.
    if query_text is not None and not query_text.strip():
        query_text = None

    # Generate embedding for vector/hybrid search
    query_embedding = None
    if query_text and search_mode in ("vector", "hybrid"):
        embeddings = await self._generate_embeddings([query_text], actor)
        query_embedding = embeddings[0] if embeddings else None

    # Fallback to timestamp-based retrieval when no query
    if query_embedding is None and query_text is None and search_mode != "timestamp":
        search_mode = "timestamp"

    namespace_name = await self._get_tool_namespace_name(organization_id)

    # Build filters
    all_filters = []

    if tool_types:
        if len(tool_types) == 1:
            all_filters.append(("tool_type", "Eq", tool_types[0]))
        else:
            all_filters.append(("tool_type", "In", tool_types))

    if tags:
        all_filters.append(("tags", "ContainsAny", tags))

    # Combine filters: a lone filter is passed as is, several are AND-ed together
    final_filter = None
    if len(all_filters) == 1:
        final_filter = all_filters[0]
    elif len(all_filters) > 1:
        final_filter = ("And", all_filters)

    try:
        result = await self._execute_query(
            namespace_name=namespace_name,
            search_mode=search_mode,
            query_embedding=query_embedding,
            query_text=query_text,
            top_k=top_k,
            include_attributes=["text", "name", "organization_id", "tool_type", "tags", "created_at"],
            filters=final_filter,
            vector_weight=vector_weight,
            fts_weight=fts_weight,
        )

        if search_mode == "hybrid":
            # Hybrid responses carry two result sets: vector first, full-text second
            vector_results = self._process_tool_query_results(result.results[0])
            fts_results = self._process_tool_query_results(result.results[1])
            return self._reciprocal_rank_fusion(
                vector_results=vector_results,
                fts_results=fts_results,
                get_id_func=lambda d: d["id"],
                vector_weight=vector_weight,
                fts_weight=fts_weight,
                top_k=top_k,
            )

        # Single-mode search: derive a reciprocal-rank score from the result position
        results_with_metadata = []
        for rank, tool_dict in enumerate(self._process_tool_query_results(result), start=1):
            metadata = {
                "combined_score": 1.0 / rank,
                "search_mode": search_mode,
                f"{search_mode}_rank": rank,
            }
            results_with_metadata.append((tool_dict, metadata["combined_score"], metadata))
        return results_with_metadata

    except Exception as e:
        logger.error(f"Failed to query tools from Turbopuffer: {e}")
        raise
|
||||
|
||||
def _process_tool_query_results(self, result) -> List[dict]:
|
||||
"""Process results from a tool query into tool dicts."""
|
||||
tools = []
|
||||
for row in result.rows:
|
||||
tool_dict = {
|
||||
"id": row.id,
|
||||
"text": getattr(row, "text", ""),
|
||||
"name": getattr(row, "name", ""),
|
||||
"organization_id": getattr(row, "organization_id", None),
|
||||
"tool_type": getattr(row, "tool_type", None),
|
||||
"tags": getattr(row, "tags", []),
|
||||
"created_at": getattr(row, "created_at", None),
|
||||
}
|
||||
tools.append(tool_dict)
|
||||
return tools
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
|
||||
@@ -206,3 +206,23 @@ class ToolRunFromSource(LettaBase):
|
||||
)
|
||||
pip_requirements: list[PipRequirement] | None = Field(None, description="Optional list of pip packages required by this tool.")
|
||||
npm_requirements: list[NpmRequirement] | None = Field(None, description="Optional list of npm packages required by this tool.")
|
||||
|
||||
|
||||
class ToolSearchRequest(LettaBase):
    """Request model for searching tools using semantic search."""

    # Free-text query; may be omitted
    query: str | None = Field(default=None, description="Text query for semantic search.")
    # Which retrieval strategy to run; hybrid unless the caller says otherwise
    search_mode: Literal["vector", "fts", "hybrid"] = Field(default="hybrid", description="Search mode: vector, fts, or hybrid.")
    # Optional filters
    tool_types: list[str] | None = Field(default=None, description="Filter by tool types (e.g., 'custom', 'letta_core').")
    tags: list[str] | None = Field(default=None, description="Filter by tags (match any).")
    # Result cap, validated to the inclusive range 1..100
    limit: int = Field(default=50, ge=1, le=100, description="Maximum number of results to return.")
|
||||
|
||||
|
||||
class ToolSearchResult(LettaBase):
    """Result from a tool search operation."""

    # Required: the tool that matched
    tool: Tool = Field(default=..., description="The matched tool.")
    # Optional diagnostics about how the match was found
    embedded_text: str | None = Field(default=None, description="The embedded text content used for matching.")
    fts_rank: int | None = Field(default=None, description="Full-text search rank position.")
    vector_rank: int | None = Field(default=None, description="Vector search rank position.")
    # Required: the score the result list is ordered by
    combined_score: float = Field(default=..., description="Combined relevance score (RRF for hybrid mode).")
|
||||
|
||||
@@ -31,7 +31,7 @@ from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.mcp import UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.pip_requirement import PipRequirement
|
||||
from letta.schemas.tool import BaseTool, Tool, ToolCreate, ToolRunFromSource, ToolUpdate
|
||||
from letta.schemas.tool import BaseTool, Tool, ToolCreate, ToolRunFromSource, ToolSearchRequest, ToolSearchResult, ToolUpdate
|
||||
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
|
||||
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
|
||||
from letta.server.server import SyncServer
|
||||
@@ -271,6 +271,46 @@ async def list_tools(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/search", response_model=List[ToolSearchResult], operation_id="search_tools")
async def search_tools(
    request: ToolSearchRequest = Body(...),
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
):
    """
    Search tools using semantic search.

    Requires tool embedding to be enabled (embed_tools=True). Uses vector search,
    full-text search, or hybrid mode to find tools matching the query.

    Returns tools ranked by relevance with their search scores.
    """
    # NOTE: the docstring above is published verbatim as the endpoint description in
    # the OpenAPI spec, so implementation notes live in comments instead.
    actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)

    # Only the manager call is guarded. pydantic's ValidationError subclasses ValueError,
    # so building the response models inside the same try would report a server-side
    # response bug to the client as a 400.
    try:
        results = await server.tool_manager.search_tools_async(
            actor=actor,
            query_text=request.query,
            search_mode=request.search_mode,
            tool_types=request.tool_types,
            tags=request.tags,
            limit=request.limit,
        )
    except ValueError as e:
        # Raised e.g. when tool embedding is disabled; surface it as a bad request
        raise HTTPException(status_code=400, detail=str(e)) from e

    return [
        ToolSearchResult(
            tool=tool,
            embedded_text=None,  # Could be populated if needed
            fts_rank=metadata.get("fts_rank"),
            vector_rank=metadata.get("vector_rank"),
            combined_score=metadata.get("combined_score", 0.0),
        )
        for tool, metadata in results
    ]
|
||||
|
||||
|
||||
@router.post("/", response_model=Tool, operation_id="create_tool")
|
||||
async def create_tool(
|
||||
request: ToolCreate = Body(...),
|
||||
|
||||
@@ -41,7 +41,7 @@ from letta.services.helpers.agent_manager_helper import calculate_multi_agent_to
|
||||
from letta.services.mcp.types import SSEServerConfig, StdioServerConfig
|
||||
from letta.services.tool_schema_generator import generate_schema_for_tool_creation, generate_schema_for_tool_update
|
||||
from letta.settings import settings, tool_settings
|
||||
from letta.utils import enforce_types, printd
|
||||
from letta.utils import enforce_types, fire_and_forget, printd
|
||||
from letta.validators import raise_on_invalid_id
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -338,6 +338,15 @@ class ToolManager:
|
||||
if created_tool.tool_type == ToolType.CUSTOM and tool_requests_modal and modal_configured:
|
||||
await self.create_or_update_modal_app(created_tool, actor)
|
||||
|
||||
# Embed tool in Turbopuffer if enabled
|
||||
from letta.helpers.tpuf_client import should_use_tpuf_for_tools
|
||||
|
||||
if should_use_tpuf_for_tools():
|
||||
fire_and_forget(
|
||||
self._embed_tool_background(created_tool, actor),
|
||||
task_name=f"embed_tool_{created_tool.id}",
|
||||
)
|
||||
|
||||
return created_tool
|
||||
|
||||
@enforce_types
|
||||
@@ -869,6 +878,28 @@ class ToolManager:
|
||||
logger.info(f"Deploying Modal app for tool {updated_tool.id} with new hash: {new_hash}")
|
||||
await self.create_or_update_modal_app(updated_tool, actor)
|
||||
|
||||
# Update embedding in Turbopuffer if enabled (the upsert replaces the old entry)
from letta.helpers.tpuf_client import should_use_tpuf_for_tools

if should_use_tpuf_for_tools():

    async def update_tool_embedding():
        """Re-embed the updated tool in the background; failures are logged, never raised."""
        try:
            from letta.helpers.tpuf_client import TurbopufferClient

            tpuf_client = TurbopufferClient()
            # insert_tools writes with upsert semantics keyed on the tool id, so the
            # existing document is overwritten in place. Deleting it first is redundant
            # and would leave the tool unsearchable if embedding generation or the write
            # failed afterwards.
            await tpuf_client.insert_tools([updated_tool], actor.organization_id, actor)
            logger.info(f"Successfully updated tool {updated_tool.id} in Turbopuffer")
        except Exception as e:
            logger.error(f"Failed to update tool {updated_tool.id} in Turbopuffer: {e}")

    fire_and_forget(
        update_tool_embedding(),
        task_name=f"update_tool_embedding_{updated_tool.id}",
    )
|
||||
|
||||
return updated_tool
|
||||
|
||||
@enforce_types
|
||||
@@ -901,6 +932,20 @@ class ToolManager:
|
||||
logger.warning(f"Skipping Modal cleanup for corrupted tool {tool_id}: {e}")
|
||||
|
||||
await tool.hard_delete_async(db_session=session, actor=actor)
|
||||
|
||||
# Delete from Turbopuffer if enabled
|
||||
from letta.helpers.tpuf_client import should_use_tpuf_for_tools
|
||||
|
||||
if should_use_tpuf_for_tools():
|
||||
try:
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
|
||||
tpuf_client = TurbopufferClient()
|
||||
await tpuf_client.delete_tools(actor.organization_id, [tool_id])
|
||||
logger.info(f"Successfully deleted tool {tool_id} from Turbopuffer")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete tool {tool_id} from Turbopuffer: {e}")
|
||||
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Tool with id {tool_id} not found.")
|
||||
|
||||
@@ -1121,3 +1166,95 @@ class ToolManager:
|
||||
except Exception as e:
|
||||
logger.error(f"Error during Modal app deletion for tool {tool.name}: {e}")
|
||||
raise
|
||||
|
||||
async def _embed_tool_background(
    self,
    tool: PydanticTool,
    actor: PydanticUser,
) -> None:
    """Embed a single tool in Turbopuffer as a best-effort background task.

    Every exception is caught and logged, so a failed embedding never
    propagates to whoever scheduled the task.

    Args:
        tool: The tool to embed
        actor: User performing the action; its organization scopes the insert
    """
    try:
        from letta.helpers.tpuf_client import TurbopufferClient

        await TurbopufferClient().insert_tools(
            tools=[tool],
            organization_id=actor.organization_id,
            actor=actor,
        )
        logger.info(f"Successfully embedded tool {tool.id} in Turbopuffer")
    except Exception as e:
        logger.error(f"Failed to embed tool {tool.id} in Turbopuffer: {e}")
|
||||
|
||||
@enforce_types
@trace_method
async def search_tools_async(
    self,
    actor: PydanticUser,
    query_text: Optional[str] = None,
    search_mode: str = "hybrid",
    tool_types: Optional[List[str]] = None,
    tags: Optional[List[str]] = None,
    limit: int = 50,
) -> List[tuple[PydanticTool, dict]]:
    """
    Search tools using Turbopuffer semantic search.

    Turbopuffer only supplies ids and ranking metadata; every hit is then loaded
    from the database so callers receive full tool objects. Hits that can no longer
    be loaded (for example because the tool was deleted after it was indexed) are
    dropped from the result and logged.

    Args:
        actor: User performing the search
        query_text: Text query for semantic search
        search_mode: "vector", "fts", or "hybrid" (default: "hybrid")
        tool_types: Optional list of tool types to filter by
        tags: Optional list of tags to filter by
        limit: Maximum number of results to return

    Returns:
        List of (tool, metadata) tuples in search-rank order, where metadata contains
        search scores

    Raises:
        ValueError: If Turbopuffer is not enabled for tools
    """
    from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_tools

    if not should_use_tpuf_for_tools():
        raise ValueError("Tool semantic search requires tool embedding to be enabled (embed_tools=True).")

    tpuf_client = TurbopufferClient()
    results = await tpuf_client.query_tools(
        organization_id=actor.organization_id,
        actor=actor,
        query_text=query_text,
        search_mode=search_mode,
        top_k=limit,
        tool_types=tool_types,
        tags=tags,
    )

    if not results:
        return []

    # Hydrate hits in rank order so the output preserves the search ordering
    result_list = []
    for tool_dict, _, metadata in results:
        tool_id = tool_dict["id"]
        try:
            tool = await self.get_tool_by_id_async(tool_id, actor=actor)
        except Exception as e:
            # Best effort: a stale index entry must not fail the whole search, but it
            # should be visible in the logs instead of vanishing silently.
            logger.warning(f"Skipping tool {tool_id} from search results; could not load it: {e}")
            continue
        result_list.append((tool, metadata))

    return result_list
|
||||
|
||||
@@ -334,6 +334,7 @@ class Settings(BaseSettings):
|
||||
tpuf_api_key: Optional[str] = None
|
||||
tpuf_region: str = "gcp-us-central1"
|
||||
embed_all_messages: bool = False
|
||||
embed_tools: bool = False
|
||||
|
||||
# For encryption
|
||||
encryption_key: Optional[str] = None
|
||||
|
||||
@@ -16,6 +16,7 @@ from letta_client.types import CreateBlockParam, MessageCreateParam
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.schemas.message import MessageSearchResult
|
||||
from letta.schemas.tool import ToolSearchResult
|
||||
from letta.server.rest_api.routers.v1.passages import PassageSearchResult
|
||||
from letta.server.server import SyncServer
|
||||
from letta.settings import model_settings, settings
|
||||
@@ -53,6 +54,34 @@ def cleanup_agent_with_messages(client: Letta, agent_id: str):
|
||||
print(f"Warning: Failed to clean up agent {agent_id}: {e}")
|
||||
|
||||
|
||||
def cleanup_tool(client: Letta, tool_id: str):
    """
    Best-effort teardown of a test tool: drop its Turbopuffer entry (when tool
    embedding is enabled), then delete the tool through the SDK. Failures in
    either step are printed as warnings and never raised.

    Args:
        client: Letta SDK client
        tool_id: ID of the tool to clean up
    """
    # Step 1: remove the embedding; a failure here must not block the DB delete
    try:
        import asyncio

        from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_tools

        if should_use_tpuf_for_tools():
            asyncio.run(TurbopufferClient().delete_tools(DEFAULT_ORG_ID, [tool_id]))
    except Exception as tpuf_error:
        print(f"Warning: Failed to clean up Turbopuffer tool {tool_id}: {tpuf_error}")

    # Step 2: delete the tool record itself
    try:
        client.tools.delete(tool_id=tool_id)
    except Exception as delete_error:
        print(f"Warning: Failed to clean up tool {tool_id}: {delete_error}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
"""Server fixture for testing"""
|
||||
@@ -524,3 +553,163 @@ def test_passage_search_org_wide(client: Letta, enable_turbopuffer):
|
||||
# Clean up agents
|
||||
cleanup_agent_with_messages(client, agent1.id)
|
||||
cleanup_agent_with_messages(client, agent2.id)
|
||||
|
||||
|
||||
@pytest.fixture
def enable_tool_embedding():
    """Switch on Turbopuffer and tool embedding for one test, then restore the settings."""
    overridden = ("use_tpuf", "tpuf_api_key", "embed_tools", "environment")
    saved = {field: getattr(settings, field) for field in overridden}

    settings.use_tpuf = True
    # Keep a real key when one is configured; otherwise use a placeholder
    settings.tpuf_api_key = saved["tpuf_api_key"] or "test-key"
    settings.embed_tools = True
    settings.environment = "DEV"

    yield

    # Put every overridden setting back to its pre-test value
    for field, value in saved.items():
        setattr(settings, field, value)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not (settings.use_tpuf and settings.tpuf_api_key and model_settings.openai_api_key and settings.embed_tools),
    reason="Tool search requires Turbopuffer, OpenAI, and tool embedding to be enabled",
)
def test_tool_search_basic(client: Letta, enable_tool_embedding):
    """Test basic tool search functionality through the SDK.

    Creates three tools with unrelated purposes (email, weather, finance), waits
    for them to be indexed, then checks that hybrid search surfaces the expected
    tool for a matching query and that a tag filter can be supplied. All created
    tools are cleaned up afterwards, even when an assertion fails.
    """
    # Ids are collected as tools are created so the finally block can clean up
    # whatever exists, even after a partial failure.
    tool_ids = []

    try:
        # Create test tools with distinct descriptions for semantic search
        test_tools = [
            {
                "source_code": '''
def send_email_to_user(recipient: str, subject: str, body: str) -> str:
    """Send an email message to a specified recipient.

    Args:
        recipient: Email address of the recipient
        subject: Subject line of the email
        body: Body content of the email message

    Returns:
        Confirmation message
    """
    return f"Email sent to {recipient}"
''',
                "description": "Send an email message to a specified recipient with subject and body.",
                "tags": ["communication", "email"],
            },
            {
                "source_code": '''
def fetch_weather_data(city: str, units: str = "celsius") -> str:
    """Fetch current weather information for a city.

    Args:
        city: Name of the city to get weather for
        units: Temperature units (celsius or fahrenheit)

    Returns:
        Weather information string
    """
    return f"Weather in {city}: sunny, 25 {units}"
''',
                "description": "Fetch current weather information for a specified city.",
                "tags": ["weather", "api"],
            },
            {
                "source_code": '''
def calculate_compound_interest(principal: float, rate: float, years: int) -> float:
    """Calculate compound interest on an investment.

    Args:
        principal: Initial investment amount
        rate: Annual interest rate as decimal
        years: Number of years

    Returns:
        Final amount after compound interest
    """
    return principal * (1 + rate) ** years
''',
                "description": "Calculate compound interest on a financial investment over time.",
                "tags": ["finance", "calculator"],
            },
        ]

        # Create tools via SDK
        for tool_data in test_tools:
            tool = client.tools.create(
                source_code=tool_data["source_code"],
                description=tool_data["description"],
                tags=tool_data["tags"],
            )
            tool_ids.append(tool.id)

        # Wait for embeddings to be indexed (embedding is scheduled in the background
        # on tool creation, so the tools are not searchable immediately)
        time.sleep(3)

        # Test semantic search - should find email-related tool
        results = client.post(
            "/v1/tools/search",
            cast_to=list[ToolSearchResult],
            body={
                "query": "send message to someone",
                "search_mode": "hybrid",
                "limit": 10,
            },
        )

        assert len(results) > 0, "Should find at least one tool"

        # The email tool should be ranked highly for this query
        tool_names = [result["tool"]["name"] for result in results]
        assert "send_email_to_user" in tool_names, "Should find email tool for messaging query"

        # Verify result structure
        for result in results:
            assert "tool" in result, "Result should have tool field"
            assert "combined_score" in result, "Result should have combined_score field"
            assert isinstance(result["combined_score"], float), "combined_score should be a float"

        # Test search with different query - should find weather tool
        weather_results = client.post(
            "/v1/tools/search",
            cast_to=list[ToolSearchResult],
            body={
                "query": "get temperature forecast",
                "search_mode": "hybrid",
                "limit": 10,
            },
        )

        assert len(weather_results) > 0, "Should find tools for weather query"
        weather_tool_names = [result["tool"]["name"] for result in weather_results]
        assert "fetch_weather_data" in weather_tool_names, "Should find weather tool"

        # Test search with tag filter
        finance_results = client.post(
            "/v1/tools/search",
            cast_to=list[ToolSearchResult],
            body={
                "query": "money calculation",
                "tags": ["finance"],
                "search_mode": "hybrid",
                "limit": 10,
            },
        )

        # Should find the finance tool when filtering by tag
        # NOTE(review): this check is skipped entirely when the filtered search returns
        # nothing, so a broken tag filter would still pass - confirm whether an
        # unconditional assertion is intended here.
        if len(finance_results) > 0:
            finance_tool_names = [result["tool"]["name"] for result in finance_results]
            assert "calculate_compound_interest" in finance_tool_names, "Should find finance tool with tag filter"

    finally:
        # Clean up all created tools
        for tool_id in tool_ids:
            cleanup_tool(client, tool_id)
|
||||
|
||||
Reference in New Issue
Block a user