feat(asyncify): agent batch sync db calls (#2348)

This commit is contained in:
cthomas
2025-05-22 15:33:27 -07:00
committed by GitHub
parent ed9f3e6abf
commit 6f72ac0c3a
2 changed files with 20 additions and 8 deletions

View File

@@ -100,8 +100,10 @@ class BaseAgent(ABC):
# [DB Call] size of messages and archival memories
# todo: blocking for now
num_messages = num_messages or self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
num_archival_memories = num_archival_memories or self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)
if num_messages is None:
num_messages = await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)
if num_archival_memories is None:
num_archival_memories = await self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id)
new_system_message_str = compile_system_message(
system_prompt=agent_state.system,

View File

@@ -1,3 +1,4 @@
import asyncio
from datetime import datetime, timezone
from typing import Tuple
from unittest.mock import AsyncMock, patch
@@ -481,7 +482,10 @@ async def test_partial_error_from_anthropic_batch(
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents}
sizes = await asyncio.gather(
*[server.message_manager.size_async(actor=default_user, agent_id=agent.id) for agent in agents]
)
msg_counts_before = {agent.id: size for agent, size in zip(agents, sizes)}
new_batch_responses = await poll_running_llm_batches(server)
@@ -545,7 +549,7 @@ async def test_partial_error_from_anthropic_batch(
# Tool-call side effects — each agent gets at least 2 extra messages
for agent in agents:
before = msg_counts_before[agent.id] # captured just before resume
after = server.message_manager.size(actor=default_user, agent_id=agent.id)
after = await server.message_manager.size_async(actor=default_user, agent_id=agent.id)
if agent.id == agents_failed[0].id:
assert after == before, f"Agent {agent.id} should not have extra messages persisted due to Anthropic failure"
@@ -643,7 +647,10 @@ async def test_resume_step_some_stop(
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents}
sizes = await asyncio.gather(
*[server.message_manager.size_async(actor=default_user, agent_id=agent.id) for agent in agents]
)
msg_counts_before = {agent.id: size for agent, size in zip(agents, sizes)}
new_batch_responses = await poll_running_llm_batches(server)
@@ -703,7 +710,7 @@ async def test_resume_step_some_stop(
# Tool-call side effects — each agent gets at least 2 extra messages
for agent in agents:
before = msg_counts_before[agent.id] # captured just before resume
after = server.message_manager.size(actor=default_user, agent_id=agent.id)
after = await server.message_manager.size_async(actor=default_user, agent_id=agent.id)
assert after - before >= 2, (
f"Agent {agent.id} should have an assistant toolcall " f"and toolresponse message persisted."
)
@@ -803,7 +810,10 @@ async def test_resume_step_after_request_all_continue(
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents}
sizes = await asyncio.gather(
*[server.message_manager.size_async(actor=default_user, agent_id=agent.id) for agent in agents]
)
msg_counts_before = {agent.id: size for agent, size in zip(agents, sizes)}
new_batch_responses = await poll_running_llm_batches(server)
@@ -860,7 +870,7 @@ async def test_resume_step_after_request_all_continue(
# Tool-call side effects — each agent gets at least 2 extra messages
for agent in agents:
before = msg_counts_before[agent.id] # captured just before resume
after = server.message_manager.size(actor=default_user, agent_id=agent.id)
after = await server.message_manager.size_async(actor=default_user, agent_id=agent.id)
assert after - before >= 2, (
f"Agent {agent.id} should have an assistant toolcall " f"and toolresponse message persisted."
)