feat: Add archival search endpoint [LET-4184] (#4390)

* Add archival search endpoint

* Run fern autogen

* Add de-dupe logic
This commit is contained in:
Matthew Zhou
2025-09-03 10:55:20 -07:00
committed by GitHub
parent 6e633bd8f9
commit 051a5cde6a
6 changed files with 374 additions and 87 deletions

View File

@@ -328,3 +328,14 @@ class CreateArchivalMemory(BaseModel):
text: str = Field(..., description="Text to write to archival memory.")
tags: Optional[List[str]] = Field(None, description="Optional list of tags to attach to the memory.")
created_at: Optional[datetime] = Field(None, description="Optional timestamp for the memory (defaults to current UTC time).")
class ArchivalMemorySearchResult(BaseModel):
    """A single archival-memory passage returned by the archival search endpoint."""

    # Rendered as a string (not datetime) because it is pre-formatted in the agent's timezone.
    timestamp: str = Field(..., description="Timestamp of when the memory was created, formatted in agent's timezone")
    content: str = Field(..., description="Text content of the archival memory passage")
    # default_factory avoids a shared mutable default across instances.
    tags: List[str] = Field(default_factory=list, description="List of tags associated with this memory")
class ArchivalMemorySearchResponse(BaseModel):
    """Response envelope for the archival-memory search endpoint."""

    results: List[ArchivalMemorySearchResult] = Field(..., description="List of search results matching the query")
    # NOTE(review): count mirrors len(results) at the call site — it is not a total across pages.
    count: int = Field(..., description="Total number of results returned")

View File

@@ -2,7 +2,7 @@ import asyncio
import json
import traceback
from datetime import datetime, timezone
from typing import Annotated, Any, Dict, List, Optional, Union
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
from fastapi import APIRouter, Body, Depends, File, Form, Header, HTTPException, Query, Request, UploadFile, status
from fastapi.responses import JSONResponse
@@ -32,7 +32,13 @@ from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig
from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion, MessageType
from letta.schemas.letta_request import LettaAsyncRequest, LettaRequest, LettaStreamingRequest
from letta.schemas.letta_response import LettaResponse
from letta.schemas.memory import ContextWindowOverview, CreateArchivalMemory, Memory
from letta.schemas.memory import (
ArchivalMemorySearchResponse,
ArchivalMemorySearchResult,
ContextWindowOverview,
CreateArchivalMemory,
Memory,
)
from letta.schemas.message import MessageCreate
from letta.schemas.passage import Passage
from letta.schemas.run import Run
@@ -978,6 +984,55 @@ async def create_passage(
)
@router.get("/{agent_id}/archival-memory/search", response_model=ArchivalMemorySearchResponse, operation_id="search_archival_memory")
async def search_archival_memory(
    agent_id: str,
    query: str = Query(..., description="String to search for using semantic similarity"),
    tags: Optional[List[str]] = Query(None, description="Optional list of tags to filter search results"),
    tag_match_mode: Literal["any", "all"] = Query(
        "any", description="How to match tags - 'any' to match passages with any of the tags, 'all' to match only passages with all tags"
    ),
    top_k: Optional[int] = Query(None, description="Maximum number of results to return. Uses system default if not specified"),
    start_datetime: Optional[str] = Query(None, description="Filter results to passages created after this datetime. ISO 8601 format"),
    end_datetime: Optional[str] = Query(None, description="Filter results to passages created before this datetime. ISO 8601 format"),
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: str | None = Header(None, alias="user_id"),
):
    """
    Search archival memory using semantic (embedding-based) search with optional temporal filtering.

    This endpoint allows manual triggering of archival memory searches, enabling users to query
    an agent's archival memory store directly via the API. The search uses the same functionality
    as the agent's archival_memory_search tool but is accessible for external API usage.

    Raises:
        HTTPException(404): agent not found for this user.
        HTTPException(400): invalid parameters (e.g. malformed datetime strings).
        HTTPException(500): any other failure during the search.
    """
    actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)

    try:
        # Delegate to the shared agent-manager method so the API endpoint and the
        # agent's archival_memory_search tool stay behaviorally identical.
        formatted_results, count = await server.agent_manager.search_agent_archival_memory_async(
            agent_id=agent_id,
            actor=actor,
            query=query,
            tags=tags,
            tag_match_mode=tag_match_mode,
            top_k=top_k,
            start_datetime=start_datetime,
            end_datetime=end_datetime,
        )

        # Convert the plain result dicts into the typed response schema.
        search_results = [ArchivalMemorySearchResult(**result) for result in formatted_results]
        return ArchivalMemorySearchResponse(results=search_results, count=count)
    except NoResultFound as e:
        # Chain the cause (`from e`) so the original traceback is preserved for debugging.
        raise HTTPException(status_code=404, detail=f"Agent with id={agent_id} not found for user_id={actor.id}.") from e
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except HTTPException:
        # Don't let the broad handler below rewrite deliberate HTTP errors as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error during archival memory search: {str(e)}") from e
