test: Add stateful tool test for letta batch agent (#1824)
This commit is contained in:
@@ -291,9 +291,7 @@ class LLMBatchManager:
|
||||
return [item.to_pydantic() for item in results]
|
||||
|
||||
def bulk_update_llm_batch_items(
|
||||
self,
|
||||
llm_batch_id_agent_id_pairs: List[Tuple[str, str]],
|
||||
field_updates: List[Dict[str, Any]],
|
||||
self, llm_batch_id_agent_id_pairs: List[Tuple[str, str]], field_updates: List[Dict[str, Any]], strict: bool = True
|
||||
) -> None:
|
||||
"""
|
||||
Efficiently update multiple LLMBatchItem rows by (llm_batch_id, agent_id) pairs.
|
||||
@@ -301,30 +299,43 @@ class LLMBatchManager:
|
||||
Args:
|
||||
llm_batch_id_agent_id_pairs: List of (llm_batch_id, agent_id) tuples identifying items to update
|
||||
field_updates: List of dictionaries containing the fields to update for each item
|
||||
strict: Whether to error if any of the requested keys don't exist (default True).
|
||||
If False, missing pairs are skipped.
|
||||
"""
|
||||
if not llm_batch_id_agent_id_pairs or not field_updates:
|
||||
return
|
||||
|
||||
if len(llm_batch_id_agent_id_pairs) != len(field_updates):
|
||||
raise ValueError("batch_id_agent_id_pairs and field_updates must have the same length")
|
||||
raise ValueError("llm_batch_id_agent_id_pairs and field_updates must have the same length")
|
||||
|
||||
with self.session_maker() as session:
|
||||
# Lookup primary keys
|
||||
# Lookup primary keys for all requested (batch_id, agent_id) pairs
|
||||
items = (
|
||||
session.query(LLMBatchItem.id, LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id)
|
||||
.filter(tuple_(LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id).in_(llm_batch_id_agent_id_pairs))
|
||||
.all()
|
||||
)
|
||||
pair_to_pk = {(b, a): id for id, b, a in items}
|
||||
pair_to_pk = {(batch_id, agent_id): pk for pk, batch_id, agent_id in items}
|
||||
|
||||
if strict:
|
||||
requested = set(llm_batch_id_agent_id_pairs)
|
||||
found = set(pair_to_pk.keys())
|
||||
missing = requested - found
|
||||
if missing:
|
||||
raise ValueError(
|
||||
f"Cannot bulk-update batch items: no records for the following " f"(llm_batch_id, agent_id) pairs: {missing}"
|
||||
)
|
||||
|
||||
# Build mappings, skipping any missing when strict=False
|
||||
mappings = []
|
||||
for (llm_batch_id, agent_id), fields in zip(llm_batch_id_agent_id_pairs, field_updates):
|
||||
pk_id = pair_to_pk.get((llm_batch_id, agent_id))
|
||||
if not pk_id:
|
||||
for (batch_id, agent_id), fields in zip(llm_batch_id_agent_id_pairs, field_updates):
|
||||
pk = pair_to_pk.get((batch_id, agent_id))
|
||||
if pk is None:
|
||||
# skip missing in non-strict mode
|
||||
continue
|
||||
|
||||
update_fields = fields.copy()
|
||||
update_fields["id"] = pk_id
|
||||
update_fields["id"] = pk
|
||||
mappings.append(update_fields)
|
||||
|
||||
if mappings:
|
||||
@@ -332,10 +343,7 @@ class LLMBatchManager:
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def bulk_update_batch_llm_items_results_by_agent(
|
||||
self,
|
||||
updates: List[ItemUpdateInfo],
|
||||
) -> None:
|
||||
def bulk_update_batch_llm_items_results_by_agent(self, updates: List[ItemUpdateInfo], strict: bool = True) -> None:
|
||||
"""Update request status and batch results for multiple batch items."""
|
||||
batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates]
|
||||
field_updates = [
|
||||
@@ -346,29 +354,23 @@ class LLMBatchManager:
|
||||
for update in updates
|
||||
]
|
||||
|
||||
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates)
|
||||
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates, strict=strict)
|
||||
|
||||
@enforce_types
|
||||
def bulk_update_llm_batch_items_step_status_by_agent(
|
||||
self,
|
||||
updates: List[StepStatusUpdateInfo],
|
||||
) -> None:
|
||||
def bulk_update_llm_batch_items_step_status_by_agent(self, updates: List[StepStatusUpdateInfo], strict: bool = True) -> None:
|
||||
"""Update step status for multiple batch items."""
|
||||
batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates]
|
||||
field_updates = [{"step_status": update.step_status} for update in updates]
|
||||
|
||||
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates)
|
||||
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates, strict=strict)
|
||||
|
||||
@enforce_types
|
||||
def bulk_update_llm_batch_items_request_status_by_agent(
|
||||
self,
|
||||
updates: List[RequestStatusUpdateInfo],
|
||||
) -> None:
|
||||
def bulk_update_llm_batch_items_request_status_by_agent(self, updates: List[RequestStatusUpdateInfo], strict: bool = True) -> None:
|
||||
"""Update request status for multiple batch items."""
|
||||
batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates]
|
||||
field_updates = [{"request_status": update.request_status} for update in updates]
|
||||
|
||||
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates)
|
||||
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates, strict=strict)
|
||||
|
||||
@enforce_types
|
||||
def delete_llm_batch_item(self, item_id: str, actor: PydanticUser) -> None:
|
||||
|
||||
@@ -1,11 +1,3 @@
|
||||
"""
|
||||
Tests for LettaAgentBatch.step_until_request functionality.
|
||||
|
||||
This module tests the batch processing capabilities of LettaAgentBatch,
|
||||
specifically the step_until_request method which prepares agent requests
|
||||
for batch processing.
|
||||
"""
|
||||
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
@@ -92,6 +84,28 @@ def weather_tool(client):
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def rethink_tool(client):
    """Function-scoped fixture yielding the tool returned by ``client.tools.upsert_from_function`` for ``rethink_memory``."""

    # NOTE(review): the nested function is handed to upsert_from_function as-is; presumably its
    # signature and docstring are read when the tool is registered, so both are left untouched — confirm.
    # NOTE(review): its docstring says "block_name" while the parameter is `target_block_label`, and the
    # `-> str` annotation disagrees with the `return None` at the end of the body.
    def rethink_memory(agent_state: "AgentState", new_memory: str, target_block_label: str) -> str:  # type: ignore
        """
        Re-evaluate the memory in block_name, integrating new and updated facts.
        Replace outdated information with the most likely truths, avoiding redundancy with original memories.
        Ensure consistency with other memory blocks.

        Args:
            new_memory (str): The new memory with information integrated from the memory block. If there is no new information, then this should be the same as the content in the source block.
            target_block_label (str): The name of the block to write to.
        Returns:
            str: None is always returned as this function does not produce a response.
        """
        agent_state.memory.update_block_value(label=target_block_label, value=new_memory)
        return None

    tool = client.tools.upsert_from_function(func=rethink_memory)
    # Yield the created tool
    yield tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agents(client, weather_tool):
