From 2ef2bde2e1811158ca95654a1718a7435ea2e2a5 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Mon, 21 Apr 2025 15:47:20 -0700 Subject: [PATCH] test: Add stateful tool test for letta batch agent (#1824) --- letta/services/llm_batch_manager.py | 52 ++++----- tests/test_letta_agent_batch.py | 157 +++++++++++++++++++++++++--- tests/test_managers.py | 50 ++++++++- 3 files changed, 216 insertions(+), 43 deletions(-) diff --git a/letta/services/llm_batch_manager.py b/letta/services/llm_batch_manager.py index 0e944003..ec3a947b 100644 --- a/letta/services/llm_batch_manager.py +++ b/letta/services/llm_batch_manager.py @@ -291,9 +291,7 @@ class LLMBatchManager: return [item.to_pydantic() for item in results] def bulk_update_llm_batch_items( - self, - llm_batch_id_agent_id_pairs: List[Tuple[str, str]], - field_updates: List[Dict[str, Any]], + self, llm_batch_id_agent_id_pairs: List[Tuple[str, str]], field_updates: List[Dict[str, Any]], strict: bool = True ) -> None: """ Efficiently update multiple LLMBatchItem rows by (llm_batch_id, agent_id) pairs. @@ -301,30 +299,43 @@ class LLMBatchManager: Args: llm_batch_id_agent_id_pairs: List of (llm_batch_id, agent_id) tuples identifying items to update field_updates: List of dictionaries containing the fields to update for each item + strict: Whether to error if any of the requested keys don't exist (default True). + If False, missing pairs are skipped. """ if not llm_batch_id_agent_id_pairs or not field_updates: return if len(llm_batch_id_agent_id_pairs) != len(field_updates): - raise ValueError("batch_id_agent_id_pairs and field_updates must have the same length") + raise ValueError("llm_batch_id_agent_id_pairs and field_updates must have the same length") with self.session_maker() as session: - # Lookup primary keys + # Lookup primary keys for all requested (batch_id, agent_id) pairs items = ( session.query(LLMBatchItem.id, LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id) .filter(tuple_(LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id).in_(llm_batch_id_agent_id_pairs)) .all() ) - pair_to_pk = {(b, a): id for id, b, a in items} + pair_to_pk = {(batch_id, agent_id): pk for pk, batch_id, agent_id in items} + if strict: + requested = set(llm_batch_id_agent_id_pairs) + found = set(pair_to_pk.keys()) + missing = requested - found + if missing: + raise ValueError( + f"Cannot bulk-update batch items: no records for the following " f"(llm_batch_id, agent_id) pairs: {missing}" + ) + + # Build mappings, skipping any missing when strict=False mappings = [] - for (llm_batch_id, agent_id), fields in zip(llm_batch_id_agent_id_pairs, field_updates): - pk_id = pair_to_pk.get((llm_batch_id, agent_id)) - if not pk_id: + for (batch_id, agent_id), fields in zip(llm_batch_id_agent_id_pairs, field_updates): + pk = pair_to_pk.get((batch_id, agent_id)) + if pk is None: + # skip missing in non-strict mode continue update_fields = fields.copy() - update_fields["id"] = pk_id + update_fields["id"] = pk mappings.append(update_fields) if mappings: @@ -332,10 +343,7 @@ class LLMBatchManager: session.commit() @enforce_types - def bulk_update_batch_llm_items_results_by_agent( - self, - updates: List[ItemUpdateInfo], - ) -> None: + def bulk_update_batch_llm_items_results_by_agent(self, updates: List[ItemUpdateInfo], strict: bool = True) -> None: """Update request status and batch results for multiple batch items.""" batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates] field_updates = [ @@ -346,29 +354,23 @@ class LLMBatchManager: for update in updates ] - self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates) + self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates, strict=strict) @enforce_types - def bulk_update_llm_batch_items_step_status_by_agent( - self, - updates: List[StepStatusUpdateInfo], - ) -> None: + def bulk_update_llm_batch_items_step_status_by_agent(self, updates: List[StepStatusUpdateInfo], strict: bool = True) -> None: """Update step status for multiple batch items.""" batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates] field_updates = [{"step_status": update.step_status} for update in updates] - self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates) + self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates, strict=strict) @enforce_types - def bulk_update_llm_batch_items_request_status_by_agent( - self, - updates: List[RequestStatusUpdateInfo], - ) -> None: + def bulk_update_llm_batch_items_request_status_by_agent(self, updates: List[RequestStatusUpdateInfo], strict: bool = True) -> None: """Update request status for multiple batch items.""" batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates] field_updates = [{"request_status": update.request_status} for update in updates] - self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates) + self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates, strict=strict) @enforce_types def delete_llm_batch_item(self, item_id: str, actor: PydanticUser) -> None: diff --git a/tests/test_letta_agent_batch.py b/tests/test_letta_agent_batch.py index a67cc48c..9835a6f7 100644 --- a/tests/test_letta_agent_batch.py +++ b/tests/test_letta_agent_batch.py @@ -1,11 +1,3 @@ -""" -Tests for LettaAgentBatch.step_until_request functionality. - -This module tests the batch processing capabilities of LettaAgentBatch, -specifically the step_until_request method which prepares agent requests -for batch processing. -""" - import os import threading import time @@ -92,6 +84,28 @@ def weather_tool(client): yield tool +@pytest.fixture(scope="function") +def rethink_tool(client): + def rethink_memory(agent_state: "AgentState", new_memory: str, target_block_label: str) -> str: # type: ignore + """ + Re-evaluate the memory in block_name, integrating new and updated facts. + Replace outdated information with the most likely truths, avoiding redundancy with original memories. + Ensure consistency with other memory blocks. + + Args: + new_memory (str): The new memory with information integrated from the memory block. If there is no new information, then this should be the same as the content in the source block. + target_block_label (str): The name of the block to write to. + Returns: + str: None is always returned as this function does not produce a response. + """ + agent_state.memory.update_block_value(label=target_block_label, value=new_memory) + return None + + tool = client.tools.upsert_from_function(func=rethink_memory) + # Yield the created tool + yield tool + + @pytest.fixture def agents(client, weather_tool): """ @@ -173,7 +187,7 @@ def create_batch_response(batch_id: str, processing_status: str = "in_progress") ) -def create_complete_tool_response(custom_id: str, model: str, request_heartbeat: bool) -> BetaMessageBatchIndividualResponse: +def create_get_weather_tool_response(custom_id: str, model: str, request_heartbeat: bool) -> BetaMessageBatchIndividualResponse: """Create a dummy successful batch response with a tool call after user asks about weather.""" return BetaMessageBatchIndividualResponse( custom_id=custom_id, @@ -204,6 +218,39 @@ def create_complete_tool_response(custom_id: str, model: str, request_heartbeat: ) +def create_rethink_tool_response( + custom_id: str, model: str, request_heartbeat: bool, new_memory: str, target_block_label: str +) -> BetaMessageBatchIndividualResponse: + """Create a dummy successful batch response with a tool call after user asks about weather.""" + return BetaMessageBatchIndividualResponse( + custom_id=custom_id, + result=BetaMessageBatchSucceededResult( + type="succeeded", + message=BetaMessage( + id="msg_abc123", + role="assistant", + type="message", + model=model, + content=[ + {"type": "text", "text": "Let me rethink my memory."}, + { + "type": "tool_use", + "id": "tu_01234567890123456789012345", + "name": "rethink_memory", + "input": { + "new_memory": new_memory, + "target_block_label": target_block_label, + "request_heartbeat": request_heartbeat, + }, + }, + ], + usage={"input_tokens": 7, "output_tokens": 17}, + stop_reason="end_turn", + ), + ), + ) + + def create_failed_response(custom_id: str) -> BetaMessageBatchIndividualResponse: """Create a dummy failed batch response with a rate limit error.""" return BetaMessageBatchIndividualResponse( @@ -322,7 +369,89 @@ class MockAsyncIterable: @pytest.mark.asyncio -async def test_error_from_anthropic( +async def test_rethink_tool_modify_agent_state(client, disable_e2b_api_key, server, default_user, batch_job, rethink_tool): + target_block_label = "human" + new_memory = "banana" + agent = client.agents.create( + name=f"test_agent_rethink", + include_base_tools=True, + model=MODELS["sonnet"], + tags=["test_agents"], + embedding="letta/letta-free", + tool_ids=[rethink_tool.id], + memory_blocks=[ + { + "label": target_block_label, + "value": "Name: Matt", + }, + ], + ) + agents = [agent] + batch_requests = [ + LettaBatchRequest(agent_id=agent.id, messages=[MessageCreate(role="user", content=[TextContent(text=f"Rethink memory.")])]) + for agent in agents + ] + + anthropic_batch_id = "msgbatch_test_12345" + dummy_batch_response = create_batch_response( + batch_id=anthropic_batch_id, + ) + + # 1. Invoke `step_until_request` + with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response): + # Create batch runner + batch_runner = LettaAgentBatch( + message_manager=server.message_manager, + agent_manager=server.agent_manager, + block_manager=server.block_manager, + passage_manager=server.passage_manager, + batch_manager=server.batch_manager, + sandbox_config_manager=server.sandbox_config_manager, + job_manager=server.job_manager, + actor=default_user, + ) + + # Run the method under test + solver = ToolRulesSolver(tool_rules=[InitToolRule(tool_name="rethink_memory")]) + step_state_map = {agent.id: AgentStepState(step_number=0, tool_rules_solver=solver) for agent in agents} + pre_resume_response = await batch_runner.step_until_request( + batch_requests=batch_requests, + agent_step_state_mapping=step_state_map, + letta_batch_job_id=batch_job.id, + ) + + # 2. Invoke the polling job and mock responses from Anthropic + mock_retrieve = AsyncMock(return_value=create_batch_response(batch_id=pre_resume_response.letta_batch_id, processing_status="ended")) + + with patch.object(server.anthropic_async_client.beta.messages.batches, "retrieve", mock_retrieve): + mock_items = [ + create_rethink_tool_response( + custom_id=agent.id, + model=agent.llm_config.model, + request_heartbeat=False, + new_memory=new_memory, + target_block_label=target_block_label, + ) + for agent in agents + ] + + # Create the mock for results + mock_results = Mock() + mock_results.return_value = MockAsyncIterable(mock_items.copy()) + + with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results): + with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response): + await poll_running_llm_batches(server) + + # Check that the tool has been executed correctly + agent = client.agents.retrieve(agent_id=agent.id) + for block in agent.memory.blocks: + if block.label == target_block_label: + assert block.value == new_memory + + +@pytest.mark.asyncio +async def test_partial_error_from_anthropic_batch( disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job ): anthropic_batch_id = "msgbatch_test_12345" @@ -364,7 +493,7 @@ async def test_error_from_anthropic( mock_items = [create_failed_response(custom_id=agent.id) for agent in agents_failed] mock_items.extend( [ - create_complete_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=True) + create_get_weather_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=True) for agent in agents_continue ] ) @@ -501,12 +630,12 @@ async def test_resume_step_some_stop( agents_continue = agents[:1] agents_finish = agents[1:] mock_items = [ - create_complete_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=True) + create_get_weather_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=True) for agent in agents_continue ] mock_items.extend( [ - create_complete_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=False) + create_get_weather_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=False) for agent in agents_finish ] ) @@ -634,7 +763,7 @@ async def test_resume_step_after_request_all_continue( with patch.object(server.anthropic_async_client.beta.messages.batches, "retrieve", mock_retrieve): mock_items = [ - create_complete_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=True) for agent in agents + create_get_weather_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=True) for agent in agents ] # Create the mock for results diff --git a/tests/test_managers.py b/tests/test_managers.py index 3a9de069..76e6d0dc 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -1,5 +1,6 @@ import os import random +import re import string import time from datetime import datetime, timedelta, timezone @@ -5171,6 +5172,47 @@ def test_bulk_update_batch_items_request_status_by_agent( assert updated.request_status == JobStatus.expired +def test_bulk_update_nonexistent_items_should_error( + server, + default_user, + dummy_beta_message_batch, + dummy_successful_response, + letta_batch_job, +): + # Create a batch job + batch = server.batch_manager.create_llm_batch_job( + llm_provider=ProviderType.anthropic, + create_batch_response=dummy_beta_message_batch, + actor=default_user, + letta_batch_job_id=letta_batch_job.id, + ) + + nonexistent_pairs = [(batch.id, "nonexistent-agent-id")] + nonexistent_updates = [{"request_status": JobStatus.expired}] + expected_err_msg = ( + f"Cannot bulk-update batch items: no records for the following " + f"(llm_batch_id, agent_id) pairs: {{('{batch.id}', 'nonexistent-agent-id')}}" + ) + + with pytest.raises(ValueError, match=re.escape(expected_err_msg)): + server.batch_manager.bulk_update_llm_batch_items(nonexistent_pairs, nonexistent_updates) + + with pytest.raises(ValueError, match=re.escape(expected_err_msg)): + server.batch_manager.bulk_update_batch_llm_items_results_by_agent( + [ItemUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired, dummy_successful_response)] + ) + + with pytest.raises(ValueError, match=re.escape(expected_err_msg)): + server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent( + [StepStatusUpdateInfo(batch.id, "nonexistent-agent-id", AgentStepStatus.resumed)] + ) + + with pytest.raises(ValueError, match=re.escape(expected_err_msg)): + server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent( + [RequestStatusUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired)] + ) + + def test_bulk_update_nonexistent_items(server, default_user, dummy_beta_message_batch, dummy_successful_response, letta_batch_job): # Create a batch job batch = server.batch_manager.create_llm_batch_job( @@ -5187,22 +5229,22 @@ def test_bulk_update_nonexistent_items(server, default_user, dummy_beta_message_ nonexistent_updates = [{"request_status": JobStatus.expired}] # This should not raise an error, just silently skip non-existent items - server.batch_manager.bulk_update_llm_batch_items(nonexistent_pairs, nonexistent_updates) + server.batch_manager.bulk_update_llm_batch_items(nonexistent_pairs, nonexistent_updates, strict=False) # Test with higher-level methods # Results by agent server.batch_manager.bulk_update_batch_llm_items_results_by_agent( - [ItemUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired, dummy_successful_response)] + [ItemUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired, dummy_successful_response)], strict=False ) # Step status by agent server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent( - [StepStatusUpdateInfo(batch.id, "nonexistent-agent-id", AgentStepStatus.resumed)] + [StepStatusUpdateInfo(batch.id, "nonexistent-agent-id", AgentStepStatus.resumed)], strict=False ) # Request status by agent server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent( - [RequestStatusUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired)] + [RequestStatusUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired)], strict=False )