# TODO(ethan): query or path parameter for memory_id?
# @router.delete("/{agent_id}/archival")
@router.delete("/{agent_id}/archival-memory/{memory_id}", response_model=None, operation_id="delete_passage")

View File

@@ -1536,7 +1536,16 @@ class SyncServer(Server):
local_configs = self.get_local_llm_configs()
llm_models.extend(local_configs)
return llm_models
# dedupe by handle for uniqueness
# Seems like this is required from the tests?
seen_handles = set()
unique_models = []
for model in llm_models:
if model.handle not in seen_handles:
seen_handles.add(model.handle)
unique_models.append(model)
return unique_models
def list_embedding_models(self, actor: User) -> List[EmbeddingConfig]:
"""List available embedding models"""

View File

@@ -1,6 +1,7 @@
import asyncio
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Literal, Optional, Set, Tuple
from zoneinfo import ZoneInfo
import sqlalchemy as sa
from sqlalchemy import delete, func, insert, literal, or_, select, tuple_
@@ -21,6 +22,7 @@ from letta.constants import (
EXCLUDE_MODEL_KEYWORDS_FROM_BASE_TOOL_RULES,
FILES_TOOLS,
INCLUDE_MODEL_KEYWORDS_BASE_TOOL_RULES,
RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE,
)
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import get_utc_time
@@ -2752,6 +2754,127 @@ class AgentManager:
return pydantic_passages
@enforce_types
@trace_method
async def search_agent_archival_memory_async(
self,
agent_id: str,
actor: PydanticUser,
query: str,
tags: Optional[List[str]] = None,
tag_match_mode: Literal["any", "all"] = "any",
top_k: Optional[int] = None,
start_datetime: Optional[str] = None,
end_datetime: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
"""
Search archival memory using semantic (embedding-based) search with optional temporal filtering.
This is a shared method used by both the agent tool and API endpoint to ensure consistent behavior.
Args:
agent_id: ID of the agent whose archival memory to search
actor: User performing the search
query: String to search for using semantic similarity
tags: Optional list of tags to filter search results
tag_match_mode: How to match tags - "any" or "all"
top_k: Maximum number of results to return
start_datetime: Filter results after this datetime (ISO 8601 format)
end_datetime: Filter results before this datetime (ISO 8601 format)
Returns:
Tuple of (formatted_results, count)
"""
# Handle empty or whitespace-only queries
if not query or not query.strip():
return [], 0
# Get the agent to access timezone and embedding config
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
# Parse datetime parameters if provided
start_date = None
end_date = None
if start_datetime:
try:
# Try parsing as full datetime first (with time)
start_date = datetime.fromisoformat(start_datetime)
except ValueError:
try:
# Fall back to date-only format
start_date = datetime.strptime(start_datetime, "%Y-%m-%d")
# Set to beginning of day
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
except ValueError:
raise ValueError(
f"Invalid start_datetime format: {start_datetime}. Use ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM)"
)
# Apply agent's timezone if datetime is naive
if start_date.tzinfo is None and agent_state.timezone:
tz = ZoneInfo(agent_state.timezone)
start_date = start_date.replace(tzinfo=tz)
if end_datetime:
try:
# Try parsing as full datetime first (with time)
end_date = datetime.fromisoformat(end_datetime)
except ValueError:
try:
# Fall back to date-only format
end_date = datetime.strptime(end_datetime, "%Y-%m-%d")
# Set to end of day for end dates
end_date = end_date.replace(hour=23, minute=59, second=59, microsecond=999999)
except ValueError:
raise ValueError(f"Invalid end_datetime format: {end_datetime}. Use ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM)")
# Apply agent's timezone if datetime is naive
if end_date.tzinfo is None and agent_state.timezone:
tz = ZoneInfo(agent_state.timezone)
end_date = end_date.replace(tzinfo=tz)
# Convert string to TagMatchMode enum
tag_mode = TagMatchMode.ANY if tag_match_mode == "any" else TagMatchMode.ALL
# Get results using existing passage query method
limit = top_k if top_k is not None else RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
all_results = await self.query_agent_passages_async(
actor=actor,
agent_id=agent_id,
query_text=query,
limit=limit,
embedding_config=agent_state.embedding_config,
embed_query=True,
tags=tags,
tag_match_mode=tag_mode,
start_date=start_date,
end_date=end_date,
)
# Format results to include tags with friendly timestamps
formatted_results = []
for result in all_results:
# Format timestamp in agent's timezone if available
timestamp = result.created_at
if timestamp and agent_state.timezone:
try:
# Convert to agent's timezone
tz = ZoneInfo(agent_state.timezone)
local_time = timestamp.astimezone(tz)
# Format as ISO string with timezone
formatted_timestamp = local_time.isoformat()
except Exception:
# Fallback to ISO format if timezone conversion fails
formatted_timestamp = str(timestamp)
else:
# Use ISO format if no timezone is set
formatted_timestamp = str(timestamp) if timestamp else "Unknown"
formatted_results.append({"timestamp": formatted_timestamp, "content": result.text, "tags": result.tags or []})
return formatted_results, len(formatted_results)
@enforce_types
@trace_method
async def passage_size(

View File

@@ -284,92 +284,19 @@ class LettaCoreToolExecutor(ToolExecutor):
str: Query result string containing matching passages with timestamps, content, and tags.
"""
try:
# Parse datetime parameters if provided
from datetime import datetime
start_date = None
end_date = None
if start_datetime:
try:
# Try parsing as full datetime first (with time)
start_date = datetime.fromisoformat(start_datetime)
except ValueError:
try:
# Fall back to date-only format
start_date = datetime.strptime(start_datetime, "%Y-%m-%d")
# Set to beginning of day
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
except ValueError:
raise ValueError(
f"Invalid start_datetime format: {start_datetime}. Use ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM)"
)
# Apply agent's timezone if datetime is naive
if start_date.tzinfo is None and agent_state.timezone:
tz = ZoneInfo(agent_state.timezone)
start_date = start_date.replace(tzinfo=tz)
if end_datetime:
try:
# Try parsing as full datetime first (with time)
end_date = datetime.fromisoformat(end_datetime)
except ValueError:
try:
# Fall back to date-only format
end_date = datetime.strptime(end_datetime, "%Y-%m-%d")
# Set to end of day for end dates
end_date = end_date.replace(hour=23, minute=59, second=59, microsecond=999999)
except ValueError:
raise ValueError(
f"Invalid end_datetime format: {end_datetime}. Use ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM)"
)
# Apply agent's timezone if datetime is naive
if end_date.tzinfo is None and agent_state.timezone:
tz = ZoneInfo(agent_state.timezone)
end_date = end_date.replace(tzinfo=tz)
# Convert string to TagMatchMode enum
tag_mode = TagMatchMode.ANY if tag_match_mode == "any" else TagMatchMode.ALL
# Get results using passage manager
limit = top_k if top_k is not None else RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
all_results = await self.agent_manager.query_agent_passages_async(
actor=actor,
# Use the shared service method to get results
formatted_results, count = await self.agent_manager.search_agent_archival_memory_async(
agent_id=agent_state.id,
query_text=query,
limit=limit,
embedding_config=agent_state.embedding_config,
embed_query=True,
actor=actor,
query=query,
tags=tags,
tag_match_mode=tag_mode,
start_date=start_date,
end_date=end_date,
tag_match_mode=tag_match_mode,
top_k=top_k,
start_datetime=start_datetime,
end_datetime=end_datetime,
)
# Format results to include tags with friendly timestamps
formatted_results = []
for result in all_results:
# Format timestamp in agent's timezone if available
timestamp = result.created_at
if timestamp and agent_state.timezone:
try:
# Convert to agent's timezone
tz = ZoneInfo(agent_state.timezone)
local_time = timestamp.astimezone(tz)
# Format as ISO string with timezone
formatted_timestamp = local_time.isoformat()
except Exception:
# Fallback to ISO format if timezone conversion fails
formatted_timestamp = str(timestamp)
else:
# Use ISO format if no timezone is set
formatted_timestamp = str(timestamp) if timestamp else "Unknown"
formatted_results.append({"timestamp": formatted_timestamp, "content": result.text, "tags": result.tags or []})
return formatted_results, len(formatted_results)
return formatted_results, count
except Exception as e:
raise e

View File

@@ -3607,7 +3607,7 @@ def test_deprecated_methods_show_warnings(server: SyncServer, default_user, sara
@pytest.mark.asyncio
async def test_passage_tags_functionality(server: SyncServer, default_user, sarah_agent):
async def test_passage_tags_functionality(disable_turbopuffer, server: SyncServer, default_user, sarah_agent):
"""Test comprehensive tag functionality for passages."""
from letta.schemas.enums import TagMatchMode
@@ -3974,6 +3974,168 @@ async def test_tag_edge_cases(disable_turbopuffer, server: SyncServer, sarah_age
assert set(passages_case[0].tags) == set(case_tags)
@pytest.mark.asyncio
async def test_search_agent_archival_memory_async(disable_turbopuffer, server: SyncServer, default_user, sarah_agent):
    """Test the search_agent_archival_memory_async method that powers both the agent tool and API endpoint.

    Covers: basic semantic search and result shape, tag filtering ("any"/"all"
    modes), top_k limits, date-only and full ISO datetime filtering,
    missing-agent and bad-datetime errors, and empty/whitespace-only queries.
    """
    # NOTE(review): disable_turbopuffer presumably routes search through the
    # local/database vector path — confirm the fixture's semantics.
    # Get or create default archive for the agent
    archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
        agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
    )

    # Create test passages with various content and tags
    test_data = [
        {
            "text": "Python is a powerful programming language used for data science and web development.",
            "tags": ["python", "programming", "data-science", "web"],
            "created_at": datetime(2024, 1, 15, 10, 30, tzinfo=timezone.utc),
        },
        {
            "text": "Machine learning algorithms can be implemented in Python using libraries like scikit-learn.",
            "tags": ["python", "machine-learning", "algorithms"],
            "created_at": datetime(2024, 1, 16, 14, 45, tzinfo=timezone.utc),
        },
        {
            "text": "JavaScript is essential for frontend web development and modern web applications.",
            "tags": ["javascript", "frontend", "web"],
            "created_at": datetime(2024, 1, 17, 9, 15, tzinfo=timezone.utc),
        },
        {
            "text": "Database design principles are important for building scalable applications.",
            "tags": ["database", "design", "scalability"],
            "created_at": datetime(2024, 1, 18, 16, 20, tzinfo=timezone.utc),
        },
        {
            "text": "The weather today is sunny and warm, perfect for outdoor activities.",
            "tags": ["weather", "outdoor"],
            "created_at": datetime(2024, 1, 19, 11, 0, tzinfo=timezone.utc),
        },
    ]

    # Create passages in the database
    created_passages = []
    for data in test_data:
        passage = await server.passage_manager.create_agent_passage_async(
            PydanticPassage(
                text=data["text"],
                archive_id=archive.id,
                organization_id=default_user.organization_id,
                embedding=[0.1, 0.2, 0.3],  # Mock embedding
                embedding_config=DEFAULT_EMBEDDING_CONFIG,
                tags=data["tags"],
                created_at=data["created_at"],
            ),
            actor=default_user,
        )
        created_passages.append(passage)

    # Test 1: Basic search by query text
    results, count = await server.agent_manager.search_agent_archival_memory_async(
        agent_id=sarah_agent.id, actor=default_user, query="Python programming"
    )
    assert count > 0
    # count must always mirror the number of returned results
    assert len(results) == count

    # Check structure of results: each entry is a dict with timestamp/content/tags
    for result in results:
        assert "timestamp" in result
        assert "content" in result
        assert "tags" in result
        assert isinstance(result["tags"], list)

    # Test 2: Search with tag filtering - single tag
    results, count = await server.agent_manager.search_agent_archival_memory_async(
        agent_id=sarah_agent.id, actor=default_user, query="programming", tags=["python"]
    )
    assert count > 0
    # All results should have "python" tag
    for result in results:
        assert "python" in result["tags"]

    # Test 3: Search with tag filtering - multiple tags with "any" mode
    results, count = await server.agent_manager.search_agent_archival_memory_async(
        agent_id=sarah_agent.id, actor=default_user, query="development", tags=["web", "database"], tag_match_mode="any"
    )
    assert count > 0
    # All results should have at least one of the specified tags
    for result in results:
        assert any(tag in result["tags"] for tag in ["web", "database"])

    # Test 4: Search with tag filtering - multiple tags with "all" mode
    results, count = await server.agent_manager.search_agent_archival_memory_async(
        agent_id=sarah_agent.id, actor=default_user, query="Python", tags=["python", "web"], tag_match_mode="all"
    )
    # Should only return results that have BOTH tags (may be empty, so no count assert)
    for result in results:
        assert "python" in result["tags"]
        assert "web" in result["tags"]

    # Test 5: Search with top_k limit
    results, count = await server.agent_manager.search_agent_archival_memory_async(
        agent_id=sarah_agent.id, actor=default_user, query="programming", top_k=2
    )
    assert count <= 2
    assert len(results) <= 2

    # Test 6: Search with datetime filtering (date-only bounds)
    results, count = await server.agent_manager.search_agent_archival_memory_async(
        agent_id=sarah_agent.id, actor=default_user, query="programming", start_datetime="2024-01-16", end_datetime="2024-01-17"
    )
    # Should only include passages created between those dates
    for result in results:
        # Parse timestamp to verify it's in range
        timestamp_str = result["timestamp"]
        # Basic validation that timestamp exists and has expected format
        assert "2024-01-16" in timestamp_str or "2024-01-17" in timestamp_str

    # Test 7: Search with ISO datetime format (full date + time bounds)
    results, count = await server.agent_manager.search_agent_archival_memory_async(
        agent_id=sarah_agent.id,
        actor=default_user,
        query="algorithms",
        start_datetime="2024-01-16T14:00:00",
        end_datetime="2024-01-16T15:00:00",
    )
    # Should include the machine learning passage created at 14:45
    assert count >= 0  # Might be 0 if no results, but shouldn't error

    # Test 8: Search with non-existent agent should raise error
    non_existent_agent_id = "agent-00000000-0000-4000-8000-000000000000"
    with pytest.raises(Exception):  # Should raise NoResultFound or similar
        await server.agent_manager.search_agent_archival_memory_async(agent_id=non_existent_agent_id, actor=default_user, query="test")

    # Test 9: Search with invalid datetime format should raise ValueError
    with pytest.raises(ValueError, match="Invalid start_datetime format"):
        await server.agent_manager.search_agent_archival_memory_async(
            agent_id=sarah_agent.id, actor=default_user, query="test", start_datetime="invalid-date"
        )

    # Test 10: Empty query should return empty results
    results, count = await server.agent_manager.search_agent_archival_memory_async(agent_id=sarah_agent.id, actor=default_user, query="")
    assert count == 0  # Empty query should return 0 results
    assert len(results) == 0

    # Test 11: Whitespace-only query should also return empty results
    results, count = await server.agent_manager.search_agent_archival_memory_async(
        agent_id=sarah_agent.id, actor=default_user, query="   \n\t   "
    )
    assert count == 0  # Whitespace-only query should return 0 results
    assert len(results) == 0

    # Cleanup - delete the created passages so later tests see a clean archive
    for passage in created_passages:
        await server.passage_manager.delete_agent_passage_by_id_async(passage_id=passage.id, actor=default_user)
# ======================================================================================================================
# User Manager Tests
# ======================================================================================================================