feat: add usage tracking for new agent loop + filter out role=user messages in response (#2175)
This commit is contained in:
@@ -10,14 +10,18 @@ from letta.server.rest_api.utils import create_input_messages
|
||||
from letta.services.message_manager import MessageManager
|
||||
|
||||
|
||||
def _create_letta_response(new_in_context_messages: list[Message], use_assistant_message: bool) -> LettaResponse:
|
||||
def _create_letta_response(
|
||||
new_in_context_messages: list[Message], use_assistant_message: bool, usage: LettaUsageStatistics
|
||||
) -> LettaResponse:
|
||||
"""
|
||||
Converts the newly created/persisted messages into a LettaResponse.
|
||||
"""
|
||||
# NOTE: hacky solution to avoid returning heartbeat messages and the original user message
|
||||
filter_user_messages = [m for m in new_in_context_messages if m.role != "user"]
|
||||
response_messages = Message.to_letta_messages_from_list(
|
||||
messages=new_in_context_messages, use_assistant_message=use_assistant_message, reverse=False
|
||||
messages=filter_user_messages, use_assistant_message=use_assistant_message, reverse=False
|
||||
)
|
||||
return LettaResponse(messages=response_messages, usage=LettaUsageStatistics())
|
||||
return LettaResponse(messages=response_messages, usage=usage)
|
||||
|
||||
|
||||
def _prepare_in_context_messages(
|
||||
|
||||
@@ -4,6 +4,7 @@ import uuid
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from openai import AsyncStream
|
||||
from openai.types import CompletionUsage
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
@@ -23,6 +24,7 @@ from letta.schemas.letta_message_content import OmittedReasoningContent, Reasoni
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.openai.chat_completion_response import ToolCall
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.utils import create_letta_messages_from_llm_response
|
||||
from letta.services.agent_manager import AgentManager
|
||||
@@ -65,14 +67,16 @@ class LettaAgent(BaseAgent):
|
||||
@trace_method
|
||||
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True) -> LettaResponse:
|
||||
agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor)
|
||||
current_in_context_messages, new_in_context_messages = await self._step(
|
||||
current_in_context_messages, new_in_context_messages, usage = await self._step(
|
||||
agent_state=agent_state, input_messages=input_messages, max_steps=max_steps
|
||||
)
|
||||
return _create_letta_response(new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message)
|
||||
return _create_letta_response(
|
||||
new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message, usage=usage
|
||||
)
|
||||
|
||||
async def _step(
|
||||
self, agent_state: AgentState, input_messages: List[MessageCreate], max_steps: int = 10
|
||||
) -> Tuple[List[Message], List[Message]]:
|
||||
) -> Tuple[List[Message], List[Message], CompletionUsage]:
|
||||
current_in_context_messages, new_in_context_messages = _prepare_in_context_messages(
|
||||
input_messages, agent_state, self.message_manager, self.actor
|
||||
)
|
||||
@@ -82,6 +86,7 @@ class LettaAgent(BaseAgent):
|
||||
put_inner_thoughts_first=True,
|
||||
actor=self.actor,
|
||||
)
|
||||
usage = LettaUsageStatistics()
|
||||
for _ in range(max_steps):
|
||||
response = await self._get_ai_reply(
|
||||
llm_client=llm_client,
|
||||
@@ -101,6 +106,13 @@ class LettaAgent(BaseAgent):
|
||||
self.response_messages.extend(persisted_messages)
|
||||
new_in_context_messages.extend(persisted_messages)
|
||||
|
||||
# update usage
|
||||
# TODO: add run_id
|
||||
usage.step_count += 1
|
||||
usage.completion_tokens += response.usage.completion_tokens
|
||||
usage.prompt_tokens += response.usage.prompt_tokens
|
||||
usage.total_tokens += response.usage.total_tokens
|
||||
|
||||
if not should_continue:
|
||||
break
|
||||
|
||||
@@ -109,7 +121,7 @@ class LettaAgent(BaseAgent):
|
||||
message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)]
|
||||
self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor)
|
||||
|
||||
return current_in_context_messages, new_in_context_messages
|
||||
return current_in_context_messages, new_in_context_messages, usage
|
||||
|
||||
@trace_method
|
||||
async def step_stream(
|
||||
@@ -129,6 +141,7 @@ class LettaAgent(BaseAgent):
|
||||
put_inner_thoughts_first=True,
|
||||
actor=self.actor,
|
||||
)
|
||||
usage = LettaUsageStatistics()
|
||||
|
||||
for _ in range(max_steps):
|
||||
stream = await self._get_ai_reply(
|
||||
@@ -138,7 +151,6 @@ class LettaAgent(BaseAgent):
|
||||
tool_rules_solver=tool_rules_solver,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# TODO: THIS IS INCREDIBLY UGLY
|
||||
# TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
|
||||
interface = AnthropicStreamingInterface(
|
||||
@@ -147,6 +159,12 @@ class LettaAgent(BaseAgent):
|
||||
async for chunk in interface.process(stream):
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
# update usage
|
||||
usage.step_count += 1
|
||||
usage.completion_tokens += interface.output_tokens
|
||||
usage.prompt_tokens += interface.input_tokens
|
||||
usage.total_tokens += interface.input_tokens + interface.output_tokens
|
||||
|
||||
# Process resulting stream content
|
||||
tool_call = interface.get_tool_call_object()
|
||||
reasoning_content = interface.get_reasoning_content()
|
||||
@@ -179,7 +197,7 @@ class LettaAgent(BaseAgent):
|
||||
self.num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)
|
||||
|
||||
# TODO: Also yield out a letta usage stats SSE
|
||||
|
||||
yield f"data: {usage.model_dump_json()}\n\n"
|
||||
yield f"data: {MessageStreamStatus.done.model_dump_json()}\n\n"
|
||||
|
||||
@trace_method
|
||||
|
||||
@@ -74,7 +74,7 @@ class VoiceSleeptimeAgent(LettaAgent):
|
||||
]
|
||||
|
||||
# Summarize
|
||||
current_in_context_messages, new_in_context_messages = await super()._step(
|
||||
current_in_context_messages, new_in_context_messages, usage = await super()._step(
|
||||
agent_state=agent_state, input_messages=input_messages, max_steps=max_steps
|
||||
)
|
||||
new_in_context_messages, updated = self.summarizer.summarize(
|
||||
@@ -84,7 +84,9 @@ class VoiceSleeptimeAgent(LettaAgent):
|
||||
agent_id=self.agent_id, message_ids=[m.id for m in new_in_context_messages], actor=self.actor
|
||||
)
|
||||
|
||||
return _create_letta_response(new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message)
|
||||
return _create_letta_response(
|
||||
new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message, usage=usage
|
||||
)
|
||||
|
||||
@trace_method
|
||||
async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]:
|
||||
|
||||
Reference in New Issue
Block a user