diff --git a/tests/integration_test_cancellation.py b/tests/integration_test_cancellation.py index 509ad8db..45042897 100644 --- a/tests/integration_test_cancellation.py +++ b/tests/integration_test_cancellation.py @@ -6,7 +6,8 @@ from typing import Any, List import pytest from dotenv import load_dotenv -from letta_client import AsyncLetta, MessageCreate +from letta_client import AsyncLetta +from letta_client.types import MessageCreateParam from letta.log import get_logger from letta.schemas.agent import AgentState @@ -51,8 +52,8 @@ def roll_dice(num_sides: int) -> int: return random.randint(1, num_sides) -USER_MESSAGE_ROLL_DICE: List[MessageCreate] = [ - MessageCreate( +USER_MESSAGE_ROLL_DICE: List[MessageCreateParam] = [ + MessageCreateParam( role="user", content="This is an automated test message. Call the roll_dice tool with 16 sides and reply back to me with the outcome.", ) @@ -83,7 +84,7 @@ async def cancel_run_after_delay(client: AsyncLetta, agent_id: str, delay: float await client.agents.messages.cancel(agent_id=agent_id) -@pytest.fixture(scope="function") +@pytest.fixture(scope="module") def server_url() -> str: """ Provides the URL for the Letta server. @@ -100,6 +101,9 @@ def server_url() -> str: url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") if not os.getenv("LETTA_SERVER_URL"): + if os.getenv("LETTA_REDIS_HOST"): + print(f"Redis is configured at {os.getenv('LETTA_REDIS_HOST')}:{os.getenv('LETTA_REDIS_PORT', '6379')}") + thread = threading.Thread(target=_run_server, daemon=True) thread.start() @@ -111,8 +115,8 @@ def server_url() -> str: start_time = time.time() while time.time() - start_time < timeout_seconds: try: - response = httpx.get(url + "/v1/health", timeout=1.0) - if response.status_code == 200: + response = httpx.get(url + "/v1/health", timeout=1.0, follow_redirects=True) + if response.status_code < 500: break except Exception: pass @@ -154,6 +158,7 @@ async def agent_state(client: AsyncLetta) -> AgentState: await client.agents.delete(agent_state_instance.id) +@pytest.mark.skipif(not os.getenv("LETTA_REDIS_HOST"), reason="Redis is required for background streaming (set LETTA_REDIS_HOST to enable)") @pytest.mark.parametrize( "llm_config", TESTED_LLM_CONFIGS, @@ -166,25 +171,30 @@ async def test_background_streaming_cancellation( agent_state: AgentState, llm_config: LLMConfig, ) -> None: - agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + agent_state = await client.agents.update(agent_id=agent_state.id, llm_config=llm_config) - delay = 5 if llm_config.model == "gpt-5" else 0.5 + delay = 5 if llm_config.model == "gpt-5" else 1.5 _cancellation_task = asyncio.create_task(cancel_run_after_delay(client, agent_state.id, delay=delay)) - response = client.agents.messages.create_stream( + response = await client.agents.messages.stream( agent_id=agent_state.id, messages=USER_MESSAGE_ROLL_DICE, stream_tokens=True, background=True, ) messages = await accumulate_chunks(response) - run_id = messages[0].run_id + run_id = messages[0].run_id if hasattr(messages[0], "run_id") else None await _cancellation_task - run = await client.runs.retrieve(run_id=run_id) - assert run.status == JobStatus.cancelled + if run_id: + run = await client.runs.retrieve(run_id=run_id) + assert run.status == JobStatus.cancelled + else: + runs = await client.runs.list(agent_id=agent_state.id, stop_reason="cancelled", limit=1) + assert len(list(runs)) == 1 + run_id = runs.items[0].id - response = client.runs.stream(run_id=run_id, starting_after=0) + response = await client.runs.messages.stream(run_id=run_id, starting_after=0) messages_from_stream = await accumulate_chunks(response) assert len(messages_from_stream) > 0