Files
letta-server/letta/agents/letta_agent.py
jnjpng 25d75d6528 fix: update aysnc get plaintext callsites (#7069)
* base

* resolve

* fix

* fix
2025-12-17 17:31:02 -08:00

1939 lines
98 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import json
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
from typing import Optional, Union
from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk
from opentelemetry.trace import Span
from letta.agents.base_agent import BaseAgent
from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent
from letta.agents.helpers import (
_build_rule_violation_result,
_create_letta_response,
_load_last_function_response,
_pop_heartbeat,
_prepare_in_context_messages_no_persist_async,
_safe_load_tool_call_str,
generate_step_id,
)
from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX, REQUEST_HEARTBEAT_PARAM
from letta.errors import ContextWindowExceededError
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import AsyncTimer, get_utc_time, get_utc_timestamp_ns, ns_to_ms
from letta.helpers.reasoning_helper import scrub_inner_thoughts_from_messages
from letta.helpers.tool_execution_helper import enable_strict_mode
from letta.interfaces.anthropic_streaming_interface import AnthropicStreamingInterface
from letta.interfaces.openai_streaming_interface import OpenAIStreamingInterface
from letta.llm_api.llm_client import LLMClient
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.log import get_logger
from letta.otel.context import get_ctx_attributes
from letta.otel.metric_registry import MetricRegistry
from letta.otel.tracing import log_event, trace_method, tracer
from letta.schemas.agent import AgentState, UpdateAgent
from letta.schemas.enums import JobStatus, ProviderType, StepStatus, ToolType
from letta.schemas.letta_message import MessageType
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message, MessageCreateBase
from letta.schemas.openai.chat_completion_response import (
FunctionCall,
ToolCall,
UsageStatistics,
UsageStatisticsCompletionTokenDetails,
UsageStatisticsPromptTokenDetails,
)
from letta.schemas.provider_trace import ProviderTraceCreate
from letta.schemas.step import StepProgression
from letta.schemas.step_metrics import StepMetrics
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.server.rest_api.utils import (
create_approval_request_message_from_llm_response,
create_letta_messages_from_llm_response,
)
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema
from letta.services.job_manager import JobManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.step_manager import NoopStepManager, StepManager
from letta.services.summarizer.enums import SummarizationMode
from letta.services.summarizer.summarizer import Summarizer
from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
from letta.settings import model_settings, settings, summarizer_settings
from letta.system import package_function_response
from letta.types import JsonDict
from letta.utils import log_telemetry, validate_function_response
# Default label of the memory block the ephemeral summary agent writes into
# (used as the default `summary_block_label` for static_buffer summarization).
DEFAULT_SUMMARY_BLOCK_LABEL = "conversation_summary"
# Module-wide logger; each LettaAgent instance additionally creates a per-agent logger.
logger = get_logger(__name__)
class LettaAgent(BaseAgent):
def __init__(
    self,
    agent_id: str,
    message_manager: MessageManager,
    agent_manager: AgentManager,
    block_manager: BlockManager,
    job_manager: JobManager,
    passage_manager: PassageManager,
    actor: User,
    step_manager: StepManager = NoopStepManager(),
    telemetry_manager: TelemetryManager = NoopTelemetryManager(),
    current_run_id: str | None = None,
    ## summarizer settings
    summarizer_mode: SummarizationMode = summarizer_settings.mode,
    # for static_buffer mode
    summary_block_label: str = DEFAULT_SUMMARY_BLOCK_LABEL,
    message_buffer_limit: int = summarizer_settings.message_buffer_limit,
    message_buffer_min: int = summarizer_settings.message_buffer_min,
    enable_summarization: bool = summarizer_settings.enable_summarization,
    max_summarization_retries: int = summarizer_settings.max_summarization_retries,
    # for partial_evict mode
    partial_evict_summarizer_percentage: float = summarizer_settings.partial_evict_summarizer_percentage,
):
    """Bind an agent loop to one agent and one acting user.

    Args:
        agent_id: Id of the agent this loop drives; also names the per-agent logger.
        message_manager: Passed to ``BaseAgent`` and to the summarizer components.
        agent_manager: Passed to ``BaseAgent`` and to the summarizer components.
        block_manager: Memory block service; also given to the ephemeral summary agent.
        job_manager: Used to look up the current run when checking for cancellation.
        passage_manager: Stored on the instance for use by the agent loop.
        actor: User on whose behalf manager calls are made.
        step_manager: Logs steps and step metrics.
        telemetry_manager: Records provider traces.
        current_run_id: Run associated with this execution, if any.
        summarizer_mode: Summarization strategy handed to ``Summarizer``.
        summary_block_label: Memory block label the ephemeral summary agent targets.
        message_buffer_limit: Forwarded to ``Summarizer``.
        message_buffer_min: Forwarded to ``Summarizer``.
        enable_summarization: Together with a configured OpenAI API key, enables
            creation of the ephemeral summary agent.
        max_summarization_retries: Stored on the instance.
        partial_evict_summarizer_percentage: Forwarded to ``Summarizer``.

    Note:
        Default values are evaluated once, when the class body runs: the
        ``NoopStepManager``/``NoopTelemetryManager`` defaults are single instances
        shared by every agent created without explicit managers, and the
        ``summarizer_settings`` defaults reflect the settings at import time.
    """
    super().__init__(agent_id=agent_id, openai_client=None, message_manager=message_manager, agent_manager=agent_manager, actor=actor)

    # Service managers (job_manager was previously assigned twice; once is enough).
    self.block_manager = block_manager
    self.job_manager = job_manager
    self.passage_manager = passage_manager
    self.step_manager = step_manager
    self.telemetry_manager = telemetry_manager

    # Per-run state
    self.current_run_id = current_run_id
    self.response_messages: list[Message] = []
    self.last_function_response = None

    # Cached archival memory/message size
    self.num_messages = None
    self.num_archival_memories = None

    # TODO: Make this more general, factorable
    # Summarizer settings
    self.summarization_agent = None
    self.summary_block_label = summary_block_label
    self.max_summarization_retries = max_summarization_retries
    self.logger = get_logger(agent_id)

    # TODO: Expand to more
    # The ephemeral summary agent is only created when summarization is enabled and an
    # OpenAI API key is configured; otherwise the Summarizer receives None.
    if enable_summarization and model_settings.openai_api_key:
        self.summarization_agent = EphemeralSummaryAgent(
            target_block_label=self.summary_block_label,
            agent_id=agent_id,
            block_manager=self.block_manager,
            message_manager=self.message_manager,
            agent_manager=self.agent_manager,
            actor=self.actor,
        )

    self.summarizer = Summarizer(
        mode=summarizer_mode,
        # TODO consolidate to not use this, or push it into the Summarizer() class
        summarizer_agent=self.summarization_agent,
        # TODO: Make this configurable
        message_buffer_limit=message_buffer_limit,
        message_buffer_min=message_buffer_min,
        partial_evict_summarizer_percentage=partial_evict_summarizer_percentage,
        agent_manager=self.agent_manager,
        message_manager=self.message_manager,
        actor=self.actor,
        agent_id=self.agent_id,
    )
async def _check_run_cancellation(self) -> bool:
    """Report whether the run tied to this agent execution has been cancelled.

    Returns:
        ``True`` only when both a job manager and a current run id are set and the
        fetched job's status equals ``JobStatus.cancelled``. Without a job manager or
        run id the answer is ``False``. Any exception raised while fetching or
        inspecting the job is logged as a warning and also yields ``False``, so a
        failed lookup never aborts the agent loop.
    """
    if not (self.job_manager and self.current_run_id):
        return False

    try:
        run_job = await self.job_manager.get_job_by_id_async(job_id=self.current_run_id, actor=self.actor)
        is_cancelled = run_job.status == JobStatus.cancelled
    except Exception as e:
        # Best effort: log and treat the run as still active.
        logger.warning(f"Failed to check job cancellation status for job {self.current_run_id}: {e}")
        return False
    return is_cancelled
@trace_method
async def step(
    self,
    input_messages: list[MessageCreateBase],
    max_steps: int = DEFAULT_MAX_STEPS,
    run_id: str | None = None,
    use_assistant_message: bool = True,
    request_start_timestamp_ns: int | None = None,
    include_return_message_types: list[MessageType] | None = None,
    dry_run: bool = False,
) -> Union[LettaResponse, dict]:
    """Run the non-streaming agent loop and package the outcome as a ``LettaResponse``.

    The agent state is loaded with its tools, memory, tool execution environment
    variables and sources, then handed to ``_step``.

    Returns:
        When ``dry_run`` is true, whatever ``_step`` returned (normally the LLM request
        payload) is passed through unchanged. Otherwise a ``LettaResponse`` built from
        the new in-context messages, the stop reason and the accumulated usage.
    """
    # TODO (cliandy): pass in run_id and use at send_message endpoints for all step functions
    loaded_state = await self.agent_manager.get_agent_by_id_async(
        agent_id=self.agent_id,
        include_relationships=["tools", "memory", "tool_exec_environment_variables", "sources"],
        actor=self.actor,
    )
    step_output = await self._step(
        agent_state=loaded_state,
        input_messages=input_messages,
        max_steps=max_steps,
        run_id=run_id,
        request_start_timestamp_ns=request_start_timestamp_ns,
        dry_run=dry_run,
    )

    if not dry_run:
        _, produced_messages, final_stop_reason, step_usage = step_output
        return _create_letta_response(
            new_in_context_messages=produced_messages,
            use_assistant_message=use_assistant_message,
            stop_reason=final_stop_reason,
            usage=step_usage,
            include_return_message_types=include_return_message_types,
        )

    # Dry run: hand the request payload back to the caller untouched.
    return step_output
@trace_method
async def step_stream_no_tokens(
    self,
    input_messages: list[MessageCreateBase],
    max_steps: int = DEFAULT_MAX_STEPS,
    use_assistant_message: bool = True,
    request_start_timestamp_ns: int | None = None,
    include_return_message_types: list[MessageType] | None = None,
):
    """Run the agent loop, streaming each completed step's messages as SSE ``data:`` chunks.

    The LLM is called without token streaming: each step's response is awaited in
    full and processed by ``_handle_ai_response``. Only then are the step's non-user
    messages converted to Letta messages and yielded (filtered by
    ``include_return_message_types`` when it is not ``None``). If the last in-context
    message is an approval request, that iteration does not call the LLM. Instead it
    passes the pending tool call, together with the approval decision taken from
    ``input_messages[0]``, to ``_handle_ai_response``.

    Args:
        input_messages: New messages for this request.
        max_steps: Upper bound on loop iterations.
        use_assistant_message: Forwarded to ``Message.to_letta_messages_from_list``.
        request_start_timestamp_ns: Start time of the ``time_to_first_token`` span.
        include_return_message_types: Message types to stream; ``None`` streams all.

    Yields:
        ``data: ...`` SSE strings for messages, a cancellation stop reason, and the
        chunks from ``get_finish_chunks_for_stream``. When a step fails, a
        ``(data_string, 500)`` tuple carrying the stop reason is yielded just before
        the exception is re-raised.
    """
    agent_state = await self.agent_manager.get_agent_by_id_async(
        agent_id=self.agent_id,
        include_relationships=["tools", "memory", "tool_exec_environment_variables", "sources"],
        actor=self.actor,
    )
    current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_no_persist_async(
        input_messages, agent_state, self.message_manager, self.actor
    )
    # Not-yet-persisted input messages; handed to the first processed step, then reset to None.
    initial_messages = new_in_context_messages
    in_context_messages = current_in_context_messages
    tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
    llm_client = LLMClient.create(
        provider_type=agent_state.llm_config.model_endpoint_type,
        put_inner_thoughts_first=True,
        actor=self.actor,
    )
    stop_reason = None
    job_update_metadata = None
    usage = LettaUsageStatistics()

    # span for request
    request_span = tracer.start_span("time_to_first_token", start_time=request_start_timestamp_ns)
    request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})

    for i in range(max_steps):
        if in_context_messages[-1].role == "approval":
            # Resume a pending approval request: its tool call and the approval decision
            # are passed to _handle_ai_response; no LLM request is made this iteration.
            approval_request_message = in_context_messages[-1]
            step_metrics = await self.step_manager.get_step_metrics_async(step_id=approval_request_message.step_id, actor=self.actor)
            persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
                approval_request_message.tool_calls[0],
                [],  # TODO: update this
                agent_state,
                tool_rules_solver,
                usage,
                reasoning_content=approval_request_message.content,
                step_id=approval_request_message.step_id,
                initial_messages=initial_messages,
                is_final_step=(i == max_steps - 1),
                step_metrics=step_metrics,
                run_id=self.current_run_id,
                is_approval=input_messages[0].approve,
                # `== False` kept deliberately (not `not ...`); presumably `approve` may be None — confirm
                is_denial=input_messages[0].approve == False,  # noqa: E712
                denial_reason=input_messages[0].reason,
            )
            # presumably persisted_messages starts with initial_messages, hence the slice — verify
            new_message_idx = len(initial_messages) if initial_messages else 0
            self.response_messages.extend(persisted_messages[new_message_idx:])
            new_in_context_messages.extend(persisted_messages[new_message_idx:])
            initial_messages = None
            in_context_messages = current_in_context_messages + new_in_context_messages

            # stream step
            # TODO: improve TTFT
            filter_user_messages = [m for m in persisted_messages if m.role != "user" and m.role != "approval"]
            letta_messages = Message.to_letta_messages_from_list(
                filter_user_messages, use_assistant_message=use_assistant_message, reverse=False
            )
            for message in letta_messages:
                if include_return_message_types is None or message.message_type in include_return_message_types:
                    yield f"data: {message.model_dump_json()}\n\n"
        else:
            # Check for job cancellation at the start of each step
            if await self._check_run_cancellation():
                stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
                logger.info(f"Agent execution cancelled for run {self.current_run_id}")
                yield f"data: {stop_reason.model_dump_json()}\n\n"
                break

            step_id = generate_step_id()
            step_start = get_utc_timestamp_ns()
            agent_step_span = tracer.start_span("agent_step", start_time=step_start)
            agent_step_span.set_attributes({"step_id": step_id})

            step_progression = StepProgression.START
            should_continue = False
            step_metrics = StepMetrics(id=step_id)  # Initialize metrics tracking
            # The exception caught below, kept for the `finally` cleanup. The name bound by
            # `except ... as e` is deleted when that clause exits, so `finally` cannot read `e`.
            step_error: Exception | None = None

            # Create step early with PENDING status
            logged_step = await self.step_manager.log_step_async(
                actor=self.actor,
                agent_id=agent_state.id,
                provider_name=agent_state.llm_config.model_endpoint_type,
                provider_category=agent_state.llm_config.provider_category or "base",
                model=agent_state.llm_config.model,
                model_endpoint=agent_state.llm_config.model_endpoint,
                context_window_limit=agent_state.llm_config.context_window,
                usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0),
                provider_id=None,
                run_id=self.current_run_id if self.current_run_id else None,
                step_id=step_id,
                project_id=agent_state.project_id,
                status=StepStatus.PENDING,
            )
            # Only use step_id in messages if step was actually created
            effective_step_id = step_id if logged_step else None

            try:
                (
                    request_data,
                    response_data,
                    current_in_context_messages,
                    new_in_context_messages,
                    valid_tool_names,
                ) = await self._build_and_request_from_llm(
                    current_in_context_messages,
                    new_in_context_messages,
                    agent_state,
                    llm_client,
                    tool_rules_solver,
                    agent_step_span,
                    step_metrics,
                )
                in_context_messages = current_in_context_messages + new_in_context_messages

                step_progression = StepProgression.RESPONSE_RECEIVED
                log_event("agent.stream_no_tokens.llm_response.received")  # [3^]

                try:
                    response = await llm_client.convert_response_to_chat_completion(
                        response_data, in_context_messages, agent_state.llm_config
                    )
                except ValueError as e:
                    stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_llm_response.value)
                    raise e

                # update usage
                usage.step_count += 1
                usage.completion_tokens += response.usage.completion_tokens
                usage.prompt_tokens += response.usage.prompt_tokens
                usage.total_tokens += response.usage.total_tokens
                MetricRegistry().message_output_tokens.record(
                    response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
                )

                if not response.choices[0].message.tool_calls:
                    stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value)
                    raise ValueError("No tool calls found in response, model must make a tool call")
                tool_call = response.choices[0].message.tool_calls[0]
                if response.choices[0].message.reasoning_content:
                    reasoning = [
                        ReasoningContent(
                            reasoning=response.choices[0].message.reasoning_content,
                            is_native=True,
                            signature=response.choices[0].message.reasoning_content_signature,
                        )
                    ]
                elif response.choices[0].message.omitted_reasoning_content:
                    reasoning = [OmittedReasoningContent()]
                elif response.choices[0].message.content:
                    reasoning = [
                        TextContent(text=response.choices[0].message.content)
                    ]  # reasoning placed into content for legacy reasons
                else:
                    self.logger.info("No reasoning content found.")
                    reasoning = None

                persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
                    tool_call,
                    valid_tool_names,
                    agent_state,
                    tool_rules_solver,
                    response.usage,
                    reasoning_content=reasoning,
                    step_id=effective_step_id,
                    initial_messages=initial_messages,
                    agent_step_span=agent_step_span,
                    is_final_step=(i == max_steps - 1),
                    step_metrics=step_metrics,
                )
                step_progression = StepProgression.STEP_LOGGED

                # Update step with actual usage now that we have it (if step was created)
                if logged_step:
                    await self.step_manager.update_step_success_async(self.actor, step_id, response.usage, stop_reason)

                # TODO (cliandy): handle message contexts with larger refactor and dedupe logic
                new_message_idx = len(initial_messages) if initial_messages else 0
                self.response_messages.extend(persisted_messages[new_message_idx:])
                new_in_context_messages.extend(persisted_messages[new_message_idx:])
                initial_messages = None
                log_event("agent.stream_no_tokens.llm_response.processed")  # [4^]

                # log step time
                now = get_utc_timestamp_ns()
                step_ns = now - step_start
                agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
                agent_step_span.end()

                # Log LLM Trace
                if settings.track_provider_trace:
                    await self.telemetry_manager.create_provider_trace_async(
                        actor=self.actor,
                        provider_trace_create=ProviderTraceCreate(
                            request_json=request_data,
                            response_json=response_data,
                            step_id=step_id,  # Use original step_id for telemetry
                        ),
                    )
                step_progression = StepProgression.LOGGED_TRACE

                # stream step
                # TODO: improve TTFT
                filter_user_messages = [m for m in persisted_messages if m.role != "user"]
                letta_messages = Message.to_letta_messages_from_list(
                    filter_user_messages, use_assistant_message=use_assistant_message, reverse=False
                )
                letta_messages = [m for m in letta_messages if m.message_type != "approval_response_message"]

                for message in letta_messages:
                    if include_return_message_types is None or message.message_type in include_return_message_types:
                        yield f"data: {message.model_dump_json()}\n\n"

                # The histogram is in milliseconds; the clock is in nanoseconds.
                MetricRegistry().step_execution_time_ms_histogram.record(
                    ns_to_ms(get_utc_timestamp_ns() - step_start), get_ctx_attributes()
                )
                step_progression = StepProgression.FINISHED

                # Record step metrics for successful completion
                if logged_step and step_metrics:
                    # Set the step_ns that was already calculated
                    step_metrics.step_ns = step_ns
                    await self._record_step_metrics(
                        step_id=step_id,
                        agent_state=agent_state,
                        step_metrics=step_metrics,
                    )

            except Exception as e:
                step_error = e
                # Handle any unexpected errors during step processing
                self.logger.error(f"Error during step processing: {e}")
                job_update_metadata = {"error": str(e)}

                # This indicates we failed after we decided to stop stepping, which indicates a bug with our flow.
                if not stop_reason:
                    stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
                elif stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule):
                    self.logger.error("Error occurred during step processing, with valid stop reason: %s", stop_reason.stop_reason)
                elif stop_reason.stop_reason not in (
                    StopReasonType.no_tool_call,
                    StopReasonType.invalid_tool_call,
                    StopReasonType.invalid_llm_response,
                ):
                    self.logger.error("Error occurred during step processing, with unexpected stop reason: %s", stop_reason.stop_reason)

                # Send error stop reason to client and re-raise
                yield f"data: {stop_reason.model_dump_json()}\n\n", 500
                raise

            # Update step if it needs to be updated
            finally:
                # NOTE(review): `continue`/`break` inside `finally` discard any in-flight
                # exception. Both are only reached once step_progression is FINISHED, so only
                # a failure in the post-FINISHED metrics recording can be swallowed here.
                if step_progression == StepProgression.FINISHED and should_continue:
                    continue

                self.logger.debug("Running cleanup for agent loop run: %s", self.current_run_id)
                self.logger.info("Running final update. Step Progression: %s", step_progression)
                try:
                    if step_progression == StepProgression.FINISHED and not should_continue:
                        # Successfully completed - update with final usage and stop reason
                        if stop_reason is None:
                            stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
                        # Note: step already updated with success status after _handle_ai_response
                        if logged_step:
                            await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
                        break

                    # Handle error cases
                    if step_progression < StepProgression.STEP_LOGGED:
                        # Error occurred before step was fully logged
                        import traceback

                        if logged_step:
                            await self.step_manager.update_step_error_async(
                                actor=self.actor,
                                step_id=step_id,  # Use original step_id for telemetry
                                error_type=type(step_error).__name__ if step_error is not None else "Unknown",
                                error_message=str(step_error) if step_error is not None else "Unknown error",
                                error_traceback=traceback.format_exc(),
                                stop_reason=stop_reason,
                            )

                    if step_progression <= StepProgression.RESPONSE_RECEIVED:
                        # TODO (cliandy): persist response if we get it back
                        if settings.track_errored_messages and initial_messages:
                            for message in initial_messages:
                                message.is_err = True
                                message.step_id = effective_step_id
                            await self.message_manager.create_many_messages_async(
                                initial_messages,
                                actor=self.actor,
                                project_id=agent_state.project_id,
                                template_id=agent_state.template_id,
                            )
                    elif step_progression <= StepProgression.LOGGED_TRACE:
                        if stop_reason is None:
                            self.logger.error("Error in step after logging step")
                            stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
                        if logged_step:
                            await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
                    else:
                        self.logger.error("Invalid StepProgression value")

                    if settings.track_stop_reason:
                        await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=True)

                    # Record partial step metrics on failure (capture whatever timing data we have)
                    if logged_step and step_metrics and step_progression < StepProgression.FINISHED:
                        # Calculate total step time up to the failure point
                        step_metrics.step_ns = get_utc_timestamp_ns() - step_start
                        await self._record_step_metrics(
                            step_id=step_id,
                            agent_state=agent_state,
                            step_metrics=step_metrics,
                            # This method has no `run_id` local, so the previous
                            # locals().get("run_id", ...) always resolved to this value.
                            job_id=self.current_run_id,
                        )

                except Exception as e:
                    self.logger.error("Failed to update step: %s", e)

        if not should_continue:
            break

    # Extend the in context message ids
    if not agent_state.message_buffer_autoclear:
        await self._rebuild_context_window(
            in_context_messages=current_in_context_messages,
            new_letta_messages=new_in_context_messages,
            llm_config=agent_state.llm_config,
            total_tokens=usage.total_tokens,
            force=False,
        )

    await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=False)

    # Return back usage
    for finish_chunk in self.get_finish_chunks_for_stream(usage, stop_reason):
        yield f"data: {finish_chunk}\n\n"
