feat(asyncify): agent batch sync db calls (#2348)
This commit is contained in:
@@ -100,8 +100,10 @@ class BaseAgent(ABC):
|
||||
|
||||
# [DB Call] size of messages and archival memories
|
||||
# todo: blocking for now
|
||||
num_messages = num_messages or self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
|
||||
num_archival_memories = num_archival_memories or self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)
|
||||
if num_messages is None:
|
||||
num_messages = await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
if num_archival_memories is None:
|
||||
num_archival_memories = await self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
|
||||
new_system_message_str = compile_system_message(
|
||||
system_prompt=agent_state.system,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from typing import Tuple
|
||||
from unittest.mock import AsyncMock, patch
|
||||
@@ -481,7 +482,10 @@ async def test_partial_error_from_anthropic_batch(
|
||||
|
||||
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
|
||||
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
|
||||
msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents}
|
||||
sizes = await asyncio.gather(
|
||||
*[server.message_manager.size_async(actor=default_user, agent_id=agent.id) for agent in agents]
|
||||
)
|
||||
msg_counts_before = {agent.id: size for agent, size in zip(agents, sizes)}
|
||||
|
||||
new_batch_responses = await poll_running_llm_batches(server)
|
||||
|
||||
@@ -545,7 +549,7 @@ async def test_partial_error_from_anthropic_batch(
|
||||
# Tool‑call side‑effects – each agent gets at least 2 extra messages
|
||||
for agent in agents:
|
||||
before = msg_counts_before[agent.id] # captured just before resume
|
||||
after = server.message_manager.size(actor=default_user, agent_id=agent.id)
|
||||
after = await server.message_manager.size_async(actor=default_user, agent_id=agent.id)
|
||||
|
||||
if agent.id == agents_failed[0].id:
|
||||
assert after == before, f"Agent {agent.id} should not have extra messages persisted due to Anthropic failure"
|
||||
@@ -643,7 +647,10 @@ async def test_resume_step_some_stop(
|
||||
|
||||
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
|
||||
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
|
||||
msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents}
|
||||
sizes = await asyncio.gather(
|
||||
*[server.message_manager.size_async(actor=default_user, agent_id=agent.id) for agent in agents]
|
||||
)
|
||||
msg_counts_before = {agent.id: size for agent, size in zip(agents, sizes)}
|
||||
|
||||
new_batch_responses = await poll_running_llm_batches(server)
|
||||
|
||||
@@ -703,7 +710,7 @@ async def test_resume_step_some_stop(
|
||||
# Tool‑call side‑effects – each agent gets at least 2 extra messages
|
||||
for agent in agents:
|
||||
before = msg_counts_before[agent.id] # captured just before resume
|
||||
after = server.message_manager.size(actor=default_user, agent_id=agent.id)
|
||||
after = await server.message_manager.size_async(actor=default_user, agent_id=agent.id)
|
||||
assert after - before >= 2, (
|
||||
f"Agent {agent.id} should have an assistant tool‑call " f"and tool‑response message persisted."
|
||||
)
|
||||
@@ -803,7 +810,10 @@ async def test_resume_step_after_request_all_continue(
|
||||
|
||||
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
|
||||
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
|
||||
msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents}
|
||||
sizes = await asyncio.gather(
|
||||
*[server.message_manager.size_async(actor=default_user, agent_id=agent.id) for agent in agents]
|
||||
)
|
||||
msg_counts_before = {agent.id: size for agent, size in zip(agents, sizes)}
|
||||
|
||||
new_batch_responses = await poll_running_llm_batches(server)
|
||||
|
||||
@@ -860,7 +870,7 @@ async def test_resume_step_after_request_all_continue(
|
||||
# Tool‑call side‑effects – each agent gets at least 2 extra messages
|
||||
for agent in agents:
|
||||
before = msg_counts_before[agent.id] # captured just before resume
|
||||
after = server.message_manager.size(actor=default_user, agent_id=agent.id)
|
||||
after = await server.message_manager.size_async(actor=default_user, agent_id=agent.id)
|
||||
assert after - before >= 2, (
|
||||
f"Agent {agent.id} should have an assistant tool‑call " f"and tool‑response message persisted."
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user