From 4deaea4d958d607a113fffeb2d9ac0bf1c8f1099 Mon Sep 17 00:00:00 2001 From: cthomas Date: Fri, 20 Jun 2025 13:33:48 -0700 Subject: [PATCH] test: add token count check in streaming tests (#2936) --- tests/integration_test_send_message.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/tests/integration_test_send_message.py b/tests/integration_test_send_message.py index 76dcd583..14fa877f 100644 --- a/tests/integration_test_send_message.py +++ b/tests/integration_test_send_message.py @@ -331,24 +331,36 @@ def assert_image_input_response( assert messages[index].step_count > 0 -def accumulate_chunks(chunks: List[Any]) -> List[Any]: +def accumulate_chunks(chunks: List[Any], verify_token_streaming: bool = False) -> List[Any]: """ Accumulates chunks into a list of messages. """ messages = [] current_message = None prev_message_type = None + chunk_count = 0 for chunk in chunks: current_message_type = chunk.message_type if prev_message_type != current_message_type: messages.append(current_message) + if ( + prev_message_type + and verify_token_streaming + and current_message.message_type in ["reasoning_message", "assistant_message", "tool_call_message"] + ): + assert chunk_count > 1, f"Expected more than one chunk for {current_message.message_type}" current_message = None + chunk_count = 0 if current_message is None: current_message = chunk else: pass # TODO: actually accumulate the chunks. For now we only care about the count prev_message_type = current_message_type + chunk_count += 1 messages.append(current_message) + if verify_token_streaming and current_message.message_type in ["reasoning_message", "assistant_message", "tool_call_message"]: + assert chunk_count > 1, f"Expected more than one chunk for {current_message.message_type}" + return [m for m in messages if m is not None] @@ -774,7 +786,9 @@ def test_token_streaming_greeting_with_assistant_message( messages=USER_MESSAGE_FORCE_REPLY, stream_tokens=True, ) - messages = accumulate_chunks(list(response)) + messages = accumulate_chunks( + list(response), verify_token_streaming=(llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"]) + ) assert_greeting_with_assistant_message_response(messages, streaming=True, token_streaming=True, llm_config=llm_config) messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id) assert_greeting_with_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config) @@ -803,7 +817,9 @@ def test_token_streaming_greeting_without_assistant_message( use_assistant_message=False, stream_tokens=True, ) - messages = accumulate_chunks(list(response)) + messages = accumulate_chunks( + list(response), verify_token_streaming=(llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"]) + ) assert_greeting_without_assistant_message_response(messages, streaming=True, token_streaming=True, llm_config=llm_config) messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id, use_assistant_message=False) assert_greeting_without_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config) @@ -831,7 +847,9 @@ def test_token_streaming_tool_call( messages=USER_MESSAGE_ROLL_DICE, stream_tokens=True, ) - messages = accumulate_chunks(list(response)) + messages = accumulate_chunks( + list(response), verify_token_streaming=(llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"]) + ) assert_tool_call_response(messages, streaming=True, llm_config=llm_config) messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id) assert_tool_call_response(messages_from_db, from_db=True, llm_config=llm_config)