feat(asyncify): migrate list passages (#2275)

This commit is contained in:
cthomas
2025-05-20 12:56:04 -07:00
committed by GitHub
parent 55497cd64f
commit 710c03dbd1
5 changed files with 130 additions and 32 deletions

View File

@@ -438,7 +438,7 @@ class VoiceAgent(BaseAgent):
if start_date and end_date and start_date > end_date:
start_date, end_date = end_date, start_date
archival_results = self.agent_manager.list_passages(
archival_results = await self.agent_manager.list_passages_async(
actor=self.actor,
agent_id=self.agent_id,
query_text=archival_query,

View File

@@ -484,7 +484,7 @@ def detach_block(
@router.get("/{agent_id}/archival-memory", response_model=List[Passage], operation_id="list_passages")
def list_passages(
async def list_passages(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
after: Optional[str] = Query(None, description="Unique ID of the memory to start the query range at."),
@@ -499,11 +499,11 @@ def list_passages(
"""
Retrieve the memories in an agent's archival memory store (paginated query).
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return server.get_agent_archival(
user_id=actor.id,
return await server.get_agent_archival_async(
agent_id=agent_id,
actor=actor,
after=after,
before=before,
query_text=search,

View File

@@ -1002,6 +1002,30 @@ class SyncServer(Server):
)
return records
async def get_agent_archival_async(
self,
agent_id: str,
actor: User,
after: Optional[str] = None,
before: Optional[str] = None,
limit: Optional[int] = 100,
order_by: Optional[str] = "created_at",
reverse: Optional[bool] = False,
query_text: Optional[str] = None,
ascending: Optional[bool] = True,
) -> List[Passage]:
# iterate over records
records = await self.agent_manager.list_passages_async(
actor=actor,
agent_id=agent_id,
after=after,
query_text=query_text,
before=before,
ascending=ascending,
limit=limit,
)
return records
def insert_archival_memory(self, agent_id: str, memory_contents: str, actor: User) -> List[Passage]:
# Get the agent object (loaded in memory)
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)

View File

@@ -1961,6 +1961,65 @@ class AgentManager:
return [p.to_pydantic() for p in passages]
@enforce_types
async def list_passages_async(
self,
actor: PydanticUser,
agent_id: Optional[str] = None,
file_id: Optional[str] = None,
limit: Optional[int] = 50,
query_text: Optional[str] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
before: Optional[str] = None,
after: Optional[str] = None,
source_id: Optional[str] = None,
embed_query: bool = False,
ascending: bool = True,
embedding_config: Optional[EmbeddingConfig] = None,
agent_only: bool = False,
) -> List[PydanticPassage]:
"""Lists all passages attached to an agent."""
async with db_registry.async_session() as session:
main_query = self._build_passage_query(
actor=actor,
agent_id=agent_id,
file_id=file_id,
query_text=query_text,
start_date=start_date,
end_date=end_date,
before=before,
after=after,
source_id=source_id,
embed_query=embed_query,
ascending=ascending,
embedding_config=embedding_config,
agent_only=agent_only,
)
# Add limit
if limit:
main_query = main_query.limit(limit)
# Execute query
result = await session.execute(main_query)
passages = []
for row in result:
data = dict(row._mapping)
if data["agent_id"] is not None:
# This is an AgentPassage - remove source fields
data.pop("source_id", None)
data.pop("file_id", None)
passage = AgentPassage(**data)
else:
# This is a SourcePassage - remove agent field
data.pop("agent_id", None)
passage = SourcePassage(**data)
passages.append(passage)
return [p.to_pydantic() for p in passages]
@enforce_types
def passage_size(
self,

View File

@@ -1817,41 +1817,44 @@ async def test_refresh_memory_async(server: SyncServer, default_user, event_loop
# ======================================================================================================================
def test_agent_list_passages_basic(server, default_user, sarah_agent, agent_passages_setup):
@pytest.mark.asyncio
async def test_agent_list_passages_basic(server, default_user, sarah_agent, agent_passages_setup, event_loop):
"""Test basic listing functionality of agent passages"""
all_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id)
all_passages = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id)
assert len(all_passages) == 5 # 3 source + 2 agent passages
def test_agent_list_passages_ordering(server, default_user, sarah_agent, agent_passages_setup):
@pytest.mark.asyncio
async def test_agent_list_passages_ordering(server, default_user, sarah_agent, agent_passages_setup, event_loop):
"""Test ordering of agent passages"""
# Test ascending order
asc_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, ascending=True)
asc_passages = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id, ascending=True)
assert len(asc_passages) == 5
for i in range(1, len(asc_passages)):
assert asc_passages[i - 1].created_at <= asc_passages[i].created_at
# Test descending order
desc_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, ascending=False)
desc_passages = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id, ascending=False)
assert len(desc_passages) == 5
for i in range(1, len(desc_passages)):
assert desc_passages[i - 1].created_at >= desc_passages[i].created_at
def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent_passages_setup):
@pytest.mark.asyncio
async def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent_passages_setup, event_loop):
"""Test pagination of agent passages"""
# Test limit
limited_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, limit=3)
limited_passages = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=3)
assert len(limited_passages) == 3
# Test cursor-based pagination
first_page = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, limit=2, ascending=True)
first_page = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=2, ascending=True)
assert len(first_page) == 2
second_page = server.agent_manager.list_passages(
second_page = await server.agent_manager.list_passages_async(
actor=default_user, agent_id=sarah_agent.id, after=first_page[-1].id, limit=2, ascending=True
)
assert len(second_page) == 2
@@ -1865,14 +1868,14 @@ def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent
[mid]
* | * * | *
"""
middle_page = server.agent_manager.list_passages(
middle_page = await server.agent_manager.list_passages_async(
actor=default_user, agent_id=sarah_agent.id, before=second_page[-1].id, after=first_page[0].id, ascending=True
)
assert len(middle_page) == 2
assert middle_page[0].id == first_page[-1].id
assert middle_page[1].id == second_page[0].id
middle_page_desc = server.agent_manager.list_passages(
middle_page_desc = await server.agent_manager.list_passages_async(
actor=default_user, agent_id=sarah_agent.id, before=second_page[-1].id, after=first_page[0].id, ascending=False
)
assert len(middle_page_desc) == 2
@@ -1880,31 +1883,40 @@ def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent
assert middle_page_desc[1].id == first_page[-1].id
def test_agent_list_passages_text_search(server, default_user, sarah_agent, agent_passages_setup):
@pytest.mark.asyncio
async def test_agent_list_passages_text_search(server, default_user, sarah_agent, agent_passages_setup, event_loop):
"""Test text search functionality of agent passages"""
# Test text search for source passages
source_text_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, query_text="Source passage")
source_text_passages = await server.agent_manager.list_passages_async(
actor=default_user, agent_id=sarah_agent.id, query_text="Source passage"
)
assert len(source_text_passages) == 3
# Test text search for agent passages
agent_text_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, query_text="Agent passage")
agent_text_passages = await server.agent_manager.list_passages_async(
actor=default_user, agent_id=sarah_agent.id, query_text="Agent passage"
)
assert len(agent_text_passages) == 2
def test_agent_list_passages_agent_only(server, default_user, sarah_agent, agent_passages_setup):
@pytest.mark.asyncio
async def test_agent_list_passages_agent_only(server, default_user, sarah_agent, agent_passages_setup, event_loop):
"""Test text search functionality of agent passages"""
# Test text search for agent passages
agent_text_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, agent_only=True)
agent_text_passages = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id, agent_only=True)
assert len(agent_text_passages) == 2
def test_agent_list_passages_filtering(server, default_user, sarah_agent, default_source, agent_passages_setup):
@pytest.mark.asyncio
async def test_agent_list_passages_filtering(server, default_user, sarah_agent, default_source, agent_passages_setup, event_loop):
"""Test filtering functionality of agent passages"""
# Test source filtering
source_filtered = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, source_id=default_source.id)
source_filtered = await server.agent_manager.list_passages_async(
actor=default_user, agent_id=sarah_agent.id, source_id=default_source.id
)
assert len(source_filtered) == 3
# Test date filtering
@@ -1912,13 +1924,14 @@ def test_agent_list_passages_filtering(server, default_user, sarah_agent, defaul
future_date = now + timedelta(days=1)
past_date = now - timedelta(days=1)
date_filtered = server.agent_manager.list_passages(
date_filtered = await server.agent_manager.list_passages_async(
actor=default_user, agent_id=sarah_agent.id, start_date=past_date, end_date=future_date
)
assert len(date_filtered) == 5
def test_agent_list_passages_vector_search(server, default_user, sarah_agent, default_source):
@pytest.mark.asyncio
async def test_agent_list_passages_vector_search(server, default_user, sarah_agent, default_source, event_loop):
"""Test vector search functionality of agent passages"""
embed_model = embedding_model(DEFAULT_EMBEDDING_CONFIG)
@@ -1959,7 +1972,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
query_key = "What's my favorite color?"
# Test vector search with all passages
results = server.agent_manager.list_passages(
results = await server.agent_manager.list_passages_async(
actor=default_user,
agent_id=sarah_agent.id,
query_text=query_key,
@@ -1974,7 +1987,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
assert "blue" in results[1].text or "blue" in results[2].text
# Test vector search with agent_only=True
agent_only_results = server.agent_manager.list_passages(
agent_only_results = await server.agent_manager.list_passages_async(
actor=default_user,
agent_id=sarah_agent.id,
query_text=query_key,
@@ -1989,11 +2002,12 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
assert agent_only_results[1].text == "blue shoes"
def test_list_source_passages_only(server: SyncServer, default_user, default_source, agent_passages_setup):
@pytest.mark.asyncio
async def test_list_source_passages_only(server: SyncServer, default_user, default_source, agent_passages_setup, event_loop):
"""Test listing passages from a source without specifying an agent."""
# List passages by source_id without agent_id
source_passages = server.agent_manager.list_passages(
source_passages = await server.agent_manager.list_passages_async(
actor=default_user,
source_id=default_source.id,
)
@@ -2127,8 +2141,9 @@ def test_passage_get_by_id(server: SyncServer, agent_passage_fixture, source_pas
assert retrieved.text == source_passage_fixture.text
def test_passage_cascade_deletion(
server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user, default_source, sarah_agent
@pytest.mark.asyncio
async def test_passage_cascade_deletion(
server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user, default_source, sarah_agent, event_loop
):
"""Test that passages are deleted when their parent (agent or source) is deleted."""
# Verify passages exist
@@ -2139,7 +2154,7 @@ def test_passage_cascade_deletion(
# Delete agent and verify its passages are deleted
server.agent_manager.delete_agent(sarah_agent.id, default_user)
agentic_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, agent_only=True)
agentic_passages = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id, agent_only=True)
assert len(agentic_passages) == 0
# Delete source and verify its passages are deleted