test: add token count check in streaming tests (#2936)

This commit is contained in:
cthomas
2025-06-20 13:33:48 -07:00
committed by GitHub
parent e6f61a0a60
commit 4deaea4d95

View File

@@ -331,24 +331,36 @@ def assert_image_input_response(
assert messages[index].step_count > 0
def accumulate_chunks(chunks: List[Any], verify_token_streaming: bool = False) -> List[Any]:
    """
    Collapses a stream of chunks into one message per run of equal message types.

    Consecutive chunks that share the same ``message_type`` are treated as one
    run. Only the first chunk of each run is kept as the run's message; the
    remaining chunks are counted but not merged.

    Args:
        chunks: Streamed chunks, each exposing a ``message_type`` attribute.
        verify_token_streaming: When True, every run whose type is
            ``reasoning_message``, ``assistant_message`` or
            ``tool_call_message`` must consist of more than one chunk.

    Returns:
        The first chunk of every run, in stream order. An empty stream yields
        an empty list.

    Raises:
        AssertionError: If ``verify_token_streaming`` is set and one of the
            checked runs contains only a single chunk.
    """
    token_streamed_types = ("reasoning_message", "assistant_message", "tool_call_message")

    def check_run(message: Any, count: int) -> None:
        # `message` is None before the first chunk has been seen (including
        # for an empty stream), in which case there is no run to verify.
        if not verify_token_streaming or message is None:
            return
        if message.message_type in token_streamed_types:
            assert count > 1, f"Expected more than one chunk for {message.message_type}"

    messages = []
    current_message = None
    prev_message_type = None
    chunk_count = 0

    for chunk in chunks:
        current_message_type = chunk.message_type
        if prev_message_type != current_message_type:
            # The message type changed: close out the previous run, if any.
            if current_message is not None:
                messages.append(current_message)
            check_run(current_message, chunk_count)
            current_message = None
            chunk_count = 0
        if current_message is None:
            current_message = chunk
        # TODO: actually accumulate the chunks. For now we only care about the count
        prev_message_type = current_message_type
        chunk_count += 1

    # Close out the final run; nothing to do when the stream was empty.
    if current_message is not None:
        messages.append(current_message)
    check_run(current_message, chunk_count)
    return messages
@@ -774,7 +786,9 @@ def test_token_streaming_greeting_with_assistant_message(
messages=USER_MESSAGE_FORCE_REPLY,
stream_tokens=True,
)
messages = accumulate_chunks(list(response))
messages = accumulate_chunks(
list(response), verify_token_streaming=(llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"])
)
assert_greeting_with_assistant_message_response(messages, streaming=True, token_streaming=True, llm_config=llm_config)
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
assert_greeting_with_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config)
@@ -803,7 +817,9 @@ def test_token_streaming_greeting_without_assistant_message(
use_assistant_message=False,
stream_tokens=True,
)
messages = accumulate_chunks(list(response))
messages = accumulate_chunks(
list(response), verify_token_streaming=(llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"])
)
assert_greeting_without_assistant_message_response(messages, streaming=True, token_streaming=True, llm_config=llm_config)
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id, use_assistant_message=False)
assert_greeting_without_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config)
@@ -831,7 +847,9 @@ def test_token_streaming_tool_call(
messages=USER_MESSAGE_ROLL_DICE,
stream_tokens=True,
)
messages = accumulate_chunks(list(response))
messages = accumulate_chunks(
list(response), verify_token_streaming=(llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"])
)
assert_tool_call_response(messages, streaming=True, llm_config=llm_config)
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
assert_tool_call_response(messages_from_db, from_db=True, llm_config=llm_config)