async def _step(
    self,
    agent_state: AgentState,
    input_messages: list[MessageCreateBase],
    max_steps: int = DEFAULT_MAX_STEPS,
    run_id: str | None = None,
    request_start_timestamp_ns: int | None = None,
    dry_run: bool = False,
) -> Union[tuple[list[Message], list[Message], LettaStopReason | None, LettaUsageStatistics], dict]:
    """
    Carries out an invocation of the agent loop. In each step, the agent
    1. Rebuilds its memory
    2. Generates a request for the LLM
    3. Fetches a response from the LLM
    4. Processes the response

    If the last in-context message is an approval request, that iteration does not
    call the LLM. Instead it passes the pending tool call, together with the approval
    decision from ``input_messages[0]``, to ``_handle_ai_response``.

    Args:
        agent_state: Loaded agent state (tools, memory, llm_config, ...).
        input_messages: New messages for this request.
        max_steps: Upper bound on loop iterations.
        run_id: Run to attribute steps to; falls back to ``self.current_run_id``
            where a run id is recorded.
        request_start_timestamp_ns: Passed to ``_log_request``.
        dry_run: Build the LLM request for the current context and return it
            without calling the LLM.

    Returns:
        ``(current_in_context_messages, new_in_context_messages, stop_reason, usage)``,
        or the request payload dict when ``dry_run`` short-circuits.
    """
    current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_no_persist_async(
        input_messages, agent_state, self.message_manager, self.actor
    )
    # Not-yet-persisted input messages; handed to the first processed step, then reset to None.
    initial_messages = new_in_context_messages
    in_context_messages = current_in_context_messages
    tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
    llm_client = LLMClient.create(
        provider_type=agent_state.llm_config.model_endpoint_type,
        put_inner_thoughts_first=True,
        actor=self.actor,
    )

    # span for request
    request_span = tracer.start_span("time_to_first_token")
    request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})

    stop_reason = None
    job_update_metadata = None
    usage = LettaUsageStatistics()
    for i in range(max_steps):
        if in_context_messages[-1].role == "approval":
            # Resume a pending approval request: its tool call and the approval decision
            # are passed to _handle_ai_response; no LLM request is made this iteration.
            approval_request_message = in_context_messages[-1]
            step_metrics = await self.step_manager.get_step_metrics_async(step_id=approval_request_message.step_id, actor=self.actor)
            persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
                approval_request_message.tool_calls[0],
                [],  # TODO: update this
                agent_state,
                tool_rules_solver,
                usage,
                reasoning_content=approval_request_message.content,
                step_id=approval_request_message.step_id,
                initial_messages=initial_messages,
                is_final_step=(i == max_steps - 1),
                step_metrics=step_metrics,
                run_id=run_id or self.current_run_id,
                is_approval=input_messages[0].approve,
                # `== False` kept deliberately (not `not ...`); presumably `approve` may be None — confirm
                is_denial=input_messages[0].approve == False,  # noqa: E712
                denial_reason=input_messages[0].reason,
            )
            # presumably persisted_messages starts with initial_messages, hence the slice — verify
            new_message_idx = len(initial_messages) if initial_messages else 0
            self.response_messages.extend(persisted_messages[new_message_idx:])
            new_in_context_messages.extend(persisted_messages[new_message_idx:])
            initial_messages = None
            in_context_messages = current_in_context_messages + new_in_context_messages
        else:
            # If dry run, build request data and return it without making LLM call
            if dry_run:
                request_data, valid_tool_names = await self._create_llm_request_data_async(
                    llm_client=llm_client,
                    in_context_messages=current_in_context_messages + new_in_context_messages,
                    agent_state=agent_state,
                    tool_rules_solver=tool_rules_solver,
                )
                return request_data

            # Check for job cancellation at the start of each step
            if await self._check_run_cancellation():
                stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
                logger.info(f"Agent execution cancelled for run {self.current_run_id}")
                break

            step_id = generate_step_id()
            step_start = get_utc_timestamp_ns()
            agent_step_span = tracer.start_span("agent_step", start_time=step_start)
            agent_step_span.set_attributes({"step_id": step_id})

            step_progression = StepProgression.START
            should_continue = False
            step_metrics = StepMetrics(id=step_id)  # Initialize metrics tracking
            # The exception caught below, kept for the `finally` cleanup. The name bound by
            # `except ... as e` is deleted when that clause exits, so `finally` cannot read `e`.
            step_error: Exception | None = None

            # Create step early with PENDING status
            logged_step = await self.step_manager.log_step_async(
                actor=self.actor,
                agent_id=agent_state.id,
                provider_name=agent_state.llm_config.model_endpoint_type,
                provider_category=agent_state.llm_config.provider_category or "base",
                model=agent_state.llm_config.model,
                model_endpoint=agent_state.llm_config.model_endpoint,
                context_window_limit=agent_state.llm_config.context_window,
                usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0),
                provider_id=None,
                run_id=run_id if run_id else self.current_run_id,
                step_id=step_id,
                project_id=agent_state.project_id,
                status=StepStatus.PENDING,
            )
            # Only use step_id in messages if step was actually created
            effective_step_id = step_id if logged_step else None

            try:
                (
                    request_data,
                    response_data,
                    current_in_context_messages,
                    new_in_context_messages,
                    valid_tool_names,
                ) = await self._build_and_request_from_llm(
                    current_in_context_messages,
                    new_in_context_messages,
                    agent_state,
                    llm_client,
                    tool_rules_solver,
                    agent_step_span,
                    step_metrics,
                )
                in_context_messages = current_in_context_messages + new_in_context_messages

                step_progression = StepProgression.RESPONSE_RECEIVED
                log_event("agent.step.llm_response.received")  # [3^]

                try:
                    response = await llm_client.convert_response_to_chat_completion(
                        response_data, in_context_messages, agent_state.llm_config
                    )
                except ValueError as e:
                    stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_llm_response.value)
                    raise e

                usage.step_count += 1
                usage.completion_tokens += response.usage.completion_tokens
                usage.prompt_tokens += response.usage.prompt_tokens
                usage.total_tokens += response.usage.total_tokens
                usage.run_ids = [run_id] if run_id else None
                MetricRegistry().message_output_tokens.record(
                    response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
                )

                if not response.choices[0].message.tool_calls:
                    stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value)
                    raise ValueError("No tool calls found in response, model must make a tool call")
                tool_call = response.choices[0].message.tool_calls[0]
                if response.choices[0].message.reasoning_content:
                    reasoning = [
                        ReasoningContent(
                            reasoning=response.choices[0].message.reasoning_content,
                            is_native=True,
                            signature=response.choices[0].message.reasoning_content_signature,
                        )
                    ]
                elif response.choices[0].message.content:
                    reasoning = [
                        TextContent(text=response.choices[0].message.content)
                    ]  # reasoning placed into content for legacy reasons
                elif response.choices[0].message.omitted_reasoning_content:
                    reasoning = [OmittedReasoningContent()]
                else:
                    self.logger.info("No reasoning content found.")
                    reasoning = None

                persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
                    tool_call,
                    valid_tool_names,
                    agent_state,
                    tool_rules_solver,
                    response.usage,
                    reasoning_content=reasoning,
                    step_id=effective_step_id,
                    initial_messages=initial_messages,
                    agent_step_span=agent_step_span,
                    is_final_step=(i == max_steps - 1),
                    run_id=run_id,
                    step_metrics=step_metrics,
                )
                step_progression = StepProgression.STEP_LOGGED

                # Update step with actual usage now that we have it (if step was created)
                if logged_step:
                    await self.step_manager.update_step_success_async(self.actor, step_id, response.usage, stop_reason)

                new_message_idx = len(initial_messages) if initial_messages else 0
                self.response_messages.extend(persisted_messages[new_message_idx:])
                new_in_context_messages.extend(persisted_messages[new_message_idx:])
                initial_messages = None
                log_event("agent.step.llm_response.processed")  # [4^]

                # log step time
                now = get_utc_timestamp_ns()
                step_ns = now - step_start
                agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
                agent_step_span.end()

                # Log LLM Trace
                if settings.track_provider_trace:
                    await self.telemetry_manager.create_provider_trace_async(
                        actor=self.actor,
                        provider_trace_create=ProviderTraceCreate(
                            request_json=request_data,
                            response_json=response_data,
                            step_id=step_id,  # Use original step_id for telemetry
                        ),
                    )
                step_progression = StepProgression.LOGGED_TRACE

                # The histogram is in milliseconds; the clock is in nanoseconds.
                MetricRegistry().step_execution_time_ms_histogram.record(
                    ns_to_ms(get_utc_timestamp_ns() - step_start), get_ctx_attributes()
                )
                step_progression = StepProgression.FINISHED

                # Record step metrics for successful completion
                if logged_step and step_metrics:
                    # Set the step_ns that was already calculated
                    step_metrics.step_ns = step_ns
                    await self._record_step_metrics(
                        step_id=step_id,
                        agent_state=agent_state,
                        step_metrics=step_metrics,
                        run_id=run_id if run_id else self.current_run_id,
                    )

            except Exception as e:
                step_error = e
                # Handle any unexpected errors during step processing
                self.logger.error(f"Error during step processing: {e}")
                job_update_metadata = {"error": str(e)}

                # This indicates we failed after we decided to stop stepping, which indicates a bug with our flow.
                if not stop_reason:
                    stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
                elif stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule):
                    self.logger.error("Error occurred during step processing, with valid stop reason: %s", stop_reason.stop_reason)
                elif stop_reason.stop_reason not in (
                    StopReasonType.no_tool_call,
                    StopReasonType.invalid_tool_call,
                    StopReasonType.invalid_llm_response,
                ):
                    self.logger.error("Error occurred during step processing, with unexpected stop reason: %s", stop_reason.stop_reason)
                raise

            # Update step if it needs to be updated
            finally:
                # NOTE(review): `continue`/`break` inside `finally` discard any in-flight
                # exception. Both are only reached once step_progression is FINISHED, so only
                # a failure in the post-FINISHED metrics recording can be swallowed here.
                if step_progression == StepProgression.FINISHED and should_continue:
                    continue

                self.logger.debug("Running cleanup for agent loop run: %s", self.current_run_id)
                self.logger.info("Running final update. Step Progression: %s", step_progression)
                try:
                    if step_progression == StepProgression.FINISHED and not should_continue:
                        # Successfully completed - record the final stop reason.
                        if stop_reason is None:
                            stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
                        # The step was already marked successful with this step's own
                        # response.usage after _handle_ai_response. Only the stop reason is
                        # updated here, matching step_stream_no_tokens. Re-calling
                        # update_step_success_async with the cumulative run `usage` would
                        # overwrite the final step's usage with totals across all steps.
                        if logged_step:
                            await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
                        break

                    # Handle error cases
                    if step_progression < StepProgression.STEP_LOGGED:
                        # Error occurred before step was fully logged
                        import traceback

                        if logged_step:
                            await self.step_manager.update_step_error_async(
                                actor=self.actor,
                                step_id=step_id,  # Use original step_id for telemetry
                                error_type=type(step_error).__name__ if step_error is not None else "Unknown",
                                error_message=str(step_error) if step_error is not None else "Unknown error",
                                error_traceback=traceback.format_exc(),
                                stop_reason=stop_reason,
                            )

                    if step_progression <= StepProgression.RESPONSE_RECEIVED:
                        # TODO (cliandy): persist response if we get it back
                        if settings.track_errored_messages and initial_messages:
                            for message in initial_messages:
                                message.is_err = True
                                message.step_id = effective_step_id
                            await self.message_manager.create_many_messages_async(
                                initial_messages,
                                actor=self.actor,
                                project_id=agent_state.project_id,
                                template_id=agent_state.template_id,
                            )
                    elif step_progression <= StepProgression.LOGGED_TRACE:
                        if stop_reason is None:
                            self.logger.error("Error in step after logging step")
                            stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
                        if logged_step:
                            await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
                    else:
                        self.logger.error("Invalid StepProgression value")

                    if settings.track_stop_reason:
                        await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=True)

                    # Record partial step metrics on failure (capture whatever timing data we have)
                    if logged_step and step_metrics and step_progression < StepProgression.FINISHED:
                        # Calculate total step time up to the failure point
                        step_metrics.step_ns = get_utc_timestamp_ns() - step_start
                        await self._record_step_metrics(
                            step_id=step_id,
                            agent_state=agent_state,
                            step_metrics=step_metrics,
                            # `run_id` is always a local here, so locals().get() never used the
                            # fallback. Fall back explicitly, as the success path does.
                            job_id=run_id or self.current_run_id,
                        )

                except Exception as e:
                    self.logger.error("Failed to update step: %s", e)

        if not should_continue:
            break

    # Extend the in context message ids
    if not agent_state.message_buffer_autoclear:
        await self._rebuild_context_window(
            in_context_messages=current_in_context_messages,
            new_letta_messages=new_in_context_messages,
            llm_config=agent_state.llm_config,
            total_tokens=usage.total_tokens,
            force=False,
        )

    await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=False)

    return current_in_context_messages, new_in_context_messages, stop_reason, usage
