From b7f2fb256a3ba52efa269ac68a62905d9b9e9c87 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Wed, 21 May 2025 07:25:49 -0700 Subject: [PATCH] fix: Fix test letta agent batch (#2295) --- letta/agents/letta_agent_batch.py | 141 +++++++++++++++++++++--------- tests/test_letta_agent_batch.py | 140 ++++++++++++++--------------- 2 files changed, 166 insertions(+), 115 deletions(-) diff --git a/letta/agents/letta_agent_batch.py b/letta/agents/letta_agent_batch.py index e154107d..03794cec 100644 --- a/letta/agents/letta_agent_batch.py +++ b/letta/agents/letta_agent_batch.py @@ -267,64 +267,119 @@ class LettaAgentBatch(BaseAgent): @trace_method async def _collect_resume_context(self, llm_batch_id: str) -> _ResumeContext: - # NOTE: We only continue for items with successful results + """ + Collect context for resuming operations from completed batch items. + + Args: + llm_batch_id: The ID of the batch to collect context for + + Returns: + _ResumeContext object containing all necessary data for resumption + """ + # Fetch only completed batch items batch_items = await self.batch_manager.list_llm_batch_items_async(llm_batch_id=llm_batch_id, request_status=JobStatus.completed) - agent_ids = [] - provider_results = {} - request_status_updates: List[RequestStatusUpdateInfo] = [] + # Exit early if no items to process + if not batch_items: + return _ResumeContext( + batch_items=[], + agent_ids=[], + agent_state_map={}, + provider_results={}, + tool_call_name_map={}, + tool_call_args_map={}, + should_continue_map={}, + request_status_updates=[], + ) - for item in batch_items: - aid = item.agent_id - agent_ids.append(aid) - provider_results[aid] = item.batch_request_result.result + # Extract agent IDs and organize items by agent ID + agent_ids = [item.agent_id for item in batch_items] + batch_item_map = {item.agent_id: item for item in batch_items} + # Collect provider results + provider_results = {item.agent_id: item.batch_request_result.result for item in batch_items} + + # Fetch agent states in a single call agent_states = await self.agent_manager.get_agents_by_ids_async(agent_ids, actor=self.actor) agent_state_map = {agent.id: agent for agent in agent_states} - name_map, args_map, cont_map = {}, {}, {} - for aid in agent_ids: - # status bookkeeping - pr = provider_results[aid] - status = ( - JobStatus.completed - if isinstance(pr, BetaMessageBatchSucceededResult) - else ( - JobStatus.failed - if isinstance(pr, BetaMessageBatchErroredResult) - else JobStatus.cancelled if isinstance(pr, BetaMessageBatchCanceledResult) else JobStatus.expired - ) - ) - request_status_updates.append(RequestStatusUpdateInfo(llm_batch_id=llm_batch_id, agent_id=aid, request_status=status)) - - # translate provider‑specific response → OpenAI‑style tool call (unchanged) - llm_client = LLMClient.create( - provider_type=item.llm_config.model_endpoint_type, - put_inner_thoughts_first=True, - actor=self.actor, - ) - tool_call = ( - llm_client.convert_response_to_chat_completion( - response_data=pr.message.model_dump(), input_messages=[], llm_config=item.llm_config - ) - .choices[0] - .message.tool_calls[0] - ) - - name, args, cont = self._extract_tool_call_and_decide_continue(tool_call, item.step_state) - name_map[aid], args_map[aid], cont_map[aid] = name, args, cont + # Process each agent's results + tool_call_results = self._process_agent_results( + agent_ids=agent_ids, batch_item_map=batch_item_map, provider_results=provider_results, llm_batch_id=llm_batch_id + ) return _ResumeContext( batch_items=batch_items, agent_ids=agent_ids, agent_state_map=agent_state_map, provider_results=provider_results, - tool_call_name_map=name_map, - tool_call_args_map=args_map, - should_continue_map=cont_map, - request_status_updates=request_status_updates, + tool_call_name_map=tool_call_results.name_map, + tool_call_args_map=tool_call_results.args_map, + should_continue_map=tool_call_results.cont_map, + request_status_updates=tool_call_results.status_updates, ) + def _process_agent_results(self, agent_ids, batch_item_map, provider_results, llm_batch_id): + """ + Process the results for each agent, extracting tool calls and determining continuation status. + + Returns: + A namedtuple containing name_map, args_map, cont_map, and status_updates + """ + from collections import namedtuple + + ToolCallResults = namedtuple("ToolCallResults", ["name_map", "args_map", "cont_map", "status_updates"]) + + name_map, args_map, cont_map = {}, {}, {} + request_status_updates = [] + + for aid in agent_ids: + item = batch_item_map[aid] + result = provider_results[aid] + + # Determine job status based on result type + status = self._determine_job_status(result) + request_status_updates.append(RequestStatusUpdateInfo(llm_batch_id=llm_batch_id, agent_id=aid, request_status=status)) + + # Process tool calls + name, args, cont = self._extract_tool_call_from_result(item, result) + name_map[aid], args_map[aid], cont_map[aid] = name, args, cont + + return ToolCallResults(name_map, args_map, cont_map, request_status_updates) + + def _determine_job_status(self, result): + """Determine job status based on result type""" + if isinstance(result, BetaMessageBatchSucceededResult): + return JobStatus.completed + elif isinstance(result, BetaMessageBatchErroredResult): + return JobStatus.failed + elif isinstance(result, BetaMessageBatchCanceledResult): + return JobStatus.cancelled + else: + return JobStatus.expired + + def _extract_tool_call_from_result(self, item, result): + """Extract tool call information from a result""" + llm_client = LLMClient.create( + provider_type=item.llm_config.model_endpoint_type, + put_inner_thoughts_first=True, + actor=self.actor, + ) + + # If result isn't a successful type, we can't extract a tool call + if not isinstance(result, BetaMessageBatchSucceededResult): + return None, None, False + + tool_call = ( + llm_client.convert_response_to_chat_completion( + response_data=result.message.model_dump(), input_messages=[], llm_config=item.llm_config + ) + .choices[0] + .message.tool_calls[0] + ) + + return self._extract_tool_call_and_decide_continue(tool_call, item.step_state) + def _update_request_statuses(self, updates: List[RequestStatusUpdateInfo]) -> None: if updates: self.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(updates=updates) diff --git a/tests/test_letta_agent_batch.py b/tests/test_letta_agent_batch.py index 11da3a19..70005133 100644 --- a/tests/test_letta_agent_batch.py +++ b/tests/test_letta_agent_batch.py @@ -1,7 +1,5 @@ -import os -import threading from datetime import datetime, timezone -from typing import Tuple +from typing import List, Optional, Tuple from unittest.mock import AsyncMock, patch import pytest @@ -14,27 +12,26 @@ from anthropic.types.beta.messages import ( BetaMessageBatchRequestCounts, BetaMessageBatchSucceededResult, ) -from dotenv import load_dotenv -from letta_client import Letta from letta.agents.letta_agent_batch import LettaAgentBatch from letta.config import LettaConfig +from letta.functions.functions import parse_source_code from letta.helpers import ToolRulesSolver from letta.jobs.llm_batch_job_polling import poll_running_llm_batches from letta.orm import Base -from letta.schemas.agent import AgentState, AgentStepState +from letta.schemas.agent import AgentState, AgentStepState, CreateAgent from letta.schemas.enums import AgentStepStatus, JobStatus, MessageRole, ProviderType from letta.schemas.job import BatchJob from letta.schemas.letta_message_content import TextContent from letta.schemas.letta_request import LettaBatchRequest from letta.schemas.message import MessageCreate +from letta.schemas.tool import Tool from letta.schemas.tool_rule import InitToolRule from letta.server.db import db_context from letta.server.server import SyncServer -from tests.utils import wait_for_server # --------------------------------------------------------------------------- # -# Test Constants +# Test Constants / Helpers # --------------------------------------------------------------------------- # # Model identifiers used in tests @@ -48,13 +45,31 @@ MODELS = { EXPECTED_ROLES = ["system", "assistant", "tool", "user", "user"] +def create_tool_from_func( + func, + tags: Optional[List[str]] = None, + description: Optional[str] = None, +): + source_code = parse_source_code(func) + source_type = "python" + if not tags: + tags = [] + + return Tool( + source_type=source_type, + source_code=source_code, + tags=tags, + description=description, + ) + + # --------------------------------------------------------------------------- # # Test Fixtures # --------------------------------------------------------------------------- # @pytest.fixture(scope="function") -def weather_tool(client): +def weather_tool(server): def get_weather(location: str) -> str: """ Fetches the current weather for a given location. @@ -79,13 +94,14 @@ def weather_tool(client): else: raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}") - tool = client.tools.upsert_from_function(func=get_weather) + actor = server.user_manager.get_user_or_default() + tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=get_weather), actor=actor) # Yield the created tool yield tool @pytest.fixture(scope="function") -def rethink_tool(client): +def rethink_tool(server): def rethink_memory(agent_state: "AgentState", new_memory: str, target_block_label: str) -> str: # type: ignore """ Re-evaluate the memory in block_name, integrating new and updated facts. @@ -101,28 +117,33 @@ def rethink_tool(client): agent_state.memory.update_block_value(label=target_block_label, value=new_memory) return None - tool = client.tools.upsert_from_function(func=rethink_memory) + actor = server.user_manager.get_user_or_default() + tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=rethink_memory), actor=actor) # Yield the created tool yield tool @pytest.fixture -def agents(client, weather_tool): +def agents(server, weather_tool): """ Create three test agents with different models. Returns: Tuple[Agent, Agent, Agent]: Three agents with sonnet, haiku, and opus models """ + actor = server.user_manager.get_user_or_default() def create_agent(suffix, model_name): - return client.agents.create( - name=f"test_agent_{suffix}", - include_base_tools=True, - model=model_name, - tags=["test_agents"], - embedding="letta/letta-free", - tool_ids=[weather_tool.id], + return server.create_agent( + CreateAgent( + name=f"test_agent_{suffix}", + include_base_tools=True, + model=model_name, + tags=["test_agents"], + embedding="letta/letta-free", + tool_ids=[weather_tool.id], + ), + actor=actor, ) return ( @@ -290,32 +311,6 @@ def clear_batch_tables(): session.commit() -def run_server(): - """Starts the Letta server in a background thread.""" - load_dotenv() - from letta.server.rest_api.app import start_server - - start_server(debug=True) - - -@pytest.fixture(scope="session") -def server_url(): - """ - Ensures a server is running and returns its base URL. - - Uses environment variable if available, otherwise starts a server - in a background thread. - """ - url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") - - if not os.getenv("LETTA_SERVER_URL"): - thread = threading.Thread(target=run_server, daemon=True) - thread.start() - wait_for_server(url) - - return url - - @pytest.fixture(scope="module") def server(): """ @@ -324,14 +319,11 @@ def server(): Loads and saves config to ensure proper initialization. """ config = LettaConfig.load() + config.save() - return SyncServer() - -@pytest.fixture(scope="session") -def client(server_url): - """Creates a REST client connected to the test server.""" - return Letta(base_url=server_url) + server = SyncServer(init_with_default_org_and_user=True) + yield server @pytest.fixture @@ -368,23 +360,27 @@ class MockAsyncIterable: # --------------------------------------------------------------------------- # -@pytest.mark.asyncio(loop_scope="session") -async def test_rethink_tool_modify_agent_state(client, disable_e2b_api_key, server, default_user, batch_job, rethink_tool): +@pytest.mark.asyncio(loop_scope="module") +async def test_rethink_tool_modify_agent_state(disable_e2b_api_key, server, default_user, batch_job, rethink_tool): target_block_label = "human" new_memory = "banana" - agent = client.agents.create( - name=f"test_agent_rethink", - include_base_tools=True, - model=MODELS["sonnet"], - tags=["test_agents"], - embedding="letta/letta-free", - tool_ids=[rethink_tool.id], - memory_blocks=[ - { - "label": target_block_label, - "value": "Name: Matt", - }, - ], + actor = server.user_manager.get_user_or_default() + agent = await server.create_agent_async( + request=CreateAgent( + name=f"test_agent_rethink", + include_base_tools=True, + model=MODELS["sonnet"], + tags=["test_agents"], + embedding="letta/letta-free", + tool_ids=[rethink_tool.id], + memory_blocks=[ + { + "label": target_block_label, + "value": "Name: Matt", + }, + ], + ), + actor=actor, ) agents = [agent] batch_requests = [ @@ -444,13 +440,13 @@ async def test_rethink_tool_modify_agent_state(client, disable_e2b_api_key, serv await poll_running_llm_batches(server) # Check that the tool has been executed correctly - agent = client.agents.retrieve(agent_id=agent.id) + agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=actor) for block in agent.memory.blocks: if block.label == target_block_label: assert block.value == new_memory -@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.asyncio(loop_scope="module") async def test_partial_error_from_anthropic_batch( disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job ): @@ -610,7 +606,7 @@ async def test_partial_error_from_anthropic_batch( assert agent_messages[0].role == MessageRole.user, "Expected initial user message" -@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.asyncio(loop_scope="module") async def test_resume_step_some_stop( disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job ): @@ -773,7 +769,7 @@ def _assert_descending_order(messages): return True -@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.asyncio(loop_scope="module") async def test_resume_step_after_request_all_continue( disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job ): @@ -911,7 +907,7 @@ async def test_resume_step_after_request_all_continue( assert agent_messages[-4].role == MessageRole.user, "Expected final system-level heartbeat user message" -@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.asyncio(loop_scope="module") async def test_step_until_request_prepares_and_submits_batch_correctly( disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response, batch_job ):