feat: add step checkpointing for progress tracking (#4458)

* feat: add step checkpointing for progress tracking

* openapi sync
This commit is contained in:
cthomas
2025-09-08 10:30:44 -07:00
committed by GitHub
parent 8f3aabd89d
commit 57e69a35bc
4 changed files with 111 additions and 21 deletions

View File

@@ -0,0 +1,33 @@
"""add build request latency to step metrics
Revision ID: 750dd87faa12
Revises: 5b804970e6a0
Create Date: 2025-09-06 14:28:32.119084
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "750dd87faa12"
down_revision: Union[str, None] = "5b804970e6a0"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add nullable nanosecond-timestamp columns to step_metrics."""
    # ### commands auto generated by Alembic - please adjust! ###
    for column_name in ("step_start_ns", "llm_request_start_ns"):
        op.add_column("step_metrics", sa.Column(column_name, sa.BigInteger(), nullable=True))
    # ### end Alembic commands ###
def downgrade() -> None:
    """Remove the nanosecond-timestamp columns added by this revision."""
    # ### commands auto generated by Alembic - please adjust! ###
    for column_name in ("step_start_ns", "llm_request_start_ns"):
        op.drop_column("step_metrics", column_name)
    # ### end Alembic commands ###

View File

@@ -1,3 +1,4 @@
import asyncio
import json
import uuid
from typing import AsyncGenerator, Tuple
@@ -369,19 +370,14 @@ class LettaAgentV2(BaseAgentV2):
step_id = approval_request.step_id
step_metrics = await self.step_manager.get_step_metrics_async(step_id=step_id, actor=self.actor)
else:
step_id = generate_step_id()
step_start_ns = get_utc_timestamp_ns()
# Check for job cancellation at the start of each step
if run_id and await self._check_run_cancellation(run_id):
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
self.logger.info(f"Agent execution cancelled for run {run_id}")
return
agent_step_span = tracer.start_span("agent_step", start_time=step_start_ns)
agent_step_span.set_attributes({"step_id": step_id})
step_metrics = StepMetrics(id=step_id) # Initialize metrics tracking
step_id = generate_step_id()
step_progression, step_metrics, agent_step_span = self._step_checkpoint_start(step_id=step_id)
# Create step early with PENDING status
logged_step = await self.step_manager.log_step_async(
@@ -412,11 +408,7 @@ class LettaAgentV2(BaseAgentV2):
yield request_data
return
provider_request_start_timestamp_ns = get_utc_timestamp_ns()
agent_step_span.add_event(
name="request_start_to_provider_request_start_ns",
attributes={"request_start_to_provider_request_start_ns": ns_to_ms(provider_request_start_timestamp_ns)},
)
step_progression, step_metrics = self._step_checkpoint_llm_request_start(step_metrics, agent_step_span)
try:
invocation = llm_adapter.invoke_llm(
@@ -436,10 +428,9 @@ class LettaAgentV2(BaseAgentV2):
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_llm_response.value)
raise
step_progression = StepProgression.RESPONSE_RECEIVED
llm_request_ns = llm_adapter.llm_request_finish_timestamp_ns - provider_request_start_timestamp_ns
step_metrics.llm_request_ns = llm_request_ns
agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": ns_to_ms(llm_request_ns)})
step_progression, step_metrics = self._step_checkpoint_llm_request_finish(
step_metrics, agent_step_span, llm_adapter.llm_request_finish_timestamp_ns
)
self._update_global_usage_stats(llm_adapter.usage)
@@ -504,11 +495,7 @@ class LettaAgentV2(BaseAgentV2):
if include_return_message_types is None or message.message_type in include_return_message_types:
yield message
step_progression = StepProgression.FINISHED
if agent_step_span is not None:
step_ns = get_utc_timestamp_ns() - step_start_ns
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
agent_step_span.end()
step_progression, step_metrics = self._step_checkpoint_finish(step_metrics, agent_step_span, run_id)
def _initialize_state(self):
self.should_continue = True
@@ -691,6 +678,41 @@ class LettaAgentV2(BaseAgentV2):
raise ValueError(f"Invalid JSON format in message: {text_content}")
return None
def _step_checkpoint_start(self, step_id: str) -> Tuple[StepProgression, StepMetrics, Span]:
    """Begin a step checkpoint: capture the start time, seed metrics, and open a tracing span.

    Returns the START progression marker, a fresh StepMetrics carrying the start
    timestamp, and the opened span tagged with the step id.
    """
    now_ns = get_utc_timestamp_ns()
    metrics = StepMetrics(id=step_id, step_start_ns=now_ns)
    span = tracer.start_span("agent_step", start_time=now_ns)
    span.set_attributes({"step_id": step_id})
    return StepProgression.START, metrics, span
def _step_checkpoint_llm_request_start(self, step_metrics: StepMetrics, agent_step_span: Span) -> Tuple[StepProgression, StepMetrics]:
    """Record the moment the provider LLM request is issued, on both metrics and span."""
    now_ns = get_utc_timestamp_ns()
    step_metrics.llm_request_start_ns = now_ns
    # NOTE(review): the event name reads like a duration, but the attribute carries the
    # absolute timestamp converted to ms (pre-existing behavior) — confirm intent.
    event_attrs = {"request_start_to_provider_request_start_ns": ns_to_ms(now_ns)}
    agent_step_span.add_event(name="request_start_to_provider_request_start_ns", attributes=event_attrs)
    return StepProgression.START, step_metrics
def _step_checkpoint_llm_request_finish(
    self, step_metrics: StepMetrics, agent_step_span: Span, llm_request_finish_timestamp_ns: int
) -> Tuple[StepProgression, StepMetrics]:
    """Close out the LLM-request phase: store its elapsed time and emit a span event."""
    elapsed_ns = llm_request_finish_timestamp_ns - step_metrics.llm_request_start_ns
    step_metrics.llm_request_ns = elapsed_ns
    agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": ns_to_ms(elapsed_ns)})
    return StepProgression.RESPONSE_RECEIVED, step_metrics
def _step_checkpoint_finish(
    self, step_metrics: StepMetrics, agent_step_span: Span | None, run_id: str | None
) -> Tuple[StepProgression, StepMetrics]:
    """Finalize a step: compute total duration, close the span, and persist metrics.

    Args:
        step_metrics: Metrics accumulated for the step; step_start_ns must be set.
        agent_step_span: The span opened at step start, or None if tracing is off.
        run_id: The run (job) this step belongs to, if any.

    Returns:
        The FINISHED progression marker and the updated metrics.
    """
    step_ns = get_utc_timestamp_ns() - step_metrics.step_start_ns
    step_metrics.step_ns = step_ns
    if agent_step_span is not None:
        agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
        agent_step_span.end()
    # Fix: run_id was accepted but never forwarded, so persisted metrics lost their
    # job association (_record_step_metrics maps run_id -> job_id); pass it through.
    # NOTE(review): metrics are constructed with StepMetrics(id=...); confirm the
    # schema also exposes a step_id attribute as read here.
    self._record_step_metrics(step_id=step_metrics.step_id, step_metrics=step_metrics, run_id=run_id)
    return StepProgression.FINISHED, step_metrics
def _update_global_usage_stats(self, step_usage_stats: LettaUsageStatistics):
self.usage.step_count += step_usage_stats.step_count
self.usage.completion_tokens += step_usage_stats.completion_tokens
@@ -1018,6 +1040,29 @@ class LettaAgentV2(BaseAgentV2):
return new_in_context_messages
def _record_step_metrics(
    self,
    *,
    step_id: str,
    step_metrics: StepMetrics,
    run_id: str | None = None,
):
    """Persist collected step metrics asynchronously (fire-and-forget).

    Schedules step_manager.record_step_metrics_async on the running event loop
    and returns the Task without awaiting it, so persistence does not block the
    agent loop. run_id is recorded as the job_id of the metrics row.
    """
    # NOTE(review): the Task reference is only returned, not stored; if callers drop
    # it, asyncio may garbage-collect the task before it completes — confirm a
    # reference is retained (e.g. a background-task set) somewhere upstream.
    # NOTE(review): step_start_ns / llm_request_start_ns from step_metrics are not
    # forwarded here — verify record_step_metrics_async persists them, or they are
    # lost despite the new columns.
    task = asyncio.create_task(
        self.step_manager.record_step_metrics_async(
            actor=self.actor,
            step_id=step_id,
            llm_request_ns=step_metrics.llm_request_ns,
            tool_execution_ns=step_metrics.tool_execution_ns,
            step_ns=step_metrics.step_ns,
            agent_id=self.agent_state.id,
            job_id=run_id,
            project_id=self.agent_state.project_id,
            template_id=self.agent_state.template_id,
            base_template_id=self.agent_state.base_template_id,
        )
    )
    return task
def get_finish_chunks_for_stream(
self,
usage: LettaUsageStatistics,

View File

@@ -43,6 +43,16 @@ class StepMetrics(SqlalchemyBase, ProjectMixin, AgentMixin):
nullable=True,
doc="The unique identifier of the job",
)
step_start_ns: Mapped[Optional[int]] = mapped_column(
BigInteger,
nullable=True,
doc="The timestamp of the start of the step in nanoseconds",
)
llm_request_start_ns: Mapped[Optional[int]] = mapped_column(
BigInteger,
nullable=True,
doc="The timestamp of the start of the LLM request in nanoseconds",
)
llm_request_ns: Mapped[Optional[int]] = mapped_column(
BigInteger,
nullable=True,

View File

@@ -15,6 +15,8 @@ class StepMetrics(StepMetricsBase):
provider_id: Optional[str] = Field(None, description="The unique identifier of the provider.")
job_id: Optional[str] = Field(None, description="The unique identifier of the job.")
agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
step_start_ns: Optional[int] = Field(None, description="The timestamp of the start of the step in nanoseconds.")
llm_request_start_ns: Optional[int] = Field(None, description="The timestamp of the start of the llm request in nanoseconds.")
llm_request_ns: Optional[int] = Field(None, description="Time spent on LLM requests in nanoseconds.")
tool_execution_ns: Optional[int] = Field(None, description="Time spent on tool execution in nanoseconds.")
step_ns: Optional[int] = Field(None, description="Total time for the step in nanoseconds.")