async def _update_agent_last_run_metrics(self, completion_time: datetime, duration_ms: float) -> None:
    """Store the latest run's completion time and duration on the agent record.

    Nothing happens unless ``settings.track_last_agent_run`` is enabled. Any exception raised while
    building or sending the update is logged and swallowed.
    """
    if settings.track_last_agent_run:
        try:
            run_stats = UpdateAgent(last_run_completion=completion_time, last_run_duration_ms=duration_ms)
            await self.agent_manager.update_agent_async(agent_id=self.agent_id, agent_update=run_stats, actor=self.actor)
        except Exception as e:
            self.logger.error(f"Failed to update agent's last run metrics: {e}")
@trace_method
async def step_stream(
    self,
    input_messages: list[MessageCreateBase],
    max_steps: int = DEFAULT_MAX_STEPS,
    use_assistant_message: bool = True,
    request_start_timestamp_ns: int | None = None,
    include_return_message_types: list[MessageType] | None = None,
) -> AsyncGenerator[str, None]:
    """
    Carries out an invocation of the agent loop in a streaming fashion that yields partial tokens.

    Whenever we detect a tool call, we yield from _handle_ai_response as well. At each step, the agent
        1. Rebuilds its memory
        2. Generates a request for the LLM
        3. Fetches a response from the LLM
        4. Processes the response

    Args:
        input_messages: New messages for this invocation. When the last in-context message is a
            pending approval request, ``input_messages[0].approve`` / ``.reason`` are read as the
            user's approval or denial.
        max_steps: Upper bound on loop iterations.
        use_assistant_message: Forwarded to the streaming interfaces; when True, a ``send_message``
            tool return is not re-emitted as a separate tool-return event.
        request_start_timestamp_ns: Request start time. When set, a ``time_to_first_token`` span is
            opened and the request duration is recorded when the loop ends.
        include_return_message_types: If given, only messages whose ``message_type`` is listed are yielded.

    Yields:
        ``data: <json>`` server-sent-event frames for streamed chunks, tool returns and finish chunks.
        On a step failure, a ``(frame, 500)`` tuple carrying the stop reason is yielded first.

    Raises:
        Any exception raised while processing a step is re-raised after the step cleanup has run.
    """
    agent_state = await self.agent_manager.get_agent_by_id_async(
        agent_id=self.agent_id,
        include_relationships=["tools", "memory", "tool_exec_environment_variables", "sources"],
        actor=self.actor,
    )
    current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_no_persist_async(
        input_messages, agent_state, self.message_manager, self.actor
    )
    initial_messages = new_in_context_messages
    in_context_messages = current_in_context_messages
    tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
    llm_client = LLMClient.create(
        provider_type=agent_state.llm_config.model_endpoint_type,
        put_inner_thoughts_first=True,
        actor=self.actor,
    )
    stop_reason = None
    job_update_metadata = None
    usage = LettaUsageStatistics()
    first_chunk, request_span = True, None
    if request_start_timestamp_ns:
        request_span = tracer.start_span("time_to_first_token", start_time=request_start_timestamp_ns)
        request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})

    for i in range(max_steps):
        if in_context_messages[-1].role == "approval":
            # The context ends on a pending approval request: act on the user's approve/deny response
            # for the stored tool call instead of calling the LLM.
            approval_request_message = in_context_messages[-1]
            step_metrics = await self.step_manager.get_step_metrics_async(step_id=approval_request_message.step_id, actor=self.actor)
            persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
                approval_request_message.tool_calls[0],
                [],  # TODO: update this
                agent_state,
                tool_rules_solver,
                usage,
                reasoning_content=approval_request_message.content,
                step_id=approval_request_message.step_id,
                initial_messages=new_in_context_messages,
                is_final_step=(i == max_steps - 1),
                step_metrics=step_metrics,
                run_id=self.current_run_id,
                is_approval=input_messages[0].approve,
                is_denial=input_messages[0].approve == False,
                denial_reason=input_messages[0].reason,
            )
            new_message_idx = len(initial_messages) if initial_messages else 0
            self.response_messages.extend(persisted_messages[new_message_idx:])
            new_in_context_messages.extend(persisted_messages[new_message_idx:])
            initial_messages = None
            in_context_messages = current_in_context_messages + new_in_context_messages

            # yields tool response as this is handled from Letta and not the response from the LLM provider
            tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
            if not (use_assistant_message and tool_return.name == "send_message"):
                # Apply message type filtering if specified
                if include_return_message_types is None or tool_return.message_type in include_return_message_types:
                    yield f"data: {tool_return.model_dump_json()}\n\n"
        else:
            step_id = generate_step_id()
            # Check for job cancellation at the start of each step
            if await self._check_run_cancellation():
                stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
                logger.info(f"Agent execution cancelled for run {self.current_run_id}")
                yield f"data: {stop_reason.model_dump_json()}\n\n"
                break

            step_start = get_utc_timestamp_ns()
            agent_step_span = tracer.start_span("agent_step", start_time=step_start)
            agent_step_span.set_attributes({"step_id": step_id})

            step_progression = StepProgression.START
            should_continue = False
            step_metrics = StepMetrics(id=step_id)  # Initialize metrics tracking
            # Holds the exception (if any) raised by this step so the `finally` cleanup can report it.
            # The name bound by `except ... as e` is deleted when that clause exits (even via `raise`),
            # so `e` itself is never visible inside `finally`.
            step_error: Exception | None = None

            # Create step early with PENDING status
            logged_step = await self.step_manager.log_step_async(
                actor=self.actor,
                agent_id=agent_state.id,
                provider_name=agent_state.llm_config.model_endpoint_type,
                provider_category=agent_state.llm_config.provider_category or "base",
                model=agent_state.llm_config.model,
                model_endpoint=agent_state.llm_config.model_endpoint,
                context_window_limit=agent_state.llm_config.context_window,
                usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0),
                provider_id=None,
                run_id=self.current_run_id if self.current_run_id else None,
                step_id=step_id,
                project_id=agent_state.project_id,
                status=StepStatus.PENDING,
            )
            # Only use step_id in messages if step was actually created
            effective_step_id = step_id if logged_step else None

            try:
                (
                    request_data,
                    stream,
                    current_in_context_messages,
                    new_in_context_messages,
                    valid_tool_names,
                    provider_request_start_timestamp_ns,
                ) = await self._build_and_request_from_llm_streaming(
                    first_chunk,
                    agent_step_span,
                    request_start_timestamp_ns,
                    current_in_context_messages,
                    new_in_context_messages,
                    agent_state,
                    llm_client,
                    tool_rules_solver,
                )

                step_progression = StepProgression.STREAM_RECEIVED
                log_event("agent.stream.llm_response.received")  # [3^]

                # TODO: THIS IS INCREDIBLY UGLY
                # TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
                if agent_state.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock]:
                    interface = AnthropicStreamingInterface(
                        use_assistant_message=use_assistant_message,
                        put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
                        requires_approval_tools=tool_rules_solver.get_requires_approval_tools(valid_tool_names),
                    )
                elif agent_state.llm_config.model_endpoint_type == ProviderType.openai:
                    interface = OpenAIStreamingInterface(
                        use_assistant_message=use_assistant_message,
                        is_openai_proxy=agent_state.llm_config.provider_name == "lmstudio_openai",
                        messages=current_in_context_messages + new_in_context_messages,
                        tools=request_data.get("tools", []),
                        put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
                        requires_approval_tools=tool_rules_solver.get_requires_approval_tools(valid_tool_names),
                    )
                else:
                    raise ValueError(f"Streaming not supported for {agent_state.llm_config}")

                async for chunk in interface.process(
                    stream,
                    ttft_span=request_span,
                ):
                    # Measure TTFT (trace, metric, and db). This should be consolidated.
                    if first_chunk and request_span is not None:
                        now = get_utc_timestamp_ns()
                        ttft_ns = now - request_start_timestamp_ns

                        request_span.add_event(name="time_to_first_token_ms", attributes={"ttft_ms": ns_to_ms(ttft_ns)})
                        metric_attributes = get_ctx_attributes()
                        metric_attributes["model.name"] = agent_state.llm_config.model
                        MetricRegistry().ttft_ms_histogram.record(ns_to_ms(ttft_ns), metric_attributes)

                        if self.current_run_id and self.job_manager:
                            await self.job_manager.record_ttft(self.current_run_id, ttft_ns, self.actor)

                        first_chunk = False

                    if include_return_message_types is None or chunk.message_type in include_return_message_types:
                        # filter down returned data
                        yield f"data: {chunk.model_dump_json()}\n\n"

                stream_end_time_ns = get_utc_timestamp_ns()

                # Some providers that rely on the OpenAI client currently e.g. LMStudio don't get usage metrics back on the last streaming chunk, fall back to manual values
                if isinstance(interface, OpenAIStreamingInterface) and not interface.input_tokens and not interface.output_tokens:
                    logger.warning(
                        f"No token usage metrics received from OpenAI streaming interface for {agent_state.llm_config.model}, falling back to estimated values. Input tokens: {interface.fallback_input_tokens}, Output tokens: {interface.fallback_output_tokens}"
                    )
                    interface.input_tokens = interface.fallback_input_tokens
                    interface.output_tokens = interface.fallback_output_tokens

                usage.step_count += 1
                usage.completion_tokens += interface.output_tokens
                usage.prompt_tokens += interface.input_tokens
                usage.total_tokens += interface.input_tokens + interface.output_tokens

                # Aggregate cache and reasoning tokens if available from streaming interface (handle None defaults)
                if hasattr(interface, "cached_tokens") and interface.cached_tokens is not None:
                    usage.cached_input_tokens = (usage.cached_input_tokens or 0) + interface.cached_tokens
                if hasattr(interface, "cache_read_tokens") and interface.cache_read_tokens is not None:
                    usage.cached_input_tokens = (usage.cached_input_tokens or 0) + interface.cache_read_tokens
                if hasattr(interface, "cache_creation_tokens") and interface.cache_creation_tokens is not None:
                    usage.cache_write_tokens = (usage.cache_write_tokens or 0) + interface.cache_creation_tokens
                if hasattr(interface, "reasoning_tokens") and interface.reasoning_tokens is not None:
                    usage.reasoning_tokens = (usage.reasoning_tokens or 0) + interface.reasoning_tokens

                MetricRegistry().message_output_tokens.record(
                    usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
                )

                # log LLM request time
                llm_request_ns = stream_end_time_ns - provider_request_start_timestamp_ns
                step_metrics.llm_request_ns = llm_request_ns

                llm_request_ms = ns_to_ms(llm_request_ns)
                agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ms})
                MetricRegistry().llm_execution_time_ms_histogram.record(
                    llm_request_ms,
                    dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}),
                )

                # Process resulting stream content
                try:
                    tool_call = interface.get_tool_call_object()
                except ValueError as e:
                    stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value)
                    raise e
                except Exception as e:
                    stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_tool_call.value)
                    raise e
                reasoning_content = interface.get_reasoning_content()
                persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
                    tool_call,
                    valid_tool_names,
                    agent_state,
                    tool_rules_solver,
                    UsageStatistics(
                        completion_tokens=usage.completion_tokens,
                        prompt_tokens=usage.prompt_tokens,
                        total_tokens=usage.total_tokens,
                    ),
                    reasoning_content=reasoning_content,
                    pre_computed_assistant_message_id=interface.letta_message_id,
                    step_id=effective_step_id,
                    initial_messages=initial_messages,
                    agent_step_span=agent_step_span,
                    is_final_step=(i == max_steps - 1),
                    step_metrics=step_metrics,
                )
                step_progression = StepProgression.STEP_LOGGED

                # Update step with actual usage now that we have it (if step was created)
                if logged_step:
                    # Build detailed token breakdowns from LettaUsageStatistics
                    # Use `is not None` to capture 0 values (meaning "provider reported 0 cached/reasoning tokens")
                    # Only include fields that were actually reported by the provider
                    prompt_details = None
                    if usage.cached_input_tokens is not None or usage.cache_write_tokens is not None:
                        prompt_details = UsageStatisticsPromptTokenDetails(
                            cached_tokens=usage.cached_input_tokens,
                            cache_read_tokens=usage.cached_input_tokens,
                            cache_creation_tokens=usage.cache_write_tokens,
                        )

                    completion_details = None
                    if usage.reasoning_tokens is not None:
                        completion_details = UsageStatisticsCompletionTokenDetails(
                            reasoning_tokens=usage.reasoning_tokens,
                        )

                    await self.step_manager.update_step_success_async(
                        self.actor,
                        step_id,
                        UsageStatistics(
                            completion_tokens=usage.completion_tokens,
                            prompt_tokens=usage.prompt_tokens,
                            total_tokens=usage.total_tokens,
                            prompt_tokens_details=prompt_details,
                            completion_tokens_details=completion_details,
                        ),
                        stop_reason,
                    )

                new_message_idx = len(initial_messages) if initial_messages else 0
                self.response_messages.extend(persisted_messages[new_message_idx:])
                new_in_context_messages.extend(persisted_messages[new_message_idx:])

                initial_messages = None

                # log total step time
                now = get_utc_timestamp_ns()
                step_ns = now - step_start
                agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
                agent_step_span.end()

                # TODO (cliandy): the stream POST request span has ended at this point, we should tie this to the stream
                # log_event("agent.stream.llm_response.processed") # [4^]

                # Log LLM Trace
                # We are piecing together the streamed response here.
                # Content here does not match the actual response schema as streams come in chunks.
                if settings.track_provider_trace:
                    await self.telemetry_manager.create_provider_trace_async(
                        actor=self.actor,
                        provider_trace_create=ProviderTraceCreate(
                            request_json=request_data,
                            response_json={
                                "content": {
                                    "tool_call": tool_call.model_dump_json(),
                                    "reasoning": [content.model_dump_json() for content in reasoning_content],
                                },
                                "id": interface.message_id,
                                "model": interface.model,
                                "role": "assistant",
                                # "stop_reason": "",
                                # "stop_sequence": None,
                                "type": "message",
                                "usage": {
                                    "input_tokens": usage.prompt_tokens,
                                    "output_tokens": usage.completion_tokens,
                                },
                            },
                            step_id=step_id,  # Use original step_id for telemetry
                        ),
                    )
                step_progression = StepProgression.LOGGED_TRACE

                if persisted_messages[-1].role != "approval":
                    # yields tool response as this is handled from Letta and not the response from the LLM provider
                    tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
                    if not (use_assistant_message and tool_return.name == "send_message"):
                        # Apply message type filtering if specified
                        if include_return_message_types is None or tool_return.message_type in include_return_message_types:
                            yield f"data: {tool_return.model_dump_json()}\n\n"

                # TODO (cliandy): consolidate and expand with trace
                # The histogram is in milliseconds (like llm_execution_time_ms_histogram above), so convert from ns.
                MetricRegistry().step_execution_time_ms_histogram.record(ns_to_ms(get_utc_timestamp_ns() - step_start), get_ctx_attributes())
                step_progression = StepProgression.FINISHED

                # Record step metrics for successful completion
                if logged_step and step_metrics:
                    try:
                        # Set the step_ns that was already calculated
                        step_metrics.step_ns = step_ns

                        # Get context attributes for project and template IDs
                        ctx_attrs = get_ctx_attributes()

                        await self._record_step_metrics(
                            step_id=step_id,
                            agent_state=agent_state,
                            step_metrics=step_metrics,
                            ctx_attrs=ctx_attrs,
                            job_id=self.current_run_id,
                        )
                    except Exception as metrics_error:
                        self.logger.warning(f"Failed to record step metrics: {metrics_error}")

            except Exception as e:
                # Keep a reference for the `finally` cleanup; `e` is unbound once this clause exits.
                step_error = e
                # Handle any unexpected errors during step processing
                self.logger.error(f"Error during step processing: {e}")
                job_update_metadata = {"error": str(e)}

                # This indicates we failed after we decided to stop stepping, which indicates a bug with our flow.
                if not stop_reason:
                    stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
                elif stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule):
                    self.logger.error("Error occurred during step processing, with valid stop reason: %s", stop_reason.stop_reason)
                elif stop_reason.stop_reason not in (
                    StopReasonType.no_tool_call,
                    StopReasonType.invalid_tool_call,
                    StopReasonType.invalid_llm_response,
                ):
                    self.logger.error("Error occurred during step processing, with unexpected stop reason: %s", stop_reason.stop_reason)

                # Send error stop reason to client and re-raise with expected response code
                # NOTE(review): this yields a (frame, status_code) tuple despite the AsyncGenerator[str, None]
                # annotation; presumably the streaming response wrapper unpacks it — confirm.
                yield f"data: {stop_reason.model_dump_json()}\n\n", 500
                raise

            # Update step if it needs to be updated
            finally:
                # NOTE: `continue`/`break` inside `finally` discard any in-flight exception; here they are
                # only reached once step_progression is FINISHED, i.e. after the try body ran to completion.
                if step_progression == StepProgression.FINISHED and should_continue:
                    continue

                self.logger.debug("Running cleanup for agent loop run: %s", self.current_run_id)
                self.logger.info("Running final update. Step Progression: %s", step_progression)
                try:
                    if step_progression == StepProgression.FINISHED and not should_continue:
                        # Successfully completed - update with final usage and stop reason
                        if stop_reason is None:
                            stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
                        # Note: step already updated with success status after _handle_ai_response
                        if logged_step:
                            await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
                        break

                    # Handle error cases
                    if step_progression < StepProgression.STEP_LOGGED:
                        # Error occurred before step was fully logged
                        import traceback

                        if logged_step:
                            await self.step_manager.update_step_error_async(
                                actor=self.actor,
                                step_id=step_id,  # Use original step_id for telemetry
                                error_type=type(step_error).__name__ if step_error is not None else "Unknown",
                                error_message=str(step_error) if step_error is not None else "Unknown error",
                                error_traceback=traceback.format_exc(),
                                stop_reason=stop_reason,
                            )

                    if step_progression <= StepProgression.STREAM_RECEIVED:
                        # Persist the errored input messages only if nothing was streamed yet
                        # (first_chunk is cleared in the TTFT branch above, which requires request_span).
                        if first_chunk and settings.track_errored_messages and initial_messages:
                            for message in initial_messages:
                                message.is_err = True
                                message.step_id = effective_step_id
                            await self.message_manager.create_many_messages_async(
                                initial_messages,
                                actor=self.actor,
                                project_id=agent_state.project_id,
                                template_id=agent_state.template_id,
                            )
                    elif step_progression <= StepProgression.LOGGED_TRACE:
                        if stop_reason is None:
                            self.logger.error("Error in step after logging step")
                            stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
                        if logged_step:
                            await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
                    else:
                        self.logger.error("Invalid StepProgression value")

                    # Do tracking for failure cases. Can consolidate with success conditions later.
                    if settings.track_stop_reason:
                        await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=True)

                    # Record partial step metrics on failure (capture whatever timing data we have)
                    if logged_step and step_metrics and step_progression < StepProgression.FINISHED:
                        try:
                            # Calculate total step time up to the failure point
                            step_metrics.step_ns = get_utc_timestamp_ns() - step_start

                            # Get context attributes for project and template IDs
                            ctx_attrs = get_ctx_attributes()

                            await self._record_step_metrics(
                                step_id=step_id,
                                agent_state=agent_state,
                                step_metrics=step_metrics,
                                ctx_attrs=ctx_attrs,
                                job_id=self.current_run_id,
                            )
                        except Exception as metrics_error:
                            self.logger.warning(f"Failed to record step metrics: {metrics_error}")

                except Exception as e:
                    self.logger.error("Failed to update step: %s", e)

        if not should_continue:
            break

    # Extend the in context message ids
    if not agent_state.message_buffer_autoclear:
        await self._rebuild_context_window(
            in_context_messages=current_in_context_messages,
            new_letta_messages=new_in_context_messages,
            llm_config=agent_state.llm_config,
            total_tokens=usage.total_tokens,
            force=False,
        )

    await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=False)

    for finish_chunk in self.get_finish_chunks_for_stream(usage, stop_reason):
        yield f"data: {finish_chunk}\n\n"
