diff --git a/letta/agents/base_agent.py b/letta/agents/base_agent.py index a349366d..69342758 100644 --- a/letta/agents/base_agent.py +++ b/letta/agents/base_agent.py @@ -100,8 +100,10 @@ class BaseAgent(ABC): # [DB Call] size of messages and archival memories # todo: blocking for now - num_messages = num_messages or self.message_manager.size(actor=self.actor, agent_id=agent_state.id) - num_archival_memories = num_archival_memories or self.passage_manager.size(actor=self.actor, agent_id=agent_state.id) + if num_messages is None: + num_messages = await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id) + if num_archival_memories is None: + num_archival_memories = await self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id) new_system_message_str = compile_system_message( system_prompt=agent_state.system, diff --git a/tests/test_letta_agent_batch.py b/tests/test_letta_agent_batch.py index da2a6666..bc8f7e8e 100644 --- a/tests/test_letta_agent_batch.py +++ b/tests/test_letta_agent_batch.py @@ -1,3 +1,4 @@ +import asyncio from datetime import datetime, timezone from typing import Tuple from unittest.mock import AsyncMock, patch @@ -481,7 +482,10 @@ async def test_partial_error_from_anthropic_batch( with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results): with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response): - msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents} + sizes = await asyncio.gather( + *[server.message_manager.size_async(actor=default_user, agent_id=agent.id) for agent in agents] + ) + msg_counts_before = {agent.id: size for agent, size in zip(agents, sizes)} new_batch_responses = await poll_running_llm_batches(server) @@ -545,7 +549,7 @@ async def test_partial_error_from_anthropic_batch( # Tool‑call side‑effects – each agent gets at least 2 extra messages for agent in agents: before = msg_counts_before[agent.id] # captured just before resume - after = server.message_manager.size(actor=default_user, agent_id=agent.id) + after = await server.message_manager.size_async(actor=default_user, agent_id=agent.id) if agent.id == agents_failed[0].id: assert after == before, f"Agent {agent.id} should not have extra messages persisted due to Anthropic failure" @@ -643,7 +647,10 @@ async def test_resume_step_some_stop( with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results): with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response): - msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents} + sizes = await asyncio.gather( + *[server.message_manager.size_async(actor=default_user, agent_id=agent.id) for agent in agents] + ) + msg_counts_before = {agent.id: size for agent, size in zip(agents, sizes)} new_batch_responses = await poll_running_llm_batches(server) @@ -703,7 +710,7 @@ async def test_resume_step_some_stop( # Tool‑call side‑effects – each agent gets at least 2 extra messages for agent in agents: before = msg_counts_before[agent.id] # captured just before resume - after = server.message_manager.size(actor=default_user, agent_id=agent.id) + after = await server.message_manager.size_async(actor=default_user, agent_id=agent.id) assert after - before >= 2, ( f"Agent {agent.id} should have an assistant tool‑call " f"and tool‑response message persisted." ) @@ -803,7 +810,10 @@ async def test_resume_step_after_request_all_continue( with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results): with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response): - msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents} + sizes = await asyncio.gather( + *[server.message_manager.size_async(actor=default_user, agent_id=agent.id) for agent in agents] + ) + msg_counts_before = {agent.id: size for agent, size in zip(agents, sizes)} new_batch_responses = await poll_running_llm_batches(server) @@ -860,7 +870,7 @@ async def test_resume_step_after_request_all_continue( # Tool‑call side‑effects – each agent gets at least 2 extra messages for agent in agents: before = msg_counts_before[agent.id] # captured just before resume - after = server.message_manager.size(actor=default_user, agent_id=agent.id) + after = await server.message_manager.size_async(actor=default_user, agent_id=agent.id) assert after - before >= 2, ( f"Agent {agent.id} should have an assistant tool‑call " f"and tool‑response message persisted." )