feat: support step streaming for new agent loop (#2182)
@@ -75,6 +75,68 @@ class LettaAgent(BaseAgent):
            new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message, usage=usage
        )

    @trace_method
    async def step_stream_no_tokens(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True):
        agent_state = await self.agent_manager.get_agent_by_id_async(self.agent_id, actor=self.actor)
        current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
            input_messages, agent_state, self.message_manager, self.actor
        )
        tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
        llm_client = LLMClient.create(
            provider_type=agent_state.llm_config.model_endpoint_type,
            put_inner_thoughts_first=True,
            actor=self.actor,
        )
        usage = LettaUsageStatistics()
        for _ in range(max_steps):
            response = await self._get_ai_reply(
                llm_client=llm_client,
                in_context_messages=current_in_context_messages + new_in_context_messages,
                agent_state=agent_state,
                tool_rules_solver=tool_rules_solver,
                stream=False,
                # TODO: also pass in reasoning content
            )

            if not response.choices[0].message.tool_calls:
                # TODO: make into a real error
                raise ValueError("No tool calls found in response, model must make a tool call")
            tool_call = response.choices[0].message.tool_calls[0]
            reasoning = [TextContent(text=response.choices[0].message.content)]  # reasoning placed into content for legacy reasons

            persisted_messages, should_continue = await self._handle_ai_response(
                tool_call, agent_state, tool_rules_solver, reasoning_content=reasoning
            )
            self.response_messages.extend(persisted_messages)
            new_in_context_messages.extend(persisted_messages)

            # stream step
            # TODO: improve TTFT
            filter_user_messages = [m for m in persisted_messages if m.role != "user"]
            letta_messages = Message.to_letta_messages_from_list(
                filter_user_messages, use_assistant_message=use_assistant_message, reverse=False
            )
            for message in letta_messages:
                yield f"data: {message.model_dump_json()}\n\n"

            # update usage
            # TODO: add run_id
            usage.step_count += 1
            usage.completion_tokens += response.usage.completion_tokens
            usage.prompt_tokens += response.usage.prompt_tokens
            usage.total_tokens += response.usage.total_tokens

            if not should_continue:
                break

        # Extend the in context message ids
        if not agent_state.message_buffer_autoclear:
            message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)]
            self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor)

        # Return back usage
        yield f"data: {usage.model_dump_json()}\n\n"

    async def _step(
        self, agent_state: AgentState, input_messages: List[MessageCreate], max_steps: int = 10
    ) -> Tuple[List[Message], List[Message], CompletionUsage]:
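Note (not part of the diff): step_stream_no_tokens yields one server-sent-event frame per persisted message, in the form "data: <json>\n\n", followed by a final frame carrying the accumulated usage statistics. A minimal sketch of parsing such frames on the client side, assuming only the standard library and illustrative payloads:

import json
from typing import Iterable, Iterator

def parse_sse_frames(frames: Iterable[str]) -> Iterator[dict]:
    """Yield the JSON payload of each 'data: ...' SSE frame."""
    for frame in frames:
        frame = frame.strip()
        if frame.startswith("data: "):
            yield json.loads(frame[len("data: "):])

# Illustrative frames: the last payload mirrors the usage dump, the rest are messages.
frames = [
    'data: {"message_type": "assistant_message", "content": "Hi!"}\n\n',
    'data: {"step_count": 1, "total_tokens": 42}\n\n',
]
for payload in parse_sse_frames(frames):
    print(payload)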
@@ -98,6 +160,9 @@ class LettaAgent(BaseAgent):
                # TODO: also pass in reasoning content
            )

            if not response.choices[0].message.tool_calls:
                # TODO: make into a real error
                raise ValueError("No tool calls found in response, model must make a tool call")
            tool_call = response.choices[0].message.tool_calls[0]
            reasoning = [TextContent(text=response.choices[0].message.content)]  # reasoning placed into content for legacy reasons

@@ -126,7 +191,7 @@ class LettaAgent(BaseAgent):

    @trace_method
    async def step_stream(
-        self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True, stream_tokens: bool = False
+        self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True
    ) -> AsyncGenerator[str, None]:
        """
        Main streaming loop that yields partial tokens.
@@ -164,6 +229,9 @@ class LettaAgent(BaseAgent):
                    use_assistant_message=use_assistant_message,
                    put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
                )
            else:
                raise ValueError(f"Streaming not supported for {agent_state.llm_config}")

            async for chunk in interface.process(stream):
                yield f"data: {chunk.model_dump_json()}\n\n"

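Note (not part of the diff): with stream_tokens dropped from its signature, step_stream remains an async generator of SSE-formatted strings. A minimal sketch of draining it under the new signature, assuming an already-constructed LettaAgent instance named agent and a list of MessageCreate objects named messages (both illustrative):

async def collect_stream(agent, messages):
    # Collects every SSE frame ('data: {...}\n\n') emitted by the token-streaming loop.
    frames = []
    async for frame in agent.step_stream(messages, max_steps=10, use_assistant_message=True):
        frames.append(frame)
    return frames

# frames = asyncio.run(collect_stream(agent, messages))  # illustrative invocation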
@@ -635,7 +635,7 @@ async def send_message(
    agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent
    experimental_header = request_obj.headers.get("X-EXPERIMENTAL") or "false"
    feature_enabled = settings.use_experimental or experimental_header.lower() == "true"
-    model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "google_vertex", "google_ai"]
+    model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai"]

    if agent_eligible and feature_enabled and model_compatible:
        experimental_agent = LettaAgent(
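Note (not part of the diff): the experimental loop is gated by the server-side use_experimental setting or by a per-request X-EXPERIMENTAL header. A minimal sketch of opting in from a client with the requests library; the base URL, agent id, route, and payload shape are illustrative assumptions, not confirmed by this diff:

import requests

base_url = "http://localhost:8283"  # assumed local Letta server
agent_id = "agent-00000000-0000-0000-0000-000000000000"  # placeholder id

resp = requests.post(
    f"{base_url}/v1/agents/{agent_id}/messages",
    headers={"X-EXPERIMENTAL": "true"},  # opts into the new agent loop when the model is compatible
    json={"messages": [{"role": "user", "content": "Hello!"}]},
)
print(resp.status_code)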
@@ -706,13 +706,18 @@ async def send_message_streaming(
            passage_manager=server.passage_manager,
            actor=actor,
        )

-        result = StreamingResponse(
-            experimental_agent.step_stream(
-                request.messages, max_steps=10, use_assistant_message=request.use_assistant_message, stream_tokens=request.stream_tokens
-            ),
-            media_type="text/event-stream",
-        )
+        if request.stream_tokens:
+            result = StreamingResponse(
+                experimental_agent.step_stream(request.messages, max_steps=10, use_assistant_message=request.use_assistant_message),
+                media_type="text/event-stream",
+            )
+        else:
+            result = StreamingResponse(
+                experimental_agent.step_stream_no_tokens(
+                    request.messages, max_steps=10, use_assistant_message=request.use_assistant_message
+                ),
+                media_type="text/event-stream",
+            )
    else:
        result = await server.send_message_to_agent(
            agent_id=agent_id,
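Note (not part of the diff): from the client's perspective, stream_tokens=True keeps the existing token-by-token streaming via step_stream, while stream_tokens=False now yields one complete message per agent step via step_stream_no_tokens. A minimal sketch using the same SDK call as the tests below; the base URL, agent id, and message payload are illustrative:

from letta_client import Letta, MessageCreate

client = Letta(base_url="http://localhost:8283")  # assumed local server

# Step streaming: each chunk is a complete Letta message, ending with the usage statistics.
for chunk in client.agents.messages.create_stream(
    agent_id="agent-00000000-0000-0000-0000-000000000000",  # placeholder id
    messages=[MessageCreate(role="user", content="Hello!")],
    stream_tokens=False,
):
    print(type(chunk).__name__)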
@@ -636,6 +636,33 @@ async def test_streaming_tool_call_async_client(
    assert_tool_call_response(messages, streaming=True)


@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
    ids=[c.model for c in TESTED_LLM_CONFIGS],
)
def test_step_streaming_greeting_with_assistant_message(
    disable_e2b_api_key: Any,
    client: Letta,
    agent_state: AgentState,
    llm_config: LLMConfig,
) -> None:
    """
    Tests sending a streaming message with a synchronous client.
    Checks that each chunk in the stream has the correct message types.
    """
    agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
    response = client.agents.messages.create_stream(
        agent_id=agent_state.id,
        messages=USER_MESSAGE_GREETING,
        stream_tokens=False,
    )
    messages = []
    for message in response:
        messages.append(message)
    assert_greeting_with_assistant_message_response(messages, streaming=True)


@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,