From 6f72ac0c3acfa7ed5de99558497e16fcb376f096 Mon Sep 17 00:00:00 2001 From: cthomas Date: Thu, 22 May 2025 15:33:27 -0700 Subject: [PATCH] feat(asyncify): agent batch sync db calls (#2348) --- letta/agents/base_agent.py | 6 ++++-- tests/test_letta_agent_batch.py | 22 ++++++++++++++++------ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/letta/agents/base_agent.py b/letta/agents/base_agent.py index a349366d..69342758 100644 --- a/letta/agents/base_agent.py +++ b/letta/agents/base_agent.py @@ -100,8 +100,10 @@ class BaseAgent(ABC): # [DB Call] size of messages and archival memories # todo: blocking for now - num_messages = num_messages or self.message_manager.size(actor=self.actor, agent_id=agent_state.id) - num_archival_memories = num_archival_memories or self.passage_manager.size(actor=self.actor, agent_id=agent_state.id) + if num_messages is None: + num_messages = await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id) + if num_archival_memories is None: + num_archival_memories = await self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id) new_system_message_str = compile_system_message( system_prompt=agent_state.system, diff --git a/tests/test_letta_agent_batch.py b/tests/test_letta_agent_batch.py index da2a6666..bc8f7e8e 100644 --- a/tests/test_letta_agent_batch.py +++ b/tests/test_letta_agent_batch.py @@ -1,3 +1,4 @@ +import asyncio from datetime import datetime, timezone from typing import Tuple from unittest.mock import AsyncMock, patch @@ -481,7 +482,10 @@ async def test_partial_error_from_anthropic_batch( with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results): with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response): - msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents} + sizes = await asyncio.gather( + *[server.message_manager.size_async(actor=default_user, agent_id=agent.id) for agent in agents] + ) + msg_counts_before = {agent.id: size for agent, size in zip(agents, sizes)} new_batch_responses = await poll_running_llm_batches(server) @@ -545,7 +549,7 @@ async def test_partial_error_from_anthropic_batch( # Tool‑call side‑effects – each agent gets at least 2 extra messages for agent in agents: before = msg_counts_before[agent.id] # captured just before resume - after = server.message_manager.size(actor=default_user, agent_id=agent.id) + after = await server.message_manager.size_async(actor=default_user, agent_id=agent.id) if agent.id == agents_failed[0].id: assert after == before, f"Agent {agent.id} should not have extra messages persisted due to Anthropic failure" @@ -643,7 +647,10 @@ async def test_resume_step_some_stop( with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results): with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response): - msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents} + sizes = await asyncio.gather( + *[server.message_manager.size_async(actor=default_user, agent_id=agent.id) for agent in agents] + ) + msg_counts_before = {agent.id: size for agent, size in zip(agents, sizes)} new_batch_responses = await poll_running_llm_batches(server) @@ -703,7 +710,7 @@ async def test_resume_step_some_stop( # Tool‑call side‑effects – each agent gets at least 2 extra messages for agent in agents: before = msg_counts_before[agent.id] # captured just before resume - after = server.message_manager.size(actor=default_user, agent_id=agent.id) + after = await server.message_manager.size_async(actor=default_user, agent_id=agent.id) assert after - before >= 2, ( f"Agent {agent.id} should have an assistant tool‑call " f"and tool‑response message persisted." ) @@ -803,7 +810,10 @@ async def test_resume_step_after_request_all_continue( with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results): with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response): - msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents} + sizes = await asyncio.gather( + *[server.message_manager.size_async(actor=default_user, agent_id=agent.id) for agent in agents] + ) + msg_counts_before = {agent.id: size for agent, size in zip(agents, sizes)} new_batch_responses = await poll_running_llm_batches(server) @@ -860,7 +870,7 @@ async def test_resume_step_after_request_all_continue( # Tool‑call side‑effects – each agent gets at least 2 extra messages for agent in agents: before = msg_counts_before[agent.id] # captured just before resume - after = server.message_manager.size(actor=default_user, agent_id=agent.id) + after = await server.message_manager.size_async(actor=default_user, agent_id=agent.id) assert after - before >= 2, ( f"Agent {agent.id} should have an assistant tool‑call " f"and tool‑response message persisted." )