async def _log_request(
    self, request_start_timestamp_ns: int, request_span: "Span | None", job_update_metadata: dict | None, is_error: bool
):
    """Finalize request-level telemetry once the agent loop ends.

    Does nothing when no start timestamp was provided. Otherwise it:
    - records the elapsed time on ``request_span`` (if any);
    - updates the agent's last-run metrics;
    - when ``settings.track_agent_run`` is on and a run id exists, stores the response duration and
      marks the job failed (``is_error``) or completed, attaching ``job_update_metadata``;
    - ends ``request_span``.
    """
    if not request_start_timestamp_ns:
        return

    finished_ns = get_utc_timestamp_ns()
    finished_at = get_utc_time()
    elapsed_ns = finished_ns - request_start_timestamp_ns
    elapsed_ms = ns_to_ms(elapsed_ns)

    if request_span:
        request_span.add_event(name="letta_request_ms", attributes={"duration_ms": elapsed_ms})

    await self._update_agent_last_run_metrics(finished_at, elapsed_ms)

    if settings.track_agent_run and self.current_run_id:
        await self.job_manager.record_response_duration(self.current_run_id, elapsed_ns, self.actor)
        final_status = JobStatus.failed if is_error else JobStatus.completed
        await self.job_manager.safe_update_job_status_async(
            job_id=self.current_run_id,
            new_status=final_status,
            actor=self.actor,
            metadata=job_update_metadata,
        )

    if request_span:
        request_span.end()
