test: Add stateful tool test for letta batch agent (#1824)

This commit is contained in:
Matthew Zhou
2025-04-21 15:47:20 -07:00
committed by GitHub
parent 128989820a
commit 2ef2bde2e1
3 changed files with 216 additions and 43 deletions

View File

@@ -291,9 +291,7 @@ class LLMBatchManager:
return [item.to_pydantic() for item in results]
def bulk_update_llm_batch_items(
self,
llm_batch_id_agent_id_pairs: List[Tuple[str, str]],
field_updates: List[Dict[str, Any]],
self, llm_batch_id_agent_id_pairs: List[Tuple[str, str]], field_updates: List[Dict[str, Any]], strict: bool = True
) -> None:
"""
Efficiently update multiple LLMBatchItem rows by (llm_batch_id, agent_id) pairs.
@@ -301,30 +299,43 @@ class LLMBatchManager:
Args:
llm_batch_id_agent_id_pairs: List of (llm_batch_id, agent_id) tuples identifying items to update
field_updates: List of dictionaries containing the fields to update for each item
strict: Whether to error if any of the requested keys don't exist (default True).
If False, missing pairs are skipped.
"""
if not llm_batch_id_agent_id_pairs or not field_updates:
return
if len(llm_batch_id_agent_id_pairs) != len(field_updates):
raise ValueError("batch_id_agent_id_pairs and field_updates must have the same length")
raise ValueError("llm_batch_id_agent_id_pairs and field_updates must have the same length")
with self.session_maker() as session:
# Lookup primary keys
# Lookup primary keys for all requested (batch_id, agent_id) pairs
items = (
session.query(LLMBatchItem.id, LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id)
.filter(tuple_(LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id).in_(llm_batch_id_agent_id_pairs))
.all()
)
pair_to_pk = {(b, a): id for id, b, a in items}
pair_to_pk = {(batch_id, agent_id): pk for pk, batch_id, agent_id in items}
if strict:
requested = set(llm_batch_id_agent_id_pairs)
found = set(pair_to_pk.keys())
missing = requested - found
if missing:
raise ValueError(
f"Cannot bulk-update batch items: no records for the following " f"(llm_batch_id, agent_id) pairs: {missing}"
)
# Build mappings, skipping any missing when strict=False
mappings = []
for (llm_batch_id, agent_id), fields in zip(llm_batch_id_agent_id_pairs, field_updates):
pk_id = pair_to_pk.get((llm_batch_id, agent_id))
if not pk_id:
for (batch_id, agent_id), fields in zip(llm_batch_id_agent_id_pairs, field_updates):
pk = pair_to_pk.get((batch_id, agent_id))
if pk is None:
# skip missing in non-strict mode
continue
update_fields = fields.copy()
update_fields["id"] = pk_id
update_fields["id"] = pk
mappings.append(update_fields)
if mappings:
@@ -332,10 +343,7 @@ class LLMBatchManager:
session.commit()
@enforce_types
def bulk_update_batch_llm_items_results_by_agent(
self,
updates: List[ItemUpdateInfo],
) -> None:
def bulk_update_batch_llm_items_results_by_agent(self, updates: List[ItemUpdateInfo], strict: bool = True) -> None:
"""Update request status and batch results for multiple batch items."""
batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates]
field_updates = [
@@ -346,29 +354,23 @@ class LLMBatchManager:
for update in updates
]
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates)
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates, strict=strict)
@enforce_types
def bulk_update_llm_batch_items_step_status_by_agent(self, updates: List[StepStatusUpdateInfo], strict: bool = True) -> None:
    """Update step status for multiple batch items.

    Args:
        updates: One update per item, each identifying the target row by its
            (llm_batch_id, agent_id) pair and carrying the new step_status.
        strict: Forwarded to bulk_update_llm_batch_items. If True (default),
            raise ValueError when any referenced (llm_batch_id, agent_id) pair
            has no matching row; if False, missing pairs are silently skipped.
    """
    batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates]
    field_updates = [{"step_status": update.step_status} for update in updates]
    self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates, strict=strict)