|
||||
"""
|
||||
@@ -173,7 +187,7 @@ def create_batch_response(batch_id: str, processing_status: str = "in_progress")
|
||||
)
|
||||
|
||||
|
||||
def create_complete_tool_response(custom_id: str, model: str, request_heartbeat: bool) -> BetaMessageBatchIndividualResponse:
|
||||
def create_get_weather_tool_response(custom_id: str, model: str, request_heartbeat: bool) -> BetaMessageBatchIndividualResponse:
|
||||
"""Create a dummy successful batch response with a tool call after user asks about weather."""
|
||||
return BetaMessageBatchIndividualResponse(
|
||||
custom_id=custom_id,
|
||||
@@ -204,6 +218,39 @@ def create_complete_tool_response(custom_id: str, model: str, request_heartbeat:
|
||||
)
|
||||
|
||||
|
||||
def create_rethink_tool_response(
    custom_id: str, model: str, request_heartbeat: bool, new_memory: str, target_block_label: str
) -> BetaMessageBatchIndividualResponse:
    """Create a dummy successful batch response whose message calls the ``rethink_memory`` tool.

    Args:
        custom_id: Stored as the response's ``custom_id``.
        model: Model name reported on the message.
        request_heartbeat: Forwarded as the tool call's ``request_heartbeat`` input.
        new_memory: Forwarded as the tool call's ``new_memory`` input.
        target_block_label: Forwarded as the tool call's ``target_block_label`` input.

    Returns:
        A succeeded individual response wrapping one assistant message made of a
        text part followed by a ``tool_use`` part.
    """
    # Arguments the mocked model passes to the tool.
    tool_arguments = {
        "new_memory": new_memory,
        "target_block_label": target_block_label,
        "request_heartbeat": request_heartbeat,
    }
    tool_use_part = {
        "type": "tool_use",
        "id": "tu_01234567890123456789012345",
        "name": "rethink_memory",
        "input": tool_arguments,
    }
    text_part = {"type": "text", "text": "Let me rethink my memory."}

    assistant_message = BetaMessage(
        id="msg_abc123",
        role="assistant",
        type="message",
        model=model,
        content=[text_part, tool_use_part],
        usage={"input_tokens": 7, "output_tokens": 17},
        stop_reason="end_turn",
    )
    succeeded = BetaMessageBatchSucceededResult(type="succeeded", message=assistant_message)
    return BetaMessageBatchIndividualResponse(custom_id=custom_id, result=succeeded)
|
||||
|
||||
|
||||
def create_failed_response(custom_id: str) -> BetaMessageBatchIndividualResponse:
|
||||
"""Create a dummy failed batch response with a rate limit error."""
|
||||
return BetaMessageBatchIndividualResponse(
|
||||
@@ -322,7 +369,89 @@ class MockAsyncIterable:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_from_anthropic(
|
||||
async def test_rethink_tool_modify_agent_state(client, disable_e2b_api_key, server, default_user, batch_job, rethink_tool):
    """Run a mocked batch round-trip and check the `rethink_memory` tool call rewrote the agent's block.

    Steps:
      1. Create one agent with the rethink tool and a ``human`` block, then call
         ``step_until_request`` with the batch submission patched out.
      2. Run ``poll_running_llm_batches`` with ``retrieve``/``results`` patched so the
         batch reports "ended" and returns a ``rethink_memory`` tool call.
      3. Re-fetch the agent and assert the target block now holds ``new_memory``.
    """
    target_block_label = "human"
    new_memory = "banana"
    agent = client.agents.create(
        name="test_agent_rethink",
        include_base_tools=True,
        model=MODELS["sonnet"],
        tags=["test_agents"],
        embedding="letta/letta-free",
        tool_ids=[rethink_tool.id],
        memory_blocks=[
            {
                "label": target_block_label,
                "value": "Name: Matt",
            },
        ],
    )
    agents = [agent]
    batch_requests = [
        LettaBatchRequest(agent_id=batch_agent.id, messages=[MessageCreate(role="user", content=[TextContent(text="Rethink memory.")])])
        for batch_agent in agents
    ]

    anthropic_batch_id = "msgbatch_test_12345"
    dummy_batch_response = create_batch_response(
        batch_id=anthropic_batch_id,
    )

    # 1. Invoke `step_until_request`
    with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
        # Create batch runner
        batch_runner = LettaAgentBatch(
            message_manager=server.message_manager,
            agent_manager=server.agent_manager,
            block_manager=server.block_manager,
            passage_manager=server.passage_manager,
            batch_manager=server.batch_manager,
            sandbox_config_manager=server.sandbox_config_manager,
            job_manager=server.job_manager,
            actor=default_user,
        )

        # Run the method under test
        solver = ToolRulesSolver(tool_rules=[InitToolRule(tool_name="rethink_memory")])
        step_state_map = {batch_agent.id: AgentStepState(step_number=0, tool_rules_solver=solver) for batch_agent in agents}
        pre_resume_response = await batch_runner.step_until_request(
            batch_requests=batch_requests,
            agent_step_state_mapping=step_state_map,
            letta_batch_job_id=batch_job.id,
        )

    # 2. Invoke the polling job and mock responses from Anthropic
    mock_retrieve = AsyncMock(return_value=create_batch_response(batch_id=pre_resume_response.letta_batch_id, processing_status="ended"))

    with patch.object(server.anthropic_async_client.beta.messages.batches, "retrieve", mock_retrieve):
        mock_items = [
            create_rethink_tool_response(
                custom_id=batch_agent.id,
                model=batch_agent.llm_config.model,
                request_heartbeat=False,
                new_memory=new_memory,
                target_block_label=target_block_label,
            )
            for batch_agent in agents
        ]

        # Create the mock for results
        mock_results = Mock()
        mock_results.return_value = MockAsyncIterable(mock_items.copy())

        with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
            with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
                await poll_running_llm_batches(server)

    # Check that the tool has been executed correctly
    agent = client.agents.retrieve(agent_id=agent.id)
    target_blocks = [block for block in agent.memory.blocks if block.label == target_block_label]
    # Without this guard the loop below would pass vacuously if the block were missing.
    assert target_blocks, f"agent has no memory block labeled {target_block_label!r}"
    for block in target_blocks:
        assert block.value == new_memory
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_error_from_anthropic_batch(
|
||||
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
|
||||
):
|
||||
anthropic_batch_id = "msgbatch_test_12345"
|
||||
@@ -364,7 +493,7 @@ async def test_error_from_anthropic(
|
||||
mock_items = [create_failed_response(custom_id=agent.id) for agent in agents_failed]
|
||||
mock_items.extend(
|
||||
[
|
||||
create_complete_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=True)
|
||||
create_get_weather_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=True)
|
||||
for agent in agents_continue
|
||||
]
|
||||
)
|
||||
@@ -501,12 +630,12 @@ async def test_resume_step_some_stop(
|
||||
agents_continue = agents[:1]
|
||||
agents_finish = agents[1:]
|
||||
mock_items = [
|
||||
create_complete_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=True)
|
||||
create_get_weather_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=True)
|
||||
for agent in agents_continue
|
||||
]
|
||||
mock_items.extend(
|
||||
[
|
||||
create_complete_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=False)
|
||||
create_get_weather_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=False)
|
||||
for agent in agents_finish
|
||||
]
|
||||
)
|
||||
@@ -634,7 +763,7 @@ async def test_resume_step_after_request_all_continue(
|
||||
|
||||
with patch.object(server.anthropic_async_client.beta.messages.batches, "retrieve", mock_retrieve):
|
||||
mock_items = [
|
||||
create_complete_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=True) for agent in agents
|
||||
create_get_weather_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=True) for agent in agents
|
||||
]
|
||||
|
||||
# Create the mock for results
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
import time
|
||||
from datetime import datetime, timedelta, timezone
|
||||
@@ -5171,6 +5172,47 @@ def test_bulk_update_batch_items_request_status_by_agent(
|
||||
assert updated.request_status == JobStatus.expired
|
||||
|
||||
|
||||
def test_bulk_update_nonexistent_items_should_error(
    server,
    default_user,
    dummy_beta_message_batch,
    dummy_successful_response,
    letta_batch_job,
):
    """Each bulk-update entry point raises ValueError for an unknown (llm_batch_id, agent_id) pair.

    All calls are made without the ``strict`` argument, and the error text is
    matched literally, including the formatted set of missing pairs.
    """
    manager = server.batch_manager

    # Create a batch job
    batch = manager.create_llm_batch_job(
        llm_provider=ProviderType.anthropic,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
        letta_batch_job_id=letta_batch_job.id,
    )

    missing_pair_repr = f"('{batch.id}', 'nonexistent-agent-id')"
    expected_pattern = re.escape(
        "Cannot bulk-update batch items: no records for the following " f"(llm_batch_id, agent_id) pairs: {{{missing_pair_repr}}}"
    )

    # Deferred so that argument construction also happens inside the `pytest.raises` block.
    attempts = (
        lambda: manager.bulk_update_llm_batch_items(
            [(batch.id, "nonexistent-agent-id")],
            [{"request_status": JobStatus.expired}],
        ),
        lambda: manager.bulk_update_batch_llm_items_results_by_agent(
            [ItemUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired, dummy_successful_response)]
        ),
        lambda: manager.bulk_update_llm_batch_items_step_status_by_agent(
            [StepStatusUpdateInfo(batch.id, "nonexistent-agent-id", AgentStepStatus.resumed)]
        ),
        lambda: manager.bulk_update_llm_batch_items_request_status_by_agent(
            [RequestStatusUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired)]
        ),
    )

    for attempt in attempts:
        with pytest.raises(ValueError, match=expected_pattern):
            attempt()
|
||||
|
||||
|
||||
def test_bulk_update_nonexistent_items(server, default_user, dummy_beta_message_batch, dummy_successful_response, letta_batch_job):
|
||||
# Create a batch job
|
||||
batch = server.batch_manager.create_llm_batch_job(
|
||||
@@ -5187,22 +5229,22 @@ def test_bulk_update_nonexistent_items(server, default_user, dummy_beta_message_
|
||||
nonexistent_updates = [{"request_status": JobStatus.expired}]
|
||||
|
||||
# This should not raise an error, just silently skip non-existent items
|
||||
server.batch_manager.bulk_update_llm_batch_items(nonexistent_pairs, nonexistent_updates)
|
||||
server.batch_manager.bulk_update_llm_batch_items(nonexistent_pairs, nonexistent_updates, strict=False)
|
||||
|
||||
# Test with higher-level methods
|
||||
# Results by agent
|
||||
server.batch_manager.bulk_update_batch_llm_items_results_by_agent(
|
||||
[ItemUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired, dummy_successful_response)]
|
||||
[ItemUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired, dummy_successful_response)], strict=False
|
||||
)
|
||||
|
||||
# Step status by agent
|
||||
server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent(
|
||||
[StepStatusUpdateInfo(batch.id, "nonexistent-agent-id", AgentStepStatus.resumed)]
|
||||
[StepStatusUpdateInfo(batch.id, "nonexistent-agent-id", AgentStepStatus.resumed)], strict=False
|
||||
)
|
||||
|
||||
# Request status by agent
|
||||
server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(
|
||||
[RequestStatusUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired)]
|
||||
[RequestStatusUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired)], strict=False
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user