async def _record_step_metrics(
    self,
    *,
    step_id: str,
    agent_state: AgentState,
    step_metrics: StepMetrics,
    ctx_attrs: dict | None = None,
    job_id: str | None = None,
) -> None:
    """Forward one step's timing data to ``step_manager.record_step_metrics_async``.

    ``ctx_attrs`` supplies project/template ids and falls back to ``get_ctx_attributes()`` when falsy.
    The project id falls back to ``agent_state.project_id`` and the job id to ``self.current_run_id``.
    Best effort: any exception is logged as a warning and not re-raised.
    """
    try:
        attributes = ctx_attrs if ctx_attrs else get_ctx_attributes()
        await self.step_manager.record_step_metrics_async(
            actor=self.actor,
            step_id=step_id,
            llm_request_ns=step_metrics.llm_request_ns,
            tool_execution_ns=step_metrics.tool_execution_ns,
            step_ns=step_metrics.step_ns,
            agent_id=agent_state.id,
            job_id=job_id if job_id else self.current_run_id,
            project_id=attributes.get("project.id") if attributes.get("project.id") else agent_state.project_id,
            template_id=attributes.get("template.id"),
            base_template_id=attributes.get("base_template.id"),
        )
    except Exception as metrics_error:
        self.logger.warning(f"Failed to record step metrics: {metrics_error}")
