From 07ecdc9c6d5e14f8a0efa57755d8d951e53ab52a Mon Sep 17 00:00:00 2001 From: cthomas Date: Sun, 1 Jun 2025 16:14:30 -0700 Subject: [PATCH] test: make send message tests less flaky (#2578) --- tests/integration_test_send_message.py | 70 ++++++++++++++++---------- 1 file changed, 44 insertions(+), 26 deletions(-) diff --git a/tests/integration_test_send_message.py b/tests/integration_test_send_message.py index 351daea9..58f0f300 100644 --- a/tests/integration_test_send_message.py +++ b/tests/integration_test_send_message.py @@ -122,9 +122,20 @@ def roll_dice(num_sides: int) -> int: USER_MESSAGE_OTID = str(uuid.uuid4()) -USER_MESSAGE_GREETING: List[MessageCreate] = [MessageCreate(role="user", content="Hi there.", otid=USER_MESSAGE_OTID)] -USER_MESSAGE_TOOL_CALL: List[MessageCreate] = [ - MessageCreate(role="user", content="Call the roll_dice tool with 16 sides and tell me the outcome.", otid=USER_MESSAGE_OTID) +USER_MESSAGE_RESPONSE: str = "Teamwork makes the dream work" +USER_MESSAGE_FORCE_REPLY: List[MessageCreate] = [ + MessageCreate( + role="user", + content=f"This is an automated test message. Call the send_message tool with the message '{USER_MESSAGE_RESPONSE}'.", + otid=USER_MESSAGE_OTID, + ) +] +USER_MESSAGE_ROLL_DICE: List[MessageCreate] = [ + MessageCreate( + role="user", + content="This is an automated test message. Call the roll_dice tool with 16 sides and tell me the outcome.", + otid=USER_MESSAGE_OTID, + ) ] all_configs = [ "openai-gpt-4o-mini.json", @@ -145,6 +156,7 @@ TESTED_LLM_CONFIGS: List[LLMConfig] = [get_llm_config(fn) for fn in filenames] def assert_greeting_with_assistant_message_response( messages: List[Any], streaming: bool = False, + token_streaming: bool = False, from_db: bool = False, ) -> None: """ @@ -166,6 +178,8 @@ def assert_greeting_with_assistant_message_response( index += 1 assert isinstance(messages[index], AssistantMessage) + if not token_streaming: + assert USER_MESSAGE_RESPONSE in messages[index].content assert messages[index].otid and messages[index].otid[-1] == "1" index += 1 @@ -180,6 +194,7 @@ def assert_greeting_with_assistant_message_response( def assert_greeting_without_assistant_message_response( messages: List[Any], streaming: bool = False, + token_streaming: bool = False, from_db: bool = False, ) -> None: """ @@ -201,6 +216,9 @@ def assert_greeting_without_assistant_message_response( index += 1 assert isinstance(messages[index], ToolCallMessage) + assert messages[index].tool_call.name == "send_message" + if not token_streaming: + assert USER_MESSAGE_RESPONSE in messages[index].tool_call.arguments assert messages[index].otid and messages[index].otid[-1] == "1" index += 1 @@ -339,7 +357,7 @@ def test_greeting_with_assistant_message( agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = client.agents.messages.create( agent_id=agent_state.id, - messages=USER_MESSAGE_GREETING, + messages=USER_MESSAGE_FORCE_REPLY, ) assert_greeting_with_assistant_message_response(response.messages) messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id) @@ -365,7 +383,7 @@ def test_greeting_without_assistant_message( agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = client.agents.messages.create( agent_id=agent_state.id, - messages=USER_MESSAGE_GREETING, + messages=USER_MESSAGE_FORCE_REPLY, use_assistant_message=False, ) assert_greeting_without_assistant_message_response(response.messages) @@ -394,7 +412,7 @@ def test_tool_call( agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = client.agents.messages.create( agent_id=agent_state.id, - messages=USER_MESSAGE_TOOL_CALL, + messages=USER_MESSAGE_ROLL_DICE, ) assert_tool_call_response(response.messages) messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id) @@ -421,7 +439,7 @@ async def test_greeting_with_assistant_message_async_client( agent_state = await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = await async_client.agents.messages.create( agent_id=agent_state.id, - messages=USER_MESSAGE_GREETING, + messages=USER_MESSAGE_FORCE_REPLY, ) assert_greeting_with_assistant_message_response(response.messages) @@ -446,7 +464,7 @@ async def test_greeting_without_assistant_message_async_client( agent_state = await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = await async_client.agents.messages.create( agent_id=agent_state.id, - messages=USER_MESSAGE_GREETING, + messages=USER_MESSAGE_FORCE_REPLY, use_assistant_message=False, ) assert_greeting_without_assistant_message_response(response.messages) @@ -474,7 +492,7 @@ async def test_tool_call_async_client( agent_state = await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = await async_client.agents.messages.create( agent_id=agent_state.id, - messages=USER_MESSAGE_TOOL_CALL, + messages=USER_MESSAGE_ROLL_DICE, ) assert_tool_call_response(response.messages) @@ -497,7 +515,7 @@ def test_streaming_greeting_with_assistant_message( agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = client.agents.messages.create_stream( agent_id=agent_state.id, - messages=USER_MESSAGE_GREETING, + messages=USER_MESSAGE_FORCE_REPLY, ) chunks = list(response) messages = accumulate_chunks(chunks) @@ -522,7 +540,7 @@ def test_streaming_greeting_without_assistant_message( agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = client.agents.messages.create_stream( agent_id=agent_state.id, - messages=USER_MESSAGE_GREETING, + messages=USER_MESSAGE_FORCE_REPLY, use_assistant_message=False, ) chunks = list(response) @@ -550,7 +568,7 @@ def test_streaming_tool_call( agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = client.agents.messages.create_stream( agent_id=agent_state.id, - messages=USER_MESSAGE_TOOL_CALL, + messages=USER_MESSAGE_ROLL_DICE, ) chunks = list(response) messages = accumulate_chunks(chunks) @@ -577,7 +595,7 @@ async def test_streaming_greeting_with_assistant_message_async_client( await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = async_client.agents.messages.create_stream( agent_id=agent_state.id, - messages=USER_MESSAGE_GREETING, + messages=USER_MESSAGE_FORCE_REPLY, ) chunks = [chunk async for chunk in response] messages = accumulate_chunks(chunks) @@ -604,7 +622,7 @@ async def test_streaming_greeting_without_assistant_message_async_client( await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = async_client.agents.messages.create_stream( agent_id=agent_state.id, - messages=USER_MESSAGE_GREETING, + messages=USER_MESSAGE_FORCE_REPLY, use_assistant_message=False, ) chunks = [chunk async for chunk in response] @@ -634,7 +652,7 @@ async def test_streaming_tool_call_async_client( agent_state = await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = async_client.agents.messages.create_stream( agent_id=agent_state.id, - messages=USER_MESSAGE_TOOL_CALL, + messages=USER_MESSAGE_ROLL_DICE, ) chunks = [chunk async for chunk in response] messages = accumulate_chunks(chunks) @@ -659,7 +677,7 @@ def test_step_streaming_greeting_with_assistant_message( agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = client.agents.messages.create_stream( agent_id=agent_state.id, - messages=USER_MESSAGE_GREETING, + messages=USER_MESSAGE_FORCE_REPLY, stream_tokens=False, ) messages = [] @@ -686,12 +704,12 @@ def test_token_streaming_greeting_with_assistant_message( agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = client.agents.messages.create_stream( agent_id=agent_state.id, - messages=USER_MESSAGE_GREETING, + messages=USER_MESSAGE_FORCE_REPLY, stream_tokens=True, ) chunks = list(response) messages = accumulate_chunks(chunks) - assert_greeting_with_assistant_message_response(messages, streaming=True) + assert_greeting_with_assistant_message_response(messages, streaming=True, token_streaming=True) @pytest.mark.parametrize( @@ -712,13 +730,13 @@ def test_token_streaming_greeting_without_assistant_message( agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = client.agents.messages.create_stream( agent_id=agent_state.id, - messages=USER_MESSAGE_GREETING, + messages=USER_MESSAGE_FORCE_REPLY, use_assistant_message=False, stream_tokens=True, ) chunks = list(response) messages = accumulate_chunks(chunks) - assert_greeting_without_assistant_message_response(messages, streaming=True) + assert_greeting_without_assistant_message_response(messages, streaming=True, token_streaming=True) @pytest.mark.parametrize( @@ -741,7 +759,7 @@ def test_token_streaming_tool_call( agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = client.agents.messages.create_stream( agent_id=agent_state.id, - messages=USER_MESSAGE_TOOL_CALL, + messages=USER_MESSAGE_ROLL_DICE, stream_tokens=True, ) chunks = list(response) @@ -769,7 +787,7 @@ async def test_token_streaming_greeting_with_assistant_message_async_client( await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = async_client.agents.messages.create_stream( agent_id=agent_state.id, - messages=USER_MESSAGE_GREETING, + messages=USER_MESSAGE_FORCE_REPLY, stream_tokens=True, ) chunks = [chunk async for chunk in response] @@ -797,7 +815,7 @@ async def test_token_streaming_greeting_without_assistant_message_async_client( await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = async_client.agents.messages.create_stream( agent_id=agent_state.id, - messages=USER_MESSAGE_GREETING, + messages=USER_MESSAGE_FORCE_REPLY, use_assistant_message=False, stream_tokens=True, ) @@ -828,7 +846,7 @@ async def test_token_streaming_tool_call_async_client( agent_state = await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = async_client.agents.messages.create_stream( agent_id=agent_state.id, - messages=USER_MESSAGE_TOOL_CALL, + messages=USER_MESSAGE_ROLL_DICE, stream_tokens=True, ) chunks = [chunk async for chunk in response] @@ -855,7 +873,7 @@ def test_async_greeting_with_assistant_message( run = client.agents.messages.create_async( agent_id=agent_state.id, - messages=USER_MESSAGE_GREETING, + messages=USER_MESSAGE_FORCE_REPLY, ) run = wait_for_run_completion(client, run.id) @@ -888,7 +906,7 @@ async def test_async_greeting_with_assistant_message_async_client( run = await async_client.agents.messages.create_async( agent_id=agent_state.id, - messages=USER_MESSAGE_GREETING, + messages=USER_MESSAGE_FORCE_REPLY, ) # Use the synchronous client to check job completion run = wait_for_run_completion(client, run.id)