fix: pass in usage statistics to patch streaming error (#2264)
This commit is contained in:
@@ -139,7 +139,7 @@ class LettaAgent(BaseAgent):
|
||||
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
|
||||
|
||||
persisted_messages, should_continue = await self._handle_ai_response(
|
||||
tool_call, agent_state, tool_rules_solver, reasoning_content=reasoning
|
||||
tool_call, agent_state, tool_rules_solver, response.usage, reasoning_content=reasoning
|
||||
)
|
||||
self.response_messages.extend(persisted_messages)
|
||||
new_in_context_messages.extend(persisted_messages)
|
||||
@@ -241,7 +241,7 @@ class LettaAgent(BaseAgent):
|
||||
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
|
||||
|
||||
persisted_messages, should_continue = await self._handle_ai_response(
|
||||
tool_call, agent_state, tool_rules_solver, reasoning_content=reasoning, step_id=step_id, usage=usage
|
||||
tool_call, agent_state, tool_rules_solver, response.usage, reasoning_content=reasoning, step_id=step_id
|
||||
)
|
||||
self.response_messages.extend(persisted_messages)
|
||||
new_in_context_messages.extend(persisted_messages)
|
||||
@@ -352,11 +352,15 @@ class LettaAgent(BaseAgent):
|
||||
tool_call,
|
||||
agent_state,
|
||||
tool_rules_solver,
|
||||
UsageStatistics(
|
||||
completion_tokens=interface.output_tokens,
|
||||
prompt_tokens=interface.input_tokens,
|
||||
total_tokens=interface.input_tokens + interface.output_tokens,
|
||||
),
|
||||
reasoning_content=reasoning_content,
|
||||
pre_computed_assistant_message_id=interface.letta_assistant_message_id,
|
||||
pre_computed_tool_message_id=interface.letta_tool_message_id,
|
||||
step_id=step_id,
|
||||
usage=usage,
|
||||
)
|
||||
self.response_messages.extend(persisted_messages)
|
||||
new_in_context_messages.extend(persisted_messages)
|
||||
@@ -472,11 +476,11 @@ class LettaAgent(BaseAgent):
|
||||
tool_call: ToolCall,
|
||||
agent_state: AgentState,
|
||||
tool_rules_solver: ToolRulesSolver,
|
||||
usage: UsageStatistics,
|
||||
reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
|
||||
pre_computed_assistant_message_id: Optional[str] = None,
|
||||
pre_computed_tool_message_id: Optional[str] = None,
|
||||
step_id: str | None = None,
|
||||
usage: LettaUsageStatistics = None,
|
||||
) -> Tuple[List[Message], bool]:
|
||||
"""
|
||||
Now that streaming is done, handle the final AI response.
|
||||
@@ -537,11 +541,7 @@ class LettaAgent(BaseAgent):
|
||||
model=agent_state.llm_config.model,
|
||||
model_endpoint=agent_state.llm_config.model_endpoint,
|
||||
context_window_limit=agent_state.llm_config.context_window,
|
||||
usage=UsageStatistics(
|
||||
total_tokens=usage.total_tokens,
|
||||
prompt_tokens=usage.prompt_tokens,
|
||||
completion_tokens=usage.completion_tokens,
|
||||
),
|
||||
usage=usage,
|
||||
provider_id=None,
|
||||
job_id=None,
|
||||
step_id=step_id,
|
||||
|
||||
@@ -171,6 +171,10 @@ def assert_greeting_with_assistant_message_response(
|
||||
|
||||
if streaming:
|
||||
assert isinstance(messages[index], LettaUsageStatistics)
|
||||
assert messages[index].prompt_tokens > 0
|
||||
assert messages[index].completion_tokens > 0
|
||||
assert messages[index].total_tokens > 0
|
||||
assert messages[index].step_count > 0
|
||||
|
||||
|
||||
def assert_greeting_without_assistant_message_response(
|
||||
|
||||
Reference in New Issue
Block a user