# noinspection PyInconsistentReturns
async def _build_and_request_from_llm(
    self,
    current_in_context_messages: list[Message],
    new_in_context_messages: list[Message],
    agent_state: AgentState,
    llm_client: LLMClientBase,
    tool_rules_solver: ToolRulesSolver,
    agent_step_span: "Span",
    step_metrics: StepMetrics,
) -> tuple[dict, dict, list[Message], list[Message], list[str]] | None:
    """Build an LLM request and send it (non-streaming), retrying after summarization on failure.

    Each attempt rebuilds the request from the current context and times the provider call,
    storing the duration on ``step_metrics.llm_request_ns`` and recording it as a metric and span
    event. On an exception, ``_handle_llm_error`` either compacts the context (context-window
    overflow) or raises the client-translated error. There are at most
    ``self.max_summarization_retries`` retries; the last failure is re-raised.

    Returns:
        ``(request_data, response, current_in_context_messages, new_in_context_messages,
        valid_tool_names)``. Implicitly returns None only if the retry range is empty.
    """
    for attempt in range(self.max_summarization_retries + 1):
        try:
            log_event("agent.stream_no_tokens.messages.refreshed")
            # Create LLM request data
            request_data, valid_tool_names = await self._create_llm_request_data_async(
                llm_client=llm_client,
                in_context_messages=current_in_context_messages + new_in_context_messages,
                agent_state=agent_state,
                tool_rules_solver=tool_rules_solver,
            )
            log_event("agent.stream_no_tokens.llm_request.created")

            async with AsyncTimer() as timer:
                response = await llm_client.request_async(request_data, agent_state.llm_config)

            # Track LLM request time
            step_metrics.llm_request_ns = int(timer.elapsed_ns)
            elapsed_ms = timer.elapsed_ms
            MetricRegistry().llm_execution_time_ms_histogram.record(
                elapsed_ms,
                dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}),
            )
            agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": elapsed_ms})
            return request_data, response, current_in_context_messages, new_in_context_messages, valid_tool_names
        except Exception as e:
            out_of_retries = attempt == self.max_summarization_retries
            if out_of_retries:
                raise e
            # Compact the context (or raise the translated error) before the next attempt.
            current_in_context_messages = await self._handle_llm_error(
                e,
                llm_client=llm_client,
                in_context_messages=current_in_context_messages,
                new_letta_messages=new_in_context_messages,
                llm_config=agent_state.llm_config,
                force=True,
            )
            new_in_context_messages = []
            log_event(f"agent.stream_no_tokens.retry_attempt.{attempt + 1}")
# noinspection PyInconsistentReturns
async def _build_and_request_from_llm_streaming(
    self,
    first_chunk: bool,
    ttft_span: "Span",
    request_start_timestamp_ns: int,
    current_in_context_messages: list[Message],
    new_in_context_messages: list[Message],
    agent_state: AgentState,
    llm_client: LLMClientBase,
    tool_rules_solver: ToolRulesSolver,
) -> tuple[dict, AsyncStream[ChatCompletionChunk], list[Message], list[Message], list[str], int] | None:
    """Build an LLM request and open a provider stream, retrying after summarization on failure.

    When ``first_chunk`` is set and ``ttft_span`` is given, the delay between request start and the
    provider call is added to ``ttft_span`` as an event. Error handling and retry count match
    ``_build_and_request_from_llm``.

    Returns:
        ``(request_data, stream, current_in_context_messages, new_in_context_messages,
        valid_tool_names, provider_request_start_timestamp_ns)``. Implicitly returns None only if
        the retry range is empty.
    """
    for attempt in range(self.max_summarization_retries + 1):
        try:
            log_event("agent.stream_no_tokens.messages.refreshed")
            # Create LLM request data
            request_data, valid_tool_names = await self._create_llm_request_data_async(
                llm_client=llm_client,
                in_context_messages=current_in_context_messages + new_in_context_messages,
                agent_state=agent_state,
                tool_rules_solver=tool_rules_solver,
            )
            log_event("agent.stream.llm_request.created")  # [2^]

            provider_request_start_timestamp_ns = get_utc_timestamp_ns()
            if first_chunk and ttft_span is not None:
                # NOTE(review): the event/attribute name says "_ns" but the value is converted to ms.
                pre_provider_delay_ns = provider_request_start_timestamp_ns - request_start_timestamp_ns
                ttft_span.add_event(
                    name="request_start_to_provider_request_start_ns",
                    attributes={"request_start_to_provider_request_start_ns": ns_to_ms(pre_provider_delay_ns)},
                )

            stream = await llm_client.stream_async(request_data, agent_state.llm_config)
            return (
                request_data,
                stream,
                current_in_context_messages,
                new_in_context_messages,
                valid_tool_names,
                provider_request_start_timestamp_ns,
            )
        except Exception as e:
            out_of_retries = attempt == self.max_summarization_retries
            if out_of_retries:
                raise e
            # Compact the context (or raise the translated error) before the next attempt.
            current_in_context_messages = await self._handle_llm_error(
                e,
                llm_client=llm_client,
                in_context_messages=current_in_context_messages,
                new_letta_messages=new_in_context_messages,
                llm_config=agent_state.llm_config,
                force=True,
            )
            new_in_context_messages = []
            log_event(f"agent.stream_no_tokens.retry_attempt.{attempt + 1}")
@trace_method
async def _handle_llm_error(
    self,
    e: Exception,
    llm_client: LLMClientBase,
    in_context_messages: list[Message],
    new_letta_messages: list[Message],
    llm_config: LLMConfig,
    force: bool,
) -> list[Message]:
    """Recover from an LLM call failure, or surface it.

    A ``ContextWindowExceededError`` is handled by rebuilding (summarizing) the context window, and the
    resulting in-context messages are returned. Any other exception is converted via
    ``llm_client.handle_llm_error`` and raised.
    """
    if not isinstance(e, ContextWindowExceededError):
        raise llm_client.handle_llm_error(e)
    return await self._rebuild_context_window(
        in_context_messages=in_context_messages,
        new_letta_messages=new_letta_messages,
        llm_config=llm_config,
        force=force,
    )