@enforce_types
def bulk_update_llm_batch_items_request_status_by_agent(self, updates: List[RequestStatusUpdateInfo], strict: bool = True) -> None:
    """Update request status for multiple batch items.

    Args:
        updates: One update per item, each identifying the target row by its
            (llm_batch_id, agent_id) pair and carrying the new request_status.
        strict: Forwarded to bulk_update_llm_batch_items. If True (default),
            raise ValueError when any referenced (llm_batch_id, agent_id) pair
            has no matching row; if False, missing pairs are silently skipped.
    """
    batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates]
    field_updates = [{"request_status": update.request_status} for update in updates]
    self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates, strict=strict)
@enforce_types
def delete_llm_batch_item(self, item_id: str, actor: PydanticUser) -> None:

View File

@@ -1,11 +1,3 @@
"""
Tests for LettaAgentBatch.step_until_request functionality.
This module tests the batch processing capabilities of LettaAgentBatch,
specifically the step_until_request method which prepares agent requests
for batch processing.
"""
import os
import threading
import time
@@ -92,6 +84,28 @@ def weather_tool(client):
yield tool
@pytest.fixture(scope="function")
def rethink_tool(client):
    # Fixture: registers `rethink_memory` as a Letta tool via the client and yields it.
    # NOTE: the inner docstring doubles as the tool schema shown to the LLM — do not edit casually.
    def rethink_memory(agent_state: "AgentState", new_memory: str, target_block_label: str) -> str:  # type: ignore
        """
        Re-evaluate the memory in block_name, integrating new and updated facts.
        Replace outdated information with the most likely truths, avoiding redundancy with original memories.
        Ensure consistency with other memory blocks.

        Args:
            new_memory (str): The new memory with information integrated from the memory block. If there is no new information, then this should be the same as the content in the source block.
            target_block_label (str): The name of the block to write to.

        Returns:
            str: None is always returned as this function does not produce a response.
        """
        agent_state.memory.update_block_value(label=target_block_label, value=new_memory)
        return None

    tool = client.tools.upsert_from_function(func=rethink_memory)
    # Yield the created tool
    yield tool
