From 1465c48bb96e14da71afef1eff00562cd464f24a Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Mon, 19 May 2025 21:10:11 -0700 Subject: [PATCH] fix: pass in usage statistics to patch streaming error (#2264) --- letta/agents/letta_agent.py | 18 +++++++++--------- tests/integration_test_send_message.py | 4 ++++ 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 718fd583..88af2eca 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -139,7 +139,7 @@ class LettaAgent(BaseAgent): reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons persisted_messages, should_continue = await self._handle_ai_response( - tool_call, agent_state, tool_rules_solver, reasoning_content=reasoning + tool_call, agent_state, tool_rules_solver, response.usage, reasoning_content=reasoning ) self.response_messages.extend(persisted_messages) new_in_context_messages.extend(persisted_messages) @@ -241,7 +241,7 @@ class LettaAgent(BaseAgent): reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons persisted_messages, should_continue = await self._handle_ai_response( - tool_call, agent_state, tool_rules_solver, reasoning_content=reasoning, step_id=step_id, usage=usage + tool_call, agent_state, tool_rules_solver, response.usage, reasoning_content=reasoning, step_id=step_id ) self.response_messages.extend(persisted_messages) new_in_context_messages.extend(persisted_messages) @@ -352,11 +352,15 @@ class LettaAgent(BaseAgent): tool_call, agent_state, tool_rules_solver, + UsageStatistics( + completion_tokens=interface.output_tokens, + prompt_tokens=interface.input_tokens, + total_tokens=interface.input_tokens + interface.output_tokens, + ), reasoning_content=reasoning_content, pre_computed_assistant_message_id=interface.letta_assistant_message_id, pre_computed_tool_message_id=interface.letta_tool_message_id, step_id=step_id, - usage=usage, ) self.response_messages.extend(persisted_messages) new_in_context_messages.extend(persisted_messages) @@ -472,11 +476,11 @@ class LettaAgent(BaseAgent): tool_call: ToolCall, agent_state: AgentState, tool_rules_solver: ToolRulesSolver, + usage: UsageStatistics, reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None, pre_computed_assistant_message_id: Optional[str] = None, pre_computed_tool_message_id: Optional[str] = None, step_id: str | None = None, - usage: LettaUsageStatistics = None, ) -> Tuple[List[Message], bool]: """ Now that streaming is done, handle the final AI response. @@ -537,11 +541,7 @@ class LettaAgent(BaseAgent): model=agent_state.llm_config.model, model_endpoint=agent_state.llm_config.model_endpoint, context_window_limit=agent_state.llm_config.context_window, - usage=UsageStatistics( - total_tokens=usage.total_tokens, - prompt_tokens=usage.prompt_tokens, - completion_tokens=usage.completion_tokens, - ), + usage=usage, provider_id=None, job_id=None, step_id=step_id, diff --git a/tests/integration_test_send_message.py b/tests/integration_test_send_message.py index afaf7959..a7cc37f0 100644 --- a/tests/integration_test_send_message.py +++ b/tests/integration_test_send_message.py @@ -171,6 +171,10 @@ def assert_greeting_with_assistant_message_response( if streaming: assert isinstance(messages[index], LettaUsageStatistics) + assert messages[index].prompt_tokens > 0 + assert messages[index].completion_tokens > 0 + assert messages[index].total_tokens > 0 + assert messages[index].step_count > 0 def assert_greeting_without_assistant_message_response(