@trace_method
async def _rebuild_context_window(
    self,
    in_context_messages: list[Message],
    new_letta_messages: list[Message],
    llm_config: LLMConfig,
    total_tokens: int | None = None,
    force: bool = False,
) -> list[Message]:
    """Run the summarizer over the context and persist the resulting message ids on the agent.

    If ``force`` is set, or ``total_tokens`` exceeds ``llm_config.context_window``, the summarizer is
    called with ``force=True, clear=True``. Otherwise it is called with its default options.

    Returns:
        The in-context message list produced by the summarizer.
    """
    # TODO: This can be broken by bad configs, e.g. lower bound too high, initial messages too fat, etc.
    # TODO: `force` and `clear` seem to no longer be used, we should remove
    must_clear = force or (total_tokens and total_tokens > llm_config.context_window)
    if must_clear:
        self.logger.warning(
            f"Total tokens {total_tokens} exceeds configured max tokens {llm_config.context_window}, forcefully clearing message history."
        )
        summarizer_options = {"force": True, "clear": True}
    else:
        # NOTE (Sarah): Seems like this is doing nothing?
        self.logger.info(
            f"Total tokens {total_tokens} does not exceed configured max tokens {llm_config.context_window}, passing summarizing w/o force."
        )
        summarizer_options = {}

    summarized_messages, _ = await self.summarizer.summarize(
        in_context_messages=in_context_messages,
        new_letta_messages=new_letta_messages,
        **summarizer_options,
    )

    await self.agent_manager.update_message_ids_async(
        agent_id=self.agent_id,
        message_ids=[m.id for m in summarized_messages],
        actor=self.actor,
    )
    return summarized_messages
@trace_method
async def summarize_conversation_history(self) -> None:
    """Called when the developer explicitly triggers compaction via the API.

    The agent's current in-context messages are loaded and summarized with ``force=True``. The
    agent's message ids are then replaced with the summarizer's output. Returns whatever
    ``update_message_ids_async`` returns.
    NOTE(review): that return value does not match the ``-> None`` annotation — confirm which is intended.
    """
    agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=self.agent_id, actor=self.actor)
    current_messages = await self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=self.actor)

    compacted_messages, _ = await self.summarizer.summarize(in_context_messages=current_messages, new_letta_messages=[], force=True)

    compacted_ids = [m.id for m in compacted_messages]
    return await self.agent_manager.update_message_ids_async(agent_id=self.agent_id, message_ids=compacted_ids, actor=self.actor)
@trace_method
async def _create_llm_request_data_async(
    self,
    llm_client: LLMClientBase,
    in_context_messages: list[Message],
    agent_state: AgentState,
    tool_rules_solver: ToolRulesSolver,
) -> tuple[dict, list[str]]:
    """Assemble the provider request payload for a single LLM call.

    The steps are:
    - refresh memory in the context via ``_rebuild_memory_async``;
    - scrub inner thoughts when reasoning is disabled in the LLM config;
    - narrow the agent's tools to those permitted by ``tool_rules_solver``;
    - delegate payload construction to ``llm_client.build_request_data``.

    Args:
        llm_client: Client that builds the provider-specific request.
        in_context_messages: Current context window (existing + new messages).
        agent_state: Agent whose tools, tool rules, response format and LLM config drive the request.
        tool_rules_solver: Decides which tools may be called next.

    Returns:
        ``(request_data, valid_tool_names)``. ``valid_tool_names`` is the solver-allowed subset of
        the eligible tools, or every eligible tool when the solver returns an empty result.
    """
    # Message / archival-passage counts are cached on the instance; a falsy cached value
    # (None or 0) triggers a fresh lookup.
    if not self.num_messages:
        self.num_messages = await self.message_manager.size_async(
            agent_id=agent_state.id,
            actor=self.actor,
        )
    if not self.num_archival_memories:
        self.num_archival_memories = await self.passage_manager.agent_passage_size_async(
            agent_id=agent_state.id,
            actor=self.actor,
        )
    in_context_messages = await self._rebuild_memory_async(
        in_context_messages,
        agent_state,
        num_messages=self.num_messages,
        num_archival_memories=self.num_archival_memories,
        tool_rules_solver=tool_rules_solver,
    )
    # scrub inner thoughts from messages if reasoning is completely disabled
    in_context_messages = scrub_inner_thoughts_from_messages(in_context_messages, agent_state.llm_config)

    # Only tools of these types are exposed to the model. Built once rather than per tool.
    exposed_tool_types = {
        ToolType.CUSTOM,
        ToolType.LETTA_CORE,
        ToolType.LETTA_MEMORY_CORE,
        ToolType.LETTA_MULTI_AGENT_CORE,
        ToolType.LETTA_SLEEPTIME_CORE,
        ToolType.LETTA_VOICE_SLEEPTIME_CORE,
        ToolType.LETTA_BUILTIN,
        ToolType.LETTA_FILES_CORE,
        ToolType.EXTERNAL_MCP,
    }
    tools = [t for t in agent_state.tools if t.tool_type in exposed_tool_types]

    # Mirror the sync agent loop: get allowed tools or allow all if none are allowed
    self.last_function_response = self._load_last_function_response(in_context_messages)
    valid_tool_names = tool_rules_solver.get_allowed_tool_names(
        available_tools={t.name for t in tools},
        last_function_response=self.last_function_response,
    ) or list({t.name for t in tools})

    # TODO: Copied from legacy agent loop, so please be cautious
    # Force the tool call when exactly one tool is permitted.
    force_tool_call = valid_tool_names[0] if len(valid_tool_names) == 1 else None

    # Build the membership set once instead of once per tool inside the comprehension.
    valid_tool_name_set = set(valid_tool_names)
    allowed_tools = [enable_strict_mode(t.json_schema) for t in tools if t.name in valid_tool_name_set]

    # Extract terminal tool names from tool rules
    terminal_tool_names = {rule.tool_name for rule in tool_rules_solver.terminal_tool_rules}
    allowed_tools = runtime_override_tool_json_schema(
        tool_list=allowed_tools, response_format=agent_state.response_format, request_heartbeat=True, terminal_tools=terminal_tool_names
    )

    return (
        llm_client.build_request_data(
            agent_state.agent_type,
            in_context_messages,
            agent_state.llm_config,
            allowed_tools,
            force_tool_call,
        ),
        valid_tool_names,
    )
@trace_method
async def _handle_ai_response(
    self,
    tool_call: ToolCall,
    valid_tool_names: list[str],
    agent_state: AgentState,
    tool_rules_solver: ToolRulesSolver,
    usage: UsageStatistics,
    reasoning_content: list[TextContent | ReasoningContent | RedactedReasoningContent | OmittedReasoningContent] | None = None,
    pre_computed_assistant_message_id: str | None = None,
    step_id: str | None = None,
    initial_messages: list[Message] | None = None,
    agent_step_span: Optional["Span"] = None,
    is_final_step: bool | None = None,
    run_id: str | None = None,
    step_metrics: StepMetrics | None = None,
    is_approval: bool | None = None,
    is_denial: bool | None = None,
    denial_reason: str | None = None,
) -> tuple[list[Message], bool, LettaStopReason | None]:
    """
    Handle the final AI response once streaming completes: validate / execute the
    tool call, decide whether we should keep stepping, and persist the messages.

    There are three paths:

    * ``is_denial``: the user denied a previously requested tool call. An error tool
      response carrying ``denial_reason`` is persisted and the loop continues.
    * Approval required: ``tool_rules_solver`` marks the tool as requiring approval
      and this call is not itself an approved call (``is_approval`` falsy). An
      approval-request message is persisted, the tool is NOT executed, and the loop
      stops with a ``requires_approval`` stop reason.
    * Otherwise: the tool is executed via ``_execute_tool``. If the tool is not in
      ``valid_tool_names`` and was not approved, a rule-violation result is
      synthesized instead. The response is then validated and packaged, and
      ``_decide_continuation`` decides whether to step again. Finally the
      assistant/tool messages are persisted. This path also updates
      ``self.last_function_response``.

    Args:
        tool_call: The tool call emitted by the model. A random ``call_<hex>`` id is
            generated when it has none.
        valid_tool_names: Tools the model is currently allowed to call.
        agent_state: The agent being stepped.
        tool_rules_solver: Tool-rule state for this run. It is consulted for approval
            requirements and updated via ``_decide_continuation``.
        usage: Token usage for the step. NOTE(review): not referenced in this method.
        reasoning_content: Reasoning/text content to attach to the assistant message.
        pre_computed_assistant_message_id: Id to reuse for the assistant message.
        step_id: Id of the current step. Forwarded to message creation and tool
            execution.
        initial_messages: Messages to persist ahead of the ones created here.
        agent_step_span: Optional tracing span forwarded to ``_execute_tool``.
        is_final_step: True on the last allowed step. Forces the loop to stop.
        run_id: NOTE(review): not referenced here; messages are tagged with
            ``self.current_run_id`` instead.
        step_metrics: When provided, receives the tool execution time in
            nanoseconds (``tool_execution_ns``).
        is_approval: True when this call was already approved by the user. This
            skips both the approval request and the allowed-tool check.
        is_denial: True when the user denied this call.
        denial_reason: User-supplied reason, echoed back to the model on denial.

    Returns:
        ``(persisted_messages, continue_stepping, stop_reason)``, where
        ``persisted_messages`` is what ``message_manager.create_many_messages_async``
        returns for ``initial_messages`` plus the messages created here.
    """
    tool_call_id: str = tool_call.id or f"call_{uuid.uuid4().hex[:8]}"

    if is_denial:
        # Report the denial back to the model as a failed tool response so it can
        # choose a different action on the next step.
        continue_stepping = True
        stop_reason = None
        tool_call_messages = create_letta_messages_from_llm_response(
            agent_id=agent_state.id,
            model=agent_state.llm_config.model,
            function_name="",
            function_arguments={},
            tool_execution_result=ToolExecutionResult(status="error"),
            tool_call_id=tool_call_id,
            function_response=f"Error: request to call tool denied. User reason: {denial_reason}",
            timezone=agent_state.timezone,
            continue_stepping=continue_stepping,
            heartbeat_reason=f"{NON_USER_MSG_PREFIX}Continuing: user denied request to call tool.",
            reasoning_content=None,
            pre_computed_assistant_message_id=None,
            step_id=step_id,
            run_id=self.current_run_id,
            is_approval_response=True,
        )
        messages_to_persist = (initial_messages or []) + tool_call_messages
        persisted_messages = await self.message_manager.create_many_messages_async(
            messages_to_persist, actor=self.actor, project_id=agent_state.project_id, template_id=agent_state.template_id
        )
        return persisted_messages, continue_stepping, stop_reason

    # 1. Parse and validate the tool-call envelope
    tool_call_name: str = tool_call.function.name
    tool_args = _safe_load_tool_call_str(tool_call.function.arguments)
    # Strip control kwargs so they are not forwarded to the tool itself.
    request_heartbeat: bool = _pop_heartbeat(tool_args)
    tool_args.pop(INNER_THOUGHTS_KWARG, None)

    log_telemetry(
        self.logger,
        "_handle_ai_response execute tool start",
        tool_name=tool_call_name,
        tool_args=tool_args,
        tool_call_id=tool_call_id,
        request_heartbeat=request_heartbeat,
    )

    if not is_approval and tool_rules_solver.is_requires_approval_tool(tool_call_name):
        # Re-attach the heartbeat flag so it survives the approval round trip.
        tool_args[REQUEST_HEARTBEAT_PARAM] = request_heartbeat
        approval_messages = create_approval_request_message_from_llm_response(
            agent_id=agent_state.id,
            model=agent_state.llm_config.model,
            requested_tool_calls=[
                ToolCall(id=tool_call_id, function=FunctionCall(name=tool_call_name, arguments=json.dumps(tool_args)))
            ],
            reasoning_content=reasoning_content,
            pre_computed_assistant_message_id=pre_computed_assistant_message_id,
            step_id=step_id,
        )
        messages_to_persist = (initial_messages or []) + approval_messages
        continue_stepping = False
        stop_reason = LettaStopReason(stop_reason=StopReasonType.requires_approval.value)
    else:
        # 2. Execute the tool (or synthesize an error result if disallowed).
        #    An approved call bypasses the allowed-tool check.
        tool_rule_violated = tool_call_name not in valid_tool_names and not is_approval
        if tool_rule_violated:
            tool_execution_result = _build_rule_violation_result(tool_call_name, valid_tool_names, tool_rules_solver)
        else:
            # Track tool execution time
            tool_start_time = get_utc_timestamp_ns()
            tool_execution_result = await self._execute_tool(
                tool_name=tool_call_name,
                tool_args=tool_args,
                agent_state=agent_state,
                agent_step_span=agent_step_span,
                step_id=step_id,
            )
            tool_end_time = get_utc_timestamp_ns()
            # step_metrics is optional; only record timing when a collector was supplied.
            if step_metrics is not None:
                step_metrics.tool_execution_ns = tool_end_time - tool_start_time

        log_telemetry(
            self.logger,
            "_handle_ai_response execute tool finish",
            tool_execution_result=tool_execution_result,
            tool_call_id=tool_call_id,
        )

        # 3. Prepare the function-response payload.
        #    Search tools return content the model needs in full, so they are not truncated.
        truncate = tool_call_name not in {"conversation_search", "conversation_search_date", "archival_memory_search"}
        return_char_limit = next(
            (t.return_char_limit for t in agent_state.tools if t.name == tool_call_name),
            None,
        )
        function_response_string = validate_function_response(
            tool_execution_result.func_return,
            return_char_limit=return_char_limit,
            truncate=truncate,
        )
        self.last_function_response = package_function_response(
            was_success=tool_execution_result.success_flag,
            response_string=function_response_string,
            timezone=agent_state.timezone,
        )

        # 4. Decide whether to keep stepping
        continue_stepping, heartbeat_reason, stop_reason = self._decide_continuation(
            agent_state=agent_state,
            request_heartbeat=request_heartbeat,
            tool_call_name=tool_call_name,
            tool_rule_violated=tool_rule_violated,
            tool_rules_solver=tool_rules_solver,
            is_final_step=is_final_step,
        )

        # 5. Create messages (step was already created at the beginning)
        tool_call_messages = create_letta_messages_from_llm_response(
            agent_id=agent_state.id,
            model=agent_state.llm_config.model,
            function_name=tool_call_name,
            function_arguments=tool_args,
            tool_execution_result=tool_execution_result,
            tool_call_id=tool_call_id,
            function_response=function_response_string,
            timezone=agent_state.timezone,
            continue_stepping=continue_stepping,
            heartbeat_reason=heartbeat_reason,
            reasoning_content=reasoning_content,
            pre_computed_assistant_message_id=pre_computed_assistant_message_id,
            step_id=step_id,
            run_id=self.current_run_id,
            is_approval_response=is_approval or is_denial,
        )
        messages_to_persist = (initial_messages or []) + tool_call_messages

    persisted_messages = await self.message_manager.create_many_messages_async(
        messages_to_persist, actor=self.actor, project_id=agent_state.project_id, template_id=agent_state.template_id
    )
    return persisted_messages, continue_stepping, stop_reason
