feat: fix anthropic stream buffering issue (was a missing beta header) (#3933)

This commit is contained in:
Charles Packer
2025-08-14 22:26:56 -07:00
committed by GitHub
parent 6731f769ca
commit b0f335bb69
3 changed files with 370 additions and 6 deletions

View File

@@ -63,7 +63,13 @@ class AnthropicClient(LLMClientBase):
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[BetaRawMessageStreamEvent]:
client = await self._get_anthropic_client_async(llm_config, async_client=True)
request_data["stream"] = True
return await client.beta.messages.create(**request_data)
# Add fine-grained tool streaming beta header for better streaming performance
# This helps reduce buffering when streaming tool call parameters
# See: https://docs.anthropic.com/en/docs/build-with-claude/tool-use/fine-grained-streaming
betas = ["fine-grained-tool-streaming-2025-05-14"]
return await client.beta.messages.create(**request_data, betas=betas)
@trace_method
async def send_llm_batch_request_async(

View File

@@ -72,6 +72,23 @@ USER_MESSAGE_FORCE_REPLY: List[MessageCreate] = [
otid=USER_MESSAGE_OTID,
)
]
USER_MESSAGE_LONG_RESPONSE: str = (
"Teamwork makes the dream work. When people collaborate and combine their unique skills, perspectives, and experiences, they can achieve far more than any individual working alone. "
"This synergy creates an environment where innovation flourishes, problems are solved more creatively, and goals are reached more efficiently. "
"In a team setting, diverse viewpoints lead to better decision-making as different team members bring their unique backgrounds and expertise to the table. "
"Communication becomes the backbone of success, allowing ideas to flow freely and ensuring everyone is aligned toward common objectives. "
"Trust builds gradually as team members learn to rely on each other's strengths while supporting one another through challenges. "
"The collective intelligence of a group often surpasses that of even the brightest individual, as collaboration sparks creativity and innovation. "
"Successful teams celebrate victories together and learn from failures as a unit, creating a culture of continuous improvement. "
"Together, we can overcome challenges that would be insurmountable alone, achieving extraordinary results through the power of collaboration."
)
USER_MESSAGE_FORCE_LONG_REPLY: List[MessageCreate] = [
MessageCreate(
role="user",
content=f"This is an automated test message. Call the send_message tool with exactly this message: '{USER_MESSAGE_LONG_RESPONSE}'",
otid=USER_MESSAGE_OTID,
)
]
USER_MESSAGE_GREETING: List[MessageCreate] = [
MessageCreate(
role="user",
@@ -86,6 +103,23 @@ USER_MESSAGE_ROLL_DICE: List[MessageCreate] = [
otid=USER_MESSAGE_OTID,
)
]
USER_MESSAGE_ROLL_DICE_LONG: List[MessageCreate] = [
MessageCreate(
role="user",
content=(
"This is an automated test message. Call the roll_dice tool with 16 sides and send me a very detailed, comprehensive message about the outcome. "
"Your response must be at least 800 characters long. Start by explaining what dice rolling represents in games and probability theory. "
"Discuss the mathematical probability of getting each number on a 16-sided die (1/16 or 6.25% for each face). "
"Explain how 16-sided dice are commonly used in tabletop role-playing games like Dungeons & Dragons. "
"Describe the specific number you rolled and what it might mean in different gaming contexts. "
"Discuss how this particular roll compares to the expected value (8.5) of a 16-sided die. "
"Explain the concept of randomness and how true random number generation works. "
"End with some interesting facts about polyhedral dice and their history in gaming. "
"Remember, make your response detailed and at least 800 characters long."
),
otid=USER_MESSAGE_OTID,
)
]
URL_IMAGE = "https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg"
USER_MESSAGE_URL_IMAGE: List[MessageCreate] = [
MessageCreate(
@@ -172,7 +206,8 @@ def assert_greeting_with_assistant_message_response(
assert isinstance(messages[index], AssistantMessage)
if not token_streaming:
assert USER_MESSAGE_RESPONSE in messages[index].content
# Check for either short or long response
assert USER_MESSAGE_RESPONSE in messages[index].content or USER_MESSAGE_LONG_RESPONSE in messages[index].content
assert messages[index].otid and messages[index].otid[-1] == "1"
index += 1
@@ -491,7 +526,7 @@ def accumulate_chunks(chunks: List[Any], verify_token_streaming: bool = False) -
and verify_token_streaming
and current_message.message_type in ["reasoning_message", "assistant_message", "tool_call_message"]
):
assert chunk_count > 1, f"Expected more than one chunk for {current_message.message_type}"
assert chunk_count > 1, f"Expected more than one chunk for {current_message.message_type}. Messages: {messages}"
current_message = None
chunk_count = 0
if current_message is None:
@@ -964,9 +999,14 @@ def test_token_streaming_greeting_with_assistant_message(
"""
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
# Use longer message for Anthropic models to test if they stream in chunks
if llm_config.model_endpoint_type == "anthropic":
messages_to_send = USER_MESSAGE_FORCE_LONG_REPLY
else:
messages_to_send = USER_MESSAGE_FORCE_REPLY
response = client.agents.messages.create_stream(
agent_id=agent_state.id,
messages=USER_MESSAGE_FORCE_REPLY,
messages=messages_to_send,
stream_tokens=True,
)
messages = accumulate_chunks(
@@ -994,9 +1034,14 @@ def test_token_streaming_greeting_without_assistant_message(
"""
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
# Use longer message for Anthropic models to force chunking
if llm_config.model_endpoint_type == "anthropic":
messages_to_send = USER_MESSAGE_FORCE_LONG_REPLY
else:
messages_to_send = USER_MESSAGE_FORCE_REPLY
response = client.agents.messages.create_stream(
agent_id=agent_state.id,
messages=USER_MESSAGE_FORCE_REPLY,
messages=messages_to_send,
use_assistant_message=False,
stream_tokens=True,
)
@@ -1037,9 +1082,14 @@ def test_token_streaming_tool_call(
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
# Use longer message for Anthropic models to force chunking
if llm_config.model_endpoint_type == "anthropic":
messages_to_send = USER_MESSAGE_ROLL_DICE_LONG
else:
messages_to_send = USER_MESSAGE_ROLL_DICE
response = client.agents.messages.create_stream(
agent_id=agent_state.id,
messages=USER_MESSAGE_ROLL_DICE,
messages=messages_to_send,
stream_tokens=True,
)
messages = accumulate_chunks(

View File

@@ -0,0 +1,308 @@
"""Test to verify streaming behavior of Anthropic models with and without native reasoning.
This test confirms:
1. Sonnet 3.5 (20241022) with non-native reasoning exhibits batch streaming (API limitation)
- UPDATE: With fine-grained-tool-streaming beta header, this may improve
2. Sonnet 4 (20250514) with native reasoning should stream progressively
3. GPT-4.1 streams progressively as expected
Note: We've added the 'fine-grained-tool-streaming-2025-05-14' beta header to potentially
improve streaming performance with Anthropic models, especially for tool call parameters.
"""
import os
import time
from typing import Any, List, Tuple

import pytest
from letta_client import Letta, MessageCreate

from tests.utils import wait_for_server
def run_server():
    """Load environment configuration and launch the Letta REST server (blocking).

    Intended to be run in a daemon thread by the `client` fixture.
    """
    from dotenv import load_dotenv

    # Environment (API keys, server settings) must be loaded before the app imports
    load_dotenv()

    from letta.server.rest_api.app import start_server

    print("Starting server...")
    start_server(debug=True)
@pytest.fixture(scope="module")
def client():
    """Yield a Letta client, booting a local server thread when none is configured."""
    import threading

    api_url = os.getenv("LETTA_API_URL")
    server_url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")

    # No explicit server URL configured -> start our own server in the background
    if not os.getenv("LETTA_SERVER_URL"):
        print("Starting server thread")
        threading.Thread(target=run_server, daemon=True).start()
        wait_for_server(server_url)

    print("Running client tests with server:", server_url)

    # LETTA_API_URL, when set, takes precedence over the local/default server URL
    yield Letta(base_url=api_url or server_url, token=None)
@pytest.fixture
def agent_factory(client: Letta):
    """Factory fixture that creates agents for a given model and deletes them on teardown.

    Yields:
        A callable `_create_agent(model_name)` returning the created agent state.
    """
    created_agents = []

    def _create_agent(model_name: str):
        """Create an agent with the specified model, skipping the test if the provider key is missing."""
        # Check for required API keys before making any provider calls
        if "claude" in model_name.lower():
            anthropic_key = os.getenv("ANTHROPIC_API_KEY")
            if not anthropic_key:
                pytest.skip("ANTHROPIC_API_KEY not set, skipping Anthropic test")
        elif "gpt" in model_name.lower():
            openai_key = os.getenv("OPENAI_API_KEY")
            if not openai_key:
                pytest.skip("OPENAI_API_KEY not set, skipping OpenAI test")

        agent_state = client.agents.create(
            # Sanitize the model handle so it is usable as an agent name
            name=f"test_agent_{model_name.replace('/', '_').replace('.', '_')}",
            memory_blocks=[{"label": "human", "value": "Test user"}, {"label": "persona", "value": "You are a creative storyteller."}],
            model=model_name,
            embedding="letta/letta-free",
        )
        created_agents.append(agent_state)
        return agent_state

    yield _create_agent

    # Cleanup all created agents
    for agent_state in created_agents:
        try:
            client.agents.delete(agent_state.id)
        except Exception:
            # Agent might have already been deleted. Narrowed from a bare
            # `except:` so KeyboardInterrupt/SystemExit still propagate.
            pass
def detect_burst_chunks(chunks: List[Tuple[float, Any]], burst_threshold: float = 0.05) -> List[List[int]]:
    """
    Detect bursts of chunks arriving close together in time.

    Note: the element annotation was `any` (the builtin function); fixed to `typing.Any`.

    Args:
        chunks: List of (timestamp, chunk) tuples, ordered by timestamp
        burst_threshold: Maximum time difference (in seconds) between consecutive
            chunks for them to count as part of the same burst

    Returns:
        List of bursts, where each burst is a list of indices into `chunks`.
        Runs of a single chunk are not reported as bursts.
    """
    if not chunks:
        return []

    bursts: List[List[int]] = []
    current_burst = [0]
    for i in range(1, len(chunks)):
        time_diff = chunks[i][0] - chunks[i - 1][0]
        if time_diff <= burst_threshold:
            # Part of the same burst
            current_burst.append(i)
        else:
            # Gap exceeded: close out the current run (only count as burst if >1 chunk)
            if len(current_burst) > 1:
                bursts.append(current_burst)
            current_burst = [i]
    # Don't forget the last burst
    if len(current_burst) > 1:
        bursts.append(current_burst)
    return bursts
@pytest.mark.parametrize(
    "model,expected_buffering",
    [
        ("anthropic/claude-3-5-sonnet-20241022", False),  # With fine-grained streaming beta, should stream better
        ("anthropic/claude-sonnet-4-20250514", False),  # Sonnet 4 should NOT show buffering (has native reasoning)
        ("openai/gpt-4.1", False),  # GPT-4.1 should NOT show buffering (uses native reasoning)
    ],
)
def test_streaming_buffering_behavior(client: Letta, agent_factory, model: str, expected_buffering: bool):
    """
    Test streaming behavior for different models.

    With fine-grained-tool-streaming beta header:
    - Sonnet 3.5 (20241022) should now stream progressively (beta feature improves tool streaming)
    - Sonnet 4 (20250514) with native reasoning should stream progressively without buffering
    - GPT-4.1 should stream progressively without buffering

    Args:
        client: module-scoped Letta client fixture.
        agent_factory: fixture producing (and later deleting) an agent for `model`.
        model: provider-prefixed model handle used to create the agent.
        expected_buffering: when True, assert a significant burst was seen
            (>= 80% of one message type's chunks arriving within 0.1s of each
            other); when False, assert no such burst occurred.
    """
    print(f"\n=== Testing Streaming Behavior for {model} ===")
    print(f"Expected buffering: {expected_buffering}")

    # Create agent with the specified model
    agent = agent_factory(model)

    # Send a message that should generate reasoning and tool calls
    # This prompt should trigger inner thoughts and then a response
    user_message = "Think step by step about what makes a good story, then write me a creative story about a toad named Ted. Make it exactly 3 paragraphs long."

    # Create the stream
    response_stream = client.agents.messages.create_stream(
        agent_id=agent.id, messages=[MessageCreate(role="user", content=user_message)], stream_tokens=True  # Enable token streaming
    )

    # Collect chunks with timestamps (seconds elapsed since the stream started)
    chunks_with_time = []
    reasoning_chunks = []
    assistant_chunks = []
    tool_chunks = []
    start_time = time.time()
    try:
        for chunk in response_stream:
            elapsed = time.time() - start_time
            chunks_with_time.append((elapsed, chunk))

            # Categorize chunks by type
            chunk_type = type(chunk).__name__
            chunk_info = f"[{elapsed:.3f}s] {chunk_type}"

            # Check for different message types — chunks may carry a `message_type`
            # attribute or be typed objects; both shapes are handled below
            if hasattr(chunk, "message_type"):
                chunk_info += f" (message_type: {chunk.message_type})"
                if chunk.message_type == "reasoning_message":
                    reasoning_chunks.append((elapsed, chunk))
                elif chunk.message_type == "assistant_message":
                    assistant_chunks.append((elapsed, chunk))
                elif chunk.message_type == "tool_call_message":
                    tool_chunks.append((elapsed, chunk))
            elif type(chunk).__name__ == "ReasoningMessage":
                chunk_info += " (ReasoningMessage)"
                reasoning_chunks.append((elapsed, chunk))
            elif type(chunk).__name__ == "AssistantMessage":
                chunk_info += " (AssistantMessage)"
                assistant_chunks.append((elapsed, chunk))
            elif type(chunk).__name__ == "ToolCallMessage":
                chunk_info += " (ToolCallMessage)"
                tool_chunks.append((elapsed, chunk))

            # Check for inner thoughts (in tool calls for non-native reasoning)
            # NOTE(review): a chunk already categorized as tool_call_message above can
            # be appended to tool_chunks a second time here — confirm whether the
            # double count is intended.
            if hasattr(chunk, "tool_calls") and chunk.tool_calls:
                for tool_call in chunk.tool_calls:
                    if hasattr(tool_call, "function") and hasattr(tool_call.function, "arguments"):
                        # Check if this is inner thoughts
                        if "inner_thoughts" in str(tool_call.function.arguments):
                            chunk_info += " [contains inner_thoughts]"
                            tool_chunks.append((elapsed, chunk))

            print(chunk_info)

            # Optional: print chunk content snippet for debugging
            if hasattr(chunk, "content") and chunk.content:
                content_preview = str(chunk.content)[:100]
                if content_preview:
                    print(f" Content: {content_preview}...")
    except Exception as e:
        # Stream failures are logged but do not abort the burst analysis below
        print(f"Stream error: {e}")
        import traceback

        traceback.print_exc()

    # Analyze results
    print(f"\n=== Analysis ===")
    print(f"Total chunks: {len(chunks_with_time)}")
    print(f"Reasoning chunks: {len(reasoning_chunks)}")
    print(f"Assistant chunks: {len(assistant_chunks)}")
    print(f"Tool chunks: {len(tool_chunks)}")

    # Detect bursts for each type (default 0.05s threshold — informational only)
    if reasoning_chunks:
        reasoning_bursts = detect_burst_chunks(reasoning_chunks)
        print(f"\nReasoning bursts detected: {len(reasoning_bursts)}")
        for i, burst in enumerate(reasoning_bursts):
            burst_times = [reasoning_chunks[idx][0] for idx in burst]
            print(f" Burst {i+1}: {len(burst)} chunks from {burst_times[0]:.3f}s to {burst_times[-1]:.3f}s")
    if assistant_chunks:
        assistant_bursts = detect_burst_chunks(assistant_chunks)
        print(f"\nAssistant bursts detected: {len(assistant_bursts)}")
        for i, burst in enumerate(assistant_bursts):
            burst_times = [assistant_chunks[idx][0] for idx in burst]
            print(f" Burst {i+1}: {len(burst)} chunks from {burst_times[0]:.3f}s to {burst_times[-1]:.3f}s")
    if tool_chunks:
        tool_bursts = detect_burst_chunks(tool_chunks)
        print(f"\nTool call bursts detected: {len(tool_bursts)}")
        for i, burst in enumerate(tool_bursts):
            burst_times = [tool_chunks[idx][0] for idx in burst]
            print(f" Burst {i+1}: {len(burst)} chunks from {burst_times[0]:.3f}s to {burst_times[-1]:.3f}s")

    # Analyze results based on expected behavior
    print(f"\n=== Test Results ===")

    # Check if we detected large bursts (looser 0.1s threshold drives the pass/fail verdict)
    has_significant_bursts = False
    if reasoning_chunks:
        reasoning_bursts = detect_burst_chunks(reasoning_chunks, burst_threshold=0.1)
        if reasoning_bursts:
            largest_burst = max(reasoning_bursts, key=len)
            burst_percentage = len(largest_burst) / len(reasoning_chunks) * 100
            print(f"\nLargest reasoning burst: {len(largest_burst)}/{len(reasoning_chunks)} chunks ({burst_percentage:.1f}%)")
            if burst_percentage >= 80:  # Consider 80%+ as significant buffering
                has_significant_bursts = True
                print(f" -> BUFFERING DETECTED: {burst_percentage:.1f}% of reasoning chunks in single burst")
    if assistant_chunks:
        assistant_bursts = detect_burst_chunks(assistant_chunks, burst_threshold=0.1)
        if assistant_bursts:
            largest_burst = max(assistant_bursts, key=len)
            burst_percentage = len(largest_burst) / len(assistant_chunks) * 100
            print(f"Largest assistant burst: {len(largest_burst)}/{len(assistant_chunks)} chunks ({burst_percentage:.1f}%)")
            if burst_percentage >= 80:
                has_significant_bursts = True
                print(f" -> BUFFERING DETECTED: {burst_percentage:.1f}% of assistant chunks in single burst")
    if tool_chunks:
        tool_bursts = detect_burst_chunks(tool_chunks, burst_threshold=0.1)
        if tool_bursts:
            largest_burst = max(tool_bursts, key=len)
            burst_percentage = len(largest_burst) / len(tool_chunks) * 100
            print(f"Largest tool burst: {len(largest_burst)}/{len(tool_chunks)} chunks ({burst_percentage:.1f}%)")
            if burst_percentage >= 80:
                has_significant_bursts = True
                print(f" -> BUFFERING DETECTED: {burst_percentage:.1f}% of tool chunks in single burst")

    # Overall streaming analysis
    total_time = chunks_with_time[-1][0] if chunks_with_time else 0
    avg_time_between = total_time / len(chunks_with_time) if chunks_with_time else 0
    print(f"\nTotal streaming time: {total_time:.2f}s")
    print(f"Average time between chunks: {avg_time_between:.3f}s")

    # Verify test expectations
    if expected_buffering:
        assert has_significant_bursts, (
            f"Expected buffering behavior for {model}, but streaming appeared progressive. "
            f"This suggests the issue may be fixed or the test isn't detecting it properly."
        )
        print(f"\n✓ Test PASSED: {model} shows expected buffering behavior")
    else:
        assert not has_significant_bursts, (
            f"Did NOT expect buffering for {model}, but detected significant burst behavior. "
            f"This suggests {model} may also have streaming issues."
        )
        print(f"\n✓ Test PASSED: {model} shows expected progressive streaming")
if __name__ == "__main__":
    # Allow running directly for debugging
    # ("-v" = verbose, "-s" = disable output capture so per-chunk timing prints show)
    pytest.main([__file__, "-v", "-s"])