test: add token count check in streaming tests (#2936)
@@ -331,24 +331,36 @@ def assert_image_input_response(
     assert messages[index].step_count > 0
 
 
-def accumulate_chunks(chunks: List[Any]) -> List[Any]:
+def accumulate_chunks(chunks: List[Any], verify_token_streaming: bool = False) -> List[Any]:
     """
     Accumulates chunks into a list of messages.
     """
     messages = []
     current_message = None
     prev_message_type = None
+    chunk_count = 0
     for chunk in chunks:
         current_message_type = chunk.message_type
         if prev_message_type != current_message_type:
             messages.append(current_message)
+            if (
+                prev_message_type
+                and verify_token_streaming
+                and current_message.message_type in ["reasoning_message", "assistant_message", "tool_call_message"]
+            ):
+                assert chunk_count > 1, f"Expected more than one chunk for {current_message.message_type}"
             current_message = None
+            chunk_count = 0
         if current_message is None:
             current_message = chunk
         else:
             pass  # TODO: actually accumulate the chunks. For now we only care about the count
         prev_message_type = current_message_type
+        chunk_count += 1
     messages.append(current_message)
+    if verify_token_streaming and current_message.message_type in ["reasoning_message", "assistant_message", "tool_call_message"]:
+        assert chunk_count > 1, f"Expected more than one chunk for {current_message.message_type}"
+
     return [m for m in messages if m is not None]
 
 
@@ -774,7 +786,9 @@ def test_token_streaming_greeting_with_assistant_message(
         messages=USER_MESSAGE_FORCE_REPLY,
         stream_tokens=True,
     )
-    messages = accumulate_chunks(list(response))
+    messages = accumulate_chunks(
+        list(response), verify_token_streaming=(llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"])
+    )
     assert_greeting_with_assistant_message_response(messages, streaming=True, token_streaming=True, llm_config=llm_config)
     messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
     assert_greeting_with_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config)
@@ -803,7 +817,9 @@ def test_token_streaming_greeting_without_assistant_message(
         use_assistant_message=False,
         stream_tokens=True,
     )
-    messages = accumulate_chunks(list(response))
+    messages = accumulate_chunks(
+        list(response), verify_token_streaming=(llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"])
+    )
     assert_greeting_without_assistant_message_response(messages, streaming=True, token_streaming=True, llm_config=llm_config)
     messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id, use_assistant_message=False)
     assert_greeting_without_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config)
@@ -831,7 +847,9 @@ def test_token_streaming_tool_call(
         messages=USER_MESSAGE_ROLL_DICE,
         stream_tokens=True,
     )
-    messages = accumulate_chunks(list(response))
+    messages = accumulate_chunks(
+        list(response), verify_token_streaming=(llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"])
+    )
     assert_tool_call_response(messages, streaming=True, llm_config=llm_config)
     messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
     assert_tool_call_response(messages_from_db, from_db=True, llm_config=llm_config)
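
For context on what the new flag checks: with verify_token_streaming set, accumulate_chunks asserts that every reasoning, assistant, and tool-call message arrived in more than one chunk. A minimal sketch of that behavior, assuming the accumulate_chunks above is in scope and using a hypothetical Chunk dataclass in place of the SDK's streamed message objects:

from dataclasses import dataclass


@dataclass
class Chunk:
    # Hypothetical stand-in for the SDK's streamed chunk objects;
    # accumulate_chunks only reads the message_type attribute.
    message_type: str


# Token streaming should split each message across multiple chunks,
# so two reasoning chunks followed by three assistant chunks pass:
chunks = [
    Chunk("reasoning_message"),
    Chunk("reasoning_message"),
    Chunk("assistant_message"),
    Chunk("assistant_message"),
    Chunk("assistant_message"),
]
messages = accumulate_chunks(chunks, verify_token_streaming=True)
assert [m.message_type for m in messages] == ["reasoning_message", "assistant_message"]

# A message type that arrives as a single chunk trips the new assertion:
#   accumulate_chunks([Chunk("assistant_message")], verify_token_streaming=True)
#   AssertionError: Expected more than one chunk for assistant_message

The tests gate the flag on model_endpoint_type, presumably because only the anthropic, openai, and bedrock endpoints are expected to stream partial tokens here; other providers may legitimately return a message as a single chunk.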