diff --git a/fern/openapi.json b/fern/openapi.json index f0015183..781c3f9e 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -1275,6 +1275,51 @@ } } }, + "/v1/tools/search": { + "post": { + "tags": ["tools"], + "summary": "Search Tools", + "description": "Search tools using semantic search.\n\nRequires tool embedding to be enabled (embed_tools=True). Uses vector search,\nfull-text search, or hybrid mode to find tools matching the query.\n\nReturns tools ranked by relevance with their search scores.", + "operationId": "search_tools", + "parameters": [], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ToolSearchRequest" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ToolSearchResult" + }, + "title": "Response Search Tools" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, "/v1/tools/add-base-tools": { "post": { "tags": ["tools"], @@ -36392,6 +36437,125 @@ "required": ["source_code", "args"], "title": "ToolRunFromSource" }, + "ToolSearchRequest": { + "properties": { + "query": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Query", + "description": "Text query for semantic search." + }, + "search_mode": { + "type": "string", + "enum": ["vector", "fts", "hybrid"], + "title": "Search Mode", + "description": "Search mode: vector, fts, or hybrid.", + "default": "hybrid" + }, + "tool_types": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Tool Types", + "description": "Filter by tool types (e.g., 'custom', 'letta_core')." + }, + "tags": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Tags", + "description": "Filter by tags (match any)." + }, + "limit": { + "type": "integer", + "maximum": 100, + "minimum": 1, + "title": "Limit", + "description": "Maximum number of results to return.", + "default": 50 + } + }, + "additionalProperties": false, + "type": "object", + "title": "ToolSearchRequest", + "description": "Request model for searching tools using semantic search." + }, + "ToolSearchResult": { + "properties": { + "tool": { + "$ref": "#/components/schemas/Tool", + "description": "The matched tool." + }, + "embedded_text": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Embedded Text", + "description": "The embedded text content used for matching." + }, + "fts_rank": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Fts Rank", + "description": "Full-text search rank position." + }, + "vector_rank": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Vector Rank", + "description": "Vector search rank position." + }, + "combined_score": { + "type": "number", + "title": "Combined Score", + "description": "Combined relevance score (RRF for hybrid mode)." + } + }, + "additionalProperties": false, + "type": "object", + "required": ["tool", "combined_score"], + "title": "ToolSearchResult", + "description": "Result from a tool search operation." + }, "ToolType": { "type": "string", "enum": [ diff --git a/letta/helpers/tpuf_client.py b/letta/helpers/tpuf_client.py index 52c4c62d..38107060 100644 --- a/letta/helpers/tpuf_client.py +++ b/letta/helpers/tpuf_client.py @@ -1,6 +1,7 @@ """Turbopuffer utilities for archival memory storage.""" import asyncio +import json import logging from datetime import datetime, timezone from typing import Any, Callable, List, Optional, Tuple @@ -29,6 +30,11 @@ def should_use_tpuf_for_messages() -> bool: return should_use_tpuf() and bool(settings.embed_all_messages) +def should_use_tpuf_for_tools() -> bool: + """Check if Turbopuffer should be used for tools.""" + return should_use_tpuf() and bool(settings.embed_tools) + + class TurbopufferClient: """Client for managing archival memory with Turbopuffer vector database.""" @@ -104,6 +110,157 @@ class TurbopufferClient: return namespace_name + @trace_method + async def _get_tool_namespace_name(self, organization_id: str) -> str: + """Get namespace name for tools (org-scoped). + + Args: + organization_id: Organization ID for namespace generation + + Returns: + The org-scoped namespace name for tools + """ + environment = settings.environment + if environment: + namespace_name = f"tools_{organization_id}_{environment.lower()}" + else: + namespace_name = f"tools_{organization_id}" + + return namespace_name + + def _extract_tool_text(self, tool: "PydanticTool") -> str: + """Extract searchable text from a tool for embedding. + + Combines name, description, and JSON schema into a structured format + that provides rich context for semantic search. + + Args: + tool: The tool to extract text from + + Returns: + JSON-formatted string containing tool information + """ + + parts = { + "name": tool.name or "", + "description": tool.description or "", + } + + # Extract parameter information from JSON schema + if tool.json_schema: + # Include function description from schema if different from tool description + schema_description = tool.json_schema.get("description", "") + if schema_description and schema_description != tool.description: + parts["schema_description"] = schema_description + + # Extract parameter information + parameters = tool.json_schema.get("parameters", {}) + if parameters: + properties = parameters.get("properties", {}) + param_descriptions = [] + for param_name, param_info in properties.items(): + param_desc = param_info.get("description", "") + param_type = param_info.get("type", "any") + if param_desc: + param_descriptions.append(f"{param_name} ({param_type}): {param_desc}") + else: + param_descriptions.append(f"{param_name} ({param_type})") + if param_descriptions: + parts["parameters"] = param_descriptions + + # Include tags for additional context + if tool.tags: + parts["tags"] = tool.tags + + return json.dumps(parts) + + @trace_method + async def insert_tools( + self, + tools: List["PydanticTool"], + organization_id: str, + actor: "PydanticUser", + ) -> bool: + """Insert tools into Turbopuffer. + + Args: + tools: List of tools to store + organization_id: Organization ID for the tools + actor: User actor for embedding generation + + Returns: + True if successful + """ + from turbopuffer import AsyncTurbopuffer + + if not tools: + return True + + # Extract text and filter out empty content + tool_texts = [] + valid_tools = [] + for tool in tools: + text = self._extract_tool_text(tool) + if text.strip(): + tool_texts.append(text) + valid_tools.append(tool) + + if not valid_tools: + logger.warning("All tools had empty text content, skipping insertion") + return True + + # Generate embeddings + embeddings = await self._generate_embeddings(tool_texts, actor) + + namespace_name = await self._get_tool_namespace_name(organization_id) + + # Prepare column-based data + ids = [] + vectors = [] + texts = [] + names = [] + organization_ids = [] + tool_types = [] + tags_arrays = [] + created_ats = [] + + for tool, text, embedding in zip(valid_tools, tool_texts, embeddings): + ids.append(tool.id) + vectors.append(embedding) + texts.append(text) + names.append(tool.name or "") + organization_ids.append(organization_id) + tool_types.append(tool.tool_type.value if tool.tool_type else "custom") + tags_arrays.append(tool.tags or []) + created_ats.append(getattr(tool, "created_at", None) or datetime.now(timezone.utc)) + + upsert_columns = { + "id": ids, + "vector": vectors, + "text": texts, + "name": names, + "organization_id": organization_ids, + "tool_type": tool_types, + "tags": tags_arrays, + "created_at": created_ats, + } + + try: + async with _GLOBAL_TURBOPUFFER_SEMAPHORE: + async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client: + namespace = client.namespace(namespace_name) + await namespace.write( + upsert_columns=upsert_columns, + distance_metric="cosine_distance", + schema={"text": {"type": "string", "full_text_search": True}}, + ) + logger.info(f"Successfully inserted {len(ids)} tools to Turbopuffer") + return True + + except Exception as e: + logger.error(f"Failed to insert tools to Turbopuffer: {e}") + raise + @trace_method async def insert_archival_memories( self, @@ -1468,3 +1625,150 @@ class TurbopufferClient: except Exception as e: logger.error(f"Failed to delete source passages from Turbopuffer: {e}") raise + + # tool methods + + @trace_method + async def delete_tools(self, organization_id: str, tool_ids: List[str]) -> bool: + """Delete tools from Turbopuffer. + + Args: + organization_id: Organization ID for namespace lookup + tool_ids: List of tool IDs to delete + + Returns: + True if successful + """ + from turbopuffer import AsyncTurbopuffer + + if not tool_ids: + return True + + namespace_name = await self._get_tool_namespace_name(organization_id) + + try: + async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client: + namespace = client.namespace(namespace_name) + await namespace.write(deletes=tool_ids) + logger.info(f"Successfully deleted {len(tool_ids)} tools from Turbopuffer") + return True + except Exception as e: + logger.error(f"Failed to delete tools from Turbopuffer: {e}") + raise + + @trace_method + async def query_tools( + self, + organization_id: str, + actor: "PydanticUser", + query_text: Optional[str] = None, + search_mode: str = "hybrid", # "vector", "fts", "hybrid", "timestamp" + top_k: int = 50, + tool_types: Optional[List[str]] = None, + tags: Optional[List[str]] = None, + vector_weight: float = 0.5, + fts_weight: float = 0.5, + ) -> List[Tuple[dict, float, dict]]: + """Query tools from Turbopuffer using semantic search. + + Args: + organization_id: Organization ID for namespace lookup + actor: User actor for embedding generation + query_text: Text query for search + search_mode: Search mode - "vector", "fts", "hybrid", or "timestamp" + top_k: Number of results to return + tool_types: Optional list of tool types to filter by + tags: Optional list of tags to filter by (match any) + vector_weight: Weight for vector search in hybrid mode + fts_weight: Weight for FTS in hybrid mode + + Returns: + List of (tool_dict, score, metadata) tuples + """ + # Generate embedding for vector/hybrid search + query_embedding = None + if query_text and search_mode in ["vector", "hybrid"]: + embeddings = await self._generate_embeddings([query_text], actor) + query_embedding = embeddings[0] if embeddings else None + + # Fallback to timestamp-based retrieval when no query + if query_embedding is None and query_text is None and search_mode not in ["timestamp"]: + search_mode = "timestamp" + + namespace_name = await self._get_tool_namespace_name(organization_id) + + # Build filters + all_filters = [] + + if tool_types: + if len(tool_types) == 1: + all_filters.append(("tool_type", "Eq", tool_types[0])) + else: + all_filters.append(("tool_type", "In", tool_types)) + + if tags: + all_filters.append(("tags", "ContainsAny", tags)) + + # Combine filters + final_filter = None + if len(all_filters) == 1: + final_filter = all_filters[0] + elif len(all_filters) > 1: + final_filter = ("And", all_filters) + + try: + result = await self._execute_query( + namespace_name=namespace_name, + search_mode=search_mode, + query_embedding=query_embedding, + query_text=query_text, + top_k=top_k, + include_attributes=["text", "name", "organization_id", "tool_type", "tags", "created_at"], + filters=final_filter, + vector_weight=vector_weight, + fts_weight=fts_weight, + ) + + if search_mode == "hybrid": + vector_results = self._process_tool_query_results(result.results[0]) + fts_results = self._process_tool_query_results(result.results[1]) + results_with_metadata = self._reciprocal_rank_fusion( + vector_results=vector_results, + fts_results=fts_results, + get_id_func=lambda d: d["id"], + vector_weight=vector_weight, + fts_weight=fts_weight, + top_k=top_k, + ) + return results_with_metadata + else: + results = self._process_tool_query_results(result) + results_with_metadata = [] + for idx, tool_dict in enumerate(results): + metadata = { + "combined_score": 1.0 / (idx + 1), + "search_mode": search_mode, + f"{search_mode}_rank": idx + 1, + } + results_with_metadata.append((tool_dict, metadata["combined_score"], metadata)) + return results_with_metadata + + except Exception as e: + logger.error(f"Failed to query tools from Turbopuffer: {e}") + raise + + def _process_tool_query_results(self, result) -> List[dict]: + """Process results from a tool query into tool dicts.""" + tools = [] + for row in result.rows: + tool_dict = { + "id": row.id, + "text": getattr(row, "text", ""), + "name": getattr(row, "name", ""), + "organization_id": getattr(row, "organization_id", None), + "tool_type": getattr(row, "tool_type", None), + "tags": getattr(row, "tags", []), + "created_at": getattr(row, "created_at", None), + } + tools.append(tool_dict) + return tools diff --git a/letta/schemas/tool.py b/letta/schemas/tool.py index baab6f5f..f1d950f7 100644 --- a/letta/schemas/tool.py +++ b/letta/schemas/tool.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Literal, Optional from pydantic import ConfigDict, Field, model_validator @@ -206,3 +206,23 @@ class ToolRunFromSource(LettaBase): ) pip_requirements: list[PipRequirement] | None = Field(None, description="Optional list of pip packages required by this tool.") npm_requirements: list[NpmRequirement] | None = Field(None, description="Optional list of npm packages required by this tool.") + + +class ToolSearchRequest(LettaBase): + """Request model for searching tools using semantic search.""" + + query: Optional[str] = Field(None, description="Text query for semantic search.") + search_mode: Literal["vector", "fts", "hybrid"] = Field("hybrid", description="Search mode: vector, fts, or hybrid.") + tool_types: Optional[List[str]] = Field(None, description="Filter by tool types (e.g., 'custom', 'letta_core').") + tags: Optional[List[str]] = Field(None, description="Filter by tags (match any).") + limit: int = Field(50, description="Maximum number of results to return.", ge=1, le=100) + + +class ToolSearchResult(LettaBase): + """Result from a tool search operation.""" + + tool: Tool = Field(..., description="The matched tool.") + embedded_text: Optional[str] = Field(None, description="The embedded text content used for matching.") + fts_rank: Optional[int] = Field(None, description="Full-text search rank position.") + vector_rank: Optional[int] = Field(None, description="Vector search rank position.") + combined_score: float = Field(..., description="Combined relevance score (RRF for hybrid mode).") diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index edd4c83a..fbe989a9 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -31,7 +31,7 @@ from letta.schemas.letta_message_content import TextContent from letta.schemas.mcp import UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer from letta.schemas.message import Message from letta.schemas.pip_requirement import PipRequirement -from letta.schemas.tool import BaseTool, Tool, ToolCreate, ToolRunFromSource, ToolUpdate +from letta.schemas.tool import BaseTool, Tool, ToolCreate, ToolRunFromSource, ToolSearchRequest, ToolSearchResult, ToolUpdate from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode from letta.server.server import SyncServer @@ -271,6 +271,46 @@ async def list_tools( ) +@router.post("/search", response_model=List[ToolSearchResult], operation_id="search_tools") +async def search_tools( + request: ToolSearchRequest = Body(...), + server: SyncServer = Depends(get_letta_server), + headers: HeaderParams = Depends(get_headers), +): + """ + Search tools using semantic search. + + Requires tool embedding to be enabled (embed_tools=True). Uses vector search, + full-text search, or hybrid mode to find tools matching the query. + + Returns tools ranked by relevance with their search scores. + """ + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + + try: + results = await server.tool_manager.search_tools_async( + actor=actor, + query_text=request.query, + search_mode=request.search_mode, + tool_types=request.tool_types, + tags=request.tags, + limit=request.limit, + ) + + return [ + ToolSearchResult( + tool=tool, + embedded_text=None, # Could be populated if needed + fts_rank=metadata.get("fts_rank"), + vector_rank=metadata.get("vector_rank"), + combined_score=metadata.get("combined_score", 0.0), + ) + for tool, metadata in results + ] + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + @router.post("/", response_model=Tool, operation_id="create_tool") async def create_tool( request: ToolCreate = Body(...), diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 5674e801..98f5a19f 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -41,7 +41,7 @@ from letta.services.helpers.agent_manager_helper import calculate_multi_agent_to from letta.services.mcp.types import SSEServerConfig, StdioServerConfig from letta.services.tool_schema_generator import generate_schema_for_tool_creation, generate_schema_for_tool_update from letta.settings import settings, tool_settings -from letta.utils import enforce_types, printd +from letta.utils import enforce_types, fire_and_forget, printd from letta.validators import raise_on_invalid_id logger = get_logger(__name__) @@ -338,6 +338,15 @@ class ToolManager: if created_tool.tool_type == ToolType.CUSTOM and tool_requests_modal and modal_configured: await self.create_or_update_modal_app(created_tool, actor) + # Embed tool in Turbopuffer if enabled + from letta.helpers.tpuf_client import should_use_tpuf_for_tools + + if should_use_tpuf_for_tools(): + fire_and_forget( + self._embed_tool_background(created_tool, actor), + task_name=f"embed_tool_{created_tool.id}", + ) + return created_tool @enforce_types @@ -869,6 +878,28 @@ class ToolManager: logger.info(f"Deploying Modal app for tool {updated_tool.id} with new hash: {new_hash}") await self.create_or_update_modal_app(updated_tool, actor) + # Update embedding in Turbopuffer if enabled (delete old, insert new) + from letta.helpers.tpuf_client import should_use_tpuf_for_tools + + if should_use_tpuf_for_tools(): + + async def update_tool_embedding(): + try: + from letta.helpers.tpuf_client import TurbopufferClient + + tpuf_client = TurbopufferClient() + # Delete old and re-insert (simpler than update) + await tpuf_client.delete_tools(actor.organization_id, [updated_tool.id]) + await tpuf_client.insert_tools([updated_tool], actor.organization_id, actor) + logger.info(f"Successfully updated tool {updated_tool.id} in Turbopuffer") + except Exception as e: + logger.error(f"Failed to update tool {updated_tool.id} in Turbopuffer: {e}") + + fire_and_forget( + update_tool_embedding(), + task_name=f"update_tool_embedding_{updated_tool.id}", + ) + return updated_tool @enforce_types @@ -901,6 +932,20 @@ class ToolManager: logger.warning(f"Skipping Modal cleanup for corrupted tool {tool_id}: {e}") await tool.hard_delete_async(db_session=session, actor=actor) + + # Delete from Turbopuffer if enabled + from letta.helpers.tpuf_client import should_use_tpuf_for_tools + + if should_use_tpuf_for_tools(): + try: + from letta.helpers.tpuf_client import TurbopufferClient + + tpuf_client = TurbopufferClient() + await tpuf_client.delete_tools(actor.organization_id, [tool_id]) + logger.info(f"Successfully deleted tool {tool_id} from Turbopuffer") + except Exception as e: + logger.warning(f"Failed to delete tool {tool_id} from Turbopuffer: {e}") + except NoResultFound: raise ValueError(f"Tool with id {tool_id} not found.") @@ -1121,3 +1166,95 @@ class ToolManager: except Exception as e: logger.error(f"Error during Modal app deletion for tool {tool.name}: {e}") raise + + async def _embed_tool_background( + self, + tool: PydanticTool, + actor: PydanticUser, + ) -> None: + """Background task to embed a tool in Turbopuffer. + + Args: + tool: The tool to embed + actor: User performing the action + """ + try: + from letta.helpers.tpuf_client import TurbopufferClient + + tpuf_client = TurbopufferClient() + await tpuf_client.insert_tools( + tools=[tool], + organization_id=actor.organization_id, + actor=actor, + ) + logger.info(f"Successfully embedded tool {tool.id} in Turbopuffer") + except Exception as e: + logger.error(f"Failed to embed tool {tool.id} in Turbopuffer: {e}") + + @enforce_types + @trace_method + async def search_tools_async( + self, + actor: PydanticUser, + query_text: Optional[str] = None, + search_mode: str = "hybrid", + tool_types: Optional[List[str]] = None, + tags: Optional[List[str]] = None, + limit: int = 50, + ) -> List[tuple[PydanticTool, dict]]: + """ + Search tools using Turbopuffer semantic search. + + Args: + actor: User performing the search + query_text: Text query for semantic search + search_mode: "vector", "fts", or "hybrid" (default: "hybrid") + tool_types: Optional list of tool types to filter by + tags: Optional list of tags to filter by + limit: Maximum number of results to return + + Returns: + List of (tool, metadata) tuples where metadata contains search scores + + Raises: + ValueError: If Turbopuffer is not enabled for tools + """ + from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_tools + + if not should_use_tpuf_for_tools(): + raise ValueError("Tool semantic search requires tool embedding to be enabled (embed_tools=True).") + + tpuf_client = TurbopufferClient() + results = await tpuf_client.query_tools( + organization_id=actor.organization_id, + actor=actor, + query_text=query_text, + search_mode=search_mode, + top_k=limit, + tool_types=tool_types, + tags=tags, + ) + + if not results: + return [] + + # Fetch full tool objects from database + tool_ids = [tool_dict["id"] for tool_dict, _, _ in results] + tools = [] + for tool_id in tool_ids: + try: + tool = await self.get_tool_by_id_async(tool_id, actor=actor) + tools.append(tool) + except Exception: + pass # Tool may have been deleted + + tool_map = {tool.id: tool for tool in tools} + + # Build result list preserving order and including metadata + result_list = [] + for tool_dict, _, metadata in results: + tool_id = tool_dict["id"] + if tool_id in tool_map: + result_list.append((tool_map[tool_id], metadata)) + + return result_list diff --git a/letta/settings.py b/letta/settings.py index 71efaed5..9f929eeb 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -334,6 +334,7 @@ class Settings(BaseSettings): tpuf_api_key: Optional[str] = None tpuf_region: str = "gcp-us-central1" embed_all_messages: bool = False + embed_tools: bool = False # For encryption encryption_key: Optional[str] = None diff --git a/tests/sdk/search_test.py b/tests/sdk/search_test.py index 67bfab49..d167af41 100644 --- a/tests/sdk/search_test.py +++ b/tests/sdk/search_test.py @@ -16,6 +16,7 @@ from letta_client.types import CreateBlockParam, MessageCreateParam from letta.config import LettaConfig from letta.schemas.message import MessageSearchResult +from letta.schemas.tool import ToolSearchResult from letta.server.rest_api.routers.v1.passages import PassageSearchResult from letta.server.server import SyncServer from letta.settings import model_settings, settings @@ -53,6 +54,34 @@ def cleanup_agent_with_messages(client: Letta, agent_id: str): print(f"Warning: Failed to clean up agent {agent_id}: {e}") +def cleanup_tool(client: Letta, tool_id: str): + """ + Helper function to properly clean up a tool by deleting it from both + Turbopuffer and the database. + + Args: + client: Letta SDK client + tool_id: ID of the tool to clean up + """ + try: + # First, delete from Turbopuffer if tool embedding is enabled + try: + import asyncio + + from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_tools + + if should_use_tpuf_for_tools(): + tpuf_client = TurbopufferClient() + asyncio.run(tpuf_client.delete_tools(DEFAULT_ORG_ID, [tool_id])) + except Exception as e: + print(f"Warning: Failed to clean up Turbopuffer tool {tool_id}: {e}") + + # Now delete the tool from the database + client.tools.delete(tool_id=tool_id) + except Exception as e: + print(f"Warning: Failed to clean up tool {tool_id}: {e}") + + @pytest.fixture(scope="module") def server(): """Server fixture for testing""" @@ -524,3 +553,163 @@ def test_passage_search_org_wide(client: Letta, enable_turbopuffer): # Clean up agents cleanup_agent_with_messages(client, agent1.id) cleanup_agent_with_messages(client, agent2.id) + + +@pytest.fixture +def enable_tool_embedding(): + """Enable both Turbopuffer and tool embedding""" + original_use_tpuf = settings.use_tpuf + original_api_key = settings.tpuf_api_key + original_embed_tools = settings.embed_tools + original_environment = settings.environment + + settings.use_tpuf = True + settings.tpuf_api_key = settings.tpuf_api_key or "test-key" + settings.embed_tools = True + settings.environment = "DEV" + + yield + + settings.use_tpuf = original_use_tpuf + settings.tpuf_api_key = original_api_key + settings.embed_tools = original_embed_tools + settings.environment = original_environment + + +@pytest.mark.skipif( + not (settings.use_tpuf and settings.tpuf_api_key and model_settings.openai_api_key and settings.embed_tools), + reason="Tool search requires Turbopuffer, OpenAI, and tool embedding to be enabled", +) +def test_tool_search_basic(client: Letta, enable_tool_embedding): + """Test basic tool search functionality through the SDK""" + tool_ids = [] + + try: + # Create test tools with distinct descriptions for semantic search + test_tools = [ + { + "source_code": ''' +def send_email_to_user(recipient: str, subject: str, body: str) -> str: + """Send an email message to a specified recipient. + + Args: + recipient: Email address of the recipient + subject: Subject line of the email + body: Body content of the email message + + Returns: + Confirmation message + """ + return f"Email sent to {recipient}" +''', + "description": "Send an email message to a specified recipient with subject and body.", + "tags": ["communication", "email"], + }, + { + "source_code": ''' +def fetch_weather_data(city: str, units: str = "celsius") -> str: + """Fetch current weather information for a city. + + Args: + city: Name of the city to get weather for + units: Temperature units (celsius or fahrenheit) + + Returns: + Weather information string + """ + return f"Weather in {city}: sunny, 25 {units}" +''', + "description": "Fetch current weather information for a specified city.", + "tags": ["weather", "api"], + }, + { + "source_code": ''' +def calculate_compound_interest(principal: float, rate: float, years: int) -> float: + """Calculate compound interest on an investment. + + Args: + principal: Initial investment amount + rate: Annual interest rate as decimal + years: Number of years + + Returns: + Final amount after compound interest + """ + return principal * (1 + rate) ** years +''', + "description": "Calculate compound interest on a financial investment over time.", + "tags": ["finance", "calculator"], + }, + ] + + # Create tools via SDK + for tool_data in test_tools: + tool = client.tools.create( + source_code=tool_data["source_code"], + description=tool_data["description"], + tags=tool_data["tags"], + ) + tool_ids.append(tool.id) + + # Wait for embeddings to be indexed + time.sleep(3) + + # Test semantic search - should find email-related tool + results = client.post( + "/v1/tools/search", + cast_to=list[ToolSearchResult], + body={ + "query": "send message to someone", + "search_mode": "hybrid", + "limit": 10, + }, + ) + + assert len(results) > 0, "Should find at least one tool" + + # The email tool should be ranked highly for this query + tool_names = [result["tool"]["name"] for result in results] + assert "send_email_to_user" in tool_names, "Should find email tool for messaging query" + + # Verify result structure + for result in results: + assert "tool" in result, "Result should have tool field" + assert "combined_score" in result, "Result should have combined_score field" + assert isinstance(result["combined_score"], float), "combined_score should be a float" + + # Test search with different query - should find weather tool + weather_results = client.post( + "/v1/tools/search", + cast_to=list[ToolSearchResult], + body={ + "query": "get temperature forecast", + "search_mode": "hybrid", + "limit": 10, + }, + ) + + assert len(weather_results) > 0, "Should find tools for weather query" + weather_tool_names = [result["tool"]["name"] for result in weather_results] + assert "fetch_weather_data" in weather_tool_names, "Should find weather tool" + + # Test search with tag filter + finance_results = client.post( + "/v1/tools/search", + cast_to=list[ToolSearchResult], + body={ + "query": "money calculation", + "tags": ["finance"], + "search_mode": "hybrid", + "limit": 10, + }, + ) + + # Should find the finance tool when filtering by tag + if len(finance_results) > 0: + finance_tool_names = [result["tool"]["name"] for result in finance_results] + assert "calculate_compound_interest" in finance_tool_names, "Should find finance tool with tag filter" + + finally: + # Clean up all created tools + for tool_id in tool_ids: + cleanup_tool(client, tool_id)