def _decide_continuation(
    self,
    agent_state: AgentState,
    request_heartbeat: bool,
    tool_call_name: str,
    tool_rule_violated: bool,
    tool_rules_solver: ToolRulesSolver,
    is_final_step: bool | None,
) -> tuple[bool, str | None, LettaStopReason | None]:
    """
    Decide whether the agent loop should take another step after a tool call.

    Starts from the model's own heartbeat request, then applies two stages in order.

    1. Tool-rule effects.
       * A rule violation always continues, so the model can correct itself. The
         violating call is not registered with the solver.
       * Otherwise the call is registered with ``tool_rules_solver``, and then:
         * a terminal tool stops the loop, reporting a ``tool_rule`` stop reason
           only if the model had asked to continue;
         * a tool with child rules, or a continue-rule tool, forces another step.
    2. Hard stops, which override stage 1 (including rule violations).
       * On the final allowed step the loop always stops with a ``max_steps``
         stop reason.
       * Otherwise, if the loop would stop while required tools are still
         uncalled, it is forced to continue and any stop reason is cleared.

    Args:
        agent_state: The agent being stepped. Its attached tool names are the
            universe used for the required-tools check.
        request_heartbeat: Whether the model asked to continue.
        tool_call_name: Name of the tool that was called.
        tool_rule_violated: True if the call was not allowed by the tool rules.
        tool_rules_solver: Tool-rule state. It is mutated via
            ``register_tool_call`` when there is no violation.
        is_final_step: True on the last allowed step.

    Returns:
        ``(continue_stepping, heartbeat_reason, stop_reason)``. ``heartbeat_reason``
        explains a forced continuation. It may remain set when a hard stop later
        turns continuation off.
    """
    continue_stepping = request_heartbeat
    heartbeat_reason: str | None = None
    stop_reason: LettaStopReason | None = None

    if tool_rule_violated:
        continue_stepping = True
        heartbeat_reason = f"{NON_USER_MSG_PREFIX}Continuing: tool rule violation."
    else:
        tool_rules_solver.register_tool_call(tool_call_name)
        if tool_rules_solver.is_terminal_tool(tool_call_name):
            # Only surface a tool-rule stop when the model actually wanted to keep going.
            if continue_stepping:
                stop_reason = LettaStopReason(stop_reason=StopReasonType.tool_rule.value)
            continue_stepping = False
        elif tool_rules_solver.has_children_tools(tool_call_name):
            continue_stepping = True
            heartbeat_reason = f"{NON_USER_MSG_PREFIX}Continuing: child tool rule."
        elif tool_rules_solver.is_continue_tool(tool_call_name):
            continue_stepping = True
            heartbeat_reason = f"{NON_USER_MSG_PREFIX}Continuing: continue tool rule."

    # Hard stops override everything above, including rule violations, so the final
    # step always reports max_steps instead of claiming the loop will continue.
    if is_final_step:
        continue_stepping = False
        stop_reason = LettaStopReason(stop_reason=StopReasonType.max_steps.value)
    elif not continue_stepping:
        # The required-tools check can only flip a stop into a continue, so it is
        # only evaluated when the loop would otherwise stop.
        uncalled = tool_rules_solver.get_uncalled_required_tools(available_tools={t.name for t in agent_state.tools})
        if uncalled:
            continue_stepping = True
            heartbeat_reason = f"{NON_USER_MSG_PREFIX}Continuing, user expects these tools: [{', '.join(uncalled)}] to be called still."
            stop_reason = None  # reset: we're still going
    return continue_stepping, heartbeat_reason, stop_reason
@trace_method
async def _execute_tool(
    self,
    tool_name: str,
    tool_args: JsonDict,
    agent_state: AgentState,
    agent_step_span: Optional["Span"] = None,
    step_id: str | None = None,
) -> "ToolExecutionResult":
    """
    Run one of the agent's tools and return its ToolExecutionResult.

    The tool is looked up by name in ``agent_state.tools``. If no attached tool has
    that name, an error result is returned and nothing is executed. Otherwise the
    call is delegated to a ToolExecutionManager. The manager is built from this
    agent's managers and the agent's secrets, which are exposed as sandbox
    environment variables.

    Args:
        tool_name: Name of the tool to run.
        tool_args: Arguments forwarded to the tool.
        agent_state: Agent whose tools and secrets are used.
        agent_step_span: Optional tracing span. When given, a start event and a
            completion event (with duration, success flag, tool type and id) are
            added to it.
        step_id: Optional step id forwarded to the executor.

    Returns:
        The executor's ToolExecutionResult, or an error result when the tool is not
        attached to the agent.
    """
    from letta.schemas.tool_execution_result import ToolExecutionResult

    # Find the first attached tool with a matching name.
    target_tool = None
    for candidate in agent_state.tools:
        if candidate.name == tool_name:
            target_tool = candidate
            break

    if not target_tool:
        # TODO: fix this error message
        return ToolExecutionResult(
            func_return=f"Tool {tool_name} not found",
            status="error",
        )

    # TODO: This temp. Move this logic and code to executors
    span_start_ns = None
    if agent_step_span:
        span_start_ns = get_utc_timestamp_ns()
        agent_step_span.add_event(name="tool_execution_started")

    # Secret values are expected to be pre-decrypted (populated in from_orm_async);
    # a missing value is passed to the sandbox as an empty string.
    sandbox_env_vars = {}
    for secret in agent_state.secrets:
        sandbox_env_vars[secret.key] = secret.value or ""

    executor = ToolExecutionManager(
        agent_state=agent_state,
        message_manager=self.message_manager,
        agent_manager=self.agent_manager,
        block_manager=self.block_manager,
        job_manager=self.job_manager,
        passage_manager=self.passage_manager,
        sandbox_env_vars=sandbox_env_vars,
        actor=self.actor,
    )

    # TODO: Integrate sandbox result
    log_event(name=f"start_{tool_name}_execution", attributes=tool_args)
    result = await executor.execute_tool_async(
        function_name=tool_name,
        function_args=tool_args,
        tool=target_tool,
        step_id=step_id,
    )

    if agent_step_span:
        elapsed_ms = ns_to_ms(get_utc_timestamp_ns() - span_start_ns)
        agent_step_span.add_event(
            name="tool_execution_completed",
            attributes={
                "tool_name": target_tool.name,
                "duration_ms": elapsed_ms,
                "success": result.success_flag,
                "tool_type": target_tool.tool_type,
                "tool_id": target_tool.id,
            },
        )

    log_event(name=f"finish_{tool_name}_execution", attributes=result.model_dump())
    return result