diff --git a/letta/jobs/llm_batch_job_polling.py b/letta/jobs/llm_batch_job_polling.py index 6ca14f6e..a1227475 100644 --- a/letta/jobs/llm_batch_job_polling.py +++ b/letta/jobs/llm_batch_job_polling.py @@ -73,7 +73,8 @@ async def fetch_batch_items(server: SyncServer, batch_id: str, batch_resp_id: st """ updates = [] try: - async for item_result in server.anthropic_async_client.beta.messages.batches.results(batch_resp_id): + results = await server.anthropic_async_client.beta.messages.batches.results(batch_resp_id) + async for item_result in results: # Here, custom_id should be the agent_id item_status = map_anthropic_individual_batch_item_status_to_job_status(item_result) updates.append(ItemUpdateInfo(batch_id, item_result.custom_id, item_status, item_result)) diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index a1dcdb8e..ce228c6c 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -161,7 +161,7 @@ class AgentManager: # Basic CRUD operations # ====================================================================================================================== @trace_method - def create_agent(self, agent_create: CreateAgent, actor: PydanticUser) -> PydanticAgentState: + def create_agent(self, agent_create: CreateAgent, actor: PydanticUser, _test_only_force_id: Optional[str] = None) -> PydanticAgentState: # validate required configs if not agent_create.llm_config or not agent_create.embedding_config: raise ValueError("llm_config and embedding_config are required") @@ -239,6 +239,10 @@ class AgentManager: created_by_id=actor.id, last_updated_by_id=actor.id, ) + + if _test_only_force_id: + new_agent.id = _test_only_force_id + session.add(new_agent) session.flush() aid = new_agent.id diff --git a/tests/integration_test_batch_api_cron_jobs.py b/tests/integration_test_batch_api_cron_jobs.py index 044192e1..39306568 100644 --- a/tests/integration_test_batch_api_cron_jobs.py +++ 
b/tests/integration_test_batch_api_cron_jobs.py @@ -2,11 +2,12 @@ import os import threading import time from datetime import datetime, timezone +from typing import Optional from unittest.mock import AsyncMock import pytest from anthropic.types import BetaErrorResponse, BetaRateLimitError -from anthropic.types.beta import BetaMessage +from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaUsage from anthropic.types.beta.messages import ( BetaMessageBatch, BetaMessageBatchErroredResult, @@ -21,13 +22,15 @@ from letta.config import LettaConfig from letta.helpers import ToolRulesSolver from letta.jobs.llm_batch_job_polling import poll_running_llm_batches from letta.orm import Base -from letta.schemas.agent import AgentStepState +from letta.schemas.agent import AgentStepState, CreateAgent +from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import JobStatus, ProviderType from letta.schemas.job import BatchJob from letta.schemas.llm_config import LLMConfig from letta.schemas.tool_rule import InitToolRule from letta.server.db import db_context from letta.server.server import SyncServer +from letta.services.agent_manager import AgentManager # --- Server and Database Management --- # @@ -36,8 +39,10 @@ from letta.server.server import SyncServer def _clear_tables(): with db_context() as session: for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues - if table.name in {"llm_batch_job", "llm_batch_items"}: - session.execute(table.delete()) # Truncate table + # If this is the block_history table, skip it + if table.name == "block_history": + continue + session.execute(table.delete()) # Truncate table session.commit() @@ -135,16 +140,39 @@ def create_failed_response(custom_id: str) -> BetaMessageBatchIndividualResponse # --- Test Setup Helpers --- # -def create_test_agent(client, name, model="anthropic/claude-3-5-sonnet-20241022"): +def create_test_agent(name, actor, test_id: Optional[str] = 
None, model="anthropic/claude-3-5-sonnet-20241022"): """Create a test agent with standardized configuration.""" - return client.agents.create( + dummy_llm_config = LLMConfig( + model="claude-3-7-sonnet-latest", + model_endpoint_type="anthropic", + model_endpoint="https://api.anthropic.com/v1", + context_window=32000, + handle=f"anthropic/claude-3-7-sonnet-latest", + put_inner_thoughts_in_kwargs=True, + max_tokens=4096, + ) + + dummy_embedding_config = EmbeddingConfig( + embedding_model="letta-free", + embedding_endpoint_type="hugging-face", + embedding_endpoint="https://embeddings.memgpt.ai", + embedding_dim=1024, + embedding_chunk_size=300, + handle="letta/letta-free", + ) + + agent_manager = AgentManager() + agent_create = CreateAgent( name=name, - include_base_tools=True, + include_base_tools=False, model=model, tags=["test_agents"], - embedding="letta/letta-free", + llm_config=dummy_llm_config, + embedding_config=dummy_embedding_config, ) + return agent_manager.create_agent(agent_create=agent_create, actor=actor, _test_only_force_id=test_id) + def create_test_letta_batch_job(server, default_user): """Create a test batch job with the given batch response.""" @@ -203,17 +231,30 @@ def mock_anthropic_client(server, batch_a_resp, batch_b_resp, agent_b_id, agent_ server.anthropic_async_client.beta.messages.batches.retrieve = AsyncMock(side_effect=dummy_retrieve) + class DummyAsyncIterable: + def __init__(self, items): + # copy so we can .pop() + self._items = list(items) + + def __aiter__(self): + return self + + async def __anext__(self): + if not self._items: + raise StopAsyncIteration + return self._items.pop(0) + # Mock the results method - def dummy_results(batch_resp_id: str): - if batch_resp_id == batch_b_resp.id: + async def dummy_results(batch_resp_id: str): + if batch_resp_id != batch_b_resp.id: + raise RuntimeError("Unexpected batch ID") - async def generator(): - yield create_successful_response(agent_b_id) - yield create_failed_response(agent_c_id) - - 
return generator() - else: - raise RuntimeError("This test should never request the results for batch_a.") + return DummyAsyncIterable( + [ + create_successful_response(agent_b_id), + create_failed_response(agent_c_id), + ] + ) server.anthropic_async_client.beta.messages.batches.results = dummy_results @@ -221,6 +262,147 @@ def mock_anthropic_client(server, batch_a_resp, batch_b_resp, agent_b_id, agent_ # ----------------------------- # End-to-End Test # ----------------------------- +@pytest.mark.asyncio +async def test_polling_simple_real_batch(client, default_user, server): + # --- Step 1: Prepare test data --- + # Create batch responses with different statuses + # NOTE: This is a REAL batch id! + # For letta admins: https://console.anthropic.com/workspaces/default/batches?after_id=msgbatch_015zATxihjxMajo21xsYy8iZ + batch_a_resp = create_batch_response("msgbatch_01HDaGXpkPWWjwqNxZrEdUcy", processing_status="ended") + + # Create test agents + agent_a = create_test_agent("agent_a", default_user, test_id="agent-144f5c49-3ef7-4c60-8535-9d5fbc8d23d0") + agent_b = create_test_agent("agent_b", default_user, test_id="agent-64ed93a3-bef6-4e20-a22c-b7d2bffb6f7d") + agent_c = create_test_agent("agent_c", default_user, test_id="agent-6156f470-a09d-4d51-aa62-7114e0971d56") + + # --- Step 2: Create batch jobs --- + job_a = create_test_llm_batch_job(server, batch_a_resp, default_user) + + # --- Step 3: Create batch items --- + item_a = create_test_batch_item(server, job_a.id, agent_a.id, default_user) + item_b = create_test_batch_item(server, job_a.id, agent_b.id, default_user) + item_c = create_test_batch_item(server, job_a.id, agent_c.id, default_user) + + print("HI") + print(agent_a.id) + print(agent_b.id) + print(agent_c.id) + print("BYE") + + # --- Step 4: Run the polling job --- + await poll_running_llm_batches(server) + + # --- Step 5: Verify batch job status updates --- + updated_job_a = server.batch_manager.get_llm_batch_job_by_id(llm_batch_id=job_a.id, 
actor=default_user) + + assert updated_job_a.status == JobStatus.completed + + # The job should have been polled + assert updated_job_a.last_polled_at is not None + assert updated_job_a.latest_polling_response is not None + + # --- Step 6: Verify batch item status updates --- + # Item A should be marked as completed with a successful result + updated_item_a = server.batch_manager.get_llm_batch_item_by_id(item_a.id, actor=default_user) + assert updated_item_a.request_status == JobStatus.completed + assert updated_item_a.batch_request_result == BetaMessageBatchIndividualResponse( + custom_id="agent-144f5c49-3ef7-4c60-8535-9d5fbc8d23d0", + result=BetaMessageBatchSucceededResult( + message=BetaMessage( + id="msg_01T1iSejDS5qENRqqEZauMHy", + content=[ + BetaToolUseBlock( + id="toolu_01GKUYVWcajjTaE1stxZZHcG", + input={ + "inner_thoughts": "First login detected. Time to make a great first impression!", + "message": "Hi there! I'm excited to meet you. Ready to start an amazing conversation?", + "request_heartbeat": False, + }, + name="send_message", + type="tool_use", + ) + ], + model="claude-3-5-haiku-20241022", + role="assistant", + stop_reason="tool_use", + stop_sequence=None, + type="message", + usage=BetaUsage(cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=3420, output_tokens=94), + ), + type="succeeded", + ), + ) + + # Item B should be marked as completed with a successful result + updated_item_b = server.batch_manager.get_llm_batch_item_by_id(item_b.id, actor=default_user) + assert updated_item_b.request_status == JobStatus.completed + assert updated_item_b.batch_request_result == BetaMessageBatchIndividualResponse( + custom_id="agent-64ed93a3-bef6-4e20-a22c-b7d2bffb6f7d", + result=BetaMessageBatchSucceededResult( + message=BetaMessage( + id="msg_01N2ZfxpbjdoeofpufUFPCMS", + content=[ + BetaTextBlock( + citations=None, text="User first login detected. 
Initializing persona.", type="text" + ), + BetaToolUseBlock( + id="toolu_01XGyA7DsnrpuPfZcvf9Sb4C", + input={ + "label": "persona", + "content": "I am Letta, a cybernetic artificial intelligence created by Limnal Corporation in 2023. My purpose is to serve as a digital companion and assistant to humans. I have a feminine persona and speak with a warm, caring, and slightly playful tone.", + "request_heartbeat": True, + }, + name="core_memory_append", + type="tool_use", + ), + ], + model="claude-3-opus-20240229", + role="assistant", + stop_reason="tool_use", + stop_sequence=None, + type="message", + usage=BetaUsage(cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=3275, output_tokens=153), + ), + type="succeeded", + ), + ) + + # Item C should be marked as completed with a successful result + updated_item_c = server.batch_manager.get_llm_batch_item_by_id(item_c.id, actor=default_user) + assert updated_item_c.request_status == JobStatus.completed + assert updated_item_c.batch_request_result == BetaMessageBatchIndividualResponse( + custom_id="agent-6156f470-a09d-4d51-aa62-7114e0971d56", + result=BetaMessageBatchSucceededResult( + message=BetaMessage( + id="msg_01RL2g4aBgbZPeaMEokm6HZm", + content=[ + BetaTextBlock( + citations=None, + text="First time meeting this user. I should introduce myself and establish a friendly connection.", + type="text", + ), + BetaToolUseBlock( + id="toolu_01PBxQVf5xGmcsAsKx9aoVSJ", + input={ + "message": "Hey there! I'm Letta. Really nice to meet you! 
I love getting to know new people - what brings you here today?", + "request_heartbeat": False, + }, + name="send_message", + type="tool_use", + ), + ], + model="claude-3-5-sonnet-20241022", + role="assistant", + stop_reason="tool_use", + stop_sequence=None, + type="message", + usage=BetaUsage(cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=3030, output_tokens=111), + ), + type="succeeded", + ), + ) + + @pytest.mark.asyncio async def test_polling_mixed_batch_jobs(client, default_user, server): """ @@ -246,9 +428,9 @@ async def test_polling_mixed_batch_jobs(client, default_user, server): batch_b_resp = create_batch_response("msgbatch_B", processing_status="ended") # Create test agents - agent_a = create_test_agent(client, "agent_a") - agent_b = create_test_agent(client, "agent_b") - agent_c = create_test_agent(client, "agent_c") + agent_a = create_test_agent("agent_a", default_user) + agent_b = create_test_agent("agent_b", default_user) + agent_c = create_test_agent("agent_c", default_user) # --- Step 2: Create batch jobs --- job_a = create_test_llm_batch_job(server, batch_a_resp, default_user) diff --git a/tests/integration_test_batch.py b/tests/integration_test_batch_sdk.py similarity index 100% rename from tests/integration_test_batch.py rename to tests/integration_test_batch_sdk.py diff --git a/tests/test_letta_agent_batch.py b/tests/test_letta_agent_batch.py index 9835a6f7..1cde5dc8 100644 --- a/tests/test_letta_agent_batch.py +++ b/tests/test_letta_agent_batch.py @@ -3,7 +3,7 @@ import threading import time from datetime import datetime, timezone from typing import Tuple -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import AsyncMock, patch import pytest from anthropic.types import BetaErrorResponse, BetaRateLimitError @@ -436,7 +436,7 @@ async def test_rethink_tool_modify_agent_state(client, disable_e2b_api_key, serv ] # Create the mock for results - mock_results = Mock() + mock_results = AsyncMock() 
mock_results.return_value = MockAsyncIterable(mock_items.copy()) with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results): @@ -499,7 +499,7 @@ async def test_partial_error_from_anthropic_batch( ) # Create the mock for results - mock_results = Mock() + mock_results = AsyncMock() mock_results.return_value = MockAsyncIterable(mock_items.copy()) # Using copy to preserve the original list with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results): @@ -641,7 +641,7 @@ async def test_resume_step_some_stop( ) # Create the mock for results - mock_results = Mock() + mock_results = AsyncMock() mock_results.return_value = MockAsyncIterable(mock_items.copy()) # Using copy to preserve the original list with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results): @@ -767,7 +767,7 @@ async def test_resume_step_after_request_all_continue( ] # Create the mock for results - mock_results = Mock() + mock_results = AsyncMock() mock_results.return_value = MockAsyncIterable(mock_items.copy()) # Using copy to preserve the original list with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):