* feat: add usage columns to steps table

  Adds denormalized usage fields to the steps table for easier querying:

  - model_handle: The model handle (e.g., "openai/gpt-4o-mini")
  - cached_input_tokens: Tokens served from cache
  - cache_write_tokens: Tokens written to cache (Anthropic)
  - reasoning_tokens: Reasoning/thinking tokens

  These fields mirror LettaUsageStatistics and are extracted from the existing
  prompt_tokens_details and completion_tokens_details JSON columns.

  🤖 Generated with [Letta Code](https://letta.com)

  Co-Authored-By: Letta <noreply@letta.com>

* chore: regenerate OpenAPI specs and SDK for usage columns

  🤖 Generated with [Letta Code](https://letta.com)

  Co-authored-by: Sarah Wooders <sarahwooders@users.noreply.github.com>

---------

Co-authored-by: Letta <noreply@letta.com>
Co-authored-by: letta-code <248085862+letta-code@users.noreply.github.com>
Co-authored-by: Sarah Wooders <sarahwooders@users.noreply.github.com>
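For orientation, a minimal sketch of how the new denormalized columns relate to the existing detail objects. The helper below is the editor's illustration, not code from this change; only the four column names and the token-detail field names come from the diff itself.

# Editor's sketch (assumption): denormalizing the JSON detail columns onto the steps table.
def denormalize_step_usage(step, llm_config, prompt_details, completion_details):
    step.model_handle = llm_config.handle  # e.g. "openai/gpt-4o-mini"
    if prompt_details is not None:
        # cached_tokens / cache_read_tokens: tokens served from cache
        step.cached_input_tokens = prompt_details.cached_tokens or prompt_details.cache_read_tokens
        # cache_creation_tokens: tokens written to cache (Anthropic)
        step.cache_write_tokens = prompt_details.cache_creation_tokens
    if completion_details is not None:
        # reasoning / thinking tokens
        step.reasoning_tokens = completion_details.reasoning_tokens
    return step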
import json
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
from typing import Optional, Union

from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk
from opentelemetry.trace import Span

from letta.agents.base_agent import BaseAgent
from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent
from letta.agents.helpers import (
    _build_rule_violation_result,
    _create_letta_response,
    _load_last_function_response,
    _pop_heartbeat,
    _prepare_in_context_messages_no_persist_async,
    _safe_load_tool_call_str,
    generate_step_id,
)
from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX, REQUEST_HEARTBEAT_PARAM
from letta.errors import ContextWindowExceededError
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import AsyncTimer, get_utc_time, get_utc_timestamp_ns, ns_to_ms
from letta.helpers.reasoning_helper import scrub_inner_thoughts_from_messages
from letta.helpers.tool_execution_helper import enable_strict_mode
from letta.interfaces.anthropic_streaming_interface import AnthropicStreamingInterface
from letta.interfaces.openai_streaming_interface import OpenAIStreamingInterface
from letta.llm_api.llm_client import LLMClient
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.log import get_logger
from letta.otel.context import get_ctx_attributes
from letta.otel.metric_registry import MetricRegistry
from letta.otel.tracing import log_event, trace_method, tracer
from letta.schemas.agent import AgentState, UpdateAgent
from letta.schemas.enums import JobStatus, LLMCallType, ProviderType, StepStatus, ToolType
from letta.schemas.letta_message import MessageType
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message, MessageCreateBase
from letta.schemas.openai.chat_completion_response import (
    FunctionCall,
    ToolCall,
    UsageStatistics,
    UsageStatisticsCompletionTokenDetails,
    UsageStatisticsPromptTokenDetails,
)
from letta.schemas.step import StepProgression
from letta.schemas.step_metrics import StepMetrics
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.server.rest_api.utils import (
    create_approval_request_message_from_llm_response,
    create_letta_messages_from_llm_response,
)
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema
from letta.services.job_manager import JobManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.step_manager import NoopStepManager, StepManager
from letta.services.summarizer.enums import SummarizationMode
from letta.services.summarizer.summarizer import Summarizer
from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
from letta.settings import model_settings, settings, summarizer_settings
from letta.system import package_function_response
from letta.types import JsonDict
from letta.utils import log_telemetry, validate_function_response

logger = get_logger(__name__)

DEFAULT_SUMMARY_BLOCK_LABEL = "conversation_summary"


class LettaAgent(BaseAgent):
    def __init__(
        self,
        agent_id: str,
        message_manager: MessageManager,
        agent_manager: AgentManager,
        block_manager: BlockManager,
        job_manager: JobManager,
        passage_manager: PassageManager,
        actor: User,
        step_manager: StepManager = NoopStepManager(),
        telemetry_manager: TelemetryManager = NoopTelemetryManager(),
        current_run_id: str | None = None,
        ## summarizer settings
        summarizer_mode: SummarizationMode = summarizer_settings.mode,
        # for static_buffer mode
        summary_block_label: str = DEFAULT_SUMMARY_BLOCK_LABEL,
        message_buffer_limit: int = summarizer_settings.message_buffer_limit,
        message_buffer_min: int = summarizer_settings.message_buffer_min,
        enable_summarization: bool = summarizer_settings.enable_summarization,
        max_summarization_retries: int = summarizer_settings.max_summarization_retries,
        # for partial_evict mode
        partial_evict_summarizer_percentage: float = summarizer_settings.partial_evict_summarizer_percentage,
    ):
        super().__init__(agent_id=agent_id, openai_client=None, message_manager=message_manager, agent_manager=agent_manager, actor=actor)

        # TODO: Make this more general, factorable
        # Summarizer settings
        self.block_manager = block_manager
        self.job_manager = job_manager
        self.passage_manager = passage_manager
        self.step_manager = step_manager
        self.telemetry_manager = telemetry_manager
        self.current_run_id = current_run_id
        self.response_messages: list[Message] = []

        self.last_function_response = None

        # Cached archival memory/message size
        self.num_messages = None
        self.num_archival_memories = None

        self.summarization_agent = None
        self.summary_block_label = summary_block_label
        self.max_summarization_retries = max_summarization_retries
        self.logger = get_logger(agent_id)

        # TODO: Expand to more
        if enable_summarization and model_settings.openai_api_key:
            self.summarization_agent = EphemeralSummaryAgent(
                target_block_label=self.summary_block_label,
                agent_id=agent_id,
                block_manager=self.block_manager,
                message_manager=self.message_manager,
                agent_manager=self.agent_manager,
                actor=self.actor,
            )

        self.summarizer = Summarizer(
            mode=summarizer_mode,
            # TODO consolidate to not use this, or push it into the Summarizer() class
            summarizer_agent=self.summarization_agent,
            # TODO: Make this configurable
            message_buffer_limit=message_buffer_limit,
            message_buffer_min=message_buffer_min,
            partial_evict_summarizer_percentage=partial_evict_summarizer_percentage,
            agent_manager=self.agent_manager,
            message_manager=self.message_manager,
            actor=self.actor,
            agent_id=self.agent_id,
        )

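    # Editor's note (illustrative, not from the original file): the manager objects
    # and actor are normally wired up by the Letta server, so the construction below
    # is an assumption, shown only to make the dependency list concrete.
    #
    #   agent = LettaAgent(
    #       agent_id=agent_id,
    #       message_manager=message_manager,
    #       agent_manager=agent_manager,
    #       block_manager=block_manager,
    #       job_manager=job_manager,
    #       passage_manager=passage_manager,
    #       actor=actor,
    #   )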
    async def _check_run_cancellation(self) -> bool:
        """
        Check if the current run associated with this agent execution has been cancelled.

        Returns:
            True if the run is cancelled, False otherwise (or if no run is associated)
        """
        if not self.job_manager or not self.current_run_id:
            return False

        try:
            job = await self.job_manager.get_job_by_id_async(job_id=self.current_run_id, actor=self.actor)
            return job.status == JobStatus.cancelled
        except Exception as e:
            # Log the error but don't fail the execution
            logger.warning(f"Failed to check job cancellation status for job {self.current_run_id}: {e}")
            return False

    @trace_method
    async def step(
        self,
        input_messages: list[MessageCreateBase],
        max_steps: int = DEFAULT_MAX_STEPS,
        run_id: str | None = None,
        use_assistant_message: bool = True,
        request_start_timestamp_ns: int | None = None,
        include_return_message_types: list[MessageType] | None = None,
        dry_run: bool = False,
    ) -> Union[LettaResponse, dict]:
        # TODO (cliandy): pass in run_id and use at send_message endpoints for all step functions
        agent_state = await self.agent_manager.get_agent_by_id_async(
            agent_id=self.agent_id,
            include_relationships=["tools", "memory", "tool_exec_environment_variables", "sources"],
            actor=self.actor,
        )
        result = await self._step(
            agent_state=agent_state,
            input_messages=input_messages,
            max_steps=max_steps,
            run_id=run_id,
            request_start_timestamp_ns=request_start_timestamp_ns,
            dry_run=dry_run,
        )

        # If dry run, return the request payload directly
        if dry_run:
            return result

        _, new_in_context_messages, stop_reason, usage = result
        return _create_letta_response(
            new_in_context_messages=new_in_context_messages,
            use_assistant_message=use_assistant_message,
            stop_reason=stop_reason,
            usage=usage,
            include_return_message_types=include_return_message_types,
        )

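    # Editor's note (illustrative, not from the original file): a typical call looks
    # roughly like the sketch below; the exact input message schema comes from
    # MessageCreateBase subclasses and is an assumption here.
    #
    #   response = await agent.step(input_messages=[...], max_steps=5)
    #   print(response.usage.total_tokens)
    #
    # With dry_run=True, _step short-circuits after building the provider request
    # and returns that request payload as a dict instead of a LettaResponse.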
@trace_method
|
||
async def step_stream_no_tokens(
|
||
self,
|
||
input_messages: list[MessageCreateBase],
|
||
max_steps: int = DEFAULT_MAX_STEPS,
|
||
use_assistant_message: bool = True,
|
||
request_start_timestamp_ns: int | None = None,
|
||
include_return_message_types: list[MessageType] | None = None,
|
||
run_id: str | None = None,
|
||
):
|
||
agent_state = await self.agent_manager.get_agent_by_id_async(
|
||
agent_id=self.agent_id,
|
||
include_relationships=["tools", "memory", "tool_exec_environment_variables", "sources"],
|
||
actor=self.actor,
|
||
)
|
||
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_no_persist_async(
|
||
input_messages, agent_state, self.message_manager, self.actor
|
||
)
|
||
initial_messages = new_in_context_messages
|
||
in_context_messages = current_in_context_messages
|
||
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
||
llm_client = LLMClient.create(
|
||
provider_type=agent_state.llm_config.model_endpoint_type,
|
||
put_inner_thoughts_first=True,
|
||
actor=self.actor,
|
||
)
|
||
stop_reason = None
|
||
job_update_metadata = None
|
||
usage = LettaUsageStatistics()
|
||
|
||
# span for request
|
||
request_span = tracer.start_span("time_to_first_token", start_time=request_start_timestamp_ns)
|
||
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
|
||
|
||
for i in range(max_steps):
|
||
if in_context_messages[-1].role == "approval":
|
||
approval_request_message = in_context_messages[-1]
|
||
step_metrics = await self.step_manager.get_step_metrics_async(step_id=approval_request_message.step_id, actor=self.actor)
|
||
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
|
||
approval_request_message.tool_calls[0],
|
||
[], # TODO: update this
|
||
agent_state,
|
||
tool_rules_solver,
|
||
usage,
|
||
reasoning_content=approval_request_message.content,
|
||
step_id=approval_request_message.step_id,
|
||
initial_messages=initial_messages,
|
||
is_final_step=(i == max_steps - 1),
|
||
step_metrics=step_metrics,
|
||
run_id=self.current_run_id,
|
||
is_approval=input_messages[0].approve,
|
||
is_denial=input_messages[0].approve == False,
|
||
denial_reason=input_messages[0].reason,
|
||
)
|
||
new_message_idx = len(initial_messages) if initial_messages else 0
|
||
self.response_messages.extend(persisted_messages[new_message_idx:])
|
||
new_in_context_messages.extend(persisted_messages[new_message_idx:])
|
||
initial_messages = None
|
||
in_context_messages = current_in_context_messages + new_in_context_messages
|
||
|
||
# stream step
|
||
# TODO: improve TTFT
|
||
filter_user_messages = [m for m in persisted_messages if m.role != "user" and m.role != "approval"]
|
||
letta_messages = Message.to_letta_messages_from_list(
|
||
filter_user_messages, use_assistant_message=use_assistant_message, reverse=False
|
||
)
|
||
|
||
for message in letta_messages:
|
||
if include_return_message_types is None or message.message_type in include_return_message_types:
|
||
yield f"data: {message.model_dump_json()}\n\n"
|
||
else:
|
||
# Check for job cancellation at the start of each step
|
||
if await self._check_run_cancellation():
|
||
stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
|
||
logger.info(f"Agent execution cancelled for run {self.current_run_id}")
|
||
yield f"data: {stop_reason.model_dump_json()}\n\n"
|
||
break
|
||
|
||
step_id = generate_step_id()
|
||
step_start = get_utc_timestamp_ns()
|
||
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
|
||
agent_step_span.set_attributes({"step_id": step_id})
|
||
|
||
step_progression = StepProgression.START
|
||
should_continue = False
|
||
step_metrics = StepMetrics(id=step_id) # Initialize metrics tracking
|
||
|
||
# Create step early with PENDING status
|
||
logged_step = await self.step_manager.log_step_async(
|
||
actor=self.actor,
|
||
agent_id=agent_state.id,
|
||
provider_name=agent_state.llm_config.model_endpoint_type,
|
||
provider_category=agent_state.llm_config.provider_category or "base",
|
||
model=agent_state.llm_config.model,
|
||
model_endpoint=agent_state.llm_config.model_endpoint,
|
||
context_window_limit=agent_state.llm_config.context_window,
|
||
usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0),
|
||
provider_id=None,
|
||
run_id=self.current_run_id if self.current_run_id else None,
|
||
step_id=step_id,
|
||
project_id=agent_state.project_id,
|
||
status=StepStatus.PENDING,
|
||
model_handle=agent_state.llm_config.handle,
|
||
)
|
||
# Only use step_id in messages if step was actually created
|
||
effective_step_id = step_id if logged_step else None
|
||
|
||
try:
|
||
(
|
||
request_data,
|
||
response_data,
|
||
current_in_context_messages,
|
||
new_in_context_messages,
|
||
valid_tool_names,
|
||
) = await self._build_and_request_from_llm(
|
||
current_in_context_messages,
|
||
new_in_context_messages,
|
||
agent_state,
|
||
llm_client,
|
||
tool_rules_solver,
|
||
agent_step_span,
|
||
step_metrics,
|
||
run_id=run_id,
|
||
)
|
||
in_context_messages = current_in_context_messages + new_in_context_messages
|
||
|
||
step_progression = StepProgression.RESPONSE_RECEIVED
|
||
log_event("agent.stream_no_tokens.llm_response.received") # [3^]
|
||
|
||
try:
|
||
response = await llm_client.convert_response_to_chat_completion(
|
||
response_data, in_context_messages, agent_state.llm_config
|
||
)
|
||
except ValueError as e:
|
||
stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_llm_response.value)
|
||
raise e
|
||
|
||
# update usage
|
||
usage.step_count += 1
|
||
usage.completion_tokens += response.usage.completion_tokens
|
||
usage.prompt_tokens += response.usage.prompt_tokens
|
||
usage.total_tokens += response.usage.total_tokens
|
||
MetricRegistry().message_output_tokens.record(
|
||
response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
|
||
)
|
||
|
||
if not response.choices[0].message.tool_calls:
|
||
stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value)
|
||
raise ValueError("No tool calls found in response, model must make a tool call")
|
||
tool_call = response.choices[0].message.tool_calls[0]
|
||
if response.choices[0].message.reasoning_content:
|
||
reasoning = [
|
||
ReasoningContent(
|
||
reasoning=response.choices[0].message.reasoning_content,
|
||
is_native=True,
|
||
signature=response.choices[0].message.reasoning_content_signature,
|
||
)
|
||
]
|
||
elif response.choices[0].message.omitted_reasoning_content:
|
||
reasoning = [OmittedReasoningContent()]
|
||
elif response.choices[0].message.content:
|
||
reasoning = [
|
||
TextContent(text=response.choices[0].message.content)
|
||
] # reasoning placed into content for legacy reasons
|
||
else:
|
||
self.logger.info("No reasoning content found.")
|
||
reasoning = None
|
||
|
||
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
|
||
tool_call,
|
||
valid_tool_names,
|
||
agent_state,
|
||
tool_rules_solver,
|
||
response.usage,
|
||
reasoning_content=reasoning,
|
||
step_id=effective_step_id,
|
||
initial_messages=initial_messages,
|
||
agent_step_span=agent_step_span,
|
||
is_final_step=(i == max_steps - 1),
|
||
step_metrics=step_metrics,
|
||
)
|
||
step_progression = StepProgression.STEP_LOGGED
|
||
|
||
# Update step with actual usage now that we have it (if step was created)
|
||
if logged_step:
|
||
await self.step_manager.update_step_success_async(self.actor, step_id, response.usage, stop_reason)
|
||
|
||
# TODO (cliandy): handle message contexts with larger refactor and dedupe logic
|
||
new_message_idx = len(initial_messages) if initial_messages else 0
|
||
self.response_messages.extend(persisted_messages[new_message_idx:])
|
||
new_in_context_messages.extend(persisted_messages[new_message_idx:])
|
||
initial_messages = None
|
||
log_event("agent.stream_no_tokens.llm_response.processed") # [4^]
|
||
|
||
# log step time
|
||
now = get_utc_timestamp_ns()
|
||
step_ns = now - step_start
|
||
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
|
||
agent_step_span.end()
|
||
|
||
# stream step
|
||
# TODO: improve TTFT
|
||
filter_user_messages = [m for m in persisted_messages if m.role != "user"]
|
||
letta_messages = Message.to_letta_messages_from_list(
|
||
filter_user_messages, use_assistant_message=use_assistant_message, reverse=False
|
||
)
|
||
letta_messages = [m for m in letta_messages if m.message_type != "approval_response_message"]
|
||
|
||
for message in letta_messages:
|
||
if include_return_message_types is None or message.message_type in include_return_message_types:
|
||
yield f"data: {message.model_dump_json()}\n\n"
|
||
|
||
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
|
||
step_progression = StepProgression.FINISHED
|
||
|
||
# Record step metrics for successful completion
|
||
if logged_step and step_metrics:
|
||
# Set the step_ns that was already calculated
|
||
step_metrics.step_ns = step_ns
|
||
await self._record_step_metrics(
|
||
step_id=step_id,
|
||
agent_state=agent_state,
|
||
step_metrics=step_metrics,
|
||
)
|
||
|
||
except Exception as e:
|
||
# Handle any unexpected errors during step processing
|
||
self.logger.error(f"Error during step processing: {e}")
|
||
job_update_metadata = {"error": str(e)}
|
||
|
||
# This indicates we failed after we decided to stop stepping, which indicates a bug with our flow.
|
||
if not stop_reason:
|
||
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
|
||
elif stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule):
|
||
self.logger.error("Error occurred during step processing, with valid stop reason: %s", stop_reason.stop_reason)
|
||
elif stop_reason.stop_reason not in (
|
||
StopReasonType.no_tool_call,
|
||
StopReasonType.invalid_tool_call,
|
||
StopReasonType.invalid_llm_response,
|
||
):
|
||
self.logger.error("Error occurred during step processing, with unexpected stop reason: %s", stop_reason.stop_reason)
|
||
|
||
# Send error stop reason to client and re-raise
|
||
yield f"data: {stop_reason.model_dump_json()}\n\n", 500
|
||
raise
|
||
|
||
# Update step if it needs to be updated
|
||
finally:
|
||
if step_progression == StepProgression.FINISHED and should_continue:
|
||
continue
|
||
|
||
self.logger.debug("Running cleanup for agent loop run: %s", self.current_run_id)
|
||
self.logger.info("Running final update. Step Progression: %s", step_progression)
|
||
try:
|
||
if step_progression == StepProgression.FINISHED and not should_continue:
|
||
# Successfully completed - update with final usage and stop reason
|
||
if stop_reason is None:
|
||
stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
|
||
# Note: step already updated with success status after _handle_ai_response
|
||
if logged_step:
|
||
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
|
||
break
|
||
|
||
# Handle error cases
|
||
if step_progression < StepProgression.STEP_LOGGED:
|
||
# Error occurred before step was fully logged
|
||
import traceback
|
||
|
||
if logged_step:
|
||
await self.step_manager.update_step_error_async(
|
||
actor=self.actor,
|
||
step_id=step_id, # Use original step_id for telemetry
|
||
error_type=type(e).__name__ if "e" in locals() else "Unknown",
|
||
error_message=str(e) if "e" in locals() else "Unknown error",
|
||
error_traceback=traceback.format_exc(),
|
||
stop_reason=stop_reason,
|
||
)
|
||
|
||
if step_progression <= StepProgression.RESPONSE_RECEIVED:
|
||
# TODO (cliandy): persist response if we get it back
|
||
if settings.track_errored_messages and initial_messages:
|
||
for message in initial_messages:
|
||
message.is_err = True
|
||
message.step_id = effective_step_id
|
||
await self.message_manager.create_many_messages_async(
|
||
initial_messages,
|
||
actor=self.actor,
|
||
project_id=agent_state.project_id,
|
||
template_id=agent_state.template_id,
|
||
)
|
||
elif step_progression <= StepProgression.LOGGED_TRACE:
|
||
if stop_reason is None:
|
||
self.logger.error("Error in step after logging step")
|
||
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
|
||
if logged_step:
|
||
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
|
||
else:
|
||
self.logger.error("Invalid StepProgression value")
|
||
|
||
if settings.track_stop_reason:
|
||
await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=True)
|
||
|
||
# Record partial step metrics on failure (capture whatever timing data we have)
|
||
if logged_step and step_metrics and step_progression < StepProgression.FINISHED:
|
||
# Calculate total step time up to the failure point
|
||
step_metrics.step_ns = get_utc_timestamp_ns() - step_start
|
||
await self._record_step_metrics(
|
||
step_id=step_id,
|
||
agent_state=agent_state,
|
||
step_metrics=step_metrics,
|
||
job_id=locals().get("run_id", self.current_run_id),
|
||
)
|
||
|
||
except Exception as e:
|
||
self.logger.error("Failed to update step: %s", e)
|
||
|
||
if not should_continue:
|
||
break
|
||
|
||
# Extend the in context message ids
|
||
if not agent_state.message_buffer_autoclear:
|
||
await self._rebuild_context_window(
|
||
in_context_messages=current_in_context_messages,
|
||
new_letta_messages=new_in_context_messages,
|
||
llm_config=agent_state.llm_config,
|
||
total_tokens=usage.total_tokens,
|
||
force=False,
|
||
run_id=run_id,
|
||
)
|
||
|
||
await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=False)
|
||
|
||
# Return back usage
|
||
for finish_chunk in self.get_finish_chunks_for_stream(usage, stop_reason):
|
||
yield f"data: {finish_chunk}\n\n"
|
||
|
||
async def _step(
|
||
self,
|
||
agent_state: AgentState,
|
||
input_messages: list[MessageCreateBase],
|
||
max_steps: int = DEFAULT_MAX_STEPS,
|
||
run_id: str | None = None,
|
||
request_start_timestamp_ns: int | None = None,
|
||
dry_run: bool = False,
|
||
) -> Union[tuple[list[Message], list[Message], LettaStopReason | None, LettaUsageStatistics], dict]:
|
||
"""
|
||
Carries out an invocation of the agent loop. In each step, the agent
|
||
1. Rebuilds its memory
|
||
2. Generates a request for the LLM
|
||
3. Fetches a response from the LLM
|
||
4. Processes the response
|
||
"""
|
||
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_no_persist_async(
|
||
input_messages, agent_state, self.message_manager, self.actor
|
||
)
|
||
initial_messages = new_in_context_messages
|
||
in_context_messages = current_in_context_messages
|
||
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
||
llm_client = LLMClient.create(
|
||
provider_type=agent_state.llm_config.model_endpoint_type,
|
||
put_inner_thoughts_first=True,
|
||
actor=self.actor,
|
||
)
|
||
|
||
# span for request
|
||
request_span = tracer.start_span("time_to_first_token")
|
||
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
|
||
|
||
stop_reason = None
|
||
job_update_metadata = None
|
||
usage = LettaUsageStatistics()
|
||
for i in range(max_steps):
|
||
if in_context_messages[-1].role == "approval":
|
||
approval_request_message = in_context_messages[-1]
|
||
step_metrics = await self.step_manager.get_step_metrics_async(step_id=approval_request_message.step_id, actor=self.actor)
|
||
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
|
||
approval_request_message.tool_calls[0],
|
||
[], # TODO: update this
|
||
agent_state,
|
||
tool_rules_solver,
|
||
usage,
|
||
reasoning_content=approval_request_message.content,
|
||
step_id=approval_request_message.step_id,
|
||
initial_messages=initial_messages,
|
||
is_final_step=(i == max_steps - 1),
|
||
step_metrics=step_metrics,
|
||
run_id=run_id or self.current_run_id,
|
||
is_approval=input_messages[0].approve,
|
||
is_denial=input_messages[0].approve == False,
|
||
denial_reason=input_messages[0].reason,
|
||
)
|
||
new_message_idx = len(initial_messages) if initial_messages else 0
|
||
self.response_messages.extend(persisted_messages[new_message_idx:])
|
||
new_in_context_messages.extend(persisted_messages[new_message_idx:])
|
||
initial_messages = None
|
||
in_context_messages = current_in_context_messages + new_in_context_messages
|
||
else:
|
||
# If dry run, build request data and return it without making LLM call
|
||
if dry_run:
|
||
request_data, valid_tool_names = await self._create_llm_request_data_async(
|
||
llm_client=llm_client,
|
||
in_context_messages=current_in_context_messages + new_in_context_messages,
|
||
agent_state=agent_state,
|
||
tool_rules_solver=tool_rules_solver,
|
||
)
|
||
return request_data
|
||
|
||
# Check for job cancellation at the start of each step
|
||
if await self._check_run_cancellation():
|
||
stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
|
||
logger.info(f"Agent execution cancelled for run {self.current_run_id}")
|
||
break
|
||
|
||
step_id = generate_step_id()
|
||
step_start = get_utc_timestamp_ns()
|
||
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
|
||
agent_step_span.set_attributes({"step_id": step_id})
|
||
|
||
step_progression = StepProgression.START
|
||
should_continue = False
|
||
step_metrics = StepMetrics(id=step_id) # Initialize metrics tracking
|
||
|
||
# Create step early with PENDING status
|
||
logged_step = await self.step_manager.log_step_async(
|
||
actor=self.actor,
|
||
agent_id=agent_state.id,
|
||
provider_name=agent_state.llm_config.model_endpoint_type,
|
||
provider_category=agent_state.llm_config.provider_category or "base",
|
||
model=agent_state.llm_config.model,
|
||
model_endpoint=agent_state.llm_config.model_endpoint,
|
||
context_window_limit=agent_state.llm_config.context_window,
|
||
usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0),
|
||
provider_id=None,
|
||
run_id=run_id if run_id else self.current_run_id,
|
||
step_id=step_id,
|
||
project_id=agent_state.project_id,
|
||
status=StepStatus.PENDING,
|
||
model_handle=agent_state.llm_config.handle,
|
||
)
|
||
# Only use step_id in messages if step was actually created
|
||
effective_step_id = step_id if logged_step else None
|
||
|
||
try:
|
||
(
|
||
request_data,
|
||
response_data,
|
||
current_in_context_messages,
|
||
new_in_context_messages,
|
||
valid_tool_names,
|
||
) = await self._build_and_request_from_llm(
|
||
current_in_context_messages,
|
||
new_in_context_messages,
|
||
agent_state,
|
||
llm_client,
|
||
tool_rules_solver,
|
||
agent_step_span,
|
||
step_metrics,
|
||
run_id=run_id,
|
||
)
|
||
in_context_messages = current_in_context_messages + new_in_context_messages
|
||
|
||
step_progression = StepProgression.RESPONSE_RECEIVED
|
||
log_event("agent.step.llm_response.received") # [3^]
|
||
|
||
try:
|
||
response = await llm_client.convert_response_to_chat_completion(
|
||
response_data, in_context_messages, agent_state.llm_config
|
||
)
|
||
except ValueError as e:
|
||
stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_llm_response.value)
|
||
raise e
|
||
|
||
usage.step_count += 1
|
||
usage.completion_tokens += response.usage.completion_tokens
|
||
usage.prompt_tokens += response.usage.prompt_tokens
|
||
usage.total_tokens += response.usage.total_tokens
|
||
usage.run_ids = [run_id] if run_id else None
|
||
MetricRegistry().message_output_tokens.record(
|
||
response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
|
||
)
|
||
|
||
if not response.choices[0].message.tool_calls:
|
||
stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value)
|
||
raise ValueError("No tool calls found in response, model must make a tool call")
|
||
tool_call = response.choices[0].message.tool_calls[0]
|
||
if response.choices[0].message.reasoning_content:
|
||
reasoning = [
|
||
ReasoningContent(
|
||
reasoning=response.choices[0].message.reasoning_content,
|
||
is_native=True,
|
||
signature=response.choices[0].message.reasoning_content_signature,
|
||
)
|
||
]
|
||
elif response.choices[0].message.content:
|
||
reasoning = [
|
||
TextContent(text=response.choices[0].message.content)
|
||
] # reasoning placed into content for legacy reasons
|
||
elif response.choices[0].message.omitted_reasoning_content:
|
||
reasoning = [OmittedReasoningContent()]
|
||
else:
|
||
self.logger.info("No reasoning content found.")
|
||
reasoning = None
|
||
|
||
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
|
||
tool_call,
|
||
valid_tool_names,
|
||
agent_state,
|
||
tool_rules_solver,
|
||
response.usage,
|
||
reasoning_content=reasoning,
|
||
step_id=effective_step_id,
|
||
initial_messages=initial_messages,
|
||
agent_step_span=agent_step_span,
|
||
is_final_step=(i == max_steps - 1),
|
||
run_id=run_id,
|
||
step_metrics=step_metrics,
|
||
)
|
||
step_progression = StepProgression.STEP_LOGGED
|
||
|
||
# Update step with actual usage now that we have it (if step was created)
|
||
if logged_step:
|
||
await self.step_manager.update_step_success_async(self.actor, step_id, response.usage, stop_reason)
|
||
|
||
new_message_idx = len(initial_messages) if initial_messages else 0
|
||
self.response_messages.extend(persisted_messages[new_message_idx:])
|
||
new_in_context_messages.extend(persisted_messages[new_message_idx:])
|
||
|
||
initial_messages = None
|
||
log_event("agent.step.llm_response.processed") # [4^]
|
||
|
||
# log step time
|
||
now = get_utc_timestamp_ns()
|
||
step_ns = now - step_start
|
||
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
|
||
agent_step_span.end()
|
||
|
||
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
|
||
step_progression = StepProgression.FINISHED
|
||
|
||
# Record step metrics for successful completion
|
||
if logged_step and step_metrics:
|
||
# Set the step_ns that was already calculated
|
||
step_metrics.step_ns = step_ns
|
||
await self._record_step_metrics(
|
||
step_id=step_id,
|
||
agent_state=agent_state,
|
||
step_metrics=step_metrics,
|
||
run_id=run_id if run_id else self.current_run_id,
|
||
)
|
||
|
||
except Exception as e:
|
||
# Handle any unexpected errors during step processing
|
||
self.logger.error(f"Error during step processing: {e}")
|
||
job_update_metadata = {"error": str(e)}
|
||
|
||
# This indicates we failed after we decided to stop stepping, which indicates a bug with our flow.
|
||
if not stop_reason:
|
||
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
|
||
elif stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule):
|
||
self.logger.error("Error occurred during step processing, with valid stop reason: %s", stop_reason.stop_reason)
|
||
elif stop_reason.stop_reason not in (
|
||
StopReasonType.no_tool_call,
|
||
StopReasonType.invalid_tool_call,
|
||
StopReasonType.invalid_llm_response,
|
||
):
|
||
self.logger.error("Error occurred during step processing, with unexpected stop reason: %s", stop_reason.stop_reason)
|
||
raise
|
||
|
||
# Update step if it needs to be updated
|
||
finally:
|
||
if step_progression == StepProgression.FINISHED and should_continue:
|
||
continue
|
||
|
||
self.logger.debug("Running cleanup for agent loop run: %s", self.current_run_id)
|
||
self.logger.info("Running final update. Step Progression: %s", step_progression)
|
||
try:
|
||
if step_progression == StepProgression.FINISHED and not should_continue:
|
||
# Successfully completed - update with final usage and stop reason
|
||
if stop_reason is None:
|
||
stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
|
||
if logged_step:
|
||
await self.step_manager.update_step_success_async(self.actor, step_id, usage, stop_reason)
|
||
break
|
||
|
||
# Handle error cases
|
||
if step_progression < StepProgression.STEP_LOGGED:
|
||
# Error occurred before step was fully logged
|
||
import traceback
|
||
|
||
if logged_step:
|
||
await self.step_manager.update_step_error_async(
|
||
actor=self.actor,
|
||
step_id=step_id, # Use original step_id for telemetry
|
||
error_type=type(e).__name__ if "e" in locals() else "Unknown",
|
||
error_message=str(e) if "e" in locals() else "Unknown error",
|
||
error_traceback=traceback.format_exc(),
|
||
stop_reason=stop_reason,
|
||
)
|
||
|
||
if step_progression <= StepProgression.RESPONSE_RECEIVED:
|
||
# TODO (cliandy): persist response if we get it back
|
||
if settings.track_errored_messages and initial_messages:
|
||
for message in initial_messages:
|
||
message.is_err = True
|
||
message.step_id = effective_step_id
|
||
await self.message_manager.create_many_messages_async(
|
||
initial_messages,
|
||
actor=self.actor,
|
||
project_id=agent_state.project_id,
|
||
template_id=agent_state.template_id,
|
||
)
|
||
elif step_progression <= StepProgression.LOGGED_TRACE:
|
||
if stop_reason is None:
|
||
self.logger.error("Error in step after logging step")
|
||
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
|
||
if logged_step:
|
||
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
|
||
else:
|
||
self.logger.error("Invalid StepProgression value")
|
||
|
||
if settings.track_stop_reason:
|
||
await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=True)
|
||
|
||
# Record partial step metrics on failure (capture whatever timing data we have)
|
||
if logged_step and step_metrics and step_progression < StepProgression.FINISHED:
|
||
# Calculate total step time up to the failure point
|
||
step_metrics.step_ns = get_utc_timestamp_ns() - step_start
|
||
await self._record_step_metrics(
|
||
step_id=step_id,
|
||
agent_state=agent_state,
|
||
step_metrics=step_metrics,
|
||
job_id=locals().get("run_id", self.current_run_id),
|
||
)
|
||
|
||
except Exception as e:
|
||
self.logger.error("Failed to update step: %s", e)
|
||
|
||
if not should_continue:
|
||
break
|
||
|
||
# Extend the in context message ids
|
||
if not agent_state.message_buffer_autoclear:
|
||
await self._rebuild_context_window(
|
||
in_context_messages=current_in_context_messages,
|
||
new_letta_messages=new_in_context_messages,
|
||
llm_config=agent_state.llm_config,
|
||
total_tokens=usage.total_tokens,
|
||
force=False,
|
||
run_id=run_id,
|
||
)
|
||
|
||
await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=False)
|
||
|
||
return current_in_context_messages, new_in_context_messages, stop_reason, usage
|
||
|
||
    async def _update_agent_last_run_metrics(self, completion_time: datetime, duration_ms: float) -> None:
        if not settings.track_last_agent_run:
            return
        try:
            await self.agent_manager.update_agent_async(
                agent_id=self.agent_id,
                agent_update=UpdateAgent(last_run_completion=completion_time, last_run_duration_ms=duration_ms),
                actor=self.actor,
            )
        except Exception as e:
            self.logger.error(f"Failed to update agent's last run metrics: {e}")

@trace_method
|
||
async def step_stream(
|
||
self,
|
||
input_messages: list[MessageCreateBase],
|
||
max_steps: int = DEFAULT_MAX_STEPS,
|
||
use_assistant_message: bool = True,
|
||
request_start_timestamp_ns: int | None = None,
|
||
include_return_message_types: list[MessageType] | None = None,
|
||
run_id: str | None = None,
|
||
) -> AsyncGenerator[str, None]:
|
||
"""
|
||
Carries out an invocation of the agent loop in a streaming fashion that yields partial tokens.
|
||
Whenever we detect a tool call, we yield from _handle_ai_response as well. At each step, the agent
|
||
1. Rebuilds its memory
|
||
2. Generates a request for the LLM
|
||
3. Fetches a response from the LLM
|
||
4. Processes the response
|
||
"""
|
||
agent_state = await self.agent_manager.get_agent_by_id_async(
|
||
agent_id=self.agent_id,
|
||
include_relationships=["tools", "memory", "tool_exec_environment_variables", "sources"],
|
||
actor=self.actor,
|
||
)
|
||
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_no_persist_async(
|
||
input_messages, agent_state, self.message_manager, self.actor
|
||
)
|
||
initial_messages = new_in_context_messages
|
||
in_context_messages = current_in_context_messages
|
||
|
||
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
||
llm_client = LLMClient.create(
|
||
provider_type=agent_state.llm_config.model_endpoint_type,
|
||
put_inner_thoughts_first=True,
|
||
actor=self.actor,
|
||
)
|
||
stop_reason = None
|
||
job_update_metadata = None
|
||
usage = LettaUsageStatistics()
|
||
first_chunk, request_span = True, None
|
||
if request_start_timestamp_ns:
|
||
request_span = tracer.start_span("time_to_first_token", start_time=request_start_timestamp_ns)
|
||
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
|
||
|
||
for i in range(max_steps):
|
||
if in_context_messages[-1].role == "approval":
|
||
approval_request_message = in_context_messages[-1]
|
||
step_metrics = await self.step_manager.get_step_metrics_async(step_id=approval_request_message.step_id, actor=self.actor)
|
||
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
|
||
approval_request_message.tool_calls[0],
|
||
[], # TODO: update this
|
||
agent_state,
|
||
tool_rules_solver,
|
||
usage,
|
||
reasoning_content=approval_request_message.content,
|
||
step_id=approval_request_message.step_id,
|
||
initial_messages=new_in_context_messages,
|
||
is_final_step=(i == max_steps - 1),
|
||
step_metrics=step_metrics,
|
||
run_id=self.current_run_id,
|
||
is_approval=input_messages[0].approve,
|
||
is_denial=input_messages[0].approve == False,
|
||
denial_reason=input_messages[0].reason,
|
||
)
|
||
new_message_idx = len(initial_messages) if initial_messages else 0
|
||
self.response_messages.extend(persisted_messages[new_message_idx:])
|
||
new_in_context_messages.extend(persisted_messages[new_message_idx:])
|
||
initial_messages = None
|
||
in_context_messages = current_in_context_messages + new_in_context_messages
|
||
|
||
# yields tool response as this is handled from Letta and not the response from the LLM provider
|
||
tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
|
||
if not (use_assistant_message and tool_return.name == "send_message"):
|
||
# Apply message type filtering if specified
|
||
if include_return_message_types is None or tool_return.message_type in include_return_message_types:
|
||
yield f"data: {tool_return.model_dump_json()}\n\n"
|
||
else:
|
||
step_id = generate_step_id()
|
||
# Check for job cancellation at the start of each step
|
||
if await self._check_run_cancellation():
|
||
stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
|
||
logger.info(f"Agent execution cancelled for run {self.current_run_id}")
|
||
yield f"data: {stop_reason.model_dump_json()}\n\n"
|
||
break
|
||
|
||
step_start = get_utc_timestamp_ns()
|
||
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
|
||
agent_step_span.set_attributes({"step_id": step_id})
|
||
|
||
step_progression = StepProgression.START
|
||
should_continue = False
|
||
step_metrics = StepMetrics(id=step_id) # Initialize metrics tracking
|
||
|
||
# Create step early with PENDING status
|
||
logged_step = await self.step_manager.log_step_async(
|
||
actor=self.actor,
|
||
agent_id=agent_state.id,
|
||
provider_name=agent_state.llm_config.model_endpoint_type,
|
||
provider_category=agent_state.llm_config.provider_category or "base",
|
||
model=agent_state.llm_config.model,
|
||
model_endpoint=agent_state.llm_config.model_endpoint,
|
||
context_window_limit=agent_state.llm_config.context_window,
|
||
usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0),
|
||
provider_id=None,
|
||
run_id=self.current_run_id if self.current_run_id else None,
|
||
step_id=step_id,
|
||
project_id=agent_state.project_id,
|
||
status=StepStatus.PENDING,
|
||
model_handle=agent_state.llm_config.handle,
|
||
)
|
||
# Only use step_id in messages if step was actually created
|
||
effective_step_id = step_id if logged_step else None
|
||
|
||
try:
|
||
(
|
||
request_data,
|
||
stream,
|
||
current_in_context_messages,
|
||
new_in_context_messages,
|
||
valid_tool_names,
|
||
provider_request_start_timestamp_ns,
|
||
) = await self._build_and_request_from_llm_streaming(
|
||
first_chunk,
|
||
agent_step_span,
|
||
request_start_timestamp_ns,
|
||
current_in_context_messages,
|
||
new_in_context_messages,
|
||
agent_state,
|
||
llm_client,
|
||
tool_rules_solver,
|
||
run_id=run_id,
|
||
step_id=step_id,
|
||
)
|
||
|
||
step_progression = StepProgression.STREAM_RECEIVED
|
||
log_event("agent.stream.llm_response.received") # [3^]
|
||
|
||
# TODO: THIS IS INCREDIBLY UGLY
|
||
# TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
|
||
if agent_state.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock]:
|
||
interface = AnthropicStreamingInterface(
|
||
use_assistant_message=use_assistant_message,
|
||
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
|
||
requires_approval_tools=tool_rules_solver.get_requires_approval_tools(valid_tool_names),
|
||
)
|
||
elif agent_state.llm_config.model_endpoint_type == ProviderType.openai:
|
||
interface = OpenAIStreamingInterface(
|
||
use_assistant_message=use_assistant_message,
|
||
is_openai_proxy=agent_state.llm_config.provider_name == "lmstudio_openai",
|
||
messages=current_in_context_messages + new_in_context_messages,
|
||
tools=request_data.get("tools", []),
|
||
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
|
||
requires_approval_tools=tool_rules_solver.get_requires_approval_tools(valid_tool_names),
|
||
)
|
||
else:
|
||
raise ValueError(f"Streaming not supported for {agent_state.llm_config}")
|
||
|
||
async for chunk in interface.process(
|
||
stream,
|
||
ttft_span=request_span,
|
||
):
|
||
# Measure TTFT (trace, metric, and db). This should be consolidated.
|
||
if first_chunk and request_span is not None:
|
||
now = get_utc_timestamp_ns()
|
||
ttft_ns = now - request_start_timestamp_ns
|
||
|
||
request_span.add_event(name="time_to_first_token_ms", attributes={"ttft_ms": ns_to_ms(ttft_ns)})
|
||
metric_attributes = get_ctx_attributes()
|
||
metric_attributes["model.name"] = agent_state.llm_config.model
|
||
MetricRegistry().ttft_ms_histogram.record(ns_to_ms(ttft_ns), metric_attributes)
|
||
|
||
if self.current_run_id and self.job_manager:
|
||
await self.job_manager.record_ttft(self.current_run_id, ttft_ns, self.actor)
|
||
|
||
first_chunk = False
|
||
|
||
if include_return_message_types is None or chunk.message_type in include_return_message_types:
|
||
# filter down returned data
|
||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||
|
||
stream_end_time_ns = get_utc_timestamp_ns()
|
||
|
||
# Some providers that rely on the OpenAI client currently e.g. LMStudio don't get usage metrics back on the last streaming chunk, fall back to manual values
|
||
if isinstance(interface, OpenAIStreamingInterface) and not interface.input_tokens and not interface.output_tokens:
|
||
logger.warning(
|
||
f"No token usage metrics received from OpenAI streaming interface for {agent_state.llm_config.model}, falling back to estimated values. Input tokens: {interface.fallback_input_tokens}, Output tokens: {interface.fallback_output_tokens}"
|
||
)
|
||
interface.input_tokens = interface.fallback_input_tokens
|
||
interface.output_tokens = interface.fallback_output_tokens
|
||
|
||
usage.step_count += 1
|
||
usage.completion_tokens += interface.output_tokens
|
||
usage.prompt_tokens += interface.input_tokens
|
||
usage.total_tokens += interface.input_tokens + interface.output_tokens
|
||
# Aggregate cache and reasoning tokens if available from streaming interface (handle None defaults)
|
||
if hasattr(interface, "cached_tokens") and interface.cached_tokens is not None:
|
||
usage.cached_input_tokens = (usage.cached_input_tokens or 0) + interface.cached_tokens
|
||
if hasattr(interface, "cache_read_tokens") and interface.cache_read_tokens is not None:
|
||
usage.cached_input_tokens = (usage.cached_input_tokens or 0) + interface.cache_read_tokens
|
||
if hasattr(interface, "cache_creation_tokens") and interface.cache_creation_tokens is not None:
|
||
usage.cache_write_tokens = (usage.cache_write_tokens or 0) + interface.cache_creation_tokens
|
||
if hasattr(interface, "reasoning_tokens") and interface.reasoning_tokens is not None:
|
||
usage.reasoning_tokens = (usage.reasoning_tokens or 0) + interface.reasoning_tokens
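                    # Editor's note: cached_tokens (OpenAI-style) and cache_read_tokens
                    # (Anthropic-style) both count toward the denormalized cached_input_tokens
                    # column, while cache_creation_tokens feeds cache_write_tokens; which of
                    # these attributes a given streaming interface exposes depends on the provider.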
|
||
MetricRegistry().message_output_tokens.record(
|
||
usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
|
||
)
|
||
|
||
# log LLM request time
|
||
llm_request_ns = stream_end_time_ns - provider_request_start_timestamp_ns
|
||
step_metrics.llm_request_ns = llm_request_ns
|
||
|
||
llm_request_ms = ns_to_ms(llm_request_ns)
|
||
agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ms})
|
||
MetricRegistry().llm_execution_time_ms_histogram.record(
|
||
llm_request_ms,
|
||
dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}),
|
||
)
|
||
|
||
# Process resulting stream content
|
||
try:
|
||
tool_call = interface.get_tool_call_object()
|
||
except ValueError as e:
|
||
stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value)
|
||
raise e
|
||
except Exception as e:
|
||
stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_tool_call.value)
|
||
raise e
|
||
reasoning_content = interface.get_reasoning_content()
|
||
|
||
# Log provider trace telemetry after stream processing
|
||
await llm_client.log_provider_trace_async(
|
||
request_data=request_data,
|
||
response_json={
|
||
"content": {
|
||
"tool_call": tool_call.model_dump() if tool_call else None,
|
||
"reasoning": [c.model_dump() for c in reasoning_content] if reasoning_content else [],
|
||
},
|
||
"model": getattr(interface, "model", None),
|
||
"usage": {
|
||
"input_tokens": interface.input_tokens,
|
||
"output_tokens": interface.output_tokens,
|
||
},
|
||
},
|
||
llm_config=agent_state.llm_config,
|
||
latency_ms=int(llm_request_ms),
|
||
)
|
||
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
|
||
tool_call,
|
||
valid_tool_names,
|
||
agent_state,
|
||
tool_rules_solver,
|
||
UsageStatistics(
|
||
completion_tokens=usage.completion_tokens,
|
||
prompt_tokens=usage.prompt_tokens,
|
||
total_tokens=usage.total_tokens,
|
||
),
|
||
reasoning_content=reasoning_content,
|
||
pre_computed_assistant_message_id=interface.letta_message_id,
|
||
step_id=effective_step_id,
|
||
initial_messages=initial_messages,
|
||
agent_step_span=agent_step_span,
|
||
is_final_step=(i == max_steps - 1),
|
||
step_metrics=step_metrics,
|
||
)
|
||
step_progression = StepProgression.STEP_LOGGED
|
||
|
||
# Update step with actual usage now that we have it (if step was created)
|
||
if logged_step:
|
||
# Build detailed token breakdowns from LettaUsageStatistics
|
||
# Use `is not None` to capture 0 values (meaning "provider reported 0 cached/reasoning tokens")
|
||
# Only include fields that were actually reported by the provider
|
||
prompt_details = None
|
||
if usage.cached_input_tokens is not None or usage.cache_write_tokens is not None:
|
||
prompt_details = UsageStatisticsPromptTokenDetails(
|
||
cached_tokens=usage.cached_input_tokens if usage.cached_input_tokens is not None else None,
|
||
cache_read_tokens=usage.cached_input_tokens if usage.cached_input_tokens is not None else None,
|
||
cache_creation_tokens=usage.cache_write_tokens if usage.cache_write_tokens is not None else None,
|
||
)
|
||
|
||
completion_details = None
|
||
if usage.reasoning_tokens is not None:
|
||
completion_details = UsageStatisticsCompletionTokenDetails(
|
||
reasoning_tokens=usage.reasoning_tokens,
|
||
)
|
||
|
||
await self.step_manager.update_step_success_async(
|
||
self.actor,
|
||
step_id,
|
||
UsageStatistics(
|
||
completion_tokens=usage.completion_tokens,
|
||
prompt_tokens=usage.prompt_tokens,
|
||
total_tokens=usage.total_tokens,
|
||
prompt_tokens_details=prompt_details,
|
||
completion_tokens_details=completion_details,
|
||
),
|
||
stop_reason,
|
||
)
|
||
|
||
new_message_idx = len(initial_messages) if initial_messages else 0
|
||
self.response_messages.extend(persisted_messages[new_message_idx:])
|
||
new_in_context_messages.extend(persisted_messages[new_message_idx:])
|
||
|
||
initial_messages = None
|
||
|
||
# log total step time
|
||
now = get_utc_timestamp_ns()
|
||
step_ns = now - step_start
|
||
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
|
||
agent_step_span.end()
|
||
|
||
# TODO (cliandy): the stream POST request span has ended at this point, we should tie this to the stream
|
||
# log_event("agent.stream.llm_response.processed") # [4^]
|
||
|
||
if persisted_messages[-1].role != "approval":
|
||
# yields tool response as this is handled from Letta and not the response from the LLM provider
|
||
tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
|
||
if not (use_assistant_message and tool_return.name == "send_message"):
|
||
# Apply message type filtering if specified
|
||
if include_return_message_types is None or tool_return.message_type in include_return_message_types:
|
||
yield f"data: {tool_return.model_dump_json()}\n\n"
|
||
|
||
# TODO (cliandy): consolidate and expand with trace
|
||
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
|
||
step_progression = StepProgression.FINISHED
|
||
|
||
# Record step metrics for successful completion
|
||
if logged_step and step_metrics:
|
||
try:
|
||
# Set the step_ns that was already calculated
|
||
step_metrics.step_ns = step_ns
|
||
|
||
# Get context attributes for project and template IDs
|
||
ctx_attrs = get_ctx_attributes()
|
||
|
||
await self._record_step_metrics(
|
||
step_id=step_id,
|
||
agent_state=agent_state,
|
||
step_metrics=step_metrics,
|
||
ctx_attrs=ctx_attrs,
|
||
job_id=self.current_run_id,
|
||
)
|
||
except Exception as metrics_error:
|
||
self.logger.warning(f"Failed to record step metrics: {metrics_error}")
|
||
|
||
except Exception as e:
|
||
# Handle any unexpected errors during step processing
|
||
self.logger.error(f"Error during step processing: {e}")
|
||
job_update_metadata = {"error": str(e)}
|
||
|
||
# This indicates we failed after we decided to stop stepping, which indicates a bug with our flow.
|
||
if not stop_reason:
|
||
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
|
||
elif stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule):
|
||
self.logger.error("Error occurred during step processing, with valid stop reason: %s", stop_reason.stop_reason)
|
||
elif stop_reason.stop_reason not in (
|
||
StopReasonType.no_tool_call,
|
||
StopReasonType.invalid_tool_call,
|
||
StopReasonType.invalid_llm_response,
|
||
):
|
||
self.logger.error("Error occurred during step processing, with unexpected stop reason: %s", stop_reason.stop_reason)
|
||
|
||
# Send error stop reason to client and re-raise with expected response code
|
||
yield f"data: {stop_reason.model_dump_json()}\n\n", 500
|
||
raise
|
||
|
||
# Update step if it needs to be updated
|
||
finally:
|
||
if step_progression == StepProgression.FINISHED and should_continue:
|
||
continue
|
||
|
||
self.logger.debug("Running cleanup for agent loop run: %s", self.current_run_id)
|
||
self.logger.info("Running final update. Step Progression: %s", step_progression)
|
||
try:
|
||
if step_progression == StepProgression.FINISHED and not should_continue:
|
||
# Successfully completed - update with final usage and stop reason
|
||
if stop_reason is None:
|
||
stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
|
||
# Note: step already updated with success status after _handle_ai_response
|
||
if logged_step:
|
||
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
|
||
break
|
||
|
||
# Handle error cases
|
||
if step_progression < StepProgression.STEP_LOGGED:
|
||
# Error occurred before step was fully logged
|
||
import traceback
|
||
|
||
if logged_step:
|
||
await self.step_manager.update_step_error_async(
|
||
actor=self.actor,
|
||
step_id=step_id, # Use original step_id for telemetry
|
||
error_type=type(e).__name__ if "e" in locals() else "Unknown",
|
||
error_message=str(e) if "e" in locals() else "Unknown error",
|
||
error_traceback=traceback.format_exc(),
|
||
stop_reason=stop_reason,
|
||
)
|
||
|
||
if step_progression <= StepProgression.STREAM_RECEIVED:
|
||
if first_chunk and settings.track_errored_messages and initial_messages:
|
||
for message in initial_messages:
|
||
message.is_err = True
|
||
message.step_id = effective_step_id
|
||
await self.message_manager.create_many_messages_async(
|
||
initial_messages,
|
||
actor=self.actor,
|
||
project_id=agent_state.project_id,
|
||
template_id=agent_state.template_id,
|
||
)
|
||
elif step_progression <= StepProgression.LOGGED_TRACE:
|
||
if stop_reason is None:
|
||
self.logger.error("Error in step after logging step")
|
||
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
|
||
if logged_step:
|
||
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
|
||
else:
|
||
self.logger.error("Invalid StepProgression value")
|
||
|
||
# Do tracking for failure cases. Can consolidate with success conditions later.
|
||
if settings.track_stop_reason:
|
||
await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=True)
|
||
|
||
# Record partial step metrics on failure (capture whatever timing data we have)
|
||
if logged_step and step_metrics and step_progression < StepProgression.FINISHED:
|
||
try:
|
||
# Calculate total step time up to the failure point
|
||
step_metrics.step_ns = get_utc_timestamp_ns() - step_start
|
||
|
||
# Get context attributes for project and template IDs
|
||
ctx_attrs = get_ctx_attributes()
|
||
|
||
await self._record_step_metrics(
|
||
step_id=step_id,
|
||
agent_state=agent_state,
|
||
step_metrics=step_metrics,
|
||
ctx_attrs=ctx_attrs,
|
||
job_id=locals().get("run_id", self.current_run_id),
|
||
)
|
||
except Exception as metrics_error:
|
||
self.logger.warning(f"Failed to record step metrics: {metrics_error}")
|
||
|
||
except Exception as e:
|
||
self.logger.error("Failed to update step: %s", e)
|
||
|
||
if not should_continue:
|
||
break
|
||
# Extend the in context message ids
|
||
if not agent_state.message_buffer_autoclear:
|
||
await self._rebuild_context_window(
|
||
in_context_messages=current_in_context_messages,
|
||
new_letta_messages=new_in_context_messages,
|
||
llm_config=agent_state.llm_config,
|
||
total_tokens=usage.total_tokens,
|
||
force=False,
|
||
run_id=run_id,
|
||
)
|
||
|
||
await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=False)
|
||
|
||
for finish_chunk in self.get_finish_chunks_for_stream(usage, stop_reason):
|
||
yield f"data: {finish_chunk}\n\n"
|
||
|
||
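    # Request-level bookkeeping for one agent loop invocation: record the end-to-end duration
    # on the request span, refresh the agent's last-run metrics, and, when run tracking is
    # enabled, mark the current run completed or failed with any accumulated job metadata.
    # The request span is ended unconditionally so it is not leaked on early exits.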
    async def _log_request(
        self, request_start_timestamp_ns: int, request_span: "Span | None", job_update_metadata: dict | None, is_error: bool
    ):
        if request_start_timestamp_ns:
            now_ns, now = get_utc_timestamp_ns(), get_utc_time()
            duration_ns = now_ns - request_start_timestamp_ns
            if request_span:
                request_span.add_event(name="letta_request_ms", attributes={"duration_ms": ns_to_ms(duration_ns)})
            await self._update_agent_last_run_metrics(now, ns_to_ms(duration_ns))
            if settings.track_agent_run and self.current_run_id:
                await self.job_manager.record_response_duration(self.current_run_id, duration_ns, self.actor)
                await self.job_manager.safe_update_job_status_async(
                    job_id=self.current_run_id,
                    new_status=JobStatus.failed if is_error else JobStatus.completed,
                    actor=self.actor,
                    metadata=job_update_metadata,
                )
        if request_span:
            request_span.end()

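    # Persist per-step timing (LLM request, tool execution, total step) against the step record,
    # attributing it to the project/template/base-template ids taken from the OTel context
    # attributes when available. Failures are logged and swallowed so metrics never fail a run.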
    async def _record_step_metrics(
        self,
        *,
        step_id: str,
        agent_state: AgentState,
        step_metrics: StepMetrics,
        ctx_attrs: dict | None = None,
        job_id: str | None = None,
    ) -> None:
        try:
            attrs = ctx_attrs or get_ctx_attributes()
            await self.step_manager.record_step_metrics_async(
                actor=self.actor,
                step_id=step_id,
                llm_request_ns=step_metrics.llm_request_ns,
                tool_execution_ns=step_metrics.tool_execution_ns,
                step_ns=step_metrics.step_ns,
                agent_id=agent_state.id,
                job_id=job_id or self.current_run_id,
                project_id=attrs.get("project.id") or agent_state.project_id,
                template_id=attrs.get("template.id"),
                base_template_id=attrs.get("base_template.id"),
            )
        except Exception as metrics_error:
            self.logger.warning(f"Failed to record step metrics: {metrics_error}")

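    # Blocking (non-streaming) request path: build the provider payload from the current context
    # window, attach telemetry context, issue the request, and record its latency. On failure the
    # error is routed through _handle_llm_error (summarize-and-retry on context overflow), with up
    # to self.max_summarization_retries additional attempts before the exception is re-raised.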
    # noinspection PyInconsistentReturns
    async def _build_and_request_from_llm(
        self,
        current_in_context_messages: list[Message],
        new_in_context_messages: list[Message],
        agent_state: AgentState,
        llm_client: LLMClientBase,
        tool_rules_solver: ToolRulesSolver,
        agent_step_span: "Span",
        step_metrics: StepMetrics,
        run_id: str | None = None,
    ) -> tuple[dict, dict, list[Message], list[Message], list[str]] | None:
        for attempt in range(self.max_summarization_retries + 1):
            try:
                log_event("agent.stream_no_tokens.messages.refreshed")
                # Create LLM request data
                request_data, valid_tool_names = await self._create_llm_request_data_async(
                    llm_client=llm_client,
                    in_context_messages=current_in_context_messages + new_in_context_messages,
                    agent_state=agent_state,
                    tool_rules_solver=tool_rules_solver,
                )
                log_event("agent.stream_no_tokens.llm_request.created")

                async with AsyncTimer() as timer:
                    # Attempt LLM request with telemetry
                    llm_client.set_telemetry_context(
                        telemetry_manager=self.telemetry_manager,
                        agent_id=self.agent_id,
                        agent_tags=agent_state.tags,
                        run_id=self.current_run_id,
                        step_id=step_metrics.id,
                        call_type=LLMCallType.agent_step,
                    )
                    response = await llm_client.request_async_with_telemetry(request_data, agent_state.llm_config)

                # Track LLM request time
                step_metrics.llm_request_ns = int(timer.elapsed_ns)

                MetricRegistry().llm_execution_time_ms_histogram.record(
                    timer.elapsed_ms,
                    dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}),
                )
                agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": timer.elapsed_ms})

                return request_data, response, current_in_context_messages, new_in_context_messages, valid_tool_names

            except Exception as e:
                if attempt == self.max_summarization_retries:
                    raise e

                # Handle the error and prepare for retry
                current_in_context_messages = await self._handle_llm_error(
                    e,
                    llm_client=llm_client,
                    in_context_messages=current_in_context_messages,
                    new_letta_messages=new_in_context_messages,
                    llm_config=agent_state.llm_config,
                    force=True,
                    run_id=run_id,
                )
                new_in_context_messages = []
                log_event(f"agent.stream_no_tokens.retry_attempt.{attempt + 1}")

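    # Streaming variant of the request path: same retry behavior as above, but returns the raw
    # provider stream plus the provider-request start timestamp so the caller can attribute
    # time-to-first-token against request_start_timestamp_ns.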
    # noinspection PyInconsistentReturns
    async def _build_and_request_from_llm_streaming(
        self,
        first_chunk: bool,
        ttft_span: "Span",
        request_start_timestamp_ns: int,
        current_in_context_messages: list[Message],
        new_in_context_messages: list[Message],
        agent_state: AgentState,
        llm_client: LLMClientBase,
        tool_rules_solver: ToolRulesSolver,
        run_id: str | None = None,
        step_id: str | None = None,
    ) -> tuple[dict, AsyncStream[ChatCompletionChunk], list[Message], list[Message], list[str], int] | None:
        for attempt in range(self.max_summarization_retries + 1):
            try:
                log_event("agent.stream_no_tokens.messages.refreshed")
                # Create LLM request data
                request_data, valid_tool_names = await self._create_llm_request_data_async(
                    llm_client=llm_client,
                    in_context_messages=current_in_context_messages + new_in_context_messages,
                    agent_state=agent_state,
                    tool_rules_solver=tool_rules_solver,
                )
                log_event("agent.stream.llm_request.created")  # [2^]

                provider_request_start_timestamp_ns = get_utc_timestamp_ns()
                if first_chunk and ttft_span is not None:
                    request_start_to_provider_request_start_ns = provider_request_start_timestamp_ns - request_start_timestamp_ns
                    ttft_span.add_event(
                        name="request_start_to_provider_request_start_ns",
                        attributes={"request_start_to_provider_request_start_ns": ns_to_ms(request_start_to_provider_request_start_ns)},
                    )

                # Set telemetry context before streaming
                llm_client.set_telemetry_context(
                    telemetry_manager=self.telemetry_manager,
                    agent_id=self.agent_id,
                    agent_tags=agent_state.tags,
                    run_id=self.current_run_id,
                    step_id=step_id,
                    call_type=LLMCallType.agent_step,
                )

                # Attempt LLM request with telemetry wrapper
                return (
                    request_data,
                    await llm_client.stream_async(request_data, agent_state.llm_config),
                    current_in_context_messages,
                    new_in_context_messages,
                    valid_tool_names,
                    provider_request_start_timestamp_ns,
                )

            except Exception as e:
                if attempt == self.max_summarization_retries:
                    raise e

                # Handle the error and prepare for retry
                current_in_context_messages = await self._handle_llm_error(
                    e,
                    llm_client=llm_client,
                    in_context_messages=current_in_context_messages,
                    new_letta_messages=new_in_context_messages,
                    llm_config=agent_state.llm_config,
                    force=True,
                    run_id=run_id,
                )
                new_in_context_messages: list[Message] = []
                log_event(f"agent.stream_no_tokens.retry_attempt.{attempt + 1}")

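    # Error routing for LLM requests: a ContextWindowExceededError triggers a forced rebuild of
    # the context window (summarization) and returns the trimmed message list so the caller can
    # retry; any other exception is translated by the provider client and re-raised.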
    @trace_method
    async def _handle_llm_error(
        self,
        e: Exception,
        llm_client: LLMClientBase,
        in_context_messages: list[Message],
        new_letta_messages: list[Message],
        llm_config: LLMConfig,
        force: bool,
        run_id: str | None = None,
        step_id: str | None = None,
    ) -> list[Message]:
        if isinstance(e, ContextWindowExceededError):
            return await self._rebuild_context_window(
                in_context_messages=in_context_messages,
                new_letta_messages=new_letta_messages,
                llm_config=llm_config,
                force=force,
                run_id=run_id,
                step_id=step_id,
            )
        else:
            raise llm_client.handle_llm_error(e)

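    # Context-window maintenance: when forced, or when total_tokens exceeds the configured
    # context window, summarize with force/clear; otherwise run a non-forced summarization pass.
    # In both cases the agent's in-context message ids are updated to the new window.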
    @trace_method
    async def _rebuild_context_window(
        self,
        in_context_messages: list[Message],
        new_letta_messages: list[Message],
        llm_config: LLMConfig,
        total_tokens: int | None = None,
        force: bool = False,
        run_id: str | None = None,
        step_id: str | None = None,
    ) -> list[Message]:
        # If total tokens is reached, we truncate down
        # TODO: This can be broken by bad configs, e.g. lower bound too high, initial messages too fat, etc.
        # TODO: `force` and `clear` seem to no longer be used, we should remove
        if force or (total_tokens and total_tokens > llm_config.context_window):
            self.logger.warning(
                f"Total tokens {total_tokens} exceeds configured max tokens {llm_config.context_window}, forcefully clearing message history."
            )
            new_in_context_messages, updated = await self.summarizer.summarize(
                in_context_messages=in_context_messages,
                new_letta_messages=new_letta_messages,
                force=True,
                clear=True,
                run_id=run_id,
                step_id=step_id,
            )
        else:
            # NOTE (Sarah): Seems like this is doing nothing?
            self.logger.info(
                f"Total tokens {total_tokens} does not exceed configured max tokens {llm_config.context_window}, passing summarizing w/o force."
            )
            new_in_context_messages, updated = await self.summarizer.summarize(
                in_context_messages=in_context_messages,
                new_letta_messages=new_letta_messages,
                run_id=run_id,
                step_id=step_id,
            )
        await self.agent_manager.update_message_ids_async(
            agent_id=self.agent_id,
            message_ids=[m.id for m in new_in_context_messages],
            actor=self.actor,
        )

        return new_in_context_messages

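    # Developer-triggered compaction: load the agent's current in-context messages, force a
    # summarization pass with no new messages, and persist the resulting message ids.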
    @trace_method
    async def summarize_conversation_history(self) -> None:
        """Called when the developer explicitly triggers compaction via the API"""
        agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=self.agent_id, actor=self.actor)
        message_ids = agent_state.message_ids
        in_context_messages = await self.message_manager.get_messages_by_ids_async(message_ids=message_ids, actor=self.actor)
        new_in_context_messages, updated = await self.summarizer.summarize(
            in_context_messages=in_context_messages, new_letta_messages=[], force=True
        )
        return await self.agent_manager.update_message_ids_async(
            agent_id=self.agent_id, message_ids=[m.id for m in new_in_context_messages], actor=self.actor
        )

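    # Assemble the provider request: refresh memory, scrub inner thoughts when reasoning is
    # disabled, filter tools to the supported ToolType set, then apply tool rules to compute the
    # allowed tool names (falling back to all tools when the solver returns none). When exactly
    # one tool remains allowed it becomes a forced tool choice.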
    @trace_method
    async def _create_llm_request_data_async(
        self,
        llm_client: LLMClientBase,
        in_context_messages: list[Message],
        agent_state: AgentState,
        tool_rules_solver: ToolRulesSolver,
    ) -> tuple[dict, list[str]]:
        if not self.num_messages:
            self.num_messages = await self.message_manager.size_async(
                agent_id=agent_state.id,
                actor=self.actor,
            )
        if not self.num_archival_memories:
            self.num_archival_memories = await self.passage_manager.agent_passage_size_async(
                agent_id=agent_state.id,
                actor=self.actor,
            )

        in_context_messages = await self._rebuild_memory_async(
            in_context_messages,
            agent_state,
            num_messages=self.num_messages,
            num_archival_memories=self.num_archival_memories,
            tool_rules_solver=tool_rules_solver,
        )

        # scrub inner thoughts from messages if reasoning is completely disabled
        in_context_messages = scrub_inner_thoughts_from_messages(in_context_messages, agent_state.llm_config)

        tools = [
            t
            for t in agent_state.tools
            if t.tool_type
            in {
                ToolType.CUSTOM,
                ToolType.LETTA_CORE,
                ToolType.LETTA_MEMORY_CORE,
                ToolType.LETTA_MULTI_AGENT_CORE,
                ToolType.LETTA_SLEEPTIME_CORE,
                ToolType.LETTA_VOICE_SLEEPTIME_CORE,
                ToolType.LETTA_BUILTIN,
                ToolType.LETTA_FILES_CORE,
                ToolType.EXTERNAL_MCP,
            }
        ]

        # Mirror the sync agent loop: get allowed tools or allow all if none are allowed
        self.last_function_response = self._load_last_function_response(in_context_messages)
        valid_tool_names = tool_rules_solver.get_allowed_tool_names(
            available_tools=set([t.name for t in tools]),
            last_function_response=self.last_function_response,
        ) or list(set(t.name for t in tools))

        # TODO: Copied from legacy agent loop, so please be cautious
        # Set force tool
        force_tool_call = None
        if len(valid_tool_names) == 1:
            force_tool_call = valid_tool_names[0]

        allowed_tools = [
            enable_strict_mode(t.json_schema, strict=agent_state.llm_config.strict) for t in tools if t.name in set(valid_tool_names)
        ]
        # Extract terminal tool names from tool rules
        terminal_tool_names = {rule.tool_name for rule in tool_rules_solver.terminal_tool_rules}
        allowed_tools = runtime_override_tool_json_schema(
            tool_list=allowed_tools, response_format=agent_state.response_format, request_heartbeat=True, terminal_tools=terminal_tool_names
        )

        return (
            llm_client.build_request_data(
                agent_state.agent_type,
                in_context_messages,
                agent_state.llm_config,
                allowed_tools,
                force_tool_call,
            ),
            valid_tool_names,
        )

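    # Post-LLM handling for a single step. A user denial short-circuits with a synthetic error
    # tool result; approval-gated tools emit an approval request and stop; otherwise the tool is
    # executed (or a rule-violation result is synthesized), the continuation decision is made,
    # and all messages produced by the step are persisted together.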
    @trace_method
    async def _handle_ai_response(
        self,
        tool_call: ToolCall,
        valid_tool_names: list[str],
        agent_state: AgentState,
        tool_rules_solver: ToolRulesSolver,
        usage: UsageStatistics,
        reasoning_content: list[TextContent | ReasoningContent | RedactedReasoningContent | OmittedReasoningContent] | None = None,
        pre_computed_assistant_message_id: str | None = None,
        step_id: str | None = None,
        initial_messages: list[Message] | None = None,
        agent_step_span: Optional["Span"] = None,
        is_final_step: bool | None = None,
        run_id: str | None = None,
        step_metrics: StepMetrics = None,
        is_approval: bool | None = None,
        is_denial: bool | None = None,
        denial_reason: str | None = None,
    ) -> tuple[list[Message], bool, LettaStopReason | None]:
        """
        Handle the final AI response once streaming completes, execute / validate the
        tool call, decide whether we should keep stepping, and persist state.
        """
        tool_call_id: str = tool_call.id or f"call_{uuid.uuid4().hex[:8]}"

        if is_denial:
            continue_stepping = True
            stop_reason = None
            tool_call_messages = create_letta_messages_from_llm_response(
                agent_id=agent_state.id,
                model=agent_state.llm_config.model,
                function_name="",
                function_arguments={},
                tool_execution_result=ToolExecutionResult(status="error"),
                tool_call_id=tool_call_id,
                function_response=f"Error: request to call tool denied. User reason: {denial_reason}",
                timezone=agent_state.timezone,
                continue_stepping=continue_stepping,
                heartbeat_reason=f"{NON_USER_MSG_PREFIX}Continuing: user denied request to call tool.",
                reasoning_content=None,
                pre_computed_assistant_message_id=None,
                step_id=step_id,
                run_id=self.current_run_id,
                is_approval_response=True,
            )
            messages_to_persist = (initial_messages or []) + tool_call_messages
            persisted_messages = await self.message_manager.create_many_messages_async(
                messages_to_persist, actor=self.actor, project_id=agent_state.project_id, template_id=agent_state.template_id
            )
            return persisted_messages, continue_stepping, stop_reason

        # 1. Parse and validate the tool-call envelope
        tool_call_name: str = tool_call.function.name

        tool_args = _safe_load_tool_call_str(tool_call.function.arguments)
        request_heartbeat: bool = _pop_heartbeat(tool_args)
        tool_args.pop(INNER_THOUGHTS_KWARG, None)

        log_telemetry(
            self.logger,
            "_handle_ai_response execute tool start",
            tool_name=tool_call_name,
            tool_args=tool_args,
            tool_call_id=tool_call_id,
            request_heartbeat=request_heartbeat,
        )
        if not is_approval and tool_rules_solver.is_requires_approval_tool(tool_call_name):
            tool_args[REQUEST_HEARTBEAT_PARAM] = request_heartbeat
            approval_messages = create_approval_request_message_from_llm_response(
                agent_id=agent_state.id,
                model=agent_state.llm_config.model,
                requested_tool_calls=[
                    ToolCall(id=tool_call_id, function=FunctionCall(name=tool_call_name, arguments=json.dumps(tool_args)))
                ],
                reasoning_content=reasoning_content,
                pre_computed_assistant_message_id=pre_computed_assistant_message_id,
                step_id=step_id,
            )
            messages_to_persist = (initial_messages or []) + approval_messages
            continue_stepping = False
            stop_reason = LettaStopReason(stop_reason=StopReasonType.requires_approval.value)
        else:
            # 2. Execute the tool (or synthesize an error result if disallowed)
            tool_rule_violated = tool_call_name not in valid_tool_names and not is_approval
            if tool_rule_violated:
                tool_execution_result = _build_rule_violation_result(tool_call_name, valid_tool_names, tool_rules_solver)
            else:
                # Track tool execution time
                tool_start_time = get_utc_timestamp_ns()
                tool_execution_result = await self._execute_tool(
                    tool_name=tool_call_name,
                    tool_args=tool_args,
                    agent_state=agent_state,
                    agent_step_span=agent_step_span,
                    step_id=step_id,
                )
                tool_end_time = get_utc_timestamp_ns()

                # Store tool execution time in metrics
                step_metrics.tool_execution_ns = tool_end_time - tool_start_time

            log_telemetry(
                self.logger,
                "_handle_ai_response execute tool finish",
                tool_execution_result=tool_execution_result,
                tool_call_id=tool_call_id,
            )

            # 3. Prepare the function-response payload
            truncate = tool_call_name not in {"conversation_search", "conversation_search_date", "archival_memory_search"}
            return_char_limit = next(
                (t.return_char_limit for t in agent_state.tools if t.name == tool_call_name),
                None,
            )
            function_response_string = validate_function_response(
                tool_execution_result.func_return,
                return_char_limit=return_char_limit,
                truncate=truncate,
            )
            self.last_function_response = package_function_response(
                was_success=tool_execution_result.success_flag,
                response_string=function_response_string,
                timezone=agent_state.timezone,
            )

            # 4. Decide whether to keep stepping (focal section simplified)
            continue_stepping, heartbeat_reason, stop_reason = self._decide_continuation(
                agent_state=agent_state,
                request_heartbeat=request_heartbeat,
                tool_call_name=tool_call_name,
                tool_rule_violated=tool_rule_violated,
                tool_rules_solver=tool_rules_solver,
                is_final_step=is_final_step,
            )

            # 5. Create messages (step was already created at the beginning)
            tool_call_messages = create_letta_messages_from_llm_response(
                agent_id=agent_state.id,
                model=agent_state.llm_config.model,
                function_name=tool_call_name,
                function_arguments=tool_args,
                tool_execution_result=tool_execution_result,
                tool_call_id=tool_call_id,
                function_response=function_response_string,
                timezone=agent_state.timezone,
                continue_stepping=continue_stepping,
                heartbeat_reason=heartbeat_reason,
                reasoning_content=reasoning_content,
                pre_computed_assistant_message_id=pre_computed_assistant_message_id,
                step_id=step_id,
                run_id=self.current_run_id,
                is_approval_response=is_approval or is_denial,
            )
            messages_to_persist = (initial_messages or []) + tool_call_messages

        persisted_messages = await self.message_manager.create_many_messages_async(
            messages_to_persist, actor=self.actor, project_id=agent_state.project_id, template_id=agent_state.template_id
        )

        return persisted_messages, continue_stepping, stop_reason

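    # Continuation policy, illustrated with hypothetical inputs (tool names and rule setup are
    # examples only, and assume this is not the final step with no required tools left uncalled):
    #
    #   # terminal tool rule: cancels a requested heartbeat and records a tool_rule stop reason
    #   _decide_continuation(request_heartbeat=True, tool_call_name="done", tool_rule_violated=False, ...)
    #       -> (False, None, LettaStopReason(stop_reason=StopReasonType.tool_rule.value))
    #
    #   # rule violation: always keep stepping so the agent can self-correct
    #   _decide_continuation(request_heartbeat=False, tool_call_name="x", tool_rule_violated=True, ...)
    #       -> (True, f"{NON_USER_MSG_PREFIX}Continuing: tool rule violation.", None)
    #
    # is_final_step overrides everything with a max_steps stop; otherwise uncalled required tools
    # force one more step and clear any pending stop reason.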
    def _decide_continuation(
        self,
        agent_state: AgentState,
        request_heartbeat: bool,
        tool_call_name: str,
        tool_rule_violated: bool,
        tool_rules_solver: ToolRulesSolver,
        is_final_step: bool | None,
    ) -> tuple[bool, str | None, LettaStopReason | None]:
        continue_stepping = request_heartbeat
        heartbeat_reason: str | None = None
        stop_reason: LettaStopReason | None = None

        if tool_rule_violated:
            continue_stepping = True
            heartbeat_reason = f"{NON_USER_MSG_PREFIX}Continuing: tool rule violation."
        else:
            tool_rules_solver.register_tool_call(tool_call_name)

            if tool_rules_solver.is_terminal_tool(tool_call_name):
                if continue_stepping:
                    stop_reason = LettaStopReason(stop_reason=StopReasonType.tool_rule.value)
                continue_stepping = False

            elif tool_rules_solver.has_children_tools(tool_call_name):
                continue_stepping = True
                heartbeat_reason = f"{NON_USER_MSG_PREFIX}Continuing: child tool rule."

            elif tool_rules_solver.is_continue_tool(tool_call_name):
                continue_stepping = True
                heartbeat_reason = f"{NON_USER_MSG_PREFIX}Continuing: continue tool rule."

        # – hard stop overrides –
        if is_final_step:
            continue_stepping = False
            stop_reason = LettaStopReason(stop_reason=StopReasonType.max_steps.value)
        else:
            uncalled = tool_rules_solver.get_uncalled_required_tools(available_tools=set([t.name for t in agent_state.tools]))
            if not continue_stepping and uncalled:
                continue_stepping = True
                heartbeat_reason = f"{NON_USER_MSG_PREFIX}Continuing, user expects these tools: [{', '.join(uncalled)}] to be called still."

                stop_reason = None  # reset – we’re still going

        return continue_stepping, heartbeat_reason, stop_reason

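    # Single tool invocation: resolve the tool on the agent, build a ToolExecutionManager seeded
    # with the agent's pre-decrypted secrets as sandbox env vars, execute, and emit start/finish
    # span events with duration and success attributes. An unknown tool name returns an error
    # ToolExecutionResult rather than raising.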
    @trace_method
    async def _execute_tool(
        self,
        tool_name: str,
        tool_args: JsonDict,
        agent_state: AgentState,
        agent_step_span: Optional["Span"] = None,
        step_id: str | None = None,
    ) -> "ToolExecutionResult":
        """
        Executes a tool and returns the ToolExecutionResult.
        """
        from letta.schemas.tool_execution_result import ToolExecutionResult

        # Special memory case
        target_tool = next((x for x in agent_state.tools if x.name == tool_name), None)
        if not target_tool:
            # TODO: fix this error message
            return ToolExecutionResult(
                func_return=f"Tool {tool_name} not found",
                status="error",
            )

        # TODO: This temp. Move this logic and code to executors

        if agent_step_span:
            start_time = get_utc_timestamp_ns()
            agent_step_span.add_event(name="tool_execution_started")

        # Use pre-decrypted environment variable values (populated in from_orm_async)
        sandbox_env_vars = {var.key: var.value or "" for var in agent_state.secrets}
        tool_execution_manager = ToolExecutionManager(
            agent_state=agent_state,
            message_manager=self.message_manager,
            agent_manager=self.agent_manager,
            block_manager=self.block_manager,
            job_manager=self.job_manager,
            passage_manager=self.passage_manager,
            sandbox_env_vars=sandbox_env_vars,
            actor=self.actor,
        )
        # TODO: Integrate sandbox result
        log_event(name=f"start_{tool_name}_execution", attributes=tool_args)
        tool_execution_result = await tool_execution_manager.execute_tool_async(
            function_name=tool_name,
            function_args=tool_args,
            tool=target_tool,
            step_id=step_id,
        )
        if agent_step_span:
            end_time = get_utc_timestamp_ns()
            agent_step_span.add_event(
                name="tool_execution_completed",
                attributes={
                    "tool_name": target_tool.name,
                    "duration_ms": ns_to_ms(end_time - start_time),
                    "success": tool_execution_result.success_flag,
                    "tool_type": target_tool.tool_type,
                    "tool_id": target_tool.id,
                },
            )
        log_event(name=f"finish_{tool_name}_execution", attributes=tool_execution_result.model_dump())
        return tool_execution_result