Files
letta-server/letta/server/rest_api/routers/v1/passages.py
Ari Webb f9b405372d feat: add search routes [LET-6236] (#6280)
* claude code first pass

* rename routes

* search_messages and list_messages

* revert agents message search

* generate api

* fix backend for list all messages

* request for message search

* return list of letta message

* add tests

* error in archive endpoint

* fix wrong return type for archive delete

* optional params for archive creation

* add passage to tpuf on create

* fix archive manager

* support global passage search

* search by agent

* just do basic org-wide search for now

* change message test to use fresh data and clean up afterward

---------

Co-authored-by: Ari Webb <ari@letta.com>
2025-11-24 19:10:27 -08:00

109 lines
4.7 KiB
Python

from datetime import datetime
from typing import List, Literal, Optional
from fastapi import APIRouter, Body, Depends
from pydantic import BaseModel, Field
from letta.schemas.enums import TagMatchMode
from letta.schemas.passage import Passage
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
from letta.server.server import SyncServer
# Every route declared in this module is mounted under the "/passages" prefix
# and grouped under the "passages" tag in the generated OpenAPI schema.
router = APIRouter(tags=["passages"], prefix="/passages")
class PassageSearchRequest(BaseModel):
    """Body of a passage search request.

    Carries the semantic-search query text, optional scoping filters (agent,
    archive, tags, creation-time window) and an upper bound on the number of
    results returned.
    """

    # Free text that the search endpoint embeds and matches semantically.
    query: str = Field(default=..., description="Text query for semantic search")
    # Optional scoping filters; leaving both unset makes the search organization-wide.
    agent_id: Optional[str] = Field(default=None, description="Filter passages by agent ID")
    archive_id: Optional[str] = Field(default=None, description="Filter passages by archive ID")
    # Tag filtering; how several tags combine is governed by tag_match_mode.
    tags: Optional[List[str]] = Field(default=None, description="Optional list of tags to filter search results")
    tag_match_mode: Literal["any", "all"] = Field(
        default="any",
        description="How to match tags - 'any' to match passages with any of the tags, 'all' to match only passages with all tags",
    )
    # Validated to the inclusive range 1..100.
    limit: int = Field(default=50, ge=1, le=100, description="Maximum number of results to return")
    # Optional creation-time window bounds.
    start_date: Optional[datetime] = Field(default=None, description="Filter results to passages created after this datetime")
    end_date: Optional[datetime] = Field(default=None, description="Filter results to passages created before this datetime")
class PassageSearchResult(BaseModel):
    """A single hit returned by a passage search.

    Pairs the matched passage with its relevance score and whatever extra
    metadata the search returned alongside it.
    """

    passage: Passage = Field(default=..., description="The passage object")
    score: float = Field(default=..., description="Relevance score")
    # default_factory gives each instance its own empty dict when none is supplied.
    metadata: dict = Field(default_factory=dict, description="Additional metadata about the search result")
async def _resolve_search_embedding_config(
    server: SyncServer,
    actor,
    agent_id: Optional[str],
    archive_id: Optional[str],
):
    """Pick the embedding config used to embed a passage-search query.

    Resolution order:
    1. the agent's ``embedding_config`` when ``agent_id`` is given;
    2. otherwise the archive's ``embedding_config`` when ``archive_id`` is given;
    3. otherwise the ``embedding_config`` of the first agent returned by
       ``list_agents_async`` for this actor.
    If none of these yields a config, ``server.default_embedding_config`` is used.
    """
    embedding_config = None
    if agent_id:
        agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
        embedding_config = agent_state.embedding_config
    elif archive_id:
        archive = await server.archive_manager.get_archive_by_id_async(archive_id=archive_id, actor=actor)
        embedding_config = archive.embedding_config
    else:
        # Org-wide search: borrow any one agent's config. A limit=1 listing is
        # enough on its own -- an empty result already means "no agents", so a
        # separate count query would only add a database round trip.
        agents = await server.agent_manager.list_agents_async(actor=actor, limit=1)
        if agents:
            embedding_config = agents[0].embedding_config

    if not embedding_config:
        # Fall back to the server-wide default.
        embedding_config = server.default_embedding_config
    return embedding_config


@router.post("/search", response_model=List[PassageSearchResult], operation_id="search_passages")
async def search_passages(
    request: PassageSearchRequest = Body(...),
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
) -> List[PassageSearchResult]:
    """
    Search passages across the organization with optional agent and archive filtering.

    Returns passages with relevance scores.

    This endpoint supports semantic search through passages:
    - If neither agent_id nor archive_id is provided, searches ALL passages in the organization
    - If agent_id is provided, searches passages across all archives attached to that agent
    - If archive_id is provided, searches passages within that specific archive
    - If both are provided, the agent's embedding config is used to embed the query,
      and both IDs are forwarded as filters to the passage query
    """
    actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)

    # The request schema restricts tag_match_mode to "any" | "all".
    tag_mode = TagMatchMode.ANY if request.tag_match_mode == "any" else TagMatchMode.ALL

    embedding_config = await _resolve_search_embedding_config(
        server,
        actor,
        agent_id=request.agent_id,
        archive_id=request.archive_id,
    )

    # Either ID may be None; with both None the search is organization-wide.
    passages_with_metadata = await server.agent_manager.query_agent_passages_async(
        actor=actor,
        agent_id=request.agent_id,
        archive_id=request.archive_id,
        query_text=request.query,
        limit=request.limit,
        embedding_config=embedding_config,
        embed_query=True,
        tags=request.tags,
        tag_match_mode=tag_mode,
        start_date=request.start_date,
        end_date=request.end_date,
    )

    # Each hit is unpacked as a (passage, score, metadata) triple.
    return [
        PassageSearchResult(passage=passage, score=score, metadata=metadata)
        for passage, score, metadata in passages_with_metadata
    ]