feat: add search routes [LET-6236] (#6280)
* claude code first pass * rename routes * search_messages and list_messages * revert agents messagesearch * generate api * fix backend for list all messages * request for message search * return list of letta message * add tests * error in archive endpoint * archive delete return type wrong * optional params for archive creation * add passage to tpuf on create * fix archive manager * support global passage search * search by agent * just do basic org wide search for now * change message test to be about fresh data, cleanup after --------- Co-authored-by: Ari Webb <ari@letta.com>
This commit is contained in:
@@ -336,15 +336,8 @@
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Archive"
|
||||
}
|
||||
}
|
||||
}
|
||||
"204": {
|
||||
"description": "Successful Response"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
@@ -15296,6 +15289,205 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/messages/": {
|
||||
"get": {
|
||||
"tags": ["messages"],
|
||||
"summary": "List All Messages",
|
||||
"description": "List messages across all agents for the current user.",
|
||||
"operationId": "list_all_messages",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "before",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"description": "Message ID cursor for pagination. Returns messages that come before this message ID in the specified sort order",
|
||||
"title": "Before"
|
||||
},
|
||||
"description": "Message ID cursor for pagination. Returns messages that come before this message ID in the specified sort order"
|
||||
},
|
||||
{
|
||||
"name": "after",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"description": "Message ID cursor for pagination. Returns messages that come after this message ID in the specified sort order",
|
||||
"title": "After"
|
||||
},
|
||||
"description": "Message ID cursor for pagination. Returns messages that come after this message ID in the specified sort order"
|
||||
},
|
||||
{
|
||||
"name": "limit",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "integer"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"description": "Maximum number of messages to return",
|
||||
"default": 100,
|
||||
"title": "Limit"
|
||||
},
|
||||
"description": "Maximum number of messages to return"
|
||||
},
|
||||
{
|
||||
"name": "order",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"enum": ["asc", "desc"],
|
||||
"type": "string",
|
||||
"description": "Sort order for messages by creation time. 'asc' for oldest first, 'desc' for newest first",
|
||||
"default": "desc",
|
||||
"title": "Order"
|
||||
},
|
||||
"description": "Sort order for messages by creation time. 'asc' for oldest first, 'desc' for newest first"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/LettaMessageUnion"
|
||||
},
|
||||
"title": "Response List All Messages"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/HTTPValidationError"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/messages/search": {
|
||||
"post": {
|
||||
"tags": ["messages"],
|
||||
"summary": "Search All Messages",
|
||||
"description": "Search messages across the organization with optional agent filtering.\nReturns messages with FTS/vector ranks and total RRF score.\n\nThis is a cloud-only feature.",
|
||||
"operationId": "search_all_messages",
|
||||
"parameters": [],
|
||||
"requestBody": {
|
||||
"required": true,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/SearchAllMessagesRequest"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"oneOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/SystemMessage"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/UserMessage"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/ReasoningMessage"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/HiddenReasoningMessage"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/ToolCallMessage"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/ToolReturnMessage"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/AssistantMessage"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/ApprovalRequestMessage"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/ApprovalResponseMessage"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/SummaryMessage"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/EventMessage"
|
||||
}
|
||||
],
|
||||
"discriminator": {
|
||||
"propertyName": "message_type",
|
||||
"mapping": {
|
||||
"system_message": "#/components/schemas/SystemMessage",
|
||||
"user_message": "#/components/schemas/UserMessage",
|
||||
"reasoning_message": "#/components/schemas/ReasoningMessage",
|
||||
"hidden_reasoning_message": "#/components/schemas/HiddenReasoningMessage",
|
||||
"tool_call_message": "#/components/schemas/ToolCallMessage",
|
||||
"tool_return_message": "#/components/schemas/ToolReturnMessage",
|
||||
"assistant_message": "#/components/schemas/AssistantMessage",
|
||||
"approval_request_message": "#/components/schemas/ApprovalRequestMessage",
|
||||
"approval_response_message": "#/components/schemas/ApprovalResponseMessage",
|
||||
"summary": "#/components/schemas/SummaryMessage",
|
||||
"event": "#/components/schemas/EventMessage"
|
||||
}
|
||||
}
|
||||
},
|
||||
"title": "Response Search All Messages"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/HTTPValidationError"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/messages/batches": {
|
||||
"post": {
|
||||
"tags": ["messages"],
|
||||
@@ -15673,6 +15865,51 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/passages/search": {
|
||||
"post": {
|
||||
"tags": ["passages"],
|
||||
"summary": "Search Passages",
|
||||
"description": "Search passages across the organization with optional agent and archive filtering.\nReturns passages with relevance scores.\n\nThis endpoint supports semantic search through passages:\n- If neither agent_id nor archive_id is provided, searches ALL passages in the organization\n- If agent_id is provided, searches passages across all archives attached to that agent\n- If archive_id is provided, searches passages within that specific archive\n- If both are provided, agent_id takes precedence",
|
||||
"operationId": "search_passages",
|
||||
"parameters": [],
|
||||
"requestBody": {
|
||||
"required": true,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/PassageSearchRequest"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/PassageSearchResult"
|
||||
},
|
||||
"title": "Response Search Passages"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/HTTPValidationError"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/voice-beta/{agent_id}/chat/completions": {
|
||||
"post": {
|
||||
"tags": ["voice"],
|
||||
@@ -31576,6 +31813,122 @@
|
||||
"title": "PassageCreateRequest",
|
||||
"description": "Request model for creating a passage in an archive."
|
||||
},
|
||||
"PassageSearchRequest": {
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"title": "Query",
|
||||
"description": "Text query for semantic search"
|
||||
},
|
||||
"agent_id": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Agent Id",
|
||||
"description": "Filter passages by agent ID"
|
||||
},
|
||||
"archive_id": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Archive Id",
|
||||
"description": "Filter passages by archive ID"
|
||||
},
|
||||
"tags": {
|
||||
"anyOf": [
|
||||
{
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"type": "array"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Tags",
|
||||
"description": "Optional list of tags to filter search results"
|
||||
},
|
||||
"tag_match_mode": {
|
||||
"type": "string",
|
||||
"enum": ["any", "all"],
|
||||
"title": "Tag Match Mode",
|
||||
"description": "How to match tags - 'any' to match passages with any of the tags, 'all' to match only passages with all tags",
|
||||
"default": "any"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"maximum": 100,
|
||||
"minimum": 1,
|
||||
"title": "Limit",
|
||||
"description": "Maximum number of results to return",
|
||||
"default": 50
|
||||
},
|
||||
"start_date": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Start Date",
|
||||
"description": "Filter results to passages created after this datetime"
|
||||
},
|
||||
"end_date": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "End Date",
|
||||
"description": "Filter results to passages created before this datetime"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["query"],
|
||||
"title": "PassageSearchRequest",
|
||||
"description": "Request model for searching passages across archives."
|
||||
},
|
||||
"PassageSearchResult": {
|
||||
"properties": {
|
||||
"passage": {
|
||||
"$ref": "#/components/schemas/Passage",
|
||||
"description": "The passage object"
|
||||
},
|
||||
"score": {
|
||||
"type": "number",
|
||||
"title": "Score",
|
||||
"description": "Relevance score"
|
||||
},
|
||||
"metadata": {
|
||||
"additionalProperties": true,
|
||||
"type": "object",
|
||||
"title": "Metadata",
|
||||
"description": "Additional metadata about the search result"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["passage", "score"],
|
||||
"title": "PassageSearchResult",
|
||||
"description": "Result from a passage search operation with scoring details."
|
||||
},
|
||||
"PipRequirement": {
|
||||
"properties": {
|
||||
"name": {
|
||||
@@ -33158,6 +33511,59 @@
|
||||
"enum": ["e2b", "modal", "local"],
|
||||
"title": "SandboxType"
|
||||
},
|
||||
"SearchAllMessagesRequest": {
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"title": "Query",
|
||||
"description": "Text query for full-text search"
|
||||
},
|
||||
"search_mode": {
|
||||
"type": "string",
|
||||
"enum": ["vector", "fts", "hybrid"],
|
||||
"title": "Search Mode",
|
||||
"description": "Search mode to use",
|
||||
"default": "hybrid"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"maximum": 100,
|
||||
"minimum": 1,
|
||||
"title": "Limit",
|
||||
"description": "Maximum number of results to return",
|
||||
"default": 50
|
||||
},
|
||||
"start_date": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Start Date",
|
||||
"description": "Filter messages created after this date"
|
||||
},
|
||||
"end_date": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "End Date",
|
||||
"description": "Filter messages created on or before this date"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["query"],
|
||||
"title": "SearchAllMessagesRequest"
|
||||
},
|
||||
"SleeptimeManager": {
|
||||
"properties": {
|
||||
"manager_type": {
|
||||
|
||||
@@ -2168,6 +2168,14 @@ class MessageSearchRequest(BaseModel):
|
||||
end_date: Optional[datetime] = Field(None, description="Filter messages created on or before this date")
|
||||
|
||||
|
||||
class SearchAllMessagesRequest(BaseModel):
    # Request body accepted by the `search_all_messages` endpoint.
    # Documented with `#` comments on purpose: a class docstring would be
    # emitted as the schema description and alter the generated OpenAPI spec.

    # The only required field.
    query: str = Field(default=..., description="Text query for full-text search")

    # One of the three supported retrieval strategies; hybrid unless overridden.
    search_mode: Literal["vector", "fts", "hybrid"] = Field(
        default="hybrid",
        description="Search mode to use",
    )

    # Result cap, validated to the inclusive range 1..100.
    limit: int = Field(
        default=50,
        ge=1,
        le=100,
        description="Maximum number of results to return",
    )

    # Optional creation-time window applied to the matched messages.
    start_date: Optional[datetime] = Field(
        default=None,
        description="Filter messages created after this date",
    )
    end_date: Optional[datetime] = Field(
        default=None,
        description="Filter messages created on or before this date",
    )
|
||||
|
||||
|
||||
class MessageSearchResult(BaseModel):
|
||||
"""Result from a message search operation with scoring details."""
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ from letta.server.rest_api.routers.v1.jobs import router as jobs_router
|
||||
from letta.server.rest_api.routers.v1.llms import router as llm_router
|
||||
from letta.server.rest_api.routers.v1.mcp_servers import router as mcp_servers_router
|
||||
from letta.server.rest_api.routers.v1.messages import router as messages_router
|
||||
from letta.server.rest_api.routers.v1.passages import router as passages_router
|
||||
from letta.server.rest_api.routers.v1.providers import router as providers_router
|
||||
from letta.server.rest_api.routers.v1.runs import router as runs_router
|
||||
from letta.server.rest_api.routers.v1.sandbox_configs import router as sandbox_configs_router
|
||||
@@ -52,6 +53,7 @@ ROUTERS = [
|
||||
tags_router,
|
||||
telemetry_router,
|
||||
messages_router,
|
||||
passages_router,
|
||||
voice_router,
|
||||
embeddings_router,
|
||||
openai_chat_completions_router,
|
||||
|
||||
@@ -65,7 +65,7 @@ async def create_archive(
|
||||
|
||||
embedding_config = archive.embedding_config
|
||||
if embedding_config is None and archive.embedding is not None:
|
||||
handle = f"{archive.embedding.provider}/{archive.embedding.model}"
|
||||
handle = archive.embedding
|
||||
embedding_config = await server.get_embedding_config_from_handle_async(
|
||||
handle=handle,
|
||||
actor=actor,
|
||||
@@ -150,7 +150,7 @@ async def modify_archive(
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{archive_id}", response_model=PydanticArchive, operation_id="delete_archive")
|
||||
@router.delete("/{archive_id}", status_code=204, operation_id="delete_archive")
|
||||
async def delete_archive(
|
||||
archive_id: ArchiveId,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
@@ -160,10 +160,11 @@ async def delete_archive(
|
||||
Delete an archive by its ID.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await server.archive_manager.delete_archive_async(
|
||||
await server.archive_manager.delete_archive_async(
|
||||
archive_id=archive_id,
|
||||
actor=actor,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/{archive_id}/agents", response_model=List[AgentState], operation_id="list_agents_for_archive")
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
from typing import List, Literal, Optional
|
||||
from typing import Annotated, List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Query
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query
|
||||
from pydantic import Field
|
||||
from starlette.requests import Request
|
||||
|
||||
from letta.agents.letta_agent_batch import LettaAgentBatch
|
||||
from letta.errors import LettaInvalidArgumentError
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.job import BatchJob, JobStatus, JobType, JobUpdate
|
||||
from letta.schemas.letta_message import LettaMessageUnion
|
||||
from letta.schemas.letta_request import CreateBatch
|
||||
from letta.schemas.letta_response import LettaBatchMessages
|
||||
from letta.schemas.message import Message, MessageSearchRequest, MessageSearchResult, SearchAllMessagesRequest
|
||||
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
from letta.settings import settings
|
||||
@@ -18,6 +21,65 @@ router = APIRouter(prefix="/messages", tags=["messages"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# Response type for endpoints that return a flat list of Letta messages.
# The explicit schema override makes the generated spec point at the shared
# LettaMessageUnion component rather than expanding the union inline.
MessagesResponse = Annotated[
    list[LettaMessageUnion],
    Field(
        json_schema_extra={
            "type": "array",
            "items": {"$ref": "#/components/schemas/LettaMessageUnion"},
        }
    ),
]
|
||||
|
||||
|
||||
@router.get("/", response_model=MessagesResponse, operation_id="list_all_messages")
async def list_all_messages(
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
    before: Optional[str] = Query(
        default=None,
        description="Message ID cursor for pagination. Returns messages that come before this message ID in the specified sort order",
    ),
    after: Optional[str] = Query(
        default=None,
        description="Message ID cursor for pagination. Returns messages that come after this message ID in the specified sort order",
    ),
    limit: Optional[int] = Query(default=100, description="Maximum number of messages to return"),
    order: Literal["asc", "desc"] = Query(
        default="desc",
        description="Sort order for messages by creation time. 'asc' for oldest first, 'desc' for newest first",
    ),
):
    """
    List messages across all agents for the current user.
    """
    # The docstring above is published as the endpoint description in the
    # generated API spec, so its wording is left untouched.

    # Resolve the caller from the request headers.
    actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)

    # The server helper takes a boolean `reverse` flag instead of the
    # "asc"/"desc" string exposed on the API.
    newest_first = order == "desc"

    # return_message_object=False asks the server for LettaMessage-style
    # objects, which is what the response model expects.
    return await server.get_all_messages_recall_async(
        actor=actor,
        before=before,
        after=after,
        limit=limit,
        reverse=newest_first,
        return_message_object=False,
    )
|
||||
|
||||
|
||||
@router.post("/search", response_model=List[LettaMessageUnion], operation_id="search_all_messages")
async def search_all_messages(
    request: SearchAllMessagesRequest = Body(...),
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
):
    """
    Search messages across the organization.
    Returns the matching messages as a list of Letta messages.

    This is a cloud-only feature.
    """
    # The description previously advertised agent filtering and per-result
    # FTS/vector/RRF scores. SearchAllMessagesRequest carries no agent field and
    # only the message payloads are returned below, so the text was corrected.
    # Regenerate the OpenAPI spec after changing this docstring.

    # Resolve the caller from the request headers.
    actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)

    hits = await server.message_manager.search_messages_org_async(
        actor=actor,
        query_text=request.query,
        search_mode=request.search_mode,
        limit=request.limit,
        start_date=request.start_date,
        end_date=request.end_date,
    )

    # Only the `.message` of each result is kept; any ranking details on the
    # result objects are dropped because the response model is a plain list of
    # LettaMessageUnion.
    matched_messages = [hit.message for hit in hits]

    # NOTE(review): text_is_assistant_message is hard-coded to True for every
    # message regardless of which agent produced it - confirm this is acceptable
    # for mixed agent types.
    return Message.to_letta_messages_from_list(
        messages=matched_messages,
        text_is_assistant_message=True,
    )
|
||||
|
||||
|
||||
@router.post(
|
||||
"/batches",
|
||||
response_model=BatchJob,
|
||||
|
||||
108
letta/server/rest_api/routers/v1/passages.py
Normal file
108
letta/server/rest_api/routers/v1/passages.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.schemas.enums import TagMatchMode
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
router = APIRouter(prefix="/passages", tags=["passages"])
|
||||
|
||||
|
||||
class PassageSearchRequest(BaseModel):
    """Request model for searching passages across archives."""

    # The only required field.
    query: str = Field(default=..., description="Text query for semantic search")

    # Optional scoping filters; both may be omitted.
    agent_id: Optional[str] = Field(default=None, description="Filter passages by agent ID")
    archive_id: Optional[str] = Field(default=None, description="Filter passages by archive ID")

    # Optional tag filter and the rule used to combine the tags.
    tags: Optional[List[str]] = Field(
        default=None,
        description="Optional list of tags to filter search results",
    )
    tag_match_mode: Literal["any", "all"] = Field(
        default="any",
        description="How to match tags - 'any' to match passages with any of the tags, 'all' to match only passages with all tags",
    )

    # Result cap, validated to the inclusive range 1..100.
    limit: int = Field(
        default=50,
        ge=1,
        le=100,
        description="Maximum number of results to return",
    )

    # Optional creation-time window.
    start_date: Optional[datetime] = Field(
        default=None,
        description="Filter results to passages created after this datetime",
    )
    end_date: Optional[datetime] = Field(
        default=None,
        description="Filter results to passages created before this datetime",
    )
|
||||
|
||||
|
||||
class PassageSearchResult(BaseModel):
    """Result from a passage search operation with scoring details."""

    # The matched passage itself.
    passage: Passage = Field(default=..., description="The passage object")

    # Score reported for this passage by the search backend.
    score: float = Field(default=..., description="Relevance score")

    # Free-form extras; each instance gets its own empty dict by default.
    metadata: dict = Field(
        default_factory=dict,
        description="Additional metadata about the search result",
    )
|
||||
|
||||
|
||||
@router.post("/search", response_model=List[PassageSearchResult], operation_id="search_passages")
async def search_passages(
    request: PassageSearchRequest = Body(...),
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
):
    """
    Search passages across the organization with optional agent and archive filtering.
    Returns passages with relevance scores.

    This endpoint supports semantic search through passages:
    - If neither agent_id nor archive_id is provided, searches ALL passages in the organization
    - If agent_id is provided, searches passages across all archives attached to that agent
    - If archive_id is provided, searches passages within that specific archive
    - If both are provided, agent_id takes precedence
    """
    # Resolve the caller from the request headers.
    actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)

    # Convert the validated "any"/"all" literal to the enum the manager expects.
    tag_mode = TagMatchMode.ANY if request.tag_match_mode == "any" else TagMatchMode.ALL

    # Pick the embedding config used to embed the query text.
    embedding_config = None
    if request.agent_id:
        # Agent-scoped search: use the agent's own embedding config.
        agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=request.agent_id, actor=actor)
        embedding_config = agent_state.embedding_config
    elif request.archive_id:
        # Archive-scoped search: use the archive's embedding config.
        archive = await server.archive_manager.get_archive_by_id_async(archive_id=request.archive_id, actor=actor)
        embedding_config = archive.embedding_config
    else:
        # Org-wide search: borrow the config of the first listed agent.
        # A single limit=1 listing is enough - an empty result already means
        # "no agents", so no separate count query is issued beforehand.
        # NOTE(review): an arbitrary agent's config may not match the embeddings
        # stored for every passage in the organization - confirm this is OK.
        agents = await server.agent_manager.list_agents_async(actor=actor, limit=1)
        if agents:
            embedding_config = agents[0].embedding_config

    if not embedding_config:
        # Fall back to the server default.
        embedding_config = server.default_embedding_config

    passages_with_metadata = await server.agent_manager.query_agent_passages_async(
        actor=actor,
        agent_id=request.agent_id,  # Can be None for organization-wide search
        archive_id=request.archive_id,  # Can be None if searching by agent or org-wide
        query_text=request.query,
        limit=request.limit,
        embedding_config=embedding_config,
        embed_query=True,
        tags=request.tags,
        tag_match_mode=tag_mode,
        start_date=request.start_date,
        end_date=request.end_date,
    )

    # Each item is a (passage, score, metadata) triple; wrap it in the response model.
    return [
        PassageSearchResult(passage=passage, score=score, metadata=metadata)
        for passage, score, metadata in passages_with_metadata
    ]
|
||||
@@ -784,6 +784,51 @@ class SyncServer(object):
|
||||
|
||||
return records
|
||||
|
||||
async def get_all_messages_recall_async(
    self,
    actor: User,
    after: Optional[str] = None,
    before: Optional[str] = None,
    limit: Optional[int] = 100,
    group_id: Optional[str] = None,
    reverse: Optional[bool] = False,
    return_message_object: bool = True,
    use_assistant_message: bool = True,
    assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
    assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
    include_err: Optional[bool] = None,
) -> Union[List[Message], List[LettaMessage]]:
    """Fetch one page of messages without restricting the query to an agent.

    Calls ``message_manager.list_messages`` with ``agent_id=None`` and the
    given cursors, limit, group filter and ``include_err`` flag. The listing is
    requested ascending unless ``reverse`` is truthy.

    When ``return_message_object`` is False the page is converted with
    ``Message.to_letta_messages_from_list`` before being returned; otherwise the
    listed records are returned unconverted. In both cases the final list is
    flipped when ``reverse`` is truthy.
    """
    page = await self.message_manager.list_messages(
        actor=actor,
        agent_id=None,
        before=before,
        after=after,
        limit=limit,
        group_id=group_id,
        ascending=not reverse,
        include_err=include_err,
    )

    if not return_message_object:
        # NOTE: every message is treated as if it came from a letta_v1_agent,
        # so text_is_assistant_message is always True here. This may lead to
        # slightly incorrect assistant message handling for other agent types.
        page = Message.to_letta_messages_from_list(
            messages=page,
            use_assistant_message=use_assistant_message,
            assistant_message_tool_name=assistant_message_tool_name,
            assistant_message_tool_kwarg=assistant_message_tool_kwarg,
            reverse=reverse,
            include_err=include_err,
            text_is_assistant_message=True,
        )

    # NOTE(review): `reverse` is forwarded to the converter above and the list
    # is also flipped here - confirm the double handling is intended.
    return page[::-1] if reverse else page
|
||||
|
||||
def get_server_config(self, include_defaults: bool = False) -> dict:
|
||||
"""Return the base config"""
|
||||
|
||||
|
||||
@@ -2183,6 +2183,7 @@ class AgentManager:
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
agent_id: Optional[str] = None,
|
||||
archive_id: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
query_text: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
@@ -2197,17 +2198,26 @@ class AgentManager:
|
||||
) -> List[Tuple[PydanticPassage, float, dict]]:
|
||||
"""Lists all passages attached to an agent."""
|
||||
# Check if we should use Turbopuffer for vector search
|
||||
if embed_query and agent_id and query_text and embedding_config:
|
||||
# Get archive IDs for the agent
|
||||
archive_ids = await self.get_agent_archive_ids_async(agent_id=agent_id, actor=actor)
|
||||
# Support searching by either agent_id or archive_id directly
|
||||
if embed_query and query_text and embedding_config:
|
||||
target_archive_id = None
|
||||
|
||||
if archive_ids:
|
||||
# TODO: Remove this restriction once we support multiple archives with mixed vector DB providers
|
||||
if len(archive_ids) > 1:
|
||||
raise ValueError(f"Agent {agent_id} has multiple archives, which is not yet supported for vector search")
|
||||
if agent_id:
|
||||
# Get archive IDs for the agent
|
||||
archive_ids = await self.get_agent_archive_ids_async(agent_id=agent_id, actor=actor)
|
||||
|
||||
if archive_ids:
|
||||
# TODO: Remove this restriction once we support multiple archives with mixed vector DB providers
|
||||
if len(archive_ids) > 1:
|
||||
raise ValueError(f"Agent {agent_id} has multiple archives, which is not yet supported for vector search")
|
||||
target_archive_id = archive_ids[0]
|
||||
elif archive_id:
|
||||
# Use the provided archive_id directly
|
||||
target_archive_id = archive_id
|
||||
|
||||
if target_archive_id:
|
||||
# Get archive to check vector_db_provider
|
||||
archive = await self.archive_manager.get_archive_by_id_async(archive_id=archive_ids[0], actor=actor)
|
||||
archive = await self.archive_manager.get_archive_by_id_async(archive_id=target_archive_id, actor=actor)
|
||||
|
||||
# Use Turbopuffer for vector search if archive is configured for TPUF
|
||||
if archive.vector_db_provider == VectorDBProvider.TPUF:
|
||||
@@ -2226,7 +2236,7 @@ class AgentManager:
|
||||
tpuf_client = TurbopufferClient()
|
||||
# use hybrid search to combine vector and full-text search
|
||||
passages_with_scores = await tpuf_client.query_passages(
|
||||
archive_id=archive_ids[0],
|
||||
archive_id=target_archive_id,
|
||||
query_text=query_text, # pass text for potential hybrid search
|
||||
search_mode="hybrid", # use hybrid mode for better results
|
||||
top_k=limit,
|
||||
@@ -2239,14 +2249,13 @@ class AgentManager:
|
||||
|
||||
# Return full tuples with metadata
|
||||
return passages_with_scores
|
||||
else:
|
||||
return []
|
||||
|
||||
# Fall back to SQL-based search for non-vector queries or NATIVE archives
|
||||
async with db_registry.async_session() as session:
|
||||
main_query = await build_agent_passage_query(
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
archive_id=archive_id,
|
||||
query_text=query_text,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
|
||||
@@ -284,8 +284,8 @@ class ArchiveManager:
|
||||
self,
|
||||
archive_id: str,
|
||||
text: str,
|
||||
metadata: Dict = None,
|
||||
tags: List[str] = None,
|
||||
metadata: Optional[Dict] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
actor: PydanticUser = None,
|
||||
) -> PydanticPassage:
|
||||
"""Create a passage in an archive.
|
||||
@@ -335,6 +335,27 @@ class ArchiveManager:
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# If archive uses Turbopuffer, also write to Turbopuffer (dual-write)
|
||||
if archive.vector_db_provider == VectorDBProvider.TPUF:
|
||||
try:
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
|
||||
tpuf_client = TurbopufferClient()
|
||||
|
||||
# Insert to Turbopuffer with the same ID as SQL
|
||||
await tpuf_client.insert_archival_memories(
|
||||
archive_id=archive.id,
|
||||
text_chunks=[created_passage.text],
|
||||
passage_ids=[created_passage.id],
|
||||
organization_id=actor.organization_id,
|
||||
actor=actor,
|
||||
)
|
||||
logger.info(f"Uploaded passage {created_passage.id} to Turbopuffer for archive {archive_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to upload passage to Turbopuffer: {e}")
|
||||
# Don't fail the entire operation if Turbopuffer upload fails
|
||||
# The passage is already saved to SQL
|
||||
|
||||
logger.info(f"Created passage {created_passage.id} in archive {archive_id}")
|
||||
return created_passage
|
||||
|
||||
|
||||
@@ -1176,7 +1176,8 @@ async def build_source_passage_query(
|
||||
|
||||
async def build_agent_passage_query(
|
||||
actor: User,
|
||||
agent_id: str, # Required for agent passages
|
||||
agent_id: Optional[str] = None,
|
||||
archive_id: Optional[str] = None,
|
||||
query_text: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
@@ -1186,7 +1187,11 @@ async def build_agent_passage_query(
|
||||
ascending: bool = True,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
) -> Select:
|
||||
"""Build query for agent passages with all filters applied."""
|
||||
"""Build query for agent/archive passages with all filters applied.
|
||||
|
||||
Can provide agent_id, archive_id, both, or neither (org-wide search).
|
||||
If both are provided, agent_id takes precedence.
|
||||
"""
|
||||
|
||||
# Handle embedding for vector search
|
||||
embedded_text = None
|
||||
@@ -1203,12 +1208,23 @@ async def build_agent_passage_query(
|
||||
embedded_text = np.array(embeddings[0])
|
||||
embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()
|
||||
|
||||
# Base query for agent passages - join through archives_agents
|
||||
query = (
|
||||
select(ArchivalPassage)
|
||||
.join(ArchivesAgents, ArchivalPassage.archive_id == ArchivesAgents.archive_id)
|
||||
.where(ArchivesAgents.agent_id == agent_id, ArchivalPassage.organization_id == actor.organization_id)
|
||||
)
|
||||
# Base query for passages
|
||||
if agent_id:
|
||||
# Query for agent passages - join through archives_agents
|
||||
# Agent_id takes precedence if both agent_id and archive_id are provided
|
||||
query = (
|
||||
select(ArchivalPassage)
|
||||
.join(ArchivesAgents, ArchivalPassage.archive_id == ArchivesAgents.archive_id)
|
||||
.where(ArchivesAgents.agent_id == agent_id, ArchivalPassage.organization_id == actor.organization_id)
|
||||
)
|
||||
elif archive_id:
|
||||
# Query for archive passages directly
|
||||
query = select(ArchivalPassage).where(
|
||||
ArchivalPassage.archive_id == archive_id, ArchivalPassage.organization_id == actor.organization_id
|
||||
)
|
||||
else:
|
||||
# Org-wide search - all passages in organization
|
||||
query = select(ArchivalPassage).where(ArchivalPassage.organization_id == actor.organization_id)
|
||||
|
||||
# Apply filters
|
||||
if start_date:
|
||||
|
||||
440
tests/sdk_v1/search_test.py
Normal file
440
tests/sdk_v1/search_test.py
Normal file
@@ -0,0 +1,440 @@
|
||||
"""
|
||||
End-to-end tests for passage and message search endpoints using the SDK client.
|
||||
|
||||
These tests verify that the /v1/passages/search and /v1/messages/search endpoints work correctly
|
||||
with Turbopuffer integration, including vector search, FTS, hybrid search, filtering, and pagination.
|
||||
"""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
from letta_client import Letta
|
||||
from letta_client.types import CreateBlockParam, MessageCreateParam
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.server.server import SyncServer
|
||||
from letta.settings import settings
|
||||
|
||||
|
||||
def cleanup_agent_with_messages(client: Letta, agent_id: str):
    """
    Tear down a test agent, purging its Turbopuffer messages first.

    Both steps are best-effort: a failure in either one is reported with a
    printed warning instead of being raised, so teardown never fails a test.

    Args:
        client: Letta SDK client
        agent_id: ID of the agent to clean up
    """
    # Step 1: remove the agent's messages from Turbopuffer so that none are
    # left orphaned there once the agent is gone.
    try:
        import asyncio

        from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages

        if should_use_tpuf_for_messages():
            asyncio.run(TurbopufferClient().delete_all_messages(agent_id))
    except Exception as tpuf_error:
        print(f"Warning: Failed to clean up Turbopuffer messages for agent {agent_id}: {tpuf_error}")

    # Step 2: delete the agent itself (which will delete SQL messages via cascade)
    try:
        client.agents.delete(agent_id=agent_id)
    except Exception as delete_error:
        print(f"Warning: Failed to clean up agent {agent_id}: {delete_error}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def server():
    """Module-scoped SyncServer built without the default org and user."""
    # Load the Letta config and write it straight back before the server starts.
    LettaConfig.load().save()
    return SyncServer(init_with_default_org_and_user=False)
|
||||
|
||||
|
||||
@pytest.fixture
def enable_turbopuffer():
    """Temporarily enable Turbopuffer for testing.

    Sets ``settings.use_tpuf`` to True and ``settings.environment`` to "DEV",
    and falls back to a placeholder API key when none is configured. The
    original values of all three settings are restored after the test.
    """
    original_use_tpuf = settings.use_tpuf
    original_api_key = settings.tpuf_api_key
    original_environment = settings.environment

    # Enable Turbopuffer with test key
    settings.use_tpuf = True
    # The previous fallback re-assigned the saved (empty) key to itself, so it
    # never installed a test key. Use the same placeholder as
    # enable_message_embedding; a configured key is left untouched.
    settings.tpuf_api_key = settings.tpuf_api_key or "test-key"
    settings.environment = "DEV"

    yield

    # Restore original values
    settings.use_tpuf = original_use_tpuf
    settings.tpuf_api_key = original_api_key
    settings.environment = original_environment
|
||||
|
||||
|
||||
@pytest.fixture
def enable_message_embedding():
    """Switch on Turbopuffer plus message embedding for the duration of a test."""
    tracked = ("use_tpuf", "tpuf_api_key", "embed_all_messages", "environment")
    # Snapshot every setting this fixture touches so it can be put back afterwards.
    snapshot = {name: getattr(settings, name) for name in tracked}

    settings.use_tpuf = True
    # Keep a configured key; otherwise fall back to a placeholder.
    settings.tpuf_api_key = snapshot["tpuf_api_key"] or "test-key"
    settings.embed_all_messages = True
    settings.environment = "DEV"

    yield

    for name in tracked:
        setattr(settings, name, snapshot[name])
|
||||
|
||||
|
||||
@pytest.fixture
def disable_turbopuffer():
    """Force Turbopuffer and message embedding off while a test runs."""
    saved = (settings.use_tpuf, settings.embed_all_messages)

    settings.use_tpuf = settings.embed_all_messages = False

    yield

    # Put both flags back to what they were before the test.
    settings.use_tpuf, settings.embed_all_messages = saved
|
||||
|
||||
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
def test_passage_search_basic(client: Letta, enable_turbopuffer):
    """Test basic passage search functionality through the SDK.

    Creates an agent with an attached archive, inserts three passages, then
    searches once by agent_id and once by archive_id, checking the hits and
    the shape of each result (passage, score, metadata).
    """
    # Create an agent
    agent = client.agents.create(
        name=f"test_passage_search_{uuid.uuid4()}",
        memory_blocks=[CreateBlockParam(label="persona", value="test assistant")],
        model="openai/gpt-4o-mini",
        embedding="openai/text-embedding-3-small",
    )

    try:
        # Create an archive and attach to agent
        archive = client.archives.create(name=f"test_archive_{uuid.uuid4()}", embedding="openai/text-embedding-3-small")

        try:
            # Attach archive to agent
            client.agents.archives.attach(agent_id=agent.id, archive_id=archive.id)

            # Insert some passages
            test_passages = [
                "Python is a popular programming language for data science and machine learning.",
                "JavaScript is widely used for web development and frontend applications.",
                "Turbopuffer is a vector database optimized for performance and scalability.",
            ]

            for passage_text in test_passages:
                client.archives.passages.create(archive_id=archive.id, text=passage_text)

            # Wait for indexing
            time.sleep(2)

            # Test search by agent_id
            results = client.passages.search(query="python programming", agent_id=agent.id, limit=10)

            assert len(results) > 0, "Should find at least one passage"
            assert any("Python" in result.passage.text for result in results), "Should find Python-related passage"

            # Verify result structure
            for result in results:
                assert hasattr(result, "passage"), "Result should have passage field"
                assert hasattr(result, "score"), "Result should have score field"
                assert hasattr(result, "metadata"), "Result should have metadata field"
                assert isinstance(result.score, float), "Score should be a float"

            # Test search by archive_id
            archive_results = client.passages.search(query="vector database", archive_id=archive.id, limit=10)

            assert len(archive_results) > 0, "Should find passages in archive"
            assert any("Turbopuffer" in result.passage.text or "vector" in result.passage.text for result in archive_results), (
                "Should find vector-related passage"
            )

        finally:
            # Clean up archive. Best-effort: report the failure instead of
            # swallowing it, and let KeyboardInterrupt/SystemExit propagate.
            try:
                client.archives.delete(archive_id=archive.id)
            except Exception as e:
                print(f"Warning: Failed to clean up archive {archive.id}: {e}")

    finally:
        # Clean up agent
        cleanup_agent_with_messages(client, agent.id)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
def test_passage_search_with_tags(client: Letta, enable_turbopuffer):
    """Test passage search on an archive set up for tag filtering.

    NOTE: no tag filter is passed to the search yet; passages are inserted
    without tags and only a plain search by agent_id is asserted.
    """
    # Create an agent
    agent = client.agents.create(
        name=f"test_passage_tags_{uuid.uuid4()}",
        memory_blocks=[CreateBlockParam(label="persona", value="test assistant")],
        model="openai/gpt-4o-mini",
        embedding="openai/text-embedding-3-small",
    )

    try:
        # Create an archive
        archive = client.archives.create(name=f"test_archive_tags_{uuid.uuid4()}", embedding="openai/text-embedding-3-small")

        try:
            # Attach archive to agent
            client.agents.archives.attach(agent_id=agent.id, archive_id=archive.id)

            # Insert passages (currently without tags)
            # Note: Tag support may depend on the SDK version
            test_passages = [
                "Python tutorial for beginners",
                "Advanced Python techniques",
                "JavaScript basics",
            ]

            for passage_text in test_passages:
                client.archives.passages.create(archive_id=archive.id, text=passage_text)

            # Wait for indexing
            time.sleep(2)

            # Test basic search without tags first
            results = client.passages.search(query="programming tutorial", agent_id=agent.id, limit=10)

            assert len(results) > 0, "Should find passages"

            # Test with tag filtering if supported
            # Note: The SDK may not expose tag parameters directly, so this test verifies basic functionality
            # The backend will handle tag filtering when available

        finally:
            # Clean up archive. Best-effort: report the failure instead of
            # swallowing it, and let KeyboardInterrupt/SystemExit propagate.
            try:
                client.archives.delete(archive_id=archive.id)
            except Exception as e:
                print(f"Warning: Failed to clean up archive {archive.id}: {e}")

    finally:
        # Clean up agent
        cleanup_agent_with_messages(client, agent.id)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
def test_passage_search_with_date_filters(client: Letta, enable_turbopuffer):
    """Test passage search with date range filtering.

    Inserts two passages about a second apart, searches by agent_id with a
    start_date one hour in the past, and checks that every returned passage
    was created on or after that start_date.
    """
    # Create an agent
    agent = client.agents.create(
        name=f"test_passage_dates_{uuid.uuid4()}",
        memory_blocks=[CreateBlockParam(label="persona", value="test assistant")],
        model="openai/gpt-4o-mini",
        embedding="openai/text-embedding-3-small",
    )

    try:
        # Create an archive
        archive = client.archives.create(name=f"test_archive_dates_{uuid.uuid4()}", embedding="openai/text-embedding-3-small")

        try:
            # Attach archive to agent
            client.agents.archives.attach(agent_id=agent.id, archive_id=archive.id)

            # Insert passages at different times
            client.archives.passages.create(archive_id=archive.id, text="Recent passage about AI trends")

            # Wait a bit before creating another
            time.sleep(1)

            client.archives.passages.create(archive_id=archive.id, text="Another passage about machine learning")

            # Wait for indexing
            time.sleep(2)

            # Test search with date range
            now = datetime.now(timezone.utc)
            start_date = now - timedelta(hours=1)

            results = client.passages.search(query="AI machine learning", agent_id=agent.id, limit=10, start_date=start_date)

            assert len(results) > 0, "Should find recent passages"

            # Verify all results are within date range
            for result in results:
                passage_date = result.passage.created_at
                if passage_date:
                    # Convert to datetime if it's a string
                    if isinstance(passage_date, str):
                        passage_date = datetime.fromisoformat(passage_date.replace("Z", "+00:00"))
                    # Comparing a naive datetime with the aware start_date raises
                    # TypeError. Presumably naive timestamps are UTC; confirm
                    # against the API before relying on this.
                    if passage_date.tzinfo is None:
                        passage_date = passage_date.replace(tzinfo=timezone.utc)
                    assert passage_date >= start_date, "Passage should be after start_date"

        finally:
            # Clean up archive. Best-effort: report the failure instead of
            # swallowing it, and let KeyboardInterrupt/SystemExit propagate.
            try:
                client.archives.delete(archive_id=archive.id)
            except Exception as e:
                print(f"Warning: Failed to clean up archive {archive.id}: {e}")

    finally:
        # Clean up agent
        cleanup_agent_with_messages(client, agent.id)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
def test_message_search_basic(client: Letta, enable_message_embedding):
    """Send a user message to a fresh agent and find it again via FTS message search."""
    agent = client.agents.create(
        name=f"test_message_search_{uuid.uuid4()}",
        memory_blocks=[CreateBlockParam(label="persona", value="helpful assistant")],
        model="openai/gpt-4o-mini",
        embedding="openai/text-embedding-3-small",
    )

    try:
        # One message per prompt, each sent as its own request.
        prompts = ("What is the capital of Mozambique?",)
        for prompt in prompts:
            user_message = MessageCreateParam(role="user", content=prompt)
            client.agents.messages.create(agent_id=agent.id, messages=[user_message])

        # Give async embedding and database commits time to finish so the
        # message is indexed before searching.
        time.sleep(6)

        # Full-text search over messages
        hits = client.messages.search(query="capital of Mozambique", search_mode="fts", limit=10)
        assert len(hits) > 0, "Should find at least one message"
    finally:
        # Remove the agent (and its Turbopuffer messages) whatever the outcome.
        cleanup_agent_with_messages(client, agent.id)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
def test_passage_search_pagination(client: Letta, enable_turbopuffer):
    """Test that passage search honours the limit parameter.

    Inserts ten passages that all mention "programming", then searches by
    agent_id with limits of 3, 5 and 20 and checks the result counts.
    """
    # Create an agent
    agent = client.agents.create(
        name=f"test_passage_pagination_{uuid.uuid4()}",
        memory_blocks=[CreateBlockParam(label="persona", value="test assistant")],
        model="openai/gpt-4o-mini",
        embedding="openai/text-embedding-3-small",
    )

    try:
        # Create an archive
        archive = client.archives.create(name=f"test_archive_pagination_{uuid.uuid4()}", embedding="openai/text-embedding-3-small")

        try:
            # Attach archive to agent
            client.agents.archives.attach(agent_id=agent.id, archive_id=archive.id)

            # Insert many passages
            for i in range(10):
                client.archives.passages.create(archive_id=archive.id, text=f"Test passage number {i} about programming")

            # Wait for indexing
            time.sleep(2)

            # Test with different limit values
            results_limit_3 = client.passages.search(query="programming", agent_id=agent.id, limit=3)

            assert len(results_limit_3) == 3, "Should respect limit parameter"

            results_limit_5 = client.passages.search(query="programming", agent_id=agent.id, limit=5)

            assert len(results_limit_5) == 5, "Should return 5 results"

            results_all = client.passages.search(query="programming", agent_id=agent.id, limit=20)

            assert len(results_all) >= 10, "Should return all matching passages"

        finally:
            # Clean up archive. Best-effort: report the failure instead of
            # swallowing it, and let KeyboardInterrupt/SystemExit propagate.
            try:
                client.archives.delete(archive_id=archive.id)
            except Exception as e:
                print(f"Warning: Failed to clean up archive {archive.id}: {e}")

    finally:
        # Clean up agent
        cleanup_agent_with_messages(client, agent.id)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
def test_passage_search_org_wide(client: Letta, enable_turbopuffer):
    """Test organization-wide passage search (without agent_id or archive_id).

    Two agents each get their own archive holding one distinctive passage; a
    search with neither agent_id nor archive_id must return both passages.
    """
    # Create multiple agents with archives
    agent1 = client.agents.create(
        name=f"test_org_search_agent1_{uuid.uuid4()}",
        memory_blocks=[CreateBlockParam(label="persona", value="test assistant 1")],
        model="openai/gpt-4o-mini",
        embedding="openai/text-embedding-3-small",
    )

    agent2 = client.agents.create(
        name=f"test_org_search_agent2_{uuid.uuid4()}",
        memory_blocks=[CreateBlockParam(label="persona", value="test assistant 2")],
        model="openai/gpt-4o-mini",
        embedding="openai/text-embedding-3-small",
    )

    try:
        # Create archives for both agents
        archive1 = client.archives.create(name=f"test_archive_org1_{uuid.uuid4()}", embedding="openai/text-embedding-3-small")
        archive2 = client.archives.create(name=f"test_archive_org2_{uuid.uuid4()}", embedding="openai/text-embedding-3-small")

        try:
            # Attach archives
            client.agents.archives.attach(agent_id=agent1.id, archive_id=archive1.id)
            client.agents.archives.attach(agent_id=agent2.id, archive_id=archive2.id)

            # Insert passages in both archives
            client.archives.passages.create(archive_id=archive1.id, text="Unique passage in agent1 about quantum computing")

            client.archives.passages.create(archive_id=archive2.id, text="Unique passage in agent2 about blockchain technology")

            # Wait for indexing
            time.sleep(2)

            # Test org-wide search (no agent_id or archive_id)
            results = client.passages.search(query="unique passage", limit=20)

            # Should find passages from both agents
            assert len(results) >= 2, "Should find passages from multiple agents"

            found_texts = [result.passage.text for result in results]
            assert any("quantum computing" in text for text in found_texts), "Should find agent1 passage"
            assert any("blockchain" in text for text in found_texts), "Should find agent2 passage"

        finally:
            # Clean up archives. Each delete is attempted on its own so one
            # failure does not skip the other; failures are reported instead
            # of swallowed, and KeyboardInterrupt/SystemExit propagate.
            for archive in (archive1, archive2):
                try:
                    client.archives.delete(archive_id=archive.id)
                except Exception as e:
                    print(f"Warning: Failed to clean up archive {archive.id}: {e}")

    finally:
        # Clean up agents
        cleanup_agent_with_messages(client, agent1.id)
        cleanup_agent_with_messages(client, agent2.id)
|
||||
@@ -2263,6 +2263,100 @@ def test_create_agent(client: LettaSDKClient) -> None:
|
||||
client.agents.delete(agent_id=agent.id)
|
||||
|
||||
|
||||
def test_list_all_messages(client: LettaSDKClient):
    """Test listing all messages across multiple agents.

    Sends one user message to each of two agents, then checks that
    client.messages.list returns messages, that at least one of the two sent
    messages shows up, and that the limit and order parameters are accepted.
    """

    def _as_list(response):
        # The SDK may return either a plain list or a paginated object that
        # exposes its entries as `.items`; normalise both to one sequence.
        return response.items if hasattr(response, "items") else list(response)

    # Create two agents
    agent1 = client.agents.create(
        name="test_agent_1_messages",
        memory_blocks=[CreateBlockParam(label="persona", value="you are agent 1")],
        model="openai/gpt-4o-mini",
        embedding="openai/text-embedding-3-small",
    )

    agent2 = client.agents.create(
        name="test_agent_2_messages",
        memory_blocks=[CreateBlockParam(label="persona", value="you are agent 2")],
        model="openai/gpt-4o-mini",
        embedding="openai/text-embedding-3-small",
    )

    try:
        # Send messages to both agents
        agent1_msg_content = "Hello from agent 1"
        agent2_msg_content = "Hello from agent 2"

        client.agents.messages.create(
            agent_id=agent1.id,
            messages=[MessageCreateParam(role="user", content=agent1_msg_content)],
        )

        client.agents.messages.create(
            agent_id=agent2.id,
            messages=[MessageCreateParam(role="user", content=agent2_msg_content)],
        )

        # Wait a bit for messages to be persisted
        time.sleep(0.5)

        # List all messages across both agents
        all_messages = client.messages.list(limit=100)

        # Verify we got messages back
        assert hasattr(all_messages, "items") or isinstance(all_messages, list), "Should return messages list or paginated response"

        messages_list = _as_list(all_messages)

        # Should have messages from both agents (plus initial system messages)
        assert len(messages_list) > 0, "Should have at least some messages"

        # Extract message content for verification
        message_contents = []
        for msg in messages_list:
            # Handle different message types: content is either a plain string
            # or a list of parts, of which only those with `.text` are kept.
            if hasattr(msg, "content"):
                content = msg.content
                if isinstance(content, str):
                    message_contents.append(content)
                elif isinstance(content, list):
                    message_contents.extend(item.text for item in content if hasattr(item, "text"))

        # Check whether the messages we sent are present
        found_agent1_msg = any(agent1_msg_content in content for content in message_contents)
        found_agent2_msg = any(agent2_msg_content in content for content in message_contents)

        # Only one of the two is required, as in the original assertion.
        assert found_agent1_msg or found_agent2_msg, "Should find at least one of the messages we sent"

        # Test pagination parameters
        limited_list = _as_list(client.messages.list(limit=5))

        assert len(limited_list) <= 5, "Should respect limit parameter"

        # Test order parameter (desc should be default - newest first)
        desc_list = _as_list(client.messages.list(limit=10, order="desc"))

        # Verify messages are returned
        assert isinstance(desc_list, list), "Should return a list of messages"

    finally:
        # Clean up agents. The nested finally makes sure agent2 is still
        # deleted when deleting agent1 raises; errors still propagate.
        try:
            client.agents.delete(agent_id=agent1.id)
        finally:
            client.agents.delete(agent_id=agent2.id)
|
||||
|
||||
|
||||
def test_create_agent_with_tools(client: LettaSDKClient) -> None:
|
||||
"""Test creating an agent with custom inventory management tools"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user