fix: pass in usage statistics to patch streaming error (#2264)

Author: Sarah Wooders
Date: 2025-05-19 21:10:11 -07:00
Committed by: GitHub
parent 068f27d83d
commit 1465c48bb9
2 changed files with 13 additions and 9 deletions
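
The streaming code path previously threaded token usage through an optional trailing keyword, usage: LettaUsageStatistics = None, and step logging then unpacked it field by field (usage.total_tokens, usage.prompt_tokens, usage.completion_tokens), which breaks when the value is never supplied. This commit makes usage a required positional parameter of _handle_ai_response: non-streaming callers forward response.usage directly, the streaming caller assembles a UsageStatistics from the interface's running token counters, and step logging passes the object through unchanged.

A minimal sketch of the new calling convention, assembled from the hunks below (the interface counter attributes come from the diff; the import path for UsageStatistics is an assumption, not shown here):

    # Sketch only; this module path is assumed, not confirmed by the diff.
    from letta.schemas.openai.chat_completion_response import UsageStatistics

    # Non-streaming callers forward the provider-reported usage positionally:
    persisted_messages, should_continue = await self._handle_ai_response(
        tool_call, agent_state, tool_rules_solver, response.usage,
        reasoning_content=reasoning,
    )

    # The streaming caller rebuilds usage from the interface's token counters,
    # since no single provider response object survives a streamed completion:
    persisted_messages, should_continue = await self._handle_ai_response(
        tool_call, agent_state, tool_rules_solver,
        UsageStatistics(
            completion_tokens=interface.output_tokens,
            prompt_tokens=interface.input_tokens,
            total_tokens=interface.input_tokens + interface.output_tokens,
        ),
        reasoning_content=reasoning_content,
        step_id=step_id,
    )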

File 1 of 2: LettaAgent implementation

@@ -139,7 +139,7 @@ class LettaAgent(BaseAgent):
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
persisted_messages, should_continue = await self._handle_ai_response(
- tool_call, agent_state, tool_rules_solver, reasoning_content=reasoning
+ tool_call, agent_state, tool_rules_solver, response.usage, reasoning_content=reasoning
)
self.response_messages.extend(persisted_messages)
new_in_context_messages.extend(persisted_messages)
@@ -241,7 +241,7 @@ class LettaAgent(BaseAgent):
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
persisted_messages, should_continue = await self._handle_ai_response(
- tool_call, agent_state, tool_rules_solver, reasoning_content=reasoning, step_id=step_id, usage=usage
+ tool_call, agent_state, tool_rules_solver, response.usage, reasoning_content=reasoning, step_id=step_id
)
self.response_messages.extend(persisted_messages)
new_in_context_messages.extend(persisted_messages)
@@ -352,11 +352,15 @@ class LettaAgent(BaseAgent):
tool_call,
agent_state,
tool_rules_solver,
+ UsageStatistics(
+     completion_tokens=interface.output_tokens,
+     prompt_tokens=interface.input_tokens,
+     total_tokens=interface.input_tokens + interface.output_tokens,
+ ),
reasoning_content=reasoning_content,
pre_computed_assistant_message_id=interface.letta_assistant_message_id,
pre_computed_tool_message_id=interface.letta_tool_message_id,
step_id=step_id,
- usage=usage,
)
self.response_messages.extend(persisted_messages)
new_in_context_messages.extend(persisted_messages)
@@ -472,11 +476,11 @@ class LettaAgent(BaseAgent):
tool_call: ToolCall,
agent_state: AgentState,
tool_rules_solver: ToolRulesSolver,
+ usage: UsageStatistics,
reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
pre_computed_assistant_message_id: Optional[str] = None,
pre_computed_tool_message_id: Optional[str] = None,
step_id: str | None = None,
- usage: LettaUsageStatistics = None,
) -> Tuple[List[Message], bool]:
"""
Now that streaming is done, handle the final AI response.
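
Consolidated, the new signature reads as follows (reconstructed from the hunk; async and self are inferred from the await self._handle_ai_response(...) call sites):

    async def _handle_ai_response(
        self,
        tool_call: ToolCall,
        agent_state: AgentState,
        tool_rules_solver: ToolRulesSolver,
        usage: UsageStatistics,
        reasoning_content: Optional[
            List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]
        ] = None,
        pre_computed_assistant_message_id: Optional[str] = None,
        pre_computed_tool_message_id: Optional[str] = None,
        step_id: str | None = None,
    ) -> Tuple[List[Message], bool]:
        """Now that streaming is done, handle the final AI response."""

Making usage positional and required means a caller that omits it now fails immediately with a TypeError at the call site, instead of passing along a None that only blows up later inside step logging.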
@@ -537,11 +541,7 @@ class LettaAgent(BaseAgent):
model=agent_state.llm_config.model,
model_endpoint=agent_state.llm_config.model_endpoint,
context_window_limit=agent_state.llm_config.context_window,
- usage=UsageStatistics(
-     total_tokens=usage.total_tokens,
-     prompt_tokens=usage.prompt_tokens,
-     completion_tokens=usage.completion_tokens,
- ),
+ usage=usage,
provider_id=None,
job_id=None,
step_id=step_id,

File 2 of 2: streaming response test assertions

@@ -171,6 +171,10 @@ def assert_greeting_with_assistant_message_response(
if streaming:
assert isinstance(messages[index], LettaUsageStatistics)
assert messages[index].prompt_tokens > 0
+ assert messages[index].completion_tokens > 0
+ assert messages[index].total_tokens > 0
+ assert messages[index].step_count > 0
def assert_greeting_without_assistant_message_response(
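
Applied to the helper, the streaming branch now asserts every usage field rather than only prompt_tokens (reconstructed from the hunk; indentation is inferred):

    if streaming:
        assert isinstance(messages[index], LettaUsageStatistics)
        assert messages[index].prompt_tokens > 0
        assert messages[index].completion_tokens > 0
        assert messages[index].total_tokens > 0
        assert messages[index].step_count > 0

Checking completion_tokens, total_tokens, and step_count alongside prompt_tokens guards against a regression where usage is silently zeroed or dropped on the streaming path, which is the failure mode this commit patches.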