@pytest.fixture
def agents(client, weather_tool):
"""
@@ -173,7 +187,7 @@ def create_batch_response(batch_id: str, processing_status: str = "in_progress")
)
def create_complete_tool_response(custom_id: str, model: str, request_heartbeat: bool) -> BetaMessageBatchIndividualResponse:
def create_get_weather_tool_response(custom_id: str, model: str, request_heartbeat: bool) -> BetaMessageBatchIndividualResponse:
"""Create a dummy successful batch response with a tool call after user asks about weather."""
return BetaMessageBatchIndividualResponse(
custom_id=custom_id,
@@ -204,6 +218,39 @@ def create_complete_tool_response(custom_id: str, model: str, request_heartbeat:
)
def create_rethink_tool_response(
    custom_id: str, model: str, request_heartbeat: bool, new_memory: str, target_block_label: str
) -> BetaMessageBatchIndividualResponse:
    """Create a dummy successful batch response containing a `rethink_memory` tool call.

    (Docstring fixed: it was copy-pasted from the weather-tool helper.)

    Args:
        custom_id: Id stamped on the response (the agent id in these tests).
        model: Model name echoed back in the fake Anthropic message.
        request_heartbeat: Value forwarded in the tool call's `request_heartbeat` input.
        new_memory: Memory content the fake tool call writes.
        target_block_label: Label of the memory block the fake tool call targets.
    """
    return BetaMessageBatchIndividualResponse(
        custom_id=custom_id,
        result=BetaMessageBatchSucceededResult(
            type="succeeded",
            message=BetaMessage(
                id="msg_abc123",
                role="assistant",
                type="message",
                model=model,
                content=[
                    {"type": "text", "text": "Let me rethink my memory."},
                    {
                        "type": "tool_use",
                        "id": "tu_01234567890123456789012345",
                        "name": "rethink_memory",
                        "input": {
                            "new_memory": new_memory,
                            "target_block_label": target_block_label,
                            "request_heartbeat": request_heartbeat,
                        },
                    },
                ],
                usage={"input_tokens": 7, "output_tokens": 17},
                stop_reason="end_turn",
            ),
        ),
    )
def create_failed_response(custom_id: str) -> BetaMessageBatchIndividualResponse:
"""Create a dummy failed batch response with a rate limit error."""
return BetaMessageBatchIndividualResponse(
@@ -322,7 +369,89 @@ class MockAsyncIterable:
@pytest.mark.asyncio
async def test_rethink_tool_modify_agent_state(client, disable_e2b_api_key, server, default_user, batch_job, rethink_tool):
    """End-to-end batch test: a mocked `rethink_memory` tool call must rewrite the agent's memory block."""
    target_block_label = "human"
    new_memory = "banana"
    agent = client.agents.create(
        name="test_agent_rethink",  # plain string — no placeholders, so no f-prefix (ruff F541)
        include_base_tools=True,
        model=MODELS["sonnet"],
        tags=["test_agents"],
        embedding="letta/letta-free",
        tool_ids=[rethink_tool.id],
        memory_blocks=[
            {
                "label": target_block_label,
                "value": "Name: Matt",
            },
        ],
    )
    agents = [agent]
    batch_requests = [
        LettaBatchRequest(agent_id=agent.id, messages=[MessageCreate(role="user", content=[TextContent(text="Rethink memory.")])])
        for agent in agents
    ]

    anthropic_batch_id = "msgbatch_test_12345"
    dummy_batch_response = create_batch_response(
        batch_id=anthropic_batch_id,
    )

    # 1. Invoke `step_until_request` with the Anthropic submit call mocked out.
    with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
        # Create batch runner
        batch_runner = LettaAgentBatch(
            message_manager=server.message_manager,
            agent_manager=server.agent_manager,
            block_manager=server.block_manager,
            passage_manager=server.passage_manager,
            batch_manager=server.batch_manager,
            sandbox_config_manager=server.sandbox_config_manager,
            job_manager=server.job_manager,
            actor=default_user,
        )

        # Run the method under test
        solver = ToolRulesSolver(tool_rules=[InitToolRule(tool_name="rethink_memory")])
        step_state_map = {agent.id: AgentStepState(step_number=0, tool_rules_solver=solver) for agent in agents}
        pre_resume_response = await batch_runner.step_until_request(
            batch_requests=batch_requests,
            agent_step_state_mapping=step_state_map,
            letta_batch_job_id=batch_job.id,
        )

    # 2. Invoke the polling job and mock responses from Anthropic
    mock_retrieve = AsyncMock(return_value=create_batch_response(batch_id=pre_resume_response.letta_batch_id, processing_status="ended"))
    with patch.object(server.anthropic_async_client.beta.messages.batches, "retrieve", mock_retrieve):
        mock_items = [
            create_rethink_tool_response(
                custom_id=agent.id,
                model=agent.llm_config.model,
                request_heartbeat=False,
                new_memory=new_memory,
                target_block_label=target_block_label,
            )
            for agent in agents
        ]

        # Create the mock for results
        mock_results = Mock()
        mock_results.return_value = MockAsyncIterable(mock_items.copy())

        with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
            with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
                await poll_running_llm_batches(server)

        # Check that the tool has been executed correctly.
        # The original loop asserted only when the block was found — it would pass
        # vacuously if the block were missing, so first assert the block exists.
        agent = client.agents.retrieve(agent_id=agent.id)
        matching_blocks = [block for block in agent.memory.blocks if block.label == target_block_label]
        assert matching_blocks, f"agent has no '{target_block_label}' memory block"
        for block in matching_blocks:
            assert block.value == new_memory
