feat: add usage tracking for new agent loop + filter out role=user messages in response (#2175)

This commit is contained in:
Sarah Wooders
2025-05-14 13:08:27 -07:00
committed by GitHub
parent cb46805a6e
commit 935b476dcd
3 changed files with 35 additions and 11 deletions

View File

@@ -10,14 +10,18 @@ from letta.server.rest_api.utils import create_input_messages
from letta.services.message_manager import MessageManager
def _create_letta_response(new_in_context_messages: list[Message], use_assistant_message: bool) -> LettaResponse:
def _create_letta_response(
new_in_context_messages: list[Message], use_assistant_message: bool, usage: LettaUsageStatistics
) -> LettaResponse:
"""
Converts the newly created/persisted messages into a LettaResponse.
"""
# NOTE: hacky solution to avoid returning heartbeat messages and the original user message
filter_user_messages = [m for m in new_in_context_messages if m.role != "user"]
response_messages = Message.to_letta_messages_from_list(
messages=new_in_context_messages, use_assistant_message=use_assistant_message, reverse=False
messages=filter_user_messages, use_assistant_message=use_assistant_message, reverse=False
)
return LettaResponse(messages=response_messages, usage=LettaUsageStatistics())
return LettaResponse(messages=response_messages, usage=usage)
def _prepare_in_context_messages(

View File

@@ -4,6 +4,7 @@ import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
from openai import AsyncStream
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from letta.agents.base_agent import BaseAgent
@@ -23,6 +24,7 @@ from letta.schemas.letta_message_content import OmittedReasoningContent, Reasoni
from letta.schemas.letta_response import LettaResponse
from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_response import ToolCall
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.server.rest_api.utils import create_letta_messages_from_llm_response
from letta.services.agent_manager import AgentManager
@@ -65,14 +67,16 @@ class LettaAgent(BaseAgent):
@trace_method
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True) -> LettaResponse:
agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor)
current_in_context_messages, new_in_context_messages = await self._step(
current_in_context_messages, new_in_context_messages, usage = await self._step(
agent_state=agent_state, input_messages=input_messages, max_steps=max_steps
)
return _create_letta_response(new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message)
return _create_letta_response(
new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message, usage=usage
)
async def _step(
self, agent_state: AgentState, input_messages: List[MessageCreate], max_steps: int = 10
) -> Tuple[List[Message], List[Message]]:
) -> Tuple[List[Message], List[Message], CompletionUsage]:
current_in_context_messages, new_in_context_messages = _prepare_in_context_messages(
input_messages, agent_state, self.message_manager, self.actor
)
@@ -82,6 +86,7 @@ class LettaAgent(BaseAgent):
put_inner_thoughts_first=True,
actor=self.actor,
)
usage = LettaUsageStatistics()
for _ in range(max_steps):
response = await self._get_ai_reply(
llm_client=llm_client,
@@ -101,6 +106,13 @@ class LettaAgent(BaseAgent):
self.response_messages.extend(persisted_messages)
new_in_context_messages.extend(persisted_messages)
# update usage
# TODO: add run_id
usage.step_count += 1
usage.completion_tokens += response.usage.completion_tokens
usage.prompt_tokens += response.usage.prompt_tokens
usage.total_tokens += response.usage.total_tokens
if not should_continue:
break
@@ -109,7 +121,7 @@ class LettaAgent(BaseAgent):
message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)]
self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor)
return current_in_context_messages, new_in_context_messages
return current_in_context_messages, new_in_context_messages, usage
@trace_method
async def step_stream(
@@ -129,6 +141,7 @@ class LettaAgent(BaseAgent):
put_inner_thoughts_first=True,
actor=self.actor,
)
usage = LettaUsageStatistics()
for _ in range(max_steps):
stream = await self._get_ai_reply(
@@ -138,7 +151,6 @@ class LettaAgent(BaseAgent):
tool_rules_solver=tool_rules_solver,
stream=True,
)
# TODO: THIS IS INCREDIBLY UGLY
# TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
interface = AnthropicStreamingInterface(
@@ -147,6 +159,12 @@ class LettaAgent(BaseAgent):
async for chunk in interface.process(stream):
yield f"data: {chunk.model_dump_json()}\n\n"
# update usage
usage.step_count += 1
usage.completion_tokens += interface.output_tokens
usage.prompt_tokens += interface.input_tokens
usage.total_tokens += interface.input_tokens + interface.output_tokens
# Process resulting stream content
tool_call = interface.get_tool_call_object()
reasoning_content = interface.get_reasoning_content()
@@ -179,7 +197,7 @@ class LettaAgent(BaseAgent):
self.num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)
# TODO: Also yield out a letta usage stats SSE
yield f"data: {usage.model_dump_json()}\n\n"
yield f"data: {MessageStreamStatus.done.model_dump_json()}\n\n"
@trace_method

View File

@@ -74,7 +74,7 @@ class VoiceSleeptimeAgent(LettaAgent):
]
# Summarize
current_in_context_messages, new_in_context_messages = await super()._step(
current_in_context_messages, new_in_context_messages, usage = await super()._step(
agent_state=agent_state, input_messages=input_messages, max_steps=max_steps
)
new_in_context_messages, updated = self.summarizer.summarize(
@@ -84,7 +84,9 @@ class VoiceSleeptimeAgent(LettaAgent):
agent_id=self.agent_id, message_ids=[m.id for m in new_in_context_messages], actor=self.actor
)
return _create_letta_response(new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message)
return _create_letta_response(
new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message, usage=usage
)
@trace_method
async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]: