diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py index 7847956a..5f60dcbb 100644 --- a/letta/agents/helpers.py +++ b/letta/agents/helpers.py @@ -10,14 +10,18 @@ from letta.server.rest_api.utils import create_input_messages from letta.services.message_manager import MessageManager -def _create_letta_response(new_in_context_messages: list[Message], use_assistant_message: bool) -> LettaResponse: +def _create_letta_response( + new_in_context_messages: list[Message], use_assistant_message: bool, usage: LettaUsageStatistics +) -> LettaResponse: """ Converts the newly created/persisted messages into a LettaResponse. """ + # NOTE: hacky solution to avoid returning heartbeat messages and the original user message + filter_user_messages = [m for m in new_in_context_messages if m.role != "user"] response_messages = Message.to_letta_messages_from_list( - messages=new_in_context_messages, use_assistant_message=use_assistant_message, reverse=False + messages=filter_user_messages, use_assistant_message=use_assistant_message, reverse=False ) - return LettaResponse(messages=response_messages, usage=LettaUsageStatistics()) + return LettaResponse(messages=response_messages, usage=usage) def _prepare_in_context_messages( diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index b2564e7d..7aa89be8 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -4,6 +4,7 @@ import uuid from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union from openai import AsyncStream +from openai.types import CompletionUsage from openai.types.chat import ChatCompletion, ChatCompletionChunk from letta.agents.base_agent import BaseAgent @@ -23,6 +24,7 @@ from letta.schemas.letta_message_content import OmittedReasoningContent, Reasoni from letta.schemas.letta_response import LettaResponse from letta.schemas.message import Message, MessageCreate from letta.schemas.openai.chat_completion_response import ToolCall +from letta.schemas.usage 
import LettaUsageStatistics from letta.schemas.user import User from letta.server.rest_api.utils import create_letta_messages_from_llm_response from letta.services.agent_manager import AgentManager @@ -65,14 +67,16 @@ class LettaAgent(BaseAgent): @trace_method async def step(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True) -> LettaResponse: agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor) - current_in_context_messages, new_in_context_messages = await self._step( + current_in_context_messages, new_in_context_messages, usage = await self._step( agent_state=agent_state, input_messages=input_messages, max_steps=max_steps ) - return _create_letta_response(new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message) + return _create_letta_response( + new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message, usage=usage + ) async def _step( self, agent_state: AgentState, input_messages: List[MessageCreate], max_steps: int = 10 - ) -> Tuple[List[Message], List[Message]]: + ) -> Tuple[List[Message], List[Message], LettaUsageStatistics]: current_in_context_messages, new_in_context_messages = _prepare_in_context_messages( input_messages, agent_state, self.message_manager, self.actor ) @@ -82,6 +86,7 @@ class LettaAgent(BaseAgent): put_inner_thoughts_first=True, actor=self.actor, ) + usage = LettaUsageStatistics() for _ in range(max_steps): response = await self._get_ai_reply( llm_client=llm_client, @@ -101,6 +106,13 @@ class LettaAgent(BaseAgent): self.response_messages.extend(persisted_messages) new_in_context_messages.extend(persisted_messages) + # update usage + # TODO: add run_id + usage.step_count += 1 + usage.completion_tokens += response.usage.completion_tokens + usage.prompt_tokens += response.usage.prompt_tokens + usage.total_tokens += response.usage.total_tokens + if not should_continue: break @@ -109,7 +121,7 @@ class 
LettaAgent(BaseAgent): message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)] self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor) - return current_in_context_messages, new_in_context_messages + return current_in_context_messages, new_in_context_messages, usage @trace_method async def step_stream( @@ -129,6 +141,7 @@ class LettaAgent(BaseAgent): put_inner_thoughts_first=True, actor=self.actor, ) + usage = LettaUsageStatistics() for _ in range(max_steps): stream = await self._get_ai_reply( @@ -138,7 +151,6 @@ class LettaAgent(BaseAgent): tool_rules_solver=tool_rules_solver, stream=True, ) - # TODO: THIS IS INCREDIBLY UGLY # TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED interface = AnthropicStreamingInterface( @@ -147,6 +159,12 @@ class LettaAgent(BaseAgent): async for chunk in interface.process(stream): yield f"data: {chunk.model_dump_json()}\n\n" + # update usage + usage.step_count += 1 + usage.completion_tokens += interface.output_tokens + usage.prompt_tokens += interface.input_tokens + usage.total_tokens += interface.input_tokens + interface.output_tokens + # Process resulting stream content tool_call = interface.get_tool_call_object() reasoning_content = interface.get_reasoning_content() @@ -179,7 +197,7 @@ class LettaAgent(BaseAgent): self.num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id) # TODO: Also yield out a letta usage stats SSE - + yield f"data: {usage.model_dump_json()}\n\n" yield f"data: {MessageStreamStatus.done.model_dump_json()}\n\n" @trace_method diff --git a/letta/agents/voice_sleeptime_agent.py b/letta/agents/voice_sleeptime_agent.py index 9ed3bc26..86922571 100644 --- a/letta/agents/voice_sleeptime_agent.py +++ b/letta/agents/voice_sleeptime_agent.py @@ -74,7 +74,7 @@ class VoiceSleeptimeAgent(LettaAgent): ] # Summarize - current_in_context_messages, new_in_context_messages 
= await super()._step( + current_in_context_messages, new_in_context_messages, usage = await super()._step( agent_state=agent_state, input_messages=input_messages, max_steps=max_steps ) new_in_context_messages, updated = self.summarizer.summarize( @@ -84,7 +84,9 @@ class VoiceSleeptimeAgent(LettaAgent): agent_id=self.agent_id, message_ids=[m.id for m in new_in_context_messages], actor=self.actor ) - return _create_letta_response(new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message) + return _create_letta_response( + new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message, usage=usage + ) @trace_method async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]: