test: make send message tests less flaky (#2578)

This commit is contained in:
cthomas
2025-06-01 16:14:30 -07:00
committed by GitHub
parent e2e223dafa
commit 07ecdc9c6d

View File

@@ -122,9 +122,20 @@ def roll_dice(num_sides: int) -> int:
USER_MESSAGE_OTID = str(uuid.uuid4())
USER_MESSAGE_GREETING: List[MessageCreate] = [MessageCreate(role="user", content="Hi there.", otid=USER_MESSAGE_OTID)]
USER_MESSAGE_TOOL_CALL: List[MessageCreate] = [
MessageCreate(role="user", content="Call the roll_dice tool with 16 sides and tell me the outcome.", otid=USER_MESSAGE_OTID)
USER_MESSAGE_RESPONSE: str = "Teamwork makes the dream work"
USER_MESSAGE_FORCE_REPLY: List[MessageCreate] = [
MessageCreate(
role="user",
content=f"This is an automated test message. Call the send_message tool with the message '{USER_MESSAGE_RESPONSE}'.",
otid=USER_MESSAGE_OTID,
)
]
USER_MESSAGE_ROLL_DICE: List[MessageCreate] = [
MessageCreate(
role="user",
content="This is an automated test message. Call the roll_dice tool with 16 sides and tell me the outcome.",
otid=USER_MESSAGE_OTID,
)
]
all_configs = [
"openai-gpt-4o-mini.json",
@@ -145,6 +156,7 @@ TESTED_LLM_CONFIGS: List[LLMConfig] = [get_llm_config(fn) for fn in filenames]
def assert_greeting_with_assistant_message_response(
messages: List[Any],
streaming: bool = False,
token_streaming: bool = False,
from_db: bool = False,
) -> None:
"""
@@ -166,6 +178,8 @@ def assert_greeting_with_assistant_message_response(
index += 1
assert isinstance(messages[index], AssistantMessage)
if not token_streaming:
assert USER_MESSAGE_RESPONSE in messages[index].content
assert messages[index].otid and messages[index].otid[-1] == "1"
index += 1
@@ -180,6 +194,7 @@ def assert_greeting_with_assistant_message_response(
def assert_greeting_without_assistant_message_response(
messages: List[Any],
streaming: bool = False,
token_streaming: bool = False,
from_db: bool = False,
) -> None:
"""
@@ -201,6 +216,9 @@ def assert_greeting_without_assistant_message_response(
index += 1
assert isinstance(messages[index], ToolCallMessage)
assert messages[index].tool_call.name == "send_message"
if not token_streaming:
assert USER_MESSAGE_RESPONSE in messages[index].tool_call.arguments
assert messages[index].otid and messages[index].otid[-1] == "1"
index += 1
@@ -339,7 +357,7 @@ def test_greeting_with_assistant_message(
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = client.agents.messages.create(
agent_id=agent_state.id,
messages=USER_MESSAGE_GREETING,
messages=USER_MESSAGE_FORCE_REPLY,
)
assert_greeting_with_assistant_message_response(response.messages)
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
@@ -365,7 +383,7 @@ def test_greeting_without_assistant_message(
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = client.agents.messages.create(
agent_id=agent_state.id,
messages=USER_MESSAGE_GREETING,
messages=USER_MESSAGE_FORCE_REPLY,
use_assistant_message=False,
)
assert_greeting_without_assistant_message_response(response.messages)
@@ -394,7 +412,7 @@ def test_tool_call(
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = client.agents.messages.create(
agent_id=agent_state.id,
messages=USER_MESSAGE_TOOL_CALL,
messages=USER_MESSAGE_ROLL_DICE,
)
assert_tool_call_response(response.messages)
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
@@ -421,7 +439,7 @@ async def test_greeting_with_assistant_message_async_client(
agent_state = await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = await async_client.agents.messages.create(
agent_id=agent_state.id,
messages=USER_MESSAGE_GREETING,
messages=USER_MESSAGE_FORCE_REPLY,
)
assert_greeting_with_assistant_message_response(response.messages)
@@ -446,7 +464,7 @@ async def test_greeting_without_assistant_message_async_client(
agent_state = await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = await async_client.agents.messages.create(
agent_id=agent_state.id,
messages=USER_MESSAGE_GREETING,
messages=USER_MESSAGE_FORCE_REPLY,
use_assistant_message=False,
)
assert_greeting_without_assistant_message_response(response.messages)
@@ -474,7 +492,7 @@ async def test_tool_call_async_client(
agent_state = await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = await async_client.agents.messages.create(
agent_id=agent_state.id,
messages=USER_MESSAGE_TOOL_CALL,
messages=USER_MESSAGE_ROLL_DICE,
)
assert_tool_call_response(response.messages)
@@ -497,7 +515,7 @@ def test_streaming_greeting_with_assistant_message(
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = client.agents.messages.create_stream(
agent_id=agent_state.id,
messages=USER_MESSAGE_GREETING,
messages=USER_MESSAGE_FORCE_REPLY,
)
chunks = list(response)
messages = accumulate_chunks(chunks)
@@ -522,7 +540,7 @@ def test_streaming_greeting_without_assistant_message(
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = client.agents.messages.create_stream(
agent_id=agent_state.id,
messages=USER_MESSAGE_GREETING,
messages=USER_MESSAGE_FORCE_REPLY,
use_assistant_message=False,
)
chunks = list(response)
@@ -550,7 +568,7 @@ def test_streaming_tool_call(
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = client.agents.messages.create_stream(
agent_id=agent_state.id,
messages=USER_MESSAGE_TOOL_CALL,
messages=USER_MESSAGE_ROLL_DICE,
)
chunks = list(response)
messages = accumulate_chunks(chunks)
@@ -577,7 +595,7 @@ async def test_streaming_greeting_with_assistant_message_async_client(
await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = async_client.agents.messages.create_stream(
agent_id=agent_state.id,
messages=USER_MESSAGE_GREETING,
messages=USER_MESSAGE_FORCE_REPLY,
)
chunks = [chunk async for chunk in response]
messages = accumulate_chunks(chunks)
@@ -604,7 +622,7 @@ async def test_streaming_greeting_without_assistant_message_async_client(
await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = async_client.agents.messages.create_stream(
agent_id=agent_state.id,
messages=USER_MESSAGE_GREETING,
messages=USER_MESSAGE_FORCE_REPLY,
use_assistant_message=False,
)
chunks = [chunk async for chunk in response]
@@ -634,7 +652,7 @@ async def test_streaming_tool_call_async_client(
agent_state = await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = async_client.agents.messages.create_stream(
agent_id=agent_state.id,
messages=USER_MESSAGE_TOOL_CALL,
messages=USER_MESSAGE_ROLL_DICE,
)
chunks = [chunk async for chunk in response]
messages = accumulate_chunks(chunks)
@@ -659,7 +677,7 @@ def test_step_streaming_greeting_with_assistant_message(
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = client.agents.messages.create_stream(
agent_id=agent_state.id,
messages=USER_MESSAGE_GREETING,
messages=USER_MESSAGE_FORCE_REPLY,
stream_tokens=False,
)
messages = []
@@ -686,12 +704,12 @@ def test_token_streaming_greeting_with_assistant_message(
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = client.agents.messages.create_stream(
agent_id=agent_state.id,
messages=USER_MESSAGE_GREETING,
messages=USER_MESSAGE_FORCE_REPLY,
stream_tokens=True,
)
chunks = list(response)
messages = accumulate_chunks(chunks)
assert_greeting_with_assistant_message_response(messages, streaming=True)
assert_greeting_with_assistant_message_response(messages, streaming=True, token_streaming=True)
@pytest.mark.parametrize(
@@ -712,13 +730,13 @@ def test_token_streaming_greeting_without_assistant_message(
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = client.agents.messages.create_stream(
agent_id=agent_state.id,
messages=USER_MESSAGE_GREETING,
messages=USER_MESSAGE_FORCE_REPLY,
use_assistant_message=False,
stream_tokens=True,
)
chunks = list(response)
messages = accumulate_chunks(chunks)
assert_greeting_without_assistant_message_response(messages, streaming=True)
assert_greeting_without_assistant_message_response(messages, streaming=True, token_streaming=True)
@pytest.mark.parametrize(
@@ -741,7 +759,7 @@ def test_token_streaming_tool_call(
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = client.agents.messages.create_stream(
agent_id=agent_state.id,
messages=USER_MESSAGE_TOOL_CALL,
messages=USER_MESSAGE_ROLL_DICE,
stream_tokens=True,
)
chunks = list(response)
@@ -769,7 +787,7 @@ async def test_token_streaming_greeting_with_assistant_message_async_client(
await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = async_client.agents.messages.create_stream(
agent_id=agent_state.id,
messages=USER_MESSAGE_GREETING,
messages=USER_MESSAGE_FORCE_REPLY,
stream_tokens=True,
)
chunks = [chunk async for chunk in response]
@@ -797,7 +815,7 @@ async def test_token_streaming_greeting_without_assistant_message_async_client(
await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = async_client.agents.messages.create_stream(
agent_id=agent_state.id,
messages=USER_MESSAGE_GREETING,
messages=USER_MESSAGE_FORCE_REPLY,
use_assistant_message=False,
stream_tokens=True,
)
@@ -828,7 +846,7 @@ async def test_token_streaming_tool_call_async_client(
agent_state = await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = async_client.agents.messages.create_stream(
agent_id=agent_state.id,
messages=USER_MESSAGE_TOOL_CALL,
messages=USER_MESSAGE_ROLL_DICE,
stream_tokens=True,
)
chunks = [chunk async for chunk in response]
@@ -855,7 +873,7 @@ def test_async_greeting_with_assistant_message(
run = client.agents.messages.create_async(
agent_id=agent_state.id,
messages=USER_MESSAGE_GREETING,
messages=USER_MESSAGE_FORCE_REPLY,
)
run = wait_for_run_completion(client, run.id)
@@ -888,7 +906,7 @@ async def test_async_greeting_with_assistant_message_async_client(
run = await async_client.agents.messages.create_async(
agent_id=agent_state.id,
messages=USER_MESSAGE_GREETING,
messages=USER_MESSAGE_FORCE_REPLY,
)
# Use the synchronous client to check job completion
run = wait_for_run_completion(client, run.id)