@pytest.mark.asyncio
async def test_partial_error_from_anthropic_batch(
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
):
anthropic_batch_id = "msgbatch_test_12345"
@@ -364,7 +493,7 @@ async def test_error_from_anthropic(
mock_items = [create_failed_response(custom_id=agent.id) for agent in agents_failed]
mock_items.extend(
[
create_complete_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=True)
create_get_weather_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=True)
for agent in agents_continue
]
)
@@ -501,12 +630,12 @@ async def test_resume_step_some_stop(
agents_continue = agents[:1]
agents_finish = agents[1:]
mock_items = [
create_complete_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=True)
create_get_weather_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=True)
for agent in agents_continue
]
mock_items.extend(
[
create_complete_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=False)
create_get_weather_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=False)
for agent in agents_finish
]
)
@@ -634,7 +763,7 @@ async def test_resume_step_after_request_all_continue(
with patch.object(server.anthropic_async_client.beta.messages.batches, "retrieve", mock_retrieve):
mock_items = [
create_complete_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=True) for agent in agents
create_get_weather_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=True) for agent in agents
]
# Create the mock for results

View File

@@ -1,5 +1,6 @@
import os
import random
import re
import string
import time
from datetime import datetime, timedelta, timezone
@@ -5171,6 +5172,47 @@ def test_bulk_update_batch_items_request_status_by_agent(
assert updated.request_status == JobStatus.expired
def test_bulk_update_nonexistent_items_should_error(
    server,
    default_user,
    dummy_beta_message_batch,
    dummy_successful_response,
    letta_batch_job,
):
    """In strict mode (the default), bulk updates targeting unknown (llm_batch_id, agent_id) pairs raise ValueError."""
    # Create a batch job so the llm_batch_id itself is valid; only the agent id is bogus.
    batch = server.batch_manager.create_llm_batch_job(
        llm_provider=ProviderType.anthropic,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
        letta_batch_job_id=letta_batch_job.id,
    )

    nonexistent_pairs = [(batch.id, "nonexistent-agent-id")]
    nonexistent_updates = [{"request_status": JobStatus.expired}]

    expected_err_msg = (
        "Cannot bulk-update batch items: no records for the following "  # no placeholders → no f-prefix (ruff F541)
        f"(llm_batch_id, agent_id) pairs: {{('{batch.id}', 'nonexistent-agent-id')}}"
    )

    # The low-level method and every high-level wrapper must all fail loudly in strict mode.
    with pytest.raises(ValueError, match=re.escape(expected_err_msg)):
        server.batch_manager.bulk_update_llm_batch_items(nonexistent_pairs, nonexistent_updates)

    with pytest.raises(ValueError, match=re.escape(expected_err_msg)):
        server.batch_manager.bulk_update_batch_llm_items_results_by_agent(
            [ItemUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired, dummy_successful_response)]
        )

    with pytest.raises(ValueError, match=re.escape(expected_err_msg)):
        server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent(
            [StepStatusUpdateInfo(batch.id, "nonexistent-agent-id", AgentStepStatus.resumed)]
        )

    with pytest.raises(ValueError, match=re.escape(expected_err_msg)):
        server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(
            [RequestStatusUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired)]
        )
def test_bulk_update_nonexistent_items(server, default_user, dummy_beta_message_batch, dummy_successful_response, letta_batch_job):
# Create a batch job
batch = server.batch_manager.create_llm_batch_job(
@@ -5187,22 +5229,22 @@ def test_bulk_update_nonexistent_items(server, default_user, dummy_beta_message_
nonexistent_updates = [{"request_status": JobStatus.expired}]
# This should not raise an error, just silently skip non-existent items
server.batch_manager.bulk_update_llm_batch_items(nonexistent_pairs, nonexistent_updates)
server.batch_manager.bulk_update_llm_batch_items(nonexistent_pairs, nonexistent_updates, strict=False)
# Test with higher-level methods
# Results by agent
server.batch_manager.bulk_update_batch_llm_items_results_by_agent(
[ItemUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired, dummy_successful_response)]
[ItemUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired, dummy_successful_response)], strict=False
)
# Step status by agent
server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent(
[StepStatusUpdateInfo(batch.id, "nonexistent-agent-id", AgentStepStatus.resumed)]
[StepStatusUpdateInfo(batch.id, "nonexistent-agent-id", AgentStepStatus.resumed)], strict=False
)
# Request status by agent
server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(
[RequestStatusUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired)]
[RequestStatusUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired)], strict=False
)