feat: support step streaming for new agent loop (#2182)

Author: Sarah Wooders
Date: 2025-05-17 17:22:20 -07:00
Committed by: GitHub
Parent: a75effa69c
Commit: 42d3ce6d09

3 changed files with 109 additions and 9 deletions

File 1 of 3: the agent loop (class LettaAgent)

@@ -75,6 +75,68 @@ class LettaAgent(BaseAgent):
             new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message, usage=usage
         )
 
+    @trace_method
+    async def step_stream_no_tokens(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True):
+        agent_state = await self.agent_manager.get_agent_by_id_async(self.agent_id, actor=self.actor)
+        current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
+            input_messages, agent_state, self.message_manager, self.actor
+        )
+        tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
+        llm_client = LLMClient.create(
+            provider_type=agent_state.llm_config.model_endpoint_type,
+            put_inner_thoughts_first=True,
+            actor=self.actor,
+        )
+        usage = LettaUsageStatistics()
+
+        for _ in range(max_steps):
+            response = await self._get_ai_reply(
+                llm_client=llm_client,
+                in_context_messages=current_in_context_messages + new_in_context_messages,
+                agent_state=agent_state,
+                tool_rules_solver=tool_rules_solver,
+                stream=False,
+                # TODO: also pass in reasoning content
+            )
+
+            if not response.choices[0].message.tool_calls:
+                # TODO: make into a real error
+                raise ValueError("No tool calls found in response, model must make a tool call")
+            tool_call = response.choices[0].message.tool_calls[0]
+            reasoning = [TextContent(text=response.choices[0].message.content)]  # reasoning placed into content for legacy reasons
+
+            persisted_messages, should_continue = await self._handle_ai_response(
+                tool_call, agent_state, tool_rules_solver, reasoning_content=reasoning
+            )
+            self.response_messages.extend(persisted_messages)
+            new_in_context_messages.extend(persisted_messages)
+
+            # stream step
+            # TODO: improve TTFT
+            filter_user_messages = [m for m in persisted_messages if m.role != "user"]
+            letta_messages = Message.to_letta_messages_from_list(
+                filter_user_messages, use_assistant_message=use_assistant_message, reverse=False
+            )
+            for message in letta_messages:
+                yield f"data: {message.model_dump_json()}\n\n"
+
+            # update usage
+            # TODO: add run_id
+            usage.step_count += 1
+            usage.completion_tokens += response.usage.completion_tokens
+            usage.prompt_tokens += response.usage.prompt_tokens
+            usage.total_tokens += response.usage.total_tokens
+
+            if not should_continue:
+                break
+
+        # Extend the in context message ids
+        if not agent_state.message_buffer_autoclear:
+            message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)]
+            self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor)
+
+        # Return back usage
+        yield f"data: {usage.model_dump_json()}\n\n"
+
     async def _step(
         self, agent_state: AgentState, input_messages: List[MessageCreate], max_steps: int = 10
     ) -> Tuple[List[Message], List[Message], CompletionUsage]:
@@ -98,6 +160,9 @@ class LettaAgent(BaseAgent):
                 # TODO: also pass in reasoning content
             )
 
+            if not response.choices[0].message.tool_calls:
+                # TODO: make into a real error
+                raise ValueError("No tool calls found in response, model must make a tool call")
             tool_call = response.choices[0].message.tool_calls[0]
             reasoning = [TextContent(text=response.choices[0].message.content)]  # reasoning placed into content for legacy reasons
@@ -126,7 +191,7 @@ class LettaAgent(BaseAgent):
     @trace_method
     async def step_stream(
-        self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True, stream_tokens: bool = False
+        self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True
     ) -> AsyncGenerator[str, None]:
         """
         Main streaming loop that yields partial tokens.
@@ -164,6 +229,9 @@ class LettaAgent(BaseAgent):
                 use_assistant_message=use_assistant_message,
                 put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
            )
+        else:
+            raise ValueError(f"Streaming not supported for {agent_state.llm_config}")
+
         async for chunk in interface.process(stream):
             yield f"data: {chunk.model_dump_json()}\n\n"

File 2 of 3: REST API message routes

@@ -635,7 +635,7 @@ async def send_message(
     agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent
     experimental_header = request_obj.headers.get("X-EXPERIMENTAL") or "false"
     feature_enabled = settings.use_experimental or experimental_header.lower() == "true"
-    model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "google_vertex", "google_ai"]
+    model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai"]
 
     if agent_eligible and feature_enabled and model_compatible:
         experimental_agent = LettaAgent(
@@ -706,13 +706,18 @@ async def send_message_streaming(
             passage_manager=server.passage_manager,
             actor=actor,
         )
-        result = StreamingResponse(
-            experimental_agent.step_stream(
-                request.messages, max_steps=10, use_assistant_message=request.use_assistant_message, stream_tokens=request.stream_tokens
-            ),
-            media_type="text/event-stream",
-        )
+        if request.stream_tokens:
+            result = StreamingResponse(
+                experimental_agent.step_stream(request.messages, max_steps=10, use_assistant_message=request.use_assistant_message),
+                media_type="text/event-stream",
+            )
+        else:
+            result = StreamingResponse(
+                experimental_agent.step_stream_no_tokens(
+                    request.messages, max_steps=10, use_assistant_message=request.use_assistant_message
+                ),
+                media_type="text/event-stream",
+            )
     else:
         result = await server.send_message_to_agent(
             agent_id=agent_id,
File 3 of 3: client streaming tests

@@ -636,6 +636,33 @@ async def test_streaming_tool_call_async_client(
     assert_tool_call_response(messages, streaming=True)
 
 
+@pytest.mark.parametrize(
+    "llm_config",
+    TESTED_LLM_CONFIGS,
+    ids=[c.model for c in TESTED_LLM_CONFIGS],
+)
+def test_step_streaming_greeting_with_assistant_message(
+    disable_e2b_api_key: Any,
+    client: Letta,
+    agent_state: AgentState,
+    llm_config: LLMConfig,
+) -> None:
+    """
+    Tests step streaming (stream_tokens=False) with a synchronous client.
+    Checks that each chunk in the stream has the correct message types.
+    """
+    agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
+    response = client.agents.messages.create_stream(
+        agent_id=agent_state.id,
+        messages=USER_MESSAGE_GREETING,
+        stream_tokens=False,
+    )
+    messages = []
+    for message in response:
+        messages.append(message)
+    assert_greeting_with_assistant_message_response(messages, streaming=True)
+
+
 @pytest.mark.parametrize(
     "llm_config",
     TESTED_LLM_CONFIGS,
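
assert_greeting_with_assistant_message_response is defined elsewhere in the test suite; a hypothetical sketch of the step-stream properties such a checker can verify, given the frame layout above (all field names except step_count, which the loop increments, are assumptions):

def check_step_stream(events: list[dict]) -> None:
    # Hypothetical checker: every frame but the last is a Letta message,
    # and the closing frame is the usage-statistics event.
    *messages, usage = events
    assert messages, "expected at least one message before the usage frame"
    assert usage.get("step_count", 0) >= 1  # usage frame closes the stream
    # user messages are filtered out of the stream by the agent loop
    assert all(m.get("message_type") != "user_message" for m in messages)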