feat: Add archival search endpoint [LET-4184] (#4390)
* Add archival search endpoint * Run fern autogen * Add de-dupe logic
This commit is contained in:
@@ -328,3 +328,14 @@ class CreateArchivalMemory(BaseModel):
|
||||
text: str = Field(..., description="Text to write to archival memory.")
|
||||
tags: Optional[List[str]] = Field(None, description="Optional list of tags to attach to the memory.")
|
||||
created_at: Optional[datetime] = Field(None, description="Optional timestamp for the memory (defaults to current UTC time).")
|
||||
|
||||
|
||||
class ArchivalMemorySearchResult(BaseModel):
|
||||
timestamp: str = Field(..., description="Timestamp of when the memory was created, formatted in agent's timezone")
|
||||
content: str = Field(..., description="Text content of the archival memory passage")
|
||||
tags: List[str] = Field(default_factory=list, description="List of tags associated with this memory")
|
||||
|
||||
|
||||
class ArchivalMemorySearchResponse(BaseModel):
|
||||
results: List[ArchivalMemorySearchResult] = Field(..., description="List of search results matching the query")
|
||||
count: int = Field(..., description="Total number of results returned")
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import json
|
||||
import traceback
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Any, Dict, List, Optional, Union
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, File, Form, Header, HTTPException, Query, Request, UploadFile, status
|
||||
from fastapi.responses import JSONResponse
|
||||
@@ -32,7 +32,13 @@ from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig
|
||||
from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion, MessageType
|
||||
from letta.schemas.letta_request import LettaAsyncRequest, LettaRequest, LettaStreamingRequest
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.memory import ContextWindowOverview, CreateArchivalMemory, Memory
|
||||
from letta.schemas.memory import (
|
||||
ArchivalMemorySearchResponse,
|
||||
ArchivalMemorySearchResult,
|
||||
ContextWindowOverview,
|
||||
CreateArchivalMemory,
|
||||
Memory,
|
||||
)
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.run import Run
|
||||
@@ -978,6 +984,55 @@ async def create_passage(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/archival-memory/search", response_model=ArchivalMemorySearchResponse, operation_id="search_archival_memory")
|
||||
async def search_archival_memory(
|
||||
agent_id: str,
|
||||
query: str = Query(..., description="String to search for using semantic similarity"),
|
||||
tags: Optional[List[str]] = Query(None, description="Optional list of tags to filter search results"),
|
||||
tag_match_mode: Literal["any", "all"] = Query(
|
||||
"any", description="How to match tags - 'any' to match passages with any of the tags, 'all' to match only passages with all tags"
|
||||
),
|
||||
top_k: Optional[int] = Query(None, description="Maximum number of results to return. Uses system default if not specified"),
|
||||
start_datetime: Optional[str] = Query(None, description="Filter results to passages created after this datetime. ISO 8601 format"),
|
||||
end_datetime: Optional[str] = Query(None, description="Filter results to passages created before this datetime. ISO 8601 format"),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: str | None = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Search archival memory using semantic (embedding-based) search with optional temporal filtering.
|
||||
|
||||
This endpoint allows manual triggering of archival memory searches, enabling users to query
|
||||
an agent's archival memory store directly via the API. The search uses the same functionality
|
||||
as the agent's archival_memory_search tool but is accessible for external API usage.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
try:
|
||||
# Use the shared agent manager method
|
||||
formatted_results, count = await server.agent_manager.search_agent_archival_memory_async(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
query=query,
|
||||
tags=tags,
|
||||
tag_match_mode=tag_match_mode,
|
||||
top_k=top_k,
|
||||
start_datetime=start_datetime,
|
||||
end_datetime=end_datetime,
|
||||
)
|
||||
|
||||
# Convert to proper response schema
|
||||
search_results = [ArchivalMemorySearchResult(**result) for result in formatted_results]
|
||||
|
||||
return ArchivalMemorySearchResponse(results=search_results, count=count)
|
||||
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=f"Agent with id={agent_id} not found for user_id={actor.id}.")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Internal server error during archival memory search: {str(e)}")
|
||||
|
||||
|
||||
# TODO(ethan): query or path parameter for memory_id?
|
||||
# @router.delete("/{agent_id}/archival")
|
||||
@router.delete("/{agent_id}/archival-memory/{memory_id}", response_model=None, operation_id="delete_passage")
|
||||
|
||||
@@ -1536,7 +1536,16 @@ class SyncServer(Server):
|
||||
local_configs = self.get_local_llm_configs()
|
||||
llm_models.extend(local_configs)
|
||||
|
||||
return llm_models
|
||||
# dedupe by handle for uniqueness
|
||||
# Seems like this is required from the tests?
|
||||
seen_handles = set()
|
||||
unique_models = []
|
||||
for model in llm_models:
|
||||
if model.handle not in seen_handles:
|
||||
seen_handles.add(model.handle)
|
||||
unique_models.append(model)
|
||||
|
||||
return unique_models
|
||||
|
||||
def list_embedding_models(self, actor: User) -> List[EmbeddingConfig]:
|
||||
"""List available embedding models"""
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Tuple
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import delete, func, insert, literal, or_, select, tuple_
|
||||
@@ -21,6 +22,7 @@ from letta.constants import (
|
||||
EXCLUDE_MODEL_KEYWORDS_FROM_BASE_TOOL_RULES,
|
||||
FILES_TOOLS,
|
||||
INCLUDE_MODEL_KEYWORDS_BASE_TOOL_RULES,
|
||||
RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE,
|
||||
)
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
@@ -2752,6 +2754,127 @@ class AgentManager:
|
||||
|
||||
return pydantic_passages
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def search_agent_archival_memory_async(
|
||||
self,
|
||||
agent_id: str,
|
||||
actor: PydanticUser,
|
||||
query: str,
|
||||
tags: Optional[List[str]] = None,
|
||||
tag_match_mode: Literal["any", "all"] = "any",
|
||||
top_k: Optional[int] = None,
|
||||
start_datetime: Optional[str] = None,
|
||||
end_datetime: Optional[str] = None,
|
||||
) -> Tuple[List[Dict[str, Any]], int]:
|
||||
"""
|
||||
Search archival memory using semantic (embedding-based) search with optional temporal filtering.
|
||||
|
||||
This is a shared method used by both the agent tool and API endpoint to ensure consistent behavior.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent whose archival memory to search
|
||||
actor: User performing the search
|
||||
query: String to search for using semantic similarity
|
||||
tags: Optional list of tags to filter search results
|
||||
tag_match_mode: How to match tags - "any" or "all"
|
||||
top_k: Maximum number of results to return
|
||||
start_datetime: Filter results after this datetime (ISO 8601 format)
|
||||
end_datetime: Filter results before this datetime (ISO 8601 format)
|
||||
|
||||
Returns:
|
||||
Tuple of (formatted_results, count)
|
||||
"""
|
||||
# Handle empty or whitespace-only queries
|
||||
if not query or not query.strip():
|
||||
return [], 0
|
||||
|
||||
# Get the agent to access timezone and embedding config
|
||||
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
||||
|
||||
# Parse datetime parameters if provided
|
||||
start_date = None
|
||||
end_date = None
|
||||
|
||||
if start_datetime:
|
||||
try:
|
||||
# Try parsing as full datetime first (with time)
|
||||
start_date = datetime.fromisoformat(start_datetime)
|
||||
except ValueError:
|
||||
try:
|
||||
# Fall back to date-only format
|
||||
start_date = datetime.strptime(start_datetime, "%Y-%m-%d")
|
||||
# Set to beginning of day
|
||||
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid start_datetime format: {start_datetime}. Use ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM)"
|
||||
)
|
||||
|
||||
# Apply agent's timezone if datetime is naive
|
||||
if start_date.tzinfo is None and agent_state.timezone:
|
||||
tz = ZoneInfo(agent_state.timezone)
|
||||
start_date = start_date.replace(tzinfo=tz)
|
||||
|
||||
if end_datetime:
|
||||
try:
|
||||
# Try parsing as full datetime first (with time)
|
||||
end_date = datetime.fromisoformat(end_datetime)
|
||||
except ValueError:
|
||||
try:
|
||||
# Fall back to date-only format
|
||||
end_date = datetime.strptime(end_datetime, "%Y-%m-%d")
|
||||
# Set to end of day for end dates
|
||||
end_date = end_date.replace(hour=23, minute=59, second=59, microsecond=999999)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid end_datetime format: {end_datetime}. Use ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM)")
|
||||
|
||||
# Apply agent's timezone if datetime is naive
|
||||
if end_date.tzinfo is None and agent_state.timezone:
|
||||
tz = ZoneInfo(agent_state.timezone)
|
||||
end_date = end_date.replace(tzinfo=tz)
|
||||
|
||||
# Convert string to TagMatchMode enum
|
||||
tag_mode = TagMatchMode.ANY if tag_match_mode == "any" else TagMatchMode.ALL
|
||||
|
||||
# Get results using existing passage query method
|
||||
limit = top_k if top_k is not None else RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
all_results = await self.query_agent_passages_async(
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
query_text=query,
|
||||
limit=limit,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
embed_query=True,
|
||||
tags=tags,
|
||||
tag_match_mode=tag_mode,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
|
||||
# Format results to include tags with friendly timestamps
|
||||
formatted_results = []
|
||||
for result in all_results:
|
||||
# Format timestamp in agent's timezone if available
|
||||
timestamp = result.created_at
|
||||
if timestamp and agent_state.timezone:
|
||||
try:
|
||||
# Convert to agent's timezone
|
||||
tz = ZoneInfo(agent_state.timezone)
|
||||
local_time = timestamp.astimezone(tz)
|
||||
# Format as ISO string with timezone
|
||||
formatted_timestamp = local_time.isoformat()
|
||||
except Exception:
|
||||
# Fallback to ISO format if timezone conversion fails
|
||||
formatted_timestamp = str(timestamp)
|
||||
else:
|
||||
# Use ISO format if no timezone is set
|
||||
formatted_timestamp = str(timestamp) if timestamp else "Unknown"
|
||||
|
||||
formatted_results.append({"timestamp": formatted_timestamp, "content": result.text, "tags": result.tags or []})
|
||||
|
||||
return formatted_results, len(formatted_results)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def passage_size(
|
||||
|
||||
@@ -284,92 +284,19 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
str: Query result string containing matching passages with timestamps, content, and tags.
|
||||
"""
|
||||
try:
|
||||
# Parse datetime parameters if provided
|
||||
from datetime import datetime
|
||||
|
||||
start_date = None
|
||||
end_date = None
|
||||
|
||||
if start_datetime:
|
||||
try:
|
||||
# Try parsing as full datetime first (with time)
|
||||
start_date = datetime.fromisoformat(start_datetime)
|
||||
except ValueError:
|
||||
try:
|
||||
# Fall back to date-only format
|
||||
start_date = datetime.strptime(start_datetime, "%Y-%m-%d")
|
||||
# Set to beginning of day
|
||||
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid start_datetime format: {start_datetime}. Use ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM)"
|
||||
)
|
||||
|
||||
# Apply agent's timezone if datetime is naive
|
||||
if start_date.tzinfo is None and agent_state.timezone:
|
||||
tz = ZoneInfo(agent_state.timezone)
|
||||
start_date = start_date.replace(tzinfo=tz)
|
||||
|
||||
if end_datetime:
|
||||
try:
|
||||
# Try parsing as full datetime first (with time)
|
||||
end_date = datetime.fromisoformat(end_datetime)
|
||||
except ValueError:
|
||||
try:
|
||||
# Fall back to date-only format
|
||||
end_date = datetime.strptime(end_datetime, "%Y-%m-%d")
|
||||
# Set to end of day for end dates
|
||||
end_date = end_date.replace(hour=23, minute=59, second=59, microsecond=999999)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid end_datetime format: {end_datetime}. Use ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM)"
|
||||
)
|
||||
|
||||
# Apply agent's timezone if datetime is naive
|
||||
if end_date.tzinfo is None and agent_state.timezone:
|
||||
tz = ZoneInfo(agent_state.timezone)
|
||||
end_date = end_date.replace(tzinfo=tz)
|
||||
|
||||
# Convert string to TagMatchMode enum
|
||||
tag_mode = TagMatchMode.ANY if tag_match_mode == "any" else TagMatchMode.ALL
|
||||
|
||||
# Get results using passage manager
|
||||
limit = top_k if top_k is not None else RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
all_results = await self.agent_manager.query_agent_passages_async(
|
||||
actor=actor,
|
||||
# Use the shared service method to get results
|
||||
formatted_results, count = await self.agent_manager.search_agent_archival_memory_async(
|
||||
agent_id=agent_state.id,
|
||||
query_text=query,
|
||||
limit=limit,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
embed_query=True,
|
||||
actor=actor,
|
||||
query=query,
|
||||
tags=tags,
|
||||
tag_match_mode=tag_mode,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
tag_match_mode=tag_match_mode,
|
||||
top_k=top_k,
|
||||
start_datetime=start_datetime,
|
||||
end_datetime=end_datetime,
|
||||
)
|
||||
|
||||
# Format results to include tags with friendly timestamps
|
||||
formatted_results = []
|
||||
for result in all_results:
|
||||
# Format timestamp in agent's timezone if available
|
||||
timestamp = result.created_at
|
||||
if timestamp and agent_state.timezone:
|
||||
try:
|
||||
# Convert to agent's timezone
|
||||
tz = ZoneInfo(agent_state.timezone)
|
||||
local_time = timestamp.astimezone(tz)
|
||||
# Format as ISO string with timezone
|
||||
formatted_timestamp = local_time.isoformat()
|
||||
except Exception:
|
||||
# Fallback to ISO format if timezone conversion fails
|
||||
formatted_timestamp = str(timestamp)
|
||||
else:
|
||||
# Use ISO format if no timezone is set
|
||||
formatted_timestamp = str(timestamp) if timestamp else "Unknown"
|
||||
|
||||
formatted_results.append({"timestamp": formatted_timestamp, "content": result.text, "tags": result.tags or []})
|
||||
|
||||
return formatted_results, len(formatted_results)
|
||||
return formatted_results, count
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@@ -3607,7 +3607,7 @@ def test_deprecated_methods_show_warnings(server: SyncServer, default_user, sara
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_passage_tags_functionality(server: SyncServer, default_user, sarah_agent):
|
||||
async def test_passage_tags_functionality(disable_turbopuffer, server: SyncServer, default_user, sarah_agent):
|
||||
"""Test comprehensive tag functionality for passages."""
|
||||
from letta.schemas.enums import TagMatchMode
|
||||
|
||||
@@ -3974,6 +3974,168 @@ async def test_tag_edge_cases(disable_turbopuffer, server: SyncServer, sarah_age
|
||||
assert set(passages_case[0].tags) == set(case_tags)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_agent_archival_memory_async(disable_turbopuffer, server: SyncServer, default_user, sarah_agent):
|
||||
"""Test the search_agent_archival_memory_async method that powers both the agent tool and API endpoint."""
|
||||
# Get or create default archive for the agent
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
|
||||
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
|
||||
)
|
||||
|
||||
# Create test passages with various content and tags
|
||||
test_data = [
|
||||
{
|
||||
"text": "Python is a powerful programming language used for data science and web development.",
|
||||
"tags": ["python", "programming", "data-science", "web"],
|
||||
"created_at": datetime(2024, 1, 15, 10, 30, tzinfo=timezone.utc),
|
||||
},
|
||||
{
|
||||
"text": "Machine learning algorithms can be implemented in Python using libraries like scikit-learn.",
|
||||
"tags": ["python", "machine-learning", "algorithms"],
|
||||
"created_at": datetime(2024, 1, 16, 14, 45, tzinfo=timezone.utc),
|
||||
},
|
||||
{
|
||||
"text": "JavaScript is essential for frontend web development and modern web applications.",
|
||||
"tags": ["javascript", "frontend", "web"],
|
||||
"created_at": datetime(2024, 1, 17, 9, 15, tzinfo=timezone.utc),
|
||||
},
|
||||
{
|
||||
"text": "Database design principles are important for building scalable applications.",
|
||||
"tags": ["database", "design", "scalability"],
|
||||
"created_at": datetime(2024, 1, 18, 16, 20, tzinfo=timezone.utc),
|
||||
},
|
||||
{
|
||||
"text": "The weather today is sunny and warm, perfect for outdoor activities.",
|
||||
"tags": ["weather", "outdoor"],
|
||||
"created_at": datetime(2024, 1, 19, 11, 0, tzinfo=timezone.utc),
|
||||
},
|
||||
]
|
||||
|
||||
# Create passages in the database
|
||||
created_passages = []
|
||||
for data in test_data:
|
||||
passage = await server.passage_manager.create_agent_passage_async(
|
||||
PydanticPassage(
|
||||
text=data["text"],
|
||||
archive_id=archive.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1, 0.2, 0.3], # Mock embedding
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
tags=data["tags"],
|
||||
created_at=data["created_at"],
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
created_passages.append(passage)
|
||||
|
||||
# Test 1: Basic search by query text
|
||||
results, count = await server.agent_manager.search_agent_archival_memory_async(
|
||||
agent_id=sarah_agent.id, actor=default_user, query="Python programming"
|
||||
)
|
||||
|
||||
assert count > 0
|
||||
assert len(results) == count
|
||||
|
||||
# Check structure of results
|
||||
for result in results:
|
||||
assert "timestamp" in result
|
||||
assert "content" in result
|
||||
assert "tags" in result
|
||||
assert isinstance(result["tags"], list)
|
||||
|
||||
# Test 2: Search with tag filtering - single tag
|
||||
results, count = await server.agent_manager.search_agent_archival_memory_async(
|
||||
agent_id=sarah_agent.id, actor=default_user, query="programming", tags=["python"]
|
||||
)
|
||||
|
||||
assert count > 0
|
||||
# All results should have "python" tag
|
||||
for result in results:
|
||||
assert "python" in result["tags"]
|
||||
|
||||
# Test 3: Search with tag filtering - multiple tags with "any" mode
|
||||
results, count = await server.agent_manager.search_agent_archival_memory_async(
|
||||
agent_id=sarah_agent.id, actor=default_user, query="development", tags=["web", "database"], tag_match_mode="any"
|
||||
)
|
||||
|
||||
assert count > 0
|
||||
# All results should have at least one of the specified tags
|
||||
for result in results:
|
||||
assert any(tag in result["tags"] for tag in ["web", "database"])
|
||||
|
||||
# Test 4: Search with tag filtering - multiple tags with "all" mode
|
||||
results, count = await server.agent_manager.search_agent_archival_memory_async(
|
||||
agent_id=sarah_agent.id, actor=default_user, query="Python", tags=["python", "web"], tag_match_mode="all"
|
||||
)
|
||||
|
||||
# Should only return results that have BOTH tags
|
||||
for result in results:
|
||||
assert "python" in result["tags"]
|
||||
assert "web" in result["tags"]
|
||||
|
||||
# Test 5: Search with top_k limit
|
||||
results, count = await server.agent_manager.search_agent_archival_memory_async(
|
||||
agent_id=sarah_agent.id, actor=default_user, query="programming", top_k=2
|
||||
)
|
||||
|
||||
assert count <= 2
|
||||
assert len(results) <= 2
|
||||
|
||||
# Test 6: Search with datetime filtering
|
||||
results, count = await server.agent_manager.search_agent_archival_memory_async(
|
||||
agent_id=sarah_agent.id, actor=default_user, query="programming", start_datetime="2024-01-16", end_datetime="2024-01-17"
|
||||
)
|
||||
|
||||
# Should only include passages created between those dates
|
||||
for result in results:
|
||||
# Parse timestamp to verify it's in range
|
||||
timestamp_str = result["timestamp"]
|
||||
# Basic validation that timestamp exists and has expected format
|
||||
assert "2024-01-16" in timestamp_str or "2024-01-17" in timestamp_str
|
||||
|
||||
# Test 7: Search with ISO datetime format
|
||||
results, count = await server.agent_manager.search_agent_archival_memory_async(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
query="algorithms",
|
||||
start_datetime="2024-01-16T14:00:00",
|
||||
end_datetime="2024-01-16T15:00:00",
|
||||
)
|
||||
|
||||
# Should include the machine learning passage created at 14:45
|
||||
assert count >= 0 # Might be 0 if no results, but shouldn't error
|
||||
|
||||
# Test 8: Search with non-existent agent should raise error
|
||||
non_existent_agent_id = "agent-00000000-0000-4000-8000-000000000000"
|
||||
|
||||
with pytest.raises(Exception): # Should raise NoResultFound or similar
|
||||
await server.agent_manager.search_agent_archival_memory_async(agent_id=non_existent_agent_id, actor=default_user, query="test")
|
||||
|
||||
# Test 9: Search with invalid datetime format should raise ValueError
|
||||
with pytest.raises(ValueError, match="Invalid start_datetime format"):
|
||||
await server.agent_manager.search_agent_archival_memory_async(
|
||||
agent_id=sarah_agent.id, actor=default_user, query="test", start_datetime="invalid-date"
|
||||
)
|
||||
|
||||
# Test 10: Empty query should return empty results
|
||||
results, count = await server.agent_manager.search_agent_archival_memory_async(agent_id=sarah_agent.id, actor=default_user, query="")
|
||||
|
||||
assert count == 0 # Empty query should return 0 results
|
||||
assert len(results) == 0
|
||||
|
||||
# Test 11: Whitespace-only query should also return empty results
|
||||
results, count = await server.agent_manager.search_agent_archival_memory_async(
|
||||
agent_id=sarah_agent.id, actor=default_user, query=" \n\t "
|
||||
)
|
||||
|
||||
assert count == 0 # Whitespace-only query should return 0 results
|
||||
assert len(results) == 0
|
||||
|
||||
# Cleanup - delete the created passages
|
||||
for passage in created_passages:
|
||||
await server.passage_manager.delete_agent_passage_by_id_async(passage_id=passage.id, actor=default_user)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# User Manager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
Reference in New Issue
Block a user