test: migrate cancellation test to 1.0 sdk [LET-6327] (#6375)
* test: migrate cancellation test to 1.0 sdk * revert async change * debug redis * cleanup
This commit is contained in:
@@ -6,7 +6,8 @@ from typing import Any, List
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import AsyncLetta, MessageCreate
|
||||
from letta_client import AsyncLetta
|
||||
from letta_client.types import MessageCreateParam
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState
|
||||
@@ -51,8 +52,8 @@ def roll_dice(num_sides: int) -> int:
|
||||
return random.randint(1, num_sides)
|
||||
|
||||
|
||||
USER_MESSAGE_ROLL_DICE: List[MessageCreate] = [
|
||||
MessageCreate(
|
||||
USER_MESSAGE_ROLL_DICE: List[MessageCreateParam] = [
|
||||
MessageCreateParam(
|
||||
role="user",
|
||||
content="This is an automated test message. Call the roll_dice tool with 16 sides and reply back to me with the outcome.",
|
||||
)
|
||||
@@ -83,7 +84,7 @@ async def cancel_run_after_delay(client: AsyncLetta, agent_id: str, delay: float
|
||||
await client.agents.messages.cancel(agent_id=agent_id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@pytest.fixture(scope="module")
|
||||
def server_url() -> str:
|
||||
"""
|
||||
Provides the URL for the Letta server.
|
||||
@@ -100,6 +101,9 @@ def server_url() -> str:
|
||||
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
if os.getenv("LETTA_REDIS_HOST"):
|
||||
print(f"Redis is configured at {os.getenv('LETTA_REDIS_HOST')}:{os.getenv('LETTA_REDIS_PORT', '6379')}")
|
||||
|
||||
thread = threading.Thread(target=_run_server, daemon=True)
|
||||
thread.start()
|
||||
|
||||
@@ -111,8 +115,8 @@ def server_url() -> str:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout_seconds:
|
||||
try:
|
||||
response = httpx.get(url + "/v1/health", timeout=1.0)
|
||||
if response.status_code == 200:
|
||||
response = httpx.get(url + "/v1/health", timeout=1.0, follow_redirects=True)
|
||||
if response.status_code < 500:
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
@@ -154,6 +158,7 @@ async def agent_state(client: AsyncLetta) -> AgentState:
|
||||
await client.agents.delete(agent_state_instance.id)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("LETTA_REDIS_HOST"), reason="Redis is required for background streaming (set LETTA_REDIS_HOST to enable)")
|
||||
@pytest.mark.parametrize(
|
||||
"llm_config",
|
||||
TESTED_LLM_CONFIGS,
|
||||
@@ -166,25 +171,30 @@ async def test_background_streaming_cancellation(
|
||||
agent_state: AgentState,
|
||||
llm_config: LLMConfig,
|
||||
) -> None:
|
||||
agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
agent_state = await client.agents.update(agent_id=agent_state.id, llm_config=llm_config)
|
||||
|
||||
delay = 5 if llm_config.model == "gpt-5" else 0.5
|
||||
delay = 5 if llm_config.model == "gpt-5" else 1.5
|
||||
_cancellation_task = asyncio.create_task(cancel_run_after_delay(client, agent_state.id, delay=delay))
|
||||
|
||||
response = client.agents.messages.create_stream(
|
||||
response = await client.agents.messages.stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_ROLL_DICE,
|
||||
stream_tokens=True,
|
||||
background=True,
|
||||
)
|
||||
messages = await accumulate_chunks(response)
|
||||
run_id = messages[0].run_id
|
||||
run_id = messages[0].run_id if hasattr(messages[0], "run_id") else None
|
||||
|
||||
await _cancellation_task
|
||||
|
||||
run = await client.runs.retrieve(run_id=run_id)
|
||||
assert run.status == JobStatus.cancelled
|
||||
if run_id:
|
||||
run = await client.runs.retrieve(run_id=run_id)
|
||||
assert run.status == JobStatus.cancelled
|
||||
else:
|
||||
runs = await client.runs.list(agent_id=agent_state.id, stop_reason="cancelled", limit=1)
|
||||
assert len(list(runs)) == 1
|
||||
run_id = runs.items[0].id
|
||||
|
||||
response = client.runs.stream(run_id=run_id, starting_after=0)
|
||||
response = await client.runs.messages.stream(run_id=run_id, starting_after=0)
|
||||
messages_from_stream = await accumulate_chunks(response)
|
||||
assert len(messages_from_stream) > 0
|
||||
|
||||
Reference in New Issue
Block a user