test: make send message tests less flaky (#2578)
This commit is contained in:
@@ -122,9 +122,20 @@ def roll_dice(num_sides: int) -> int:
|
||||
|
||||
|
||||
USER_MESSAGE_OTID = str(uuid.uuid4())
|
||||
USER_MESSAGE_GREETING: List[MessageCreate] = [MessageCreate(role="user", content="Hi there.", otid=USER_MESSAGE_OTID)]
|
||||
USER_MESSAGE_TOOL_CALL: List[MessageCreate] = [
|
||||
MessageCreate(role="user", content="Call the roll_dice tool with 16 sides and tell me the outcome.", otid=USER_MESSAGE_OTID)
|
||||
USER_MESSAGE_RESPONSE: str = "Teamwork makes the dream work"
|
||||
USER_MESSAGE_FORCE_REPLY: List[MessageCreate] = [
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content=f"This is an automated test message. Call the send_message tool with the message '{USER_MESSAGE_RESPONSE}'.",
|
||||
otid=USER_MESSAGE_OTID,
|
||||
)
|
||||
]
|
||||
USER_MESSAGE_ROLL_DICE: List[MessageCreate] = [
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content="This is an automated test message. Call the roll_dice tool with 16 sides and tell me the outcome.",
|
||||
otid=USER_MESSAGE_OTID,
|
||||
)
|
||||
]
|
||||
all_configs = [
|
||||
"openai-gpt-4o-mini.json",
|
||||
@@ -145,6 +156,7 @@ TESTED_LLM_CONFIGS: List[LLMConfig] = [get_llm_config(fn) for fn in filenames]
|
||||
def assert_greeting_with_assistant_message_response(
|
||||
messages: List[Any],
|
||||
streaming: bool = False,
|
||||
token_streaming: bool = False,
|
||||
from_db: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -166,6 +178,8 @@ def assert_greeting_with_assistant_message_response(
|
||||
index += 1
|
||||
|
||||
assert isinstance(messages[index], AssistantMessage)
|
||||
if not token_streaming:
|
||||
assert USER_MESSAGE_RESPONSE in messages[index].content
|
||||
assert messages[index].otid and messages[index].otid[-1] == "1"
|
||||
index += 1
|
||||
|
||||
@@ -180,6 +194,7 @@ def assert_greeting_with_assistant_message_response(
|
||||
def assert_greeting_without_assistant_message_response(
|
||||
messages: List[Any],
|
||||
streaming: bool = False,
|
||||
token_streaming: bool = False,
|
||||
from_db: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -201,6 +216,9 @@ def assert_greeting_without_assistant_message_response(
|
||||
index += 1
|
||||
|
||||
assert isinstance(messages[index], ToolCallMessage)
|
||||
assert messages[index].tool_call.name == "send_message"
|
||||
if not token_streaming:
|
||||
assert USER_MESSAGE_RESPONSE in messages[index].tool_call.arguments
|
||||
assert messages[index].otid and messages[index].otid[-1] == "1"
|
||||
index += 1
|
||||
|
||||
@@ -339,7 +357,7 @@ def test_greeting_with_assistant_message(
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_GREETING,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
)
|
||||
assert_greeting_with_assistant_message_response(response.messages)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
@@ -365,7 +383,7 @@ def test_greeting_without_assistant_message(
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_GREETING,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
use_assistant_message=False,
|
||||
)
|
||||
assert_greeting_without_assistant_message_response(response.messages)
|
||||
@@ -394,7 +412,7 @@ def test_tool_call(
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_TOOL_CALL,
|
||||
messages=USER_MESSAGE_ROLL_DICE,
|
||||
)
|
||||
assert_tool_call_response(response.messages)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
@@ -421,7 +439,7 @@ async def test_greeting_with_assistant_message_async_client(
|
||||
agent_state = await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = await async_client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_GREETING,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
)
|
||||
assert_greeting_with_assistant_message_response(response.messages)
|
||||
|
||||
@@ -446,7 +464,7 @@ async def test_greeting_without_assistant_message_async_client(
|
||||
agent_state = await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = await async_client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_GREETING,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
use_assistant_message=False,
|
||||
)
|
||||
assert_greeting_without_assistant_message_response(response.messages)
|
||||
@@ -474,7 +492,7 @@ async def test_tool_call_async_client(
|
||||
agent_state = await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = await async_client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_TOOL_CALL,
|
||||
messages=USER_MESSAGE_ROLL_DICE,
|
||||
)
|
||||
assert_tool_call_response(response.messages)
|
||||
|
||||
@@ -497,7 +515,7 @@ def test_streaming_greeting_with_assistant_message(
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_GREETING,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
)
|
||||
chunks = list(response)
|
||||
messages = accumulate_chunks(chunks)
|
||||
@@ -522,7 +540,7 @@ def test_streaming_greeting_without_assistant_message(
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_GREETING,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
use_assistant_message=False,
|
||||
)
|
||||
chunks = list(response)
|
||||
@@ -550,7 +568,7 @@ def test_streaming_tool_call(
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_TOOL_CALL,
|
||||
messages=USER_MESSAGE_ROLL_DICE,
|
||||
)
|
||||
chunks = list(response)
|
||||
messages = accumulate_chunks(chunks)
|
||||
@@ -577,7 +595,7 @@ async def test_streaming_greeting_with_assistant_message_async_client(
|
||||
await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = async_client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_GREETING,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
)
|
||||
chunks = [chunk async for chunk in response]
|
||||
messages = accumulate_chunks(chunks)
|
||||
@@ -604,7 +622,7 @@ async def test_streaming_greeting_without_assistant_message_async_client(
|
||||
await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = async_client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_GREETING,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
use_assistant_message=False,
|
||||
)
|
||||
chunks = [chunk async for chunk in response]
|
||||
@@ -634,7 +652,7 @@ async def test_streaming_tool_call_async_client(
|
||||
agent_state = await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = async_client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_TOOL_CALL,
|
||||
messages=USER_MESSAGE_ROLL_DICE,
|
||||
)
|
||||
chunks = [chunk async for chunk in response]
|
||||
messages = accumulate_chunks(chunks)
|
||||
@@ -659,7 +677,7 @@ def test_step_streaming_greeting_with_assistant_message(
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_GREETING,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
stream_tokens=False,
|
||||
)
|
||||
messages = []
|
||||
@@ -686,12 +704,12 @@ def test_token_streaming_greeting_with_assistant_message(
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_GREETING,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
stream_tokens=True,
|
||||
)
|
||||
chunks = list(response)
|
||||
messages = accumulate_chunks(chunks)
|
||||
assert_greeting_with_assistant_message_response(messages, streaming=True)
|
||||
assert_greeting_with_assistant_message_response(messages, streaming=True, token_streaming=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -712,13 +730,13 @@ def test_token_streaming_greeting_without_assistant_message(
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_GREETING,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
use_assistant_message=False,
|
||||
stream_tokens=True,
|
||||
)
|
||||
chunks = list(response)
|
||||
messages = accumulate_chunks(chunks)
|
||||
assert_greeting_without_assistant_message_response(messages, streaming=True)
|
||||
assert_greeting_without_assistant_message_response(messages, streaming=True, token_streaming=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -741,7 +759,7 @@ def test_token_streaming_tool_call(
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_TOOL_CALL,
|
||||
messages=USER_MESSAGE_ROLL_DICE,
|
||||
stream_tokens=True,
|
||||
)
|
||||
chunks = list(response)
|
||||
@@ -769,7 +787,7 @@ async def test_token_streaming_greeting_with_assistant_message_async_client(
|
||||
await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = async_client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_GREETING,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
stream_tokens=True,
|
||||
)
|
||||
chunks = [chunk async for chunk in response]
|
||||
@@ -797,7 +815,7 @@ async def test_token_streaming_greeting_without_assistant_message_async_client(
|
||||
await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = async_client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_GREETING,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
use_assistant_message=False,
|
||||
stream_tokens=True,
|
||||
)
|
||||
@@ -828,7 +846,7 @@ async def test_token_streaming_tool_call_async_client(
|
||||
agent_state = await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = async_client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_TOOL_CALL,
|
||||
messages=USER_MESSAGE_ROLL_DICE,
|
||||
stream_tokens=True,
|
||||
)
|
||||
chunks = [chunk async for chunk in response]
|
||||
@@ -855,7 +873,7 @@ def test_async_greeting_with_assistant_message(
|
||||
|
||||
run = client.agents.messages.create_async(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_GREETING,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
)
|
||||
run = wait_for_run_completion(client, run.id)
|
||||
|
||||
@@ -888,7 +906,7 @@ async def test_async_greeting_with_assistant_message_async_client(
|
||||
|
||||
run = await async_client.agents.messages.create_async(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_GREETING,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
)
|
||||
# Use the synchronous client to check job completion
|
||||
run = wait_for_run_completion(client, run.id)
|
||||
|
||||
Reference in New Issue
Block a user