feat: Fix test batch sdkpy [LET-4507] (#4917)

* Fix test batch sdk

* Consolidate and fix test batch sdk
This commit is contained in:
Matthew Zhou
2025-09-24 16:00:52 -07:00
committed by Caren Thomas
parent b0bc04fec7
commit d78e0ccb58
5 changed files with 129 additions and 180 deletions

View File

@@ -13,6 +13,7 @@ from letta_client import (
ContinueToolRule,
CreateBlock,
Letta as LettaSDKClient,
LettaBatchRequest,
LettaRequest,
MaxCountPerStepToolRule,
MessageCreate,
@@ -24,6 +25,10 @@ from letta_client.core import ApiError
from letta_client.types import AgentState, ToolReturnMessage
from pydantic import BaseModel, Field
from letta.config import LettaConfig
from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
from letta.schemas.enums import JobStatus
from letta.server.server import SyncServer
from tests.helpers.utils import upload_file_and_wait
# Constants
@@ -60,6 +65,18 @@ def client() -> LettaSDKClient:
yield client
@pytest.fixture(scope="module")
def server():
"""
Creates a SyncServer instance for testing.
Loads and saves config to ensure proper initialization.
"""
config = LettaConfig.load()
config.save()
return SyncServer()
@pytest.fixture(scope="function")
def agent(client: LettaSDKClient):
agent_state = client.agents.create(
@@ -2190,3 +2207,74 @@ def test_upsert_tools(client: LettaSDKClient):
# Clean up
client.tools.delete(tool.id)
@pytest.mark.asyncio
async def test_create_batch(client: LettaSDKClient, server: SyncServer):
# create agents
agent1 = client.agents.create(
name="agent1_batch",
memory_blocks=[{"label": "persona", "value": "you are agent 1"}],
model="anthropic/claude-3-7-sonnet-20250219",
embedding="letta/letta-free",
)
agent2 = client.agents.create(
name="agent2_batch",
memory_blocks=[{"label": "persona", "value": "you are agent 2"}],
model="anthropic/claude-3-7-sonnet-20250219",
embedding="letta/letta-free",
)
# create a run
run = client.batches.create(
requests=[
LettaBatchRequest(
messages=[
MessageCreate(
role="user",
content=[
TextContent(
text="hi",
)
],
)
],
agent_id=agent1.id,
),
LettaBatchRequest(
messages=[
MessageCreate(
role="user",
content=[
TextContent(
text="hi",
)
],
)
],
agent_id=agent2.id,
),
]
)
assert run is not None
# list batches
batches = client.batches.list()
assert len(batches) >= 1, f"Expected 1 or more batches, got {len(batches)}"
assert batches[0].status == JobStatus.running
# Poll it once
await poll_running_llm_batches(server)
# get the batch results
results = client.batches.retrieve(
batch_id=run.id,
)
assert results is not None
# cancel
client.batches.cancel(batch_id=run.id)
batch_job = client.batches.retrieve(
batch_id=run.id,
)
assert batch_job.status == JobStatus.cancelled