feat: Fix test batch sdkpy [LET-4507] (#4917)
* Fix test batch sdk * Consolidate and fix test batch sdk
This commit is contained in:
committed by
Caren Thomas
parent
b0bc04fec7
commit
d78e0ccb58
@@ -13,6 +13,7 @@ from letta_client import (
|
||||
ContinueToolRule,
|
||||
CreateBlock,
|
||||
Letta as LettaSDKClient,
|
||||
LettaBatchRequest,
|
||||
LettaRequest,
|
||||
MaxCountPerStepToolRule,
|
||||
MessageCreate,
|
||||
@@ -24,6 +25,10 @@ from letta_client.core import ApiError
|
||||
from letta_client.types import AgentState, ToolReturnMessage
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.server.server import SyncServer
|
||||
from tests.helpers.utils import upload_file_and_wait
|
||||
|
||||
# Constants
|
||||
@@ -60,6 +65,18 @@ def client() -> LettaSDKClient:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
"""
|
||||
Creates a SyncServer instance for testing.
|
||||
|
||||
Loads and saves config to ensure proper initialization.
|
||||
"""
|
||||
config = LettaConfig.load()
|
||||
config.save()
|
||||
return SyncServer()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def agent(client: LettaSDKClient):
|
||||
agent_state = client.agents.create(
|
||||
@@ -2190,3 +2207,74 @@ def test_upsert_tools(client: LettaSDKClient):
|
||||
|
||||
# Clean up
|
||||
client.tools.delete(tool.id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_batch(client: LettaSDKClient, server: SyncServer):
|
||||
# create agents
|
||||
agent1 = client.agents.create(
|
||||
name="agent1_batch",
|
||||
memory_blocks=[{"label": "persona", "value": "you are agent 1"}],
|
||||
model="anthropic/claude-3-7-sonnet-20250219",
|
||||
embedding="letta/letta-free",
|
||||
)
|
||||
agent2 = client.agents.create(
|
||||
name="agent2_batch",
|
||||
memory_blocks=[{"label": "persona", "value": "you are agent 2"}],
|
||||
model="anthropic/claude-3-7-sonnet-20250219",
|
||||
embedding="letta/letta-free",
|
||||
)
|
||||
|
||||
# create a run
|
||||
run = client.batches.create(
|
||||
requests=[
|
||||
LettaBatchRequest(
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content=[
|
||||
TextContent(
|
||||
text="hi",
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
agent_id=agent1.id,
|
||||
),
|
||||
LettaBatchRequest(
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content=[
|
||||
TextContent(
|
||||
text="hi",
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
agent_id=agent2.id,
|
||||
),
|
||||
]
|
||||
)
|
||||
assert run is not None
|
||||
|
||||
# list batches
|
||||
batches = client.batches.list()
|
||||
assert len(batches) >= 1, f"Expected 1 or more batches, got {len(batches)}"
|
||||
assert batches[0].status == JobStatus.running
|
||||
|
||||
# Poll it once
|
||||
await poll_running_llm_batches(server)
|
||||
|
||||
# get the batch results
|
||||
results = client.batches.retrieve(
|
||||
batch_id=run.id,
|
||||
)
|
||||
assert results is not None
|
||||
|
||||
# cancel
|
||||
client.batches.cancel(batch_id=run.id)
|
||||
batch_job = client.batches.retrieve(
|
||||
batch_id=run.id,
|
||||
)
|
||||
assert batch_job.status == JobStatus.cancelled
|
||||
|
||||
Reference in New Issue
Block a user