From 710c03dbd1ebd4c5c6bf9db759a906705cde965e Mon Sep 17 00:00:00 2001 From: cthomas Date: Tue, 20 May 2025 12:56:04 -0700 Subject: [PATCH] feat(asyncify): migrate list passages (#2275) --- letta/agents/voice_agent.py | 2 +- letta/server/rest_api/routers/v1/agents.py | 8 +-- letta/server/server.py | 24 ++++++++ letta/services/agent_manager.py | 59 ++++++++++++++++++ tests/test_managers.py | 69 +++++++++++++--------- 5 files changed, 130 insertions(+), 32 deletions(-) diff --git a/letta/agents/voice_agent.py b/letta/agents/voice_agent.py index 1d0ab88c..3926ce5a 100644 --- a/letta/agents/voice_agent.py +++ b/letta/agents/voice_agent.py @@ -438,7 +438,7 @@ class VoiceAgent(BaseAgent): if start_date and end_date and start_date > end_date: start_date, end_date = end_date, start_date - archival_results = self.agent_manager.list_passages( + archival_results = await self.agent_manager.list_passages_async( actor=self.actor, agent_id=self.agent_id, query_text=archival_query, diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index bc8609f4..d0fd2545 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -484,7 +484,7 @@ def detach_block( @router.get("/{agent_id}/archival-memory", response_model=List[Passage], operation_id="list_passages") -def list_passages( +async def list_passages( agent_id: str, server: "SyncServer" = Depends(get_letta_server), after: Optional[str] = Query(None, description="Unique ID of the memory to start the query range at."), @@ -499,11 +499,11 @@ def list_passages( """ Retrieve the memories in an agent's archival memory store (paginated query). """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) - return server.get_agent_archival( - user_id=actor.id, + return await server.get_agent_archival_async( agent_id=agent_id, + actor=actor, after=after, before=before, query_text=search, diff --git a/letta/server/server.py b/letta/server/server.py index 7b5bc4f6..512dd4d9 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1002,6 +1002,30 @@ class SyncServer(Server): ) return records + async def get_agent_archival_async( + self, + agent_id: str, + actor: User, + after: Optional[str] = None, + before: Optional[str] = None, + limit: Optional[int] = 100, + order_by: Optional[str] = "created_at", + reverse: Optional[bool] = False, + query_text: Optional[str] = None, + ascending: Optional[bool] = True, + ) -> List[Passage]: + # iterate over records + records = await self.agent_manager.list_passages_async( + actor=actor, + agent_id=agent_id, + after=after, + query_text=query_text, + before=before, + ascending=ascending, + limit=limit, + ) + return records + def insert_archival_memory(self, agent_id: str, memory_contents: str, actor: User) -> List[Passage]: # Get the agent object (loaded in memory) agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor) diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 603388e1..326dd30f 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -1961,6 +1961,65 @@ class AgentManager: return [p.to_pydantic() for p in passages] + @enforce_types + async def list_passages_async( + self, + actor: PydanticUser, + agent_id: Optional[str] = None, + file_id: Optional[str] = None, + limit: Optional[int] = 50, + query_text: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + before: Optional[str] = None, + after: Optional[str] = None, + source_id: Optional[str] = None, + embed_query: bool = False, + ascending: bool = True, + embedding_config: Optional[EmbeddingConfig] = None, + agent_only: bool = False, + ) -> List[PydanticPassage]: + """Lists all passages attached to an agent.""" + async with db_registry.async_session() as session: + main_query = self._build_passage_query( + actor=actor, + agent_id=agent_id, + file_id=file_id, + query_text=query_text, + start_date=start_date, + end_date=end_date, + before=before, + after=after, + source_id=source_id, + embed_query=embed_query, + ascending=ascending, + embedding_config=embedding_config, + agent_only=agent_only, + ) + + # Add limit + if limit: + main_query = main_query.limit(limit) + + # Execute query + result = await session.execute(main_query) + + passages = [] + for row in result: + data = dict(row._mapping) + if data["agent_id"] is not None: + # This is an AgentPassage - remove source fields + data.pop("source_id", None) + data.pop("file_id", None) + passage = AgentPassage(**data) + else: + # This is a SourcePassage - remove agent field + data.pop("agent_id", None) + passage = SourcePassage(**data) + passages.append(passage) + + return [p.to_pydantic() for p in passages] + @enforce_types def passage_size( self, diff --git a/tests/test_managers.py b/tests/test_managers.py index fc0ff7f8..ca7fa6ea 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -1817,41 +1817,44 @@ async def test_refresh_memory_async(server: SyncServer, default_user, event_loop # ====================================================================================================================== -def test_agent_list_passages_basic(server, default_user, sarah_agent, agent_passages_setup): +@pytest.mark.asyncio +async def test_agent_list_passages_basic(server, default_user, sarah_agent, agent_passages_setup, event_loop): """Test basic listing functionality of agent passages""" - all_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id) + all_passages = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id) assert len(all_passages) == 5 # 3 source + 2 agent passages -def test_agent_list_passages_ordering(server, default_user, sarah_agent, agent_passages_setup): +@pytest.mark.asyncio +async def test_agent_list_passages_ordering(server, default_user, sarah_agent, agent_passages_setup, event_loop): """Test ordering of agent passages""" # Test ascending order - asc_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, ascending=True) + asc_passages = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id, ascending=True) assert len(asc_passages) == 5 for i in range(1, len(asc_passages)): assert asc_passages[i - 1].created_at <= asc_passages[i].created_at # Test descending order - desc_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, ascending=False) + desc_passages = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id, ascending=False) assert len(desc_passages) == 5 for i in range(1, len(desc_passages)): assert desc_passages[i - 1].created_at >= desc_passages[i].created_at -def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent_passages_setup): +@pytest.mark.asyncio +async def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent_passages_setup, event_loop): """Test pagination of agent passages""" # Test limit - limited_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, limit=3) + limited_passages = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=3) assert len(limited_passages) == 3 # Test cursor-based pagination - first_page = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, limit=2, ascending=True) + first_page = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=2, ascending=True) assert len(first_page) == 2 - second_page = server.agent_manager.list_passages( + second_page = await server.agent_manager.list_passages_async( actor=default_user, agent_id=sarah_agent.id, after=first_page[-1].id, limit=2, ascending=True ) assert len(second_page) == 2 @@ -1865,14 +1868,14 @@ def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent [mid] * | * * | * """ - middle_page = server.agent_manager.list_passages( + middle_page = await server.agent_manager.list_passages_async( actor=default_user, agent_id=sarah_agent.id, before=second_page[-1].id, after=first_page[0].id, ascending=True ) assert len(middle_page) == 2 assert middle_page[0].id == first_page[-1].id assert middle_page[1].id == second_page[0].id - middle_page_desc = server.agent_manager.list_passages( + middle_page_desc = await server.agent_manager.list_passages_async( actor=default_user, agent_id=sarah_agent.id, before=second_page[-1].id, after=first_page[0].id, ascending=False ) assert len(middle_page_desc) == 2 @@ -1880,31 +1883,40 @@ def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent assert middle_page_desc[1].id == first_page[-1].id -def test_agent_list_passages_text_search(server, default_user, sarah_agent, agent_passages_setup): +@pytest.mark.asyncio +async def test_agent_list_passages_text_search(server, default_user, sarah_agent, agent_passages_setup, event_loop): """Test text search functionality of agent passages""" # Test text search for source passages - source_text_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, query_text="Source passage") + source_text_passages = await server.agent_manager.list_passages_async( + actor=default_user, agent_id=sarah_agent.id, query_text="Source passage" + ) assert len(source_text_passages) == 3 # Test text search for agent passages - agent_text_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, query_text="Agent passage") + agent_text_passages = await server.agent_manager.list_passages_async( + actor=default_user, agent_id=sarah_agent.id, query_text="Agent passage" + ) assert len(agent_text_passages) == 2 -def test_agent_list_passages_agent_only(server, default_user, sarah_agent, agent_passages_setup): +@pytest.mark.asyncio +async def test_agent_list_passages_agent_only(server, default_user, sarah_agent, agent_passages_setup, event_loop): """Test text search functionality of agent passages""" # Test text search for agent passages - agent_text_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, agent_only=True) + agent_text_passages = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id, agent_only=True) assert len(agent_text_passages) == 2 -def test_agent_list_passages_filtering(server, default_user, sarah_agent, default_source, agent_passages_setup): +@pytest.mark.asyncio +async def test_agent_list_passages_filtering(server, default_user, sarah_agent, default_source, agent_passages_setup, event_loop): """Test filtering functionality of agent passages""" # Test source filtering - source_filtered = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, source_id=default_source.id) + source_filtered = await server.agent_manager.list_passages_async( + actor=default_user, agent_id=sarah_agent.id, source_id=default_source.id + ) assert len(source_filtered) == 3 # Test date filtering @@ -1912,13 +1924,14 @@ def test_agent_list_passages_filtering(server, default_user, sarah_agent, defaul future_date = now + timedelta(days=1) past_date = now - timedelta(days=1) - date_filtered = server.agent_manager.list_passages( + date_filtered = await server.agent_manager.list_passages_async( actor=default_user, agent_id=sarah_agent.id, start_date=past_date, end_date=future_date ) assert len(date_filtered) == 5 -def test_agent_list_passages_vector_search(server, default_user, sarah_agent, default_source): +@pytest.mark.asyncio +async def test_agent_list_passages_vector_search(server, default_user, sarah_agent, default_source, event_loop): """Test vector search functionality of agent passages""" embed_model = embedding_model(DEFAULT_EMBEDDING_CONFIG) @@ -1959,7 +1972,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de query_key = "What's my favorite color?" # Test vector search with all passages - results = server.agent_manager.list_passages( + results = await server.agent_manager.list_passages_async( actor=default_user, agent_id=sarah_agent.id, query_text=query_key, @@ -1974,7 +1987,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de assert "blue" in results[1].text or "blue" in results[2].text # Test vector search with agent_only=True - agent_only_results = server.agent_manager.list_passages( + agent_only_results = await server.agent_manager.list_passages_async( actor=default_user, agent_id=sarah_agent.id, query_text=query_key, @@ -1989,11 +2002,12 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de assert agent_only_results[1].text == "blue shoes" -def test_list_source_passages_only(server: SyncServer, default_user, default_source, agent_passages_setup): +@pytest.mark.asyncio +async def test_list_source_passages_only(server: SyncServer, default_user, default_source, agent_passages_setup, event_loop): """Test listing passages from a source without specifying an agent.""" # List passages by source_id without agent_id - source_passages = server.agent_manager.list_passages( + source_passages = await server.agent_manager.list_passages_async( actor=default_user, source_id=default_source.id, ) @@ -2127,8 +2141,9 @@ def test_passage_get_by_id(server: SyncServer, agent_passage_fixture, source_pas assert retrieved.text == source_passage_fixture.text -def test_passage_cascade_deletion( - server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user, default_source, sarah_agent +@pytest.mark.asyncio +async def test_passage_cascade_deletion( + server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user, default_source, sarah_agent, event_loop ): """Test that passages are deleted when their parent (agent or source) is deleted.""" # Verify passages exist @@ -2139,7 +2154,7 @@ def test_passage_cascade_deletion( # Delete agent and verify its passages are deleted server.agent_manager.delete_agent(sarah_agent.id, default_user) - agentic_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, agent_only=True) + agentic_passages = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id, agent_only=True) assert len(agentic_passages) == 0 # Delete source and verify its passages are deleted