diff --git a/letta/schemas/memory.py b/letta/schemas/memory.py index ec4c7ebb..bd7908a8 100644 --- a/letta/schemas/memory.py +++ b/letta/schemas/memory.py @@ -328,3 +328,14 @@ class CreateArchivalMemory(BaseModel): text: str = Field(..., description="Text to write to archival memory.") tags: Optional[List[str]] = Field(None, description="Optional list of tags to attach to the memory.") created_at: Optional[datetime] = Field(None, description="Optional timestamp for the memory (defaults to current UTC time).") + + +class ArchivalMemorySearchResult(BaseModel): + timestamp: str = Field(..., description="Timestamp of when the memory was created, formatted in agent's timezone") + content: str = Field(..., description="Text content of the archival memory passage") + tags: List[str] = Field(default_factory=list, description="List of tags associated with this memory") + + +class ArchivalMemorySearchResponse(BaseModel): + results: List[ArchivalMemorySearchResult] = Field(..., description="List of search results matching the query") + count: int = Field(..., description="Total number of results returned") diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 3962e787..06a66067 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -2,7 +2,7 @@ import asyncio import json import traceback from datetime import datetime, timezone -from typing import Annotated, Any, Dict, List, Optional, Union +from typing import Annotated, Any, Dict, List, Literal, Optional, Union from fastapi import APIRouter, Body, Depends, File, Form, Header, HTTPException, Query, Request, UploadFile, status from fastapi.responses import JSONResponse @@ -32,7 +32,13 @@ from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion, MessageType from letta.schemas.letta_request import LettaAsyncRequest, LettaRequest, LettaStreamingRequest from letta.schemas.letta_response import LettaResponse -from letta.schemas.memory import ContextWindowOverview, CreateArchivalMemory, Memory +from letta.schemas.memory import ( + ArchivalMemorySearchResponse, + ArchivalMemorySearchResult, + ContextWindowOverview, + CreateArchivalMemory, + Memory, +) from letta.schemas.message import MessageCreate from letta.schemas.passage import Passage from letta.schemas.run import Run @@ -978,6 +984,55 @@ async def create_passage( ) +@router.get("/{agent_id}/archival-memory/search", response_model=ArchivalMemorySearchResponse, operation_id="search_archival_memory") +async def search_archival_memory( + agent_id: str, + query: str = Query(..., description="String to search for using semantic similarity"), + tags: Optional[List[str]] = Query(None, description="Optional list of tags to filter search results"), + tag_match_mode: Literal["any", "all"] = Query( + "any", description="How to match tags - 'any' to match passages with any of the tags, 'all' to match only passages with all tags" + ), + top_k: Optional[int] = Query(None, description="Maximum number of results to return. Uses system default if not specified"), + start_datetime: Optional[str] = Query(None, description="Filter results to passages created after this datetime. ISO 8601 format"), + end_datetime: Optional[str] = Query(None, description="Filter results to passages created before this datetime. ISO 8601 format"), + server: "SyncServer" = Depends(get_letta_server), + actor_id: str | None = Header(None, alias="user_id"), +): + """ + Search archival memory using semantic (embedding-based) search with optional temporal filtering. + + This endpoint allows manual triggering of archival memory searches, enabling users to query + an agent's archival memory store directly via the API. The search uses the same functionality + as the agent's archival_memory_search tool but is accessible for external API usage. + """ + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + + try: + # Use the shared agent manager method + formatted_results, count = await server.agent_manager.search_agent_archival_memory_async( + agent_id=agent_id, + actor=actor, + query=query, + tags=tags, + tag_match_mode=tag_match_mode, + top_k=top_k, + start_datetime=start_datetime, + end_datetime=end_datetime, + ) + + # Convert to proper response schema + search_results = [ArchivalMemorySearchResult(**result) for result in formatted_results] + + return ArchivalMemorySearchResponse(results=search_results, count=count) + + except NoResultFound as e: + raise HTTPException(status_code=404, detail=f"Agent with id={agent_id} not found for user_id={actor.id}.") + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Internal server error during archival memory search: {str(e)}") + + # TODO(ethan): query or path parameter for memory_id? # @router.delete("/{agent_id}/archival") @router.delete("/{agent_id}/archival-memory/{memory_id}", response_model=None, operation_id="delete_passage") diff --git a/letta/server/server.py b/letta/server/server.py index aa3ed3b1..46d7e9d5 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1536,7 +1536,16 @@ class SyncServer(Server): local_configs = self.get_local_llm_configs() llm_models.extend(local_configs) - return llm_models + # dedupe by handle for uniqueness + # Seems like this is required from the tests? + seen_handles = set() + unique_models = [] + for model in llm_models: + if model.handle not in seen_handles: + seen_handles.add(model.handle) + unique_models.append(model) + + return unique_models def list_embedding_models(self, actor: User) -> List[EmbeddingConfig]: """List available embedding models""" diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 01214505..053b781e 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -1,6 +1,7 @@ import asyncio from datetime import datetime, timezone -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Literal, Optional, Set, Tuple +from zoneinfo import ZoneInfo import sqlalchemy as sa from sqlalchemy import delete, func, insert, literal, or_, select, tuple_ @@ -21,6 +22,7 @@ from letta.constants import ( EXCLUDE_MODEL_KEYWORDS_FROM_BASE_TOOL_RULES, FILES_TOOLS, INCLUDE_MODEL_KEYWORDS_BASE_TOOL_RULES, + RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE, ) from letta.helpers import ToolRulesSolver from letta.helpers.datetime_helpers import get_utc_time @@ -2752,6 +2754,127 @@ class AgentManager: return pydantic_passages + @enforce_types + @trace_method + async def search_agent_archival_memory_async( + self, + agent_id: str, + actor: PydanticUser, + query: str, + tags: Optional[List[str]] = None, + tag_match_mode: Literal["any", "all"] = "any", + top_k: Optional[int] = None, + start_datetime: Optional[str] = None, + end_datetime: Optional[str] = None, + ) -> Tuple[List[Dict[str, Any]], int]: + """ + Search archival memory using semantic (embedding-based) search with optional temporal filtering. + + This is a shared method used by both the agent tool and API endpoint to ensure consistent behavior. + + Args: + agent_id: ID of the agent whose archival memory to search + actor: User performing the search + query: String to search for using semantic similarity + tags: Optional list of tags to filter search results + tag_match_mode: How to match tags - "any" or "all" + top_k: Maximum number of results to return + start_datetime: Filter results after this datetime (ISO 8601 format) + end_datetime: Filter results before this datetime (ISO 8601 format) + + Returns: + Tuple of (formatted_results, count) + """ + # Handle empty or whitespace-only queries + if not query or not query.strip(): + return [], 0 + + # Get the agent to access timezone and embedding config + agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor) + + # Parse datetime parameters if provided + start_date = None + end_date = None + + if start_datetime: + try: + # Try parsing as full datetime first (with time) + start_date = datetime.fromisoformat(start_datetime) + except ValueError: + try: + # Fall back to date-only format + start_date = datetime.strptime(start_datetime, "%Y-%m-%d") + # Set to beginning of day + start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0) + except ValueError: + raise ValueError( + f"Invalid start_datetime format: {start_datetime}. Use ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM)" + ) + + # Apply agent's timezone if datetime is naive + if start_date.tzinfo is None and agent_state.timezone: + tz = ZoneInfo(agent_state.timezone) + start_date = start_date.replace(tzinfo=tz) + + if end_datetime: + try: + # Try parsing as full datetime first (with time) + end_date = datetime.fromisoformat(end_datetime) + except ValueError: + try: + # Fall back to date-only format + end_date = datetime.strptime(end_datetime, "%Y-%m-%d") + # Set to end of day for end dates + end_date = end_date.replace(hour=23, minute=59, second=59, microsecond=999999) + except ValueError: + raise ValueError(f"Invalid end_datetime format: {end_datetime}. Use ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM)") + + # Apply agent's timezone if datetime is naive + if end_date.tzinfo is None and agent_state.timezone: + tz = ZoneInfo(agent_state.timezone) + end_date = end_date.replace(tzinfo=tz) + + # Convert string to TagMatchMode enum + tag_mode = TagMatchMode.ANY if tag_match_mode == "any" else TagMatchMode.ALL + + # Get results using existing passage query method + limit = top_k if top_k is not None else RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE + all_results = await self.query_agent_passages_async( + actor=actor, + agent_id=agent_id, + query_text=query, + limit=limit, + embedding_config=agent_state.embedding_config, + embed_query=True, + tags=tags, + tag_match_mode=tag_mode, + start_date=start_date, + end_date=end_date, + ) + + # Format results to include tags with friendly timestamps + formatted_results = [] + for result in all_results: + # Format timestamp in agent's timezone if available + timestamp = result.created_at + if timestamp and agent_state.timezone: + try: + # Convert to agent's timezone + tz = ZoneInfo(agent_state.timezone) + local_time = timestamp.astimezone(tz) + # Format as ISO string with timezone + formatted_timestamp = local_time.isoformat() + except Exception: + # Fallback to ISO format if timezone conversion fails + formatted_timestamp = str(timestamp) + else: + # Use ISO format if no timezone is set + formatted_timestamp = str(timestamp) if timestamp else "Unknown" + + formatted_results.append({"timestamp": formatted_timestamp, "content": result.text, "tags": result.tags or []}) + + return formatted_results, len(formatted_results) + @enforce_types @trace_method async def passage_size( diff --git a/letta/services/tool_executor/core_tool_executor.py b/letta/services/tool_executor/core_tool_executor.py index ac28cf31..5561e00d 100644 --- a/letta/services/tool_executor/core_tool_executor.py +++ b/letta/services/tool_executor/core_tool_executor.py @@ -284,92 +284,19 @@ class LettaCoreToolExecutor(ToolExecutor): str: Query result string containing matching passages with timestamps, content, and tags. """ try: - # Parse datetime parameters if provided - from datetime import datetime - - start_date = None - end_date = None - - if start_datetime: - try: - # Try parsing as full datetime first (with time) - start_date = datetime.fromisoformat(start_datetime) - except ValueError: - try: - # Fall back to date-only format - start_date = datetime.strptime(start_datetime, "%Y-%m-%d") - # Set to beginning of day - start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0) - except ValueError: - raise ValueError( - f"Invalid start_datetime format: {start_datetime}. Use ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM)" - ) - - # Apply agent's timezone if datetime is naive - if start_date.tzinfo is None and agent_state.timezone: - tz = ZoneInfo(agent_state.timezone) - start_date = start_date.replace(tzinfo=tz) - - if end_datetime: - try: - # Try parsing as full datetime first (with time) - end_date = datetime.fromisoformat(end_datetime) - except ValueError: - try: - # Fall back to date-only format - end_date = datetime.strptime(end_datetime, "%Y-%m-%d") - # Set to end of day for end dates - end_date = end_date.replace(hour=23, minute=59, second=59, microsecond=999999) - except ValueError: - raise ValueError( - f"Invalid end_datetime format: {end_datetime}. Use ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM)" - ) - - # Apply agent's timezone if datetime is naive - if end_date.tzinfo is None and agent_state.timezone: - tz = ZoneInfo(agent_state.timezone) - end_date = end_date.replace(tzinfo=tz) - - # Convert string to TagMatchMode enum - tag_mode = TagMatchMode.ANY if tag_match_mode == "any" else TagMatchMode.ALL - - # Get results using passage manager - limit = top_k if top_k is not None else RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE - all_results = await self.agent_manager.query_agent_passages_async( - actor=actor, + # Use the shared service method to get results + formatted_results, count = await self.agent_manager.search_agent_archival_memory_async( agent_id=agent_state.id, - query_text=query, - limit=limit, - embedding_config=agent_state.embedding_config, - embed_query=True, + actor=actor, + query=query, tags=tags, - tag_match_mode=tag_mode, - start_date=start_date, - end_date=end_date, + tag_match_mode=tag_match_mode, + top_k=top_k, + start_datetime=start_datetime, + end_datetime=end_datetime, ) - # Format results to include tags with friendly timestamps - formatted_results = [] - for result in all_results: - # Format timestamp in agent's timezone if available - timestamp = result.created_at - if timestamp and agent_state.timezone: - try: - # Convert to agent's timezone - tz = ZoneInfo(agent_state.timezone) - local_time = timestamp.astimezone(tz) - # Format as ISO string with timezone - formatted_timestamp = local_time.isoformat() - except Exception: - # Fallback to ISO format if timezone conversion fails - formatted_timestamp = str(timestamp) - else: - # Use ISO format if no timezone is set - formatted_timestamp = str(timestamp) if timestamp else "Unknown" - - formatted_results.append({"timestamp": formatted_timestamp, "content": result.text, "tags": result.tags or []}) - - return formatted_results, len(formatted_results) + return formatted_results, count except Exception as e: raise e diff --git a/tests/test_managers.py b/tests/test_managers.py index b1ebdfdf..5fd7f0bc 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -3607,7 +3607,7 @@ def test_deprecated_methods_show_warnings(server: SyncServer, default_user, sara @pytest.mark.asyncio -async def test_passage_tags_functionality(server: SyncServer, default_user, sarah_agent): +async def test_passage_tags_functionality(disable_turbopuffer, server: SyncServer, default_user, sarah_agent): """Test comprehensive tag functionality for passages.""" from letta.schemas.enums import TagMatchMode @@ -3974,6 +3974,168 @@ async def test_tag_edge_cases(disable_turbopuffer, server: SyncServer, sarah_age assert set(passages_case[0].tags) == set(case_tags) +@pytest.mark.asyncio +async def test_search_agent_archival_memory_async(disable_turbopuffer, server: SyncServer, default_user, sarah_agent): + """Test the search_agent_archival_memory_async method that powers both the agent tool and API endpoint.""" + # Get or create default archive for the agent + archive = await server.archive_manager.get_or_create_default_archive_for_agent_async( + agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user + ) + + # Create test passages with various content and tags + test_data = [ + { + "text": "Python is a powerful programming language used for data science and web development.", + "tags": ["python", "programming", "data-science", "web"], + "created_at": datetime(2024, 1, 15, 10, 30, tzinfo=timezone.utc), + }, + { + "text": "Machine learning algorithms can be implemented in Python using libraries like scikit-learn.", + "tags": ["python", "machine-learning", "algorithms"], + "created_at": datetime(2024, 1, 16, 14, 45, tzinfo=timezone.utc), + }, + { + "text": "JavaScript is essential for frontend web development and modern web applications.", + "tags": ["javascript", "frontend", "web"], + "created_at": datetime(2024, 1, 17, 9, 15, tzinfo=timezone.utc), + }, + { + "text": "Database design principles are important for building scalable applications.", + "tags": ["database", "design", "scalability"], + "created_at": datetime(2024, 1, 18, 16, 20, tzinfo=timezone.utc), + }, + { + "text": "The weather today is sunny and warm, perfect for outdoor activities.", + "tags": ["weather", "outdoor"], + "created_at": datetime(2024, 1, 19, 11, 0, tzinfo=timezone.utc), + }, + ] + + # Create passages in the database + created_passages = [] + for data in test_data: + passage = await server.passage_manager.create_agent_passage_async( + PydanticPassage( + text=data["text"], + archive_id=archive.id, + organization_id=default_user.organization_id, + embedding=[0.1, 0.2, 0.3], # Mock embedding + embedding_config=DEFAULT_EMBEDDING_CONFIG, + tags=data["tags"], + created_at=data["created_at"], + ), + actor=default_user, + ) + created_passages.append(passage) + + # Test 1: Basic search by query text + results, count = await server.agent_manager.search_agent_archival_memory_async( + agent_id=sarah_agent.id, actor=default_user, query="Python programming" + ) + + assert count > 0 + assert len(results) == count + + # Check structure of results + for result in results: + assert "timestamp" in result + assert "content" in result + assert "tags" in result + assert isinstance(result["tags"], list) + + # Test 2: Search with tag filtering - single tag + results, count = await server.agent_manager.search_agent_archival_memory_async( + agent_id=sarah_agent.id, actor=default_user, query="programming", tags=["python"] + ) + + assert count > 0 + # All results should have "python" tag + for result in results: + assert "python" in result["tags"] + + # Test 3: Search with tag filtering - multiple tags with "any" mode + results, count = await server.agent_manager.search_agent_archival_memory_async( + agent_id=sarah_agent.id, actor=default_user, query="development", tags=["web", "database"], tag_match_mode="any" + ) + + assert count > 0 + # All results should have at least one of the specified tags + for result in results: + assert any(tag in result["tags"] for tag in ["web", "database"]) + + # Test 4: Search with tag filtering - multiple tags with "all" mode + results, count = await server.agent_manager.search_agent_archival_memory_async( + agent_id=sarah_agent.id, actor=default_user, query="Python", tags=["python", "web"], tag_match_mode="all" + ) + + # Should only return results that have BOTH tags + for result in results: + assert "python" in result["tags"] + assert "web" in result["tags"] + + # Test 5: Search with top_k limit + results, count = await server.agent_manager.search_agent_archival_memory_async( + agent_id=sarah_agent.id, actor=default_user, query="programming", top_k=2 + ) + + assert count <= 2 + assert len(results) <= 2 + + # Test 6: Search with datetime filtering + results, count = await server.agent_manager.search_agent_archival_memory_async( + agent_id=sarah_agent.id, actor=default_user, query="programming", start_datetime="2024-01-16", end_datetime="2024-01-17" + ) + + # Should only include passages created between those dates + for result in results: + # Parse timestamp to verify it's in range + timestamp_str = result["timestamp"] + # Basic validation that timestamp exists and has expected format + assert "2024-01-16" in timestamp_str or "2024-01-17" in timestamp_str + + # Test 7: Search with ISO datetime format + results, count = await server.agent_manager.search_agent_archival_memory_async( + agent_id=sarah_agent.id, + actor=default_user, + query="algorithms", + start_datetime="2024-01-16T14:00:00", + end_datetime="2024-01-16T15:00:00", + ) + + # Should include the machine learning passage created at 14:45 + assert count >= 0 # Might be 0 if no results, but shouldn't error + + # Test 8: Search with non-existent agent should raise error + non_existent_agent_id = "agent-00000000-0000-4000-8000-000000000000" + + with pytest.raises(Exception): # Should raise NoResultFound or similar + await server.agent_manager.search_agent_archival_memory_async(agent_id=non_existent_agent_id, actor=default_user, query="test") + + # Test 9: Search with invalid datetime format should raise ValueError + with pytest.raises(ValueError, match="Invalid start_datetime format"): + await server.agent_manager.search_agent_archival_memory_async( + agent_id=sarah_agent.id, actor=default_user, query="test", start_datetime="invalid-date" + ) + + # Test 10: Empty query should return empty results + results, count = await server.agent_manager.search_agent_archival_memory_async(agent_id=sarah_agent.id, actor=default_user, query="") + + assert count == 0 # Empty query should return 0 results + assert len(results) == 0 + + # Test 11: Whitespace-only query should also return empty results + results, count = await server.agent_manager.search_agent_archival_memory_async( + agent_id=sarah_agent.id, actor=default_user, query=" \n\t " + ) + + assert count == 0 # Whitespace-only query should return 0 results + assert len(results) == 0 + + # Cleanup - delete the created passages + for passage in created_passages: + await server.passage_manager.delete_agent_passage_by_id_async(passage_id=passage.id, actor=default_user) + + # ====================================================================================================================== # User Manager Tests # ======================================================================================================================