diff --git a/fern/openapi.json b/fern/openapi.json index c89a1942..30c58aee 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -336,15 +336,8 @@ } ], "responses": { - "200": { - "description": "Successful Response", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/Archive" - } - } - } + "204": { + "description": "Successful Response" }, "422": { "description": "Validation Error", @@ -15296,6 +15289,205 @@ } } }, + "/v1/messages/": { + "get": { + "tags": ["messages"], + "summary": "List All Messages", + "description": "List messages across all agents for the current user.", + "operationId": "list_all_messages", + "parameters": [ + { + "name": "before", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "description": "Message ID cursor for pagination. Returns messages that come before this message ID in the specified sort order", + "title": "Before" + }, + "description": "Message ID cursor for pagination. Returns messages that come before this message ID in the specified sort order" + }, + { + "name": "after", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "description": "Message ID cursor for pagination. Returns messages that come after this message ID in the specified sort order", + "title": "After" + }, + "description": "Message ID cursor for pagination. Returns messages that come after this message ID in the specified sort order" + }, + { + "name": "limit", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "description": "Maximum number of messages to return", + "default": 100, + "title": "Limit" + }, + "description": "Maximum number of messages to return" + }, + { + "name": "order", + "in": "query", + "required": false, + "schema": { + "enum": ["asc", "desc"], + "type": "string", + "description": "Sort order for messages by creation time. 'asc' for oldest first, 'desc' for newest first", + "default": "desc", + "title": "Order" + }, + "description": "Sort order for messages by creation time. 'asc' for oldest first, 'desc' for newest first" + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "$ref": "#/components/schemas/LettaMessageUnion" + }, + "title": "Response List All Messages" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/v1/messages/search": { + "post": { + "tags": ["messages"], + "summary": "Search All Messages", + "description": "Search messages across the organization with optional agent filtering.\nReturns messages with FTS/vector ranks and total RRF score.\n\nThis is a cloud-only feature.", + "operationId": "search_all_messages", + "parameters": [], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SearchAllMessagesRequest" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "oneOf": [ + { + "$ref": "#/components/schemas/SystemMessage" + }, + { + "$ref": "#/components/schemas/UserMessage" + }, + { + "$ref": "#/components/schemas/ReasoningMessage" + }, + { + "$ref": "#/components/schemas/HiddenReasoningMessage" + }, + { + "$ref": "#/components/schemas/ToolCallMessage" + }, + { + "$ref": "#/components/schemas/ToolReturnMessage" + }, + { + "$ref": "#/components/schemas/AssistantMessage" + }, + { + "$ref": "#/components/schemas/ApprovalRequestMessage" + }, + { + "$ref": "#/components/schemas/ApprovalResponseMessage" + }, + { + "$ref": "#/components/schemas/SummaryMessage" + }, + { + "$ref": "#/components/schemas/EventMessage" + } + ], + "discriminator": { + "propertyName": "message_type", + "mapping": { + "system_message": "#/components/schemas/SystemMessage", + "user_message": "#/components/schemas/UserMessage", + "reasoning_message": "#/components/schemas/ReasoningMessage", + "hidden_reasoning_message": "#/components/schemas/HiddenReasoningMessage", + "tool_call_message": "#/components/schemas/ToolCallMessage", + "tool_return_message": "#/components/schemas/ToolReturnMessage", + "assistant_message": "#/components/schemas/AssistantMessage", + "approval_request_message": "#/components/schemas/ApprovalRequestMessage", + "approval_response_message": "#/components/schemas/ApprovalResponseMessage", + "summary": "#/components/schemas/SummaryMessage", + "event": "#/components/schemas/EventMessage" + } + } + }, + "title": "Response Search All Messages" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, "/v1/messages/batches": { "post": { "tags": ["messages"], @@ -15673,6 +15865,51 @@ } } }, + "/v1/passages/search": { + "post": { + "tags": ["passages"], + "summary": "Search Passages", + "description": "Search passages across the organization with optional agent and archive filtering.\nReturns passages with relevance scores.\n\nThis endpoint supports semantic search through passages:\n- If neither agent_id nor archive_id is provided, searches ALL passages in the organization\n- If agent_id is provided, searches passages across all archives attached to that agent\n- If archive_id is provided, searches passages within that specific archive\n- If both are provided, agent_id takes precedence", + "operationId": "search_passages", + "parameters": [], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/PassageSearchRequest" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "$ref": "#/components/schemas/PassageSearchResult" + }, + "title": "Response Search Passages" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, "/v1/voice-beta/{agent_id}/chat/completions": { "post": { "tags": ["voice"], @@ -31576,6 +31813,122 @@ "title": "PassageCreateRequest", "description": "Request model for creating a passage in an archive." }, + "PassageSearchRequest": { + "properties": { + "query": { + "type": "string", + "title": "Query", + "description": "Text query for semantic search" + }, + "agent_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Agent Id", + "description": "Filter passages by agent ID" + }, + "archive_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Archive Id", + "description": "Filter passages by archive ID" + }, + "tags": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Tags", + "description": "Optional list of tags to filter search results" + }, + "tag_match_mode": { + "type": "string", + "enum": ["any", "all"], + "title": "Tag Match Mode", + "description": "How to match tags - 'any' to match passages with any of the tags, 'all' to match only passages with all tags", + "default": "any" + }, + "limit": { + "type": "integer", + "maximum": 100, + "minimum": 1, + "title": "Limit", + "description": "Maximum number of results to return", + "default": 50 + }, + "start_date": { + "anyOf": [ + { + "type": "string", + "format": "date-time" + }, + { + "type": "null" + } + ], + "title": "Start Date", + "description": "Filter results to passages created after this datetime" + }, + "end_date": { + "anyOf": [ + { + "type": "string", + "format": "date-time" + }, + { + "type": "null" + } + ], + "title": "End Date", + "description": "Filter results to passages created before this datetime" + } + }, + "type": "object", + "required": ["query"], + "title": "PassageSearchRequest", + "description": "Request model for searching passages across archives." + }, + "PassageSearchResult": { + "properties": { + "passage": { + "$ref": "#/components/schemas/Passage", + "description": "The passage object" + }, + "score": { + "type": "number", + "title": "Score", + "description": "Relevance score" + }, + "metadata": { + "additionalProperties": true, + "type": "object", + "title": "Metadata", + "description": "Additional metadata about the search result" + } + }, + "type": "object", + "required": ["passage", "score"], + "title": "PassageSearchResult", + "description": "Result from a passage search operation with scoring details." + }, "PipRequirement": { "properties": { "name": { @@ -33158,6 +33511,59 @@ "enum": ["e2b", "modal", "local"], "title": "SandboxType" }, + "SearchAllMessagesRequest": { + "properties": { + "query": { + "type": "string", + "title": "Query", + "description": "Text query for full-text search" + }, + "search_mode": { + "type": "string", + "enum": ["vector", "fts", "hybrid"], + "title": "Search Mode", + "description": "Search mode to use", + "default": "hybrid" + }, + "limit": { + "type": "integer", + "maximum": 100, + "minimum": 1, + "title": "Limit", + "description": "Maximum number of results to return", + "default": 50 + }, + "start_date": { + "anyOf": [ + { + "type": "string", + "format": "date-time" + }, + { + "type": "null" + } + ], + "title": "Start Date", + "description": "Filter messages created after this date" + }, + "end_date": { + "anyOf": [ + { + "type": "string", + "format": "date-time" + }, + { + "type": "null" + } + ], + "title": "End Date", + "description": "Filter messages created on or before this date" + } + }, + "type": "object", + "required": ["query"], + "title": "SearchAllMessagesRequest" + }, "SleeptimeManager": { "properties": { "manager_type": { diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 222ac78b..826c4b4f 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -2168,6 +2168,14 @@ class MessageSearchRequest(BaseModel): end_date: Optional[datetime] = Field(None, description="Filter messages created on or before this date") +class SearchAllMessagesRequest(BaseModel): + query: str = Field(..., description="Text query for full-text search") + search_mode: Literal["vector", "fts", "hybrid"] = Field("hybrid", description="Search mode to use") + limit: int = Field(50, description="Maximum number of results to return", ge=1, le=100) + start_date: Optional[datetime] = Field(None, description="Filter messages created after this date") + end_date: Optional[datetime] = Field(None, description="Filter messages created on or before this date") + + class MessageSearchResult(BaseModel): """Result from a message search operation with scoring details.""" diff --git a/letta/server/rest_api/routers/v1/__init__.py b/letta/server/rest_api/routers/v1/__init__.py index 6b9711b0..0e20956a 100644 --- a/letta/server/rest_api/routers/v1/__init__.py +++ b/letta/server/rest_api/routers/v1/__init__.py @@ -16,6 +16,7 @@ from letta.server.rest_api.routers.v1.jobs import router as jobs_router from letta.server.rest_api.routers.v1.llms import router as llm_router from letta.server.rest_api.routers.v1.mcp_servers import router as mcp_servers_router from letta.server.rest_api.routers.v1.messages import router as messages_router +from letta.server.rest_api.routers.v1.passages import router as passages_router from letta.server.rest_api.routers.v1.providers import router as providers_router from letta.server.rest_api.routers.v1.runs import router as runs_router from letta.server.rest_api.routers.v1.sandbox_configs import router as sandbox_configs_router @@ -52,6 +53,7 @@ ROUTERS = [ tags_router, telemetry_router, messages_router, + passages_router, voice_router, embeddings_router, openai_chat_completions_router, diff --git a/letta/server/rest_api/routers/v1/archives.py b/letta/server/rest_api/routers/v1/archives.py index 900178ee..468bba36 100644 --- a/letta/server/rest_api/routers/v1/archives.py +++ b/letta/server/rest_api/routers/v1/archives.py @@ -65,7 +65,7 @@ async def create_archive( embedding_config = archive.embedding_config if embedding_config is None and archive.embedding is not None: - handle = f"{archive.embedding.provider}/{archive.embedding.model}" + handle = archive.embedding embedding_config = await server.get_embedding_config_from_handle_async( handle=handle, actor=actor, @@ -150,7 +150,7 @@ async def modify_archive( ) -@router.delete("/{archive_id}", response_model=PydanticArchive, operation_id="delete_archive") +@router.delete("/{archive_id}", status_code=204, operation_id="delete_archive") async def delete_archive( archive_id: ArchiveId, server: "SyncServer" = Depends(get_letta_server), @@ -160,10 +160,11 @@ async def delete_archive( Delete an archive by its ID. """ actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - return await server.archive_manager.delete_archive_async( + await server.archive_manager.delete_archive_async( archive_id=archive_id, actor=actor, ) + return None @router.get("/{archive_id}/agents", response_model=List[AgentState], operation_id="list_agents_for_archive") diff --git a/letta/server/rest_api/routers/v1/messages.py b/letta/server/rest_api/routers/v1/messages.py index e4fd7f29..25448352 100644 --- a/letta/server/rest_api/routers/v1/messages.py +++ b/letta/server/rest_api/routers/v1/messages.py @@ -1,14 +1,17 @@ -from typing import List, Literal, Optional +from typing import Annotated, List, Literal, Optional -from fastapi import APIRouter, Body, Depends, Query +from fastapi import APIRouter, Body, Depends, HTTPException, Query +from pydantic import Field from starlette.requests import Request from letta.agents.letta_agent_batch import LettaAgentBatch from letta.errors import LettaInvalidArgumentError from letta.log import get_logger from letta.schemas.job import BatchJob, JobStatus, JobType, JobUpdate +from letta.schemas.letta_message import LettaMessageUnion from letta.schemas.letta_request import CreateBatch from letta.schemas.letta_response import LettaBatchMessages +from letta.schemas.message import Message, MessageSearchRequest, MessageSearchResult, SearchAllMessagesRequest from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server from letta.server.server import SyncServer from letta.settings import settings @@ -18,6 +21,65 @@ router = APIRouter(prefix="/messages", tags=["messages"]) logger = get_logger(__name__) +MessagesResponse = Annotated[ + list[LettaMessageUnion], Field(json_schema_extra={"type": "array", "items": {"$ref": "#/components/schemas/LettaMessageUnion"}}) +] + + +@router.get("/", response_model=MessagesResponse, operation_id="list_all_messages") +async def list_all_messages( + server: SyncServer = Depends(get_letta_server), + headers: HeaderParams = Depends(get_headers), + before: Optional[str] = Query( + None, description="Message ID cursor for pagination. Returns messages that come before this message ID in the specified sort order" + ), + after: Optional[str] = Query( + None, description="Message ID cursor for pagination. Returns messages that come after this message ID in the specified sort order" + ), + limit: Optional[int] = Query(100, description="Maximum number of messages to return"), + order: Literal["asc", "desc"] = Query( + "desc", description="Sort order for messages by creation time. 'asc' for oldest first, 'desc' for newest first" + ), +): + """ + List messages across all agents for the current user. + """ + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + return await server.get_all_messages_recall_async( + after=after, + before=before, + limit=limit, + reverse=(order == "desc"), + return_message_object=False, + actor=actor, + ) + + +@router.post("/search", response_model=List[LettaMessageUnion], operation_id="search_all_messages") +async def search_all_messages( + request: SearchAllMessagesRequest = Body(...), + server: SyncServer = Depends(get_letta_server), + headers: HeaderParams = Depends(get_headers), +): + """ + Search messages across the organization with optional agent filtering. + Returns messages with FTS/vector ranks and total RRF score. + + This is a cloud-only feature. + """ + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + + results = await server.message_manager.search_messages_org_async( + actor=actor, + query_text=request.query, + search_mode=request.search_mode, + limit=request.limit, + start_date=request.start_date, + end_date=request.end_date, + ) + return Message.to_letta_messages_from_list(messages=[result.message for result in results], text_is_assistant_message=True) + + @router.post( "/batches", response_model=BatchJob, diff --git a/letta/server/rest_api/routers/v1/passages.py b/letta/server/rest_api/routers/v1/passages.py new file mode 100644 index 00000000..cdb1010c --- /dev/null +++ b/letta/server/rest_api/routers/v1/passages.py @@ -0,0 +1,108 @@ +from datetime import datetime +from typing import List, Literal, Optional + +from fastapi import APIRouter, Body, Depends +from pydantic import BaseModel, Field + +from letta.schemas.enums import TagMatchMode +from letta.schemas.passage import Passage +from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server +from letta.server.server import SyncServer + +router = APIRouter(prefix="/passages", tags=["passages"]) + + +class PassageSearchRequest(BaseModel): + """Request model for searching passages across archives.""" + + query: str = Field(..., description="Text query for semantic search") + agent_id: Optional[str] = Field(None, description="Filter passages by agent ID") + archive_id: Optional[str] = Field(None, description="Filter passages by archive ID") + tags: Optional[List[str]] = Field(None, description="Optional list of tags to filter search results") + tag_match_mode: Literal["any", "all"] = Field( + "any", description="How to match tags - 'any' to match passages with any of the tags, 'all' to match only passages with all tags" + ) + limit: int = Field(50, description="Maximum number of results to return", ge=1, le=100) + start_date: Optional[datetime] = Field(None, description="Filter results to passages created after this datetime") + end_date: Optional[datetime] = Field(None, description="Filter results to passages created before this datetime") + + +class PassageSearchResult(BaseModel): + """Result from a passage search operation with scoring details.""" + + passage: Passage = Field(..., description="The passage object") + score: float = Field(..., description="Relevance score") + metadata: dict = Field(default_factory=dict, description="Additional metadata about the search result") + + +@router.post("/search", response_model=List[PassageSearchResult], operation_id="search_passages") +async def search_passages( + request: PassageSearchRequest = Body(...), + server: SyncServer = Depends(get_letta_server), + headers: HeaderParams = Depends(get_headers), +): + """ + Search passages across the organization with optional agent and archive filtering. + Returns passages with relevance scores. + + This endpoint supports semantic search through passages: + - If neither agent_id nor archive_id is provided, searches ALL passages in the organization + - If agent_id is provided, searches passages across all archives attached to that agent + - If archive_id is provided, searches passages within that specific archive + - If both are provided, agent_id takes precedence + """ + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + + # Convert tag_match_mode to enum + tag_mode = TagMatchMode.ANY if request.tag_match_mode == "any" else TagMatchMode.ALL + + # Determine which embedding config to use + embedding_config = None + if request.agent_id: + # Search by agent + agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=request.agent_id, actor=actor) + embedding_config = agent_state.embedding_config + elif request.archive_id: + # Search by archive_id + archive = await server.archive_manager.get_archive_by_id_async(archive_id=request.archive_id, actor=actor) + embedding_config = archive.embedding_config + else: + # Search across all passages in the organization + # Get default embedding config from any agent or use server default + agent_count = await server.agent_manager.size_async(actor=actor) + if agent_count > 0: + # Get first agent to derive embedding config + agents = await server.agent_manager.list_agents_async(actor=actor, limit=1) + if agents: + embedding_config = agents[0].embedding_config + + if not embedding_config: + # Fall back to server default + embedding_config = server.default_embedding_config + + # Search passages + passages_with_metadata = await server.agent_manager.query_agent_passages_async( + actor=actor, + agent_id=request.agent_id, # Can be None for organization-wide search + archive_id=request.archive_id, # Can be None if searching by agent or org-wide + query_text=request.query, + limit=request.limit, + embedding_config=embedding_config, + embed_query=True, + tags=request.tags, + tag_match_mode=tag_mode, + start_date=request.start_date, + end_date=request.end_date, + ) + + # Convert to PassageSearchResult objects + results = [ + PassageSearchResult( + passage=passage, + score=score, + metadata=metadata, + ) + for passage, score, metadata in passages_with_metadata + ] + + return results diff --git a/letta/server/server.py b/letta/server/server.py index e58bc3ac..4fb2d3b9 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -784,6 +784,51 @@ class SyncServer(object): return records + async def get_all_messages_recall_async( + self, + actor: User, + after: Optional[str] = None, + before: Optional[str] = None, + limit: Optional[int] = 100, + group_id: Optional[str] = None, + reverse: Optional[bool] = False, + return_message_object: bool = True, + use_assistant_message: bool = True, + assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL, + assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG, + include_err: Optional[bool] = None, + ) -> Union[List[Message], List[LettaMessage]]: + records = await self.message_manager.list_messages( + agent_id=None, + actor=actor, + after=after, + before=before, + limit=limit, + ascending=not reverse, + group_id=group_id, + include_err=include_err, + ) + + if not return_message_object: + # NOTE: We are assuming all messages are coming from letta_v1_agent. This may lead to slightly incorrect assistant message handling. + # text_is_assistant_message = agent_state.agent_type == AgentType.letta_v1_agent + text_is_assistant_message = True + + records = Message.to_letta_messages_from_list( + messages=records, + use_assistant_message=use_assistant_message, + assistant_message_tool_name=assistant_message_tool_name, + assistant_message_tool_kwarg=assistant_message_tool_kwarg, + reverse=reverse, + include_err=include_err, + text_is_assistant_message=text_is_assistant_message, + ) + + if reverse: + records = records[::-1] + + return records + def get_server_config(self, include_defaults: bool = False) -> dict: """Return the base config""" diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 0a291343..c10c1997 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -2183,6 +2183,7 @@ class AgentManager: self, actor: PydanticUser, agent_id: Optional[str] = None, + archive_id: Optional[str] = None, limit: Optional[int] = 50, query_text: Optional[str] = None, start_date: Optional[datetime] = None, @@ -2197,17 +2198,26 @@ class AgentManager: ) -> List[Tuple[PydanticPassage, float, dict]]: """Lists all passages attached to an agent.""" # Check if we should use Turbopuffer for vector search - if embed_query and agent_id and query_text and embedding_config: - # Get archive IDs for the agent - archive_ids = await self.get_agent_archive_ids_async(agent_id=agent_id, actor=actor) + # Support searching by either agent_id or archive_id directly + if embed_query and query_text and embedding_config: + target_archive_id = None - if archive_ids: - # TODO: Remove this restriction once we support multiple archives with mixed vector DB providers - if len(archive_ids) > 1: - raise ValueError(f"Agent {agent_id} has multiple archives, which is not yet supported for vector search") + if agent_id: + # Get archive IDs for the agent + archive_ids = await self.get_agent_archive_ids_async(agent_id=agent_id, actor=actor) + if archive_ids: + # TODO: Remove this restriction once we support multiple archives with mixed vector DB providers + if len(archive_ids) > 1: + raise ValueError(f"Agent {agent_id} has multiple archives, which is not yet supported for vector search") + target_archive_id = archive_ids[0] + elif archive_id: + # Use the provided archive_id directly + target_archive_id = archive_id + + if target_archive_id: # Get archive to check vector_db_provider - archive = await self.archive_manager.get_archive_by_id_async(archive_id=archive_ids[0], actor=actor) + archive = await self.archive_manager.get_archive_by_id_async(archive_id=target_archive_id, actor=actor) # Use Turbopuffer for vector search if archive is configured for TPUF if archive.vector_db_provider == VectorDBProvider.TPUF: @@ -2226,7 +2236,7 @@ class AgentManager: tpuf_client = TurbopufferClient() # use hybrid search to combine vector and full-text search passages_with_scores = await tpuf_client.query_passages( - archive_id=archive_ids[0], + archive_id=target_archive_id, query_text=query_text, # pass text for potential hybrid search search_mode="hybrid", # use hybrid mode for better results top_k=limit, @@ -2239,14 +2249,13 @@ class AgentManager: # Return full tuples with metadata return passages_with_scores - else: - return [] # Fall back to SQL-based search for non-vector queries or NATIVE archives async with db_registry.async_session() as session: main_query = await build_agent_passage_query( actor=actor, agent_id=agent_id, + archive_id=archive_id, query_text=query_text, start_date=start_date, end_date=end_date, diff --git a/letta/services/archive_manager.py b/letta/services/archive_manager.py index 2758f1f9..5ed6220f 100644 --- a/letta/services/archive_manager.py +++ b/letta/services/archive_manager.py @@ -284,8 +284,8 @@ class ArchiveManager: self, archive_id: str, text: str, - metadata: Dict = None, - tags: List[str] = None, + metadata: Optional[Dict] = None, + tags: Optional[List[str]] = None, actor: PydanticUser = None, ) -> PydanticPassage: """Create a passage in an archive. @@ -335,6 +335,27 @@ class ArchiveManager: actor=actor, ) + # If archive uses Turbopuffer, also write to Turbopuffer (dual-write) + if archive.vector_db_provider == VectorDBProvider.TPUF: + try: + from letta.helpers.tpuf_client import TurbopufferClient + + tpuf_client = TurbopufferClient() + + # Insert to Turbopuffer with the same ID as SQL + await tpuf_client.insert_archival_memories( + archive_id=archive.id, + text_chunks=[created_passage.text], + passage_ids=[created_passage.id], + organization_id=actor.organization_id, + actor=actor, + ) + logger.info(f"Uploaded passage {created_passage.id} to Turbopuffer for archive {archive_id}") + except Exception as e: + logger.error(f"Failed to upload passage to Turbopuffer: {e}") + # Don't fail the entire operation if Turbopuffer upload fails + # The passage is already saved to SQL + logger.info(f"Created passage {created_passage.id} in archive {archive_id}") return created_passage diff --git a/letta/services/helpers/agent_manager_helper.py b/letta/services/helpers/agent_manager_helper.py index 9698ff8c..cec9e98b 100644 --- a/letta/services/helpers/agent_manager_helper.py +++ b/letta/services/helpers/agent_manager_helper.py @@ -1176,7 +1176,8 @@ async def build_source_passage_query( async def build_agent_passage_query( actor: User, - agent_id: str, # Required for agent passages + agent_id: Optional[str] = None, + archive_id: Optional[str] = None, query_text: Optional[str] = None, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, @@ -1186,7 +1187,11 @@ async def build_agent_passage_query( ascending: bool = True, embedding_config: Optional[EmbeddingConfig] = None, ) -> Select: - """Build query for agent passages with all filters applied.""" + """Build query for agent/archive passages with all filters applied. + + Can provide agent_id, archive_id, both, or neither (org-wide search). + If both are provided, agent_id takes precedence. + """ # Handle embedding for vector search embedded_text = None @@ -1203,12 +1208,23 @@ async def build_agent_passage_query( embedded_text = np.array(embeddings[0]) embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist() - # Base query for agent passages - join through archives_agents - query = ( - select(ArchivalPassage) - .join(ArchivesAgents, ArchivalPassage.archive_id == ArchivesAgents.archive_id) - .where(ArchivesAgents.agent_id == agent_id, ArchivalPassage.organization_id == actor.organization_id) - ) + # Base query for passages + if agent_id: + # Query for agent passages - join through archives_agents + # Agent_id takes precedence if both agent_id and archive_id are provided + query = ( + select(ArchivalPassage) + .join(ArchivesAgents, ArchivalPassage.archive_id == ArchivesAgents.archive_id) + .where(ArchivesAgents.agent_id == agent_id, ArchivalPassage.organization_id == actor.organization_id) + ) + elif archive_id: + # Query for archive passages directly + query = select(ArchivalPassage).where( + ArchivalPassage.archive_id == archive_id, ArchivalPassage.organization_id == actor.organization_id + ) + else: + # Org-wide search - all passages in organization + query = select(ArchivalPassage).where(ArchivalPassage.organization_id == actor.organization_id) # Apply filters if start_date: diff --git a/tests/sdk_v1/search_test.py b/tests/sdk_v1/search_test.py new file mode 100644 index 00000000..da08f48b --- /dev/null +++ b/tests/sdk_v1/search_test.py @@ -0,0 +1,440 @@ +""" +End-to-end tests for passage and message search endpoints using the SDK client. + +These tests verify that the /v1/passages/search and /v1/messages/search endpoints work correctly +with Turbopuffer integration, including vector search, FTS, hybrid search, filtering, and pagination. +""" + +import time +import uuid +from datetime import datetime, timedelta, timezone + +import pytest +from letta_client import Letta +from letta_client.types import CreateBlockParam, MessageCreateParam + +from letta.config import LettaConfig +from letta.server.server import SyncServer +from letta.settings import settings + + +def cleanup_agent_with_messages(client: Letta, agent_id: str): + """ + Helper function to properly clean up an agent by first deleting all its messages + from Turbopuffer before deleting the agent itself. + + Args: + client: Letta SDK client + agent_id: ID of the agent to clean up + """ + try: + # First, delete all messages for this agent from Turbopuffer + # This ensures no orphaned messages remain in Turbopuffer + try: + import asyncio + + from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages + + if should_use_tpuf_for_messages(): + tpuf_client = TurbopufferClient() + # Delete all messages for this agent from Turbopuffer + asyncio.run(tpuf_client.delete_all_messages(agent_id)) + except Exception as e: + print(f"Warning: Failed to clean up Turbopuffer messages for agent {agent_id}: {e}") + + # Now delete the agent itself (which will delete SQL messages via cascade) + client.agents.delete(agent_id=agent_id) + except Exception as e: + print(f"Warning: Failed to clean up agent {agent_id}: {e}") + + +@pytest.fixture(scope="module") +def server(): + """Server fixture for testing""" + config = LettaConfig.load() + config.save() + server = SyncServer(init_with_default_org_and_user=False) + return server + + +@pytest.fixture +def enable_turbopuffer(): + """Temporarily enable Turbopuffer for testing""" + original_use_tpuf = settings.use_tpuf + original_api_key = settings.tpuf_api_key + original_environment = settings.environment + + # Enable Turbopuffer with test key + settings.use_tpuf = True + if not settings.tpuf_api_key: + settings.tpuf_api_key = original_api_key + settings.environment = "DEV" + + yield + + # Restore original values + settings.use_tpuf = original_use_tpuf + settings.tpuf_api_key = original_api_key + settings.environment = original_environment + + +@pytest.fixture +def enable_message_embedding(): + """Enable both Turbopuffer and message embedding""" + original_use_tpuf = settings.use_tpuf + original_api_key = settings.tpuf_api_key + original_embed_messages = settings.embed_all_messages + original_environment = settings.environment + + settings.use_tpuf = True + settings.tpuf_api_key = settings.tpuf_api_key or "test-key" + settings.embed_all_messages = True + settings.environment = "DEV" + + yield + + settings.use_tpuf = original_use_tpuf + settings.tpuf_api_key = original_api_key + settings.embed_all_messages = original_embed_messages + settings.environment = original_environment + + +@pytest.fixture +def disable_turbopuffer(): + """Ensure Turbopuffer is disabled for testing""" + original_use_tpuf = settings.use_tpuf + original_embed_messages = settings.embed_all_messages + + settings.use_tpuf = False + settings.embed_all_messages = False + + yield + + settings.use_tpuf = original_use_tpuf + settings.embed_all_messages = original_embed_messages + + +@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured") +def test_passage_search_basic(client: Letta, enable_turbopuffer): + """Test basic passage search functionality through the SDK""" + # Create an agent + agent = client.agents.create( + name=f"test_passage_search_{uuid.uuid4()}", + memory_blocks=[CreateBlockParam(label="persona", value="test assistant")], + model="openai/gpt-4o-mini", + embedding="openai/text-embedding-3-small", + ) + + try: + # Create an archive and attach to agent + archive = client.archives.create(name=f"test_archive_{uuid.uuid4()}", embedding="openai/text-embedding-3-small") + + try: + # Attach archive to agent + client.agents.archives.attach(agent_id=agent.id, archive_id=archive.id) + + # Insert some passages + test_passages = [ + "Python is a popular programming language for data science and machine learning.", + "JavaScript is widely used for web development and frontend applications.", + "Turbopuffer is a vector database optimized for performance and scalability.", + ] + + for passage_text in test_passages: + client.archives.passages.create(archive_id=archive.id, text=passage_text) + + # Wait for indexing + time.sleep(2) + + # Test search by agent_id + results = client.passages.search(query="python programming", agent_id=agent.id, limit=10) + + assert len(results) > 0, "Should find at least one passage" + assert any("Python" in result.passage.text for result in results), "Should find Python-related passage" + + # Verify result structure + for result in results: + assert hasattr(result, "passage"), "Result should have passage field" + assert hasattr(result, "score"), "Result should have score field" + assert hasattr(result, "metadata"), "Result should have metadata field" + assert isinstance(result.score, float), "Score should be a float" + + # Test search by archive_id + archive_results = client.passages.search(query="vector database", archive_id=archive.id, limit=10) + + assert len(archive_results) > 0, "Should find passages in archive" + assert any("Turbopuffer" in result.passage.text or "vector" in result.passage.text for result in archive_results), ( + "Should find vector-related passage" + ) + + finally: + # Clean up archive + try: + client.archives.delete(archive_id=archive.id) + except: + pass + + finally: + # Clean up agent + cleanup_agent_with_messages(client, agent.id) + + +@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured") +def test_passage_search_with_tags(client: Letta, enable_turbopuffer): + """Test passage search with tag filtering""" + # Create an agent + agent = client.agents.create( + name=f"test_passage_tags_{uuid.uuid4()}", + memory_blocks=[CreateBlockParam(label="persona", value="test assistant")], + model="openai/gpt-4o-mini", + embedding="openai/text-embedding-3-small", + ) + + try: + # Create an archive + archive = client.archives.create(name=f"test_archive_tags_{uuid.uuid4()}", embedding="openai/text-embedding-3-small") + + try: + # Attach archive to agent + client.agents.archives.attach(agent_id=agent.id, archive_id=archive.id) + + # Insert passages with tags (if supported) + # Note: Tag support may depend on the SDK version + test_passages = [ + "Python tutorial for beginners", + "Advanced Python techniques", + "JavaScript basics", + ] + + for passage_text in test_passages: + client.archives.passages.create(archive_id=archive.id, text=passage_text) + + # Wait for indexing + time.sleep(2) + + # Test basic search without tags first + results = client.passages.search(query="programming tutorial", agent_id=agent.id, limit=10) + + assert len(results) > 0, "Should find passages" + + # Test with tag filtering if supported + # Note: The SDK may not expose tag parameters directly, so this test verifies basic functionality + # The backend will handle tag filtering when available + + finally: + # Clean up archive + try: + client.archives.delete(archive_id=archive.id) + except: + pass + + finally: + # Clean up agent + cleanup_agent_with_messages(client, agent.id) + + +@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured") +def test_passage_search_with_date_filters(client: Letta, enable_turbopuffer): + """Test passage search with date range filtering""" + # Create an agent + agent = client.agents.create( + name=f"test_passage_dates_{uuid.uuid4()}", + memory_blocks=[CreateBlockParam(label="persona", value="test assistant")], + model="openai/gpt-4o-mini", + embedding="openai/text-embedding-3-small", + ) + + try: + # Create an archive + archive = client.archives.create(name=f"test_archive_dates_{uuid.uuid4()}", embedding="openai/text-embedding-3-small") + + try: + # Attach archive to agent + client.agents.archives.attach(agent_id=agent.id, archive_id=archive.id) + + # Insert passages at different times + client.archives.passages.create(archive_id=archive.id, text="Recent passage about AI trends") + + # Wait a bit before creating another + time.sleep(1) + + client.archives.passages.create(archive_id=archive.id, text="Another passage about machine learning") + + # Wait for indexing + time.sleep(2) + + # Test search with date range + now = datetime.now(timezone.utc) + start_date = now - timedelta(hours=1) + + results = client.passages.search(query="AI machine learning", agent_id=agent.id, limit=10, start_date=start_date) + + assert len(results) > 0, "Should find recent passages" + + # Verify all results are within date range + for result in results: + passage_date = result.passage.created_at + if passage_date: + # Convert to datetime if it's a string + if isinstance(passage_date, str): + passage_date = datetime.fromisoformat(passage_date.replace("Z", "+00:00")) + assert passage_date >= start_date, "Passage should be after start_date" + + finally: + # Clean up archive + try: + client.archives.delete(archive_id=archive.id) + except: + pass + + finally: + # Clean up agent + cleanup_agent_with_messages(client, agent.id) + + +@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured") +def test_message_search_basic(client: Letta, enable_message_embedding): + """Test basic message search functionality through the SDK""" + # Create an agent + agent = client.agents.create( + name=f"test_message_search_{uuid.uuid4()}", + memory_blocks=[CreateBlockParam(label="persona", value="helpful assistant")], + model="openai/gpt-4o-mini", + embedding="openai/text-embedding-3-small", + ) + + try: + # Send messages to the agent + test_messages = [ + "What is the capital of Mozambique?", + ] + + for msg_text in test_messages: + client.agents.messages.create(agent_id=agent.id, messages=[MessageCreateParam(role="user", content=msg_text)]) + + # Wait for messages to be indexed and database transactions to complete + # Extra time needed for async embedding and database commits + time.sleep(6) + + # Test FTS search for messages + results = client.messages.search(query="capital of Mozambique", search_mode="fts", limit=10) + + assert len(results) > 0, "Should find at least one message" + + finally: + # Clean up agent + cleanup_agent_with_messages(client, agent.id) + + +@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured") +def test_passage_search_pagination(client: Letta, enable_turbopuffer): + """Test passage search pagination""" + # Create an agent + agent = client.agents.create( + name=f"test_passage_pagination_{uuid.uuid4()}", + memory_blocks=[CreateBlockParam(label="persona", value="test assistant")], + model="openai/gpt-4o-mini", + embedding="openai/text-embedding-3-small", + ) + + try: + # Create an archive + archive = client.archives.create(name=f"test_archive_pagination_{uuid.uuid4()}", embedding="openai/text-embedding-3-small") + + try: + # Attach archive to agent + client.agents.archives.attach(agent_id=agent.id, archive_id=archive.id) + + # Insert many passages + for i in range(10): + client.archives.passages.create(archive_id=archive.id, text=f"Test passage number {i} about programming") + + # Wait for indexing + time.sleep(2) + + # Test with different limit values + results_limit_3 = client.passages.search(query="programming", agent_id=agent.id, limit=3) + + assert len(results_limit_3) == 3, "Should respect limit parameter" + + results_limit_5 = client.passages.search(query="programming", agent_id=agent.id, limit=5) + + assert len(results_limit_5) == 5, "Should return 5 results" + + results_all = client.passages.search(query="programming", agent_id=agent.id, limit=20) + + assert len(results_all) >= 10, "Should return all matching passages" + + finally: + # Clean up archive + try: + client.archives.delete(archive_id=archive.id) + except: + pass + + finally: + # Clean up agent + cleanup_agent_with_messages(client, agent.id) + + +@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured") +def test_passage_search_org_wide(client: Letta, enable_turbopuffer): + """Test organization-wide passage search (without agent_id or archive_id)""" + # Create multiple agents with archives + agent1 = client.agents.create( + name=f"test_org_search_agent1_{uuid.uuid4()}", + memory_blocks=[CreateBlockParam(label="persona", value="test assistant 1")], + model="openai/gpt-4o-mini", + embedding="openai/text-embedding-3-small", + ) + + agent2 = client.agents.create( + name=f"test_org_search_agent2_{uuid.uuid4()}", + memory_blocks=[CreateBlockParam(label="persona", value="test assistant 2")], + model="openai/gpt-4o-mini", + embedding="openai/text-embedding-3-small", + ) + + try: + # Create archives for both agents + archive1 = client.archives.create(name=f"test_archive_org1_{uuid.uuid4()}", embedding="openai/text-embedding-3-small") + archive2 = client.archives.create(name=f"test_archive_org2_{uuid.uuid4()}", embedding="openai/text-embedding-3-small") + + try: + # Attach archives + client.agents.archives.attach(agent_id=agent1.id, archive_id=archive1.id) + client.agents.archives.attach(agent_id=agent2.id, archive_id=archive2.id) + + # Insert passages in both archives + client.archives.passages.create(archive_id=archive1.id, text="Unique passage in agent1 about quantum computing") + + client.archives.passages.create(archive_id=archive2.id, text="Unique passage in agent2 about blockchain technology") + + # Wait for indexing + time.sleep(2) + + # Test org-wide search (no agent_id or archive_id) + results = client.passages.search(query="unique passage", limit=20) + + # Should find passages from both agents + assert len(results) >= 2, "Should find passages from multiple agents" + + found_texts = [result.passage.text for result in results] + assert any("quantum computing" in text for text in found_texts), "Should find agent1 passage" + assert any("blockchain" in text for text in found_texts), "Should find agent2 passage" + + finally: + # Clean up archives + try: + client.archives.delete(archive_id=archive1.id) + except: + pass + try: + client.archives.delete(archive_id=archive2.id) + except: + pass + + finally: + # Clean up agents + cleanup_agent_with_messages(client, agent1.id) + cleanup_agent_with_messages(client, agent2.id) diff --git a/tests/sdk_v1/test_sdk_client.py b/tests/sdk_v1/test_sdk_client.py index 00aac802..f38b2efc 100644 --- a/tests/sdk_v1/test_sdk_client.py +++ b/tests/sdk_v1/test_sdk_client.py @@ -2263,6 +2263,100 @@ def test_create_agent(client: LettaSDKClient) -> None: client.agents.delete(agent_id=agent.id) +def test_list_all_messages(client: LettaSDKClient): + """Test listing all messages across multiple agents.""" + # Create two agents + agent1 = client.agents.create( + name="test_agent_1_messages", + memory_blocks=[CreateBlockParam(label="persona", value="you are agent 1")], + model="openai/gpt-4o-mini", + embedding="openai/text-embedding-3-small", + ) + + agent2 = client.agents.create( + name="test_agent_2_messages", + memory_blocks=[CreateBlockParam(label="persona", value="you are agent 2")], + model="openai/gpt-4o-mini", + embedding="openai/text-embedding-3-small", + ) + + try: + # Send messages to both agents + agent1_msg_content = "Hello from agent 1" + agent2_msg_content = "Hello from agent 2" + + client.agents.messages.create( + agent_id=agent1.id, + messages=[MessageCreateParam(role="user", content=agent1_msg_content)], + ) + + client.agents.messages.create( + agent_id=agent2.id, + messages=[MessageCreateParam(role="user", content=agent2_msg_content)], + ) + + # Wait a bit for messages to be persisted + time.sleep(0.5) + + # List all messages across both agents + all_messages = client.messages.list(limit=100) + + # Verify we got messages back + assert hasattr(all_messages, "items") or isinstance(all_messages, list), "Should return messages list or paginated response" + + # Handle both list and paginated response formats + if hasattr(all_messages, "items"): + messages_list = all_messages.items + else: + messages_list = list(all_messages) + + # Should have messages from both agents (plus initial system messages) + assert len(messages_list) > 0, "Should have at least some messages" + + # Extract message content for verification + message_contents = [] + for msg in messages_list: + # Handle different message types + if hasattr(msg, "content"): + content = msg.content + if isinstance(content, str): + message_contents.append(content) + elif isinstance(content, list): + for item in content: + if hasattr(item, "text"): + message_contents.append(item.text) + + # Verify messages from both agents are present + found_agent1_msg = any(agent1_msg_content in content for content in message_contents) + found_agent2_msg = any(agent2_msg_content in content for content in message_contents) + + assert found_agent1_msg or found_agent2_msg, "Should find at least one of the messages we sent" + + # Test pagination parameters + limited_messages = client.messages.list(limit=5) + if hasattr(limited_messages, "items"): + limited_list = limited_messages.items + else: + limited_list = list(limited_messages) + + assert len(limited_list) <= 5, "Should respect limit parameter" + + # Test order parameter (desc should be default - newest first) + desc_messages = client.messages.list(limit=10, order="desc") + if hasattr(desc_messages, "items"): + desc_list = desc_messages.items + else: + desc_list = list(desc_messages) + + # Verify messages are returned + assert isinstance(desc_list, list), "Should return a list of messages" + + finally: + # Clean up agents + client.agents.delete(agent_id=agent1.id) + client.agents.delete(agent_id=agent2.id) + + def test_create_agent_with_tools(client: LettaSDKClient) -> None: """Test creating an agent with custom inventory management tools"""