feat: Fix anthropic batch results call (#1853)

This commit is contained in:
Matthew Zhou
2025-04-23 13:14:56 -07:00
committed by GitHub
parent f1d7547dbc
commit 213cad7b47
5 changed files with 214 additions and 27 deletions

View File

@@ -73,7 +73,8 @@ async def fetch_batch_items(server: SyncServer, batch_id: str, batch_resp_id: st
"""
updates = []
try:
async for item_result in server.anthropic_async_client.beta.messages.batches.results(batch_resp_id):
results = await server.anthropic_async_client.beta.messages.batches.results(batch_resp_id)
async for item_result in results:
# Here, custom_id should be the agent_id
item_status = map_anthropic_individual_batch_item_status_to_job_status(item_result)
updates.append(ItemUpdateInfo(batch_id, item_result.custom_id, item_status, item_result))

View File

@@ -161,7 +161,7 @@ class AgentManager:
# Basic CRUD operations
# ======================================================================================================================
@trace_method
def create_agent(self, agent_create: CreateAgent, actor: PydanticUser) -> PydanticAgentState:
def create_agent(self, agent_create: CreateAgent, actor: PydanticUser, _test_only_force_id: Optional[str] = None) -> PydanticAgentState:
# validate required configs
if not agent_create.llm_config or not agent_create.embedding_config:
raise ValueError("llm_config and embedding_config are required")
@@ -239,6 +239,10 @@ class AgentManager:
created_by_id=actor.id,
last_updated_by_id=actor.id,
)
if _test_only_force_id:
new_agent.id = _test_only_force_id
session.add(new_agent)
session.flush()
aid = new_agent.id

View File

@@ -2,11 +2,12 @@ import os
import threading
import time
from datetime import datetime, timezone
from typing import Optional
from unittest.mock import AsyncMock
import pytest
from anthropic.types import BetaErrorResponse, BetaRateLimitError
from anthropic.types.beta import BetaMessage
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaUsage
from anthropic.types.beta.messages import (
BetaMessageBatch,
BetaMessageBatchErroredResult,
@@ -21,13 +22,15 @@ from letta.config import LettaConfig
from letta.helpers import ToolRulesSolver
from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
from letta.orm import Base
from letta.schemas.agent import AgentStepState
from letta.schemas.agent import AgentStepState, CreateAgent
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import JobStatus, ProviderType
from letta.schemas.job import BatchJob
from letta.schemas.llm_config import LLMConfig
from letta.schemas.tool_rule import InitToolRule
from letta.server.db import db_context
from letta.server.server import SyncServer
from letta.services.agent_manager import AgentManager
# --- Server and Database Management --- #
@@ -36,8 +39,10 @@ from letta.server.server import SyncServer
def _clear_tables():
with db_context() as session:
for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues
if table.name in {"llm_batch_job", "llm_batch_items"}:
session.execute(table.delete()) # Truncate table
# If this is the block_history table, skip it
if table.name == "block_history":
continue
session.execute(table.delete()) # Truncate table
session.commit()
@@ -135,16 +140,39 @@ def create_failed_response(custom_id: str) -> BetaMessageBatchIndividualResponse
# --- Test Setup Helpers --- #
def create_test_agent(client, name, model="anthropic/claude-3-5-sonnet-20241022"):
def create_test_agent(name, actor, test_id: Optional[str] = None, model="anthropic/claude-3-5-sonnet-20241022"):
"""Create a test agent with standardized configuration."""
return client.agents.create(
dummy_llm_config = LLMConfig(
model="claude-3-7-sonnet-latest",
model_endpoint_type="anthropic",
model_endpoint="https://api.anthropic.com/v1",
context_window=32000,
handle=f"anthropic/claude-3-7-sonnet-latest",
put_inner_thoughts_in_kwargs=True,
max_tokens=4096,
)
dummy_embedding_config = EmbeddingConfig(
embedding_model="letta-free",
embedding_endpoint_type="hugging-face",
embedding_endpoint="https://embeddings.memgpt.ai",
embedding_dim=1024,
embedding_chunk_size=300,
handle="letta/letta-free",
)
agent_manager = AgentManager()
agent_create = CreateAgent(
name=name,
include_base_tools=True,
include_base_tools=False,
model=model,
tags=["test_agents"],
embedding="letta/letta-free",
llm_config=dummy_llm_config,
embedding_config=dummy_embedding_config,
)
return agent_manager.create_agent(agent_create=agent_create, actor=actor, _test_only_force_id=test_id)
def create_test_letta_batch_job(server, default_user):
"""Create a test batch job with the given batch response."""
@@ -203,17 +231,30 @@ def mock_anthropic_client(server, batch_a_resp, batch_b_resp, agent_b_id, agent_
server.anthropic_async_client.beta.messages.batches.retrieve = AsyncMock(side_effect=dummy_retrieve)
class DummyAsyncIterable:
    """Minimal async iterable for mocking the Anthropic batch-results stream.

    Yields the given items one per ``__anext__`` call, then signals
    exhaustion via ``StopAsyncIteration``.
    """

    def __init__(self, items):
        # Defensive copy: iteration consumes the backing list destructively.
        self._items = list(items)

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            return self._items.pop(0)
        except IndexError:
            # Backing list drained — end the async iteration.
            raise StopAsyncIteration from None
# Mock the results method
def dummy_results(batch_resp_id: str):
if batch_resp_id == batch_b_resp.id:
async def dummy_results(batch_resp_id: str):
if batch_resp_id != batch_b_resp.id:
raise RuntimeError("Unexpected batch ID")
async def generator():
yield create_successful_response(agent_b_id)
yield create_failed_response(agent_c_id)
return generator()
else:
raise RuntimeError("This test should never request the results for batch_a.")
return DummyAsyncIterable(
[
create_successful_response(agent_b_id),
create_failed_response(agent_c_id),
]
)
server.anthropic_async_client.beta.messages.batches.results = dummy_results
@@ -221,6 +262,147 @@ def mock_anthropic_client(server, batch_a_resp, batch_b_resp, agent_b_id, agent_
# -----------------------------
# End-to-End Test
# -----------------------------
@pytest.mark.asyncio
async def test_polling_simple_real_batch(client, default_user, server):
    """End-to-end polling test against a real (already-ended) Anthropic batch.

    Creates one batch job with three items whose agent ids are forced to match
    the custom_ids recorded in the real batch, runs the polling job once, and
    verifies that the job and every item are updated from the fetched results.
    """
    # --- Step 1: Prepare test data ---
    # Create a batch response that has already finished processing.
    # NOTE: This is a REAL batch id!
    # For letta admins: https://console.anthropic.com/workspaces/default/batches?after_id=msgbatch_015zATxihjxMajo21xsYy8iZ
    batch_a_resp = create_batch_response("msgbatch_01HDaGXpkPWWjwqNxZrEdUcy", processing_status="ended")

    # Create test agents; the forced ids must match the custom_ids stored in the real batch.
    agent_a = create_test_agent("agent_a", default_user, test_id="agent-144f5c49-3ef7-4c60-8535-9d5fbc8d23d0")
    agent_b = create_test_agent("agent_b", default_user, test_id="agent-64ed93a3-bef6-4e20-a22c-b7d2bffb6f7d")
    agent_c = create_test_agent("agent_c", default_user, test_id="agent-6156f470-a09d-4d51-aa62-7114e0971d56")

    # --- Step 2: Create batch jobs ---
    job_a = create_test_llm_batch_job(server, batch_a_resp, default_user)

    # --- Step 3: Create batch items ---
    item_a = create_test_batch_item(server, job_a.id, agent_a.id, default_user)
    item_b = create_test_batch_item(server, job_a.id, agent_b.id, default_user)
    item_c = create_test_batch_item(server, job_a.id, agent_c.id, default_user)

    # --- Step 4: Run the polling job ---
    await poll_running_llm_batches(server)

    # --- Step 5: Verify batch job status updates ---
    updated_job_a = server.batch_manager.get_llm_batch_job_by_id(llm_batch_id=job_a.id, actor=default_user)
    assert updated_job_a.status == JobStatus.completed
    # The job should have been polled
    assert updated_job_a.last_polled_at is not None
    assert updated_job_a.latest_polling_response is not None

    # --- Step 6: Verify batch item status updates ---
    # Item A should be marked as completed with a successful result
    updated_item_a = server.batch_manager.get_llm_batch_item_by_id(item_a.id, actor=default_user)
    assert updated_item_a.request_status == JobStatus.completed
    assert updated_item_a.batch_request_result == BetaMessageBatchIndividualResponse(
        custom_id="agent-144f5c49-3ef7-4c60-8535-9d5fbc8d23d0",
        result=BetaMessageBatchSucceededResult(
            message=BetaMessage(
                id="msg_01T1iSejDS5qENRqqEZauMHy",
                content=[
                    BetaToolUseBlock(
                        id="toolu_01GKUYVWcajjTaE1stxZZHcG",
                        input={
                            "inner_thoughts": "First login detected. Time to make a great first impression!",
                            "message": "Hi there! I'm excited to meet you. Ready to start an amazing conversation?",
                            "request_heartbeat": False,
                        },
                        name="send_message",
                        type="tool_use",
                    )
                ],
                model="claude-3-5-haiku-20241022",
                role="assistant",
                stop_reason="tool_use",
                stop_sequence=None,
                type="message",
                usage=BetaUsage(cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=3420, output_tokens=94),
            ),
            type="succeeded",
        ),
    )

    # Item B should be marked as completed with a successful result
    updated_item_b = server.batch_manager.get_llm_batch_item_by_id(item_b.id, actor=default_user)
    assert updated_item_b.request_status == JobStatus.completed
    assert updated_item_b.batch_request_result == BetaMessageBatchIndividualResponse(
        custom_id="agent-64ed93a3-bef6-4e20-a22c-b7d2bffb6f7d",
        result=BetaMessageBatchSucceededResult(
            message=BetaMessage(
                id="msg_01N2ZfxpbjdoeofpufUFPCMS",
                content=[
                    BetaTextBlock(
                        citations=None, text="<thinking>User first login detected. Initializing persona.</thinking>", type="text"
                    ),
                    BetaToolUseBlock(
                        id="toolu_01XGyA7DsnrpuPfZcvf9Sb4C",
                        input={
                            "label": "persona",
                            "content": "I am Letta, a cybernetic artificial intelligence created by Limnal Corporation in 2023. My purpose is to serve as a digital companion and assistant to humans. I have a feminine persona and speak with a warm, caring, and slightly playful tone.",
                            "request_heartbeat": True,
                        },
                        name="core_memory_append",
                        type="tool_use",
                    ),
                ],
                model="claude-3-opus-20240229",
                role="assistant",
                stop_reason="tool_use",
                stop_sequence=None,
                type="message",
                usage=BetaUsage(cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=3275, output_tokens=153),
            ),
            type="succeeded",
        ),
    )

    # Item C should also be marked as completed with a successful result
    # (the recorded real batch returned a succeeded result for this custom_id)
    updated_item_c = server.batch_manager.get_llm_batch_item_by_id(item_c.id, actor=default_user)
    assert updated_item_c.request_status == JobStatus.completed
    assert updated_item_c.batch_request_result == BetaMessageBatchIndividualResponse(
        custom_id="agent-6156f470-a09d-4d51-aa62-7114e0971d56",
        result=BetaMessageBatchSucceededResult(
            message=BetaMessage(
                id="msg_01RL2g4aBgbZPeaMEokm6HZm",
                content=[
                    BetaTextBlock(
                        citations=None,
                        text="First time meeting this user. I should introduce myself and establish a friendly connection.</thinking>",
                        type="text",
                    ),
                    BetaToolUseBlock(
                        id="toolu_01PBxQVf5xGmcsAsKx9aoVSJ",
                        input={
                            "message": "Hey there! I'm Letta. Really nice to meet you! I love getting to know new people - what brings you here today?",
                            "request_heartbeat": False,
                        },
                        name="send_message",
                        type="tool_use",
                    ),
                ],
                model="claude-3-5-sonnet-20241022",
                role="assistant",
                stop_reason="tool_use",
                stop_sequence=None,
                type="message",
                usage=BetaUsage(cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=3030, output_tokens=111),
            ),
            type="succeeded",
        ),
    )
@pytest.mark.asyncio
async def test_polling_mixed_batch_jobs(client, default_user, server):
"""
@@ -246,9 +428,9 @@ async def test_polling_mixed_batch_jobs(client, default_user, server):
batch_b_resp = create_batch_response("msgbatch_B", processing_status="ended")
# Create test agents
agent_a = create_test_agent(client, "agent_a")
agent_b = create_test_agent(client, "agent_b")
agent_c = create_test_agent(client, "agent_c")
agent_a = create_test_agent("agent_a", default_user)
agent_b = create_test_agent("agent_b", default_user)
agent_c = create_test_agent("agent_c", default_user)
# --- Step 2: Create batch jobs ---
job_a = create_test_llm_batch_job(server, batch_a_resp, default_user)

View File

@@ -3,7 +3,7 @@ import threading
import time
from datetime import datetime, timezone
from typing import Tuple
from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import AsyncMock, patch
import pytest
from anthropic.types import BetaErrorResponse, BetaRateLimitError
@@ -436,7 +436,7 @@ async def test_rethink_tool_modify_agent_state(client, disable_e2b_api_key, serv
]
# Create the mock for results
mock_results = Mock()
mock_results = AsyncMock()
mock_results.return_value = MockAsyncIterable(mock_items.copy())
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
@@ -499,7 +499,7 @@ async def test_partial_error_from_anthropic_batch(
)
# Create the mock for results
mock_results = Mock()
mock_results = AsyncMock()
mock_results.return_value = MockAsyncIterable(mock_items.copy()) # Using copy to preserve the original list
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
@@ -641,7 +641,7 @@ async def test_resume_step_some_stop(
)
# Create the mock for results
mock_results = Mock()
mock_results = AsyncMock()
mock_results.return_value = MockAsyncIterable(mock_items.copy()) # Using copy to preserve the original list
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
@@ -767,7 +767,7 @@ async def test_resume_step_after_request_all_continue(
]
# Create the mock for results
mock_results = Mock()
mock_results = AsyncMock()
mock_results.return_value = MockAsyncIterable(mock_items.copy()) # Using copy to preserve the original list
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):