feat: add step checkpointing for progress tracking (#4458)
* feat: add step checkpointing for progress tracking * openapi sync
This commit is contained in:
@@ -0,0 +1,33 @@
|
||||
"""add build request latency to step metrics
|
||||
|
||||
Revision ID: 750dd87faa12
|
||||
Revises: 5b804970e6a0
|
||||
Create Date: 2025-09-06 14:28:32.119084
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "750dd87faa12"
|
||||
down_revision: Union[str, None] = "5b804970e6a0"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("step_metrics", sa.Column("step_start_ns", sa.BigInteger(), nullable=True))
|
||||
op.add_column("step_metrics", sa.Column("llm_request_start_ns", sa.BigInteger(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("step_metrics", "step_start_ns")
|
||||
op.drop_column("step_metrics", "llm_request_start_ns")
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from typing import AsyncGenerator, Tuple
|
||||
@@ -369,19 +370,14 @@ class LettaAgentV2(BaseAgentV2):
|
||||
step_id = approval_request.step_id
|
||||
step_metrics = await self.step_manager.get_step_metrics_async(step_id=step_id, actor=self.actor)
|
||||
else:
|
||||
step_id = generate_step_id()
|
||||
step_start_ns = get_utc_timestamp_ns()
|
||||
|
||||
# Check for job cancellation at the start of each step
|
||||
if run_id and await self._check_run_cancellation(run_id):
|
||||
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
|
||||
self.logger.info(f"Agent execution cancelled for run {run_id}")
|
||||
return
|
||||
|
||||
agent_step_span = tracer.start_span("agent_step", start_time=step_start_ns)
|
||||
agent_step_span.set_attributes({"step_id": step_id})
|
||||
|
||||
step_metrics = StepMetrics(id=step_id) # Initialize metrics tracking
|
||||
step_id = generate_step_id()
|
||||
step_progression, step_metrics, agent_step_span = self._step_checkpoint_start(step_id=step_id)
|
||||
|
||||
# Create step early with PENDING status
|
||||
logged_step = await self.step_manager.log_step_async(
|
||||
@@ -412,11 +408,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
yield request_data
|
||||
return
|
||||
|
||||
provider_request_start_timestamp_ns = get_utc_timestamp_ns()
|
||||
agent_step_span.add_event(
|
||||
name="request_start_to_provider_request_start_ns",
|
||||
attributes={"request_start_to_provider_request_start_ns": ns_to_ms(provider_request_start_timestamp_ns)},
|
||||
)
|
||||
step_progression, step_metrics = self._step_checkpoint_llm_request_start(step_metrics, agent_step_span)
|
||||
|
||||
try:
|
||||
invocation = llm_adapter.invoke_llm(
|
||||
@@ -436,10 +428,9 @@ class LettaAgentV2(BaseAgentV2):
|
||||
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_llm_response.value)
|
||||
raise
|
||||
|
||||
step_progression = StepProgression.RESPONSE_RECEIVED
|
||||
llm_request_ns = llm_adapter.llm_request_finish_timestamp_ns - provider_request_start_timestamp_ns
|
||||
step_metrics.llm_request_ns = llm_request_ns
|
||||
agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": ns_to_ms(llm_request_ns)})
|
||||
step_progression, step_metrics = self._step_checkpoint_llm_request_finish(
|
||||
step_metrics, agent_step_span, llm_adapter.llm_request_finish_timestamp_ns
|
||||
)
|
||||
|
||||
self._update_global_usage_stats(llm_adapter.usage)
|
||||
|
||||
@@ -504,11 +495,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
if include_return_message_types is None or message.message_type in include_return_message_types:
|
||||
yield message
|
||||
|
||||
step_progression = StepProgression.FINISHED
|
||||
if agent_step_span is not None:
|
||||
step_ns = get_utc_timestamp_ns() - step_start_ns
|
||||
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
|
||||
agent_step_span.end()
|
||||
step_progression, step_metrics = self._step_checkpoint_finish(step_metrics, agent_step_span, run_id)
|
||||
|
||||
def _initialize_state(self):
|
||||
self.should_continue = True
|
||||
@@ -691,6 +678,41 @@ class LettaAgentV2(BaseAgentV2):
|
||||
raise ValueError(f"Invalid JSON format in message: {text_content}")
|
||||
return None
|
||||
|
||||
def _step_checkpoint_start(self, step_id: str) -> Tuple[StepProgression, StepMetrics, Span]:
|
||||
step_start_ns = get_utc_timestamp_ns()
|
||||
step_metrics = StepMetrics(id=step_id, step_start_ns=step_start_ns)
|
||||
agent_step_span = tracer.start_span("agent_step", start_time=step_start_ns)
|
||||
agent_step_span.set_attributes({"step_id": step_id})
|
||||
return StepProgression.START, step_metrics, agent_step_span
|
||||
|
||||
def _step_checkpoint_llm_request_start(self, step_metrics: StepMetrics, agent_step_span: Span) -> Tuple[StepProgression, StepMetrics]:
|
||||
llm_request_start_ns = get_utc_timestamp_ns()
|
||||
step_metrics.llm_request_start_ns = llm_request_start_ns
|
||||
agent_step_span.add_event(
|
||||
name="request_start_to_provider_request_start_ns",
|
||||
attributes={"request_start_to_provider_request_start_ns": ns_to_ms(llm_request_start_ns)},
|
||||
)
|
||||
return StepProgression.START, step_metrics
|
||||
|
||||
def _step_checkpoint_llm_request_finish(
|
||||
self, step_metrics: StepMetrics, agent_step_span: Span, llm_request_finish_timestamp_ns: int
|
||||
) -> Tuple[StepProgression, StepMetrics]:
|
||||
llm_request_ns = llm_request_finish_timestamp_ns - step_metrics.llm_request_start_ns
|
||||
step_metrics.llm_request_ns = llm_request_ns
|
||||
agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": ns_to_ms(llm_request_ns)})
|
||||
return StepProgression.RESPONSE_RECEIVED, step_metrics
|
||||
|
||||
def _step_checkpoint_finish(
|
||||
self, step_metrics: StepMetrics, agent_step_span: Span | None, run_id: str | None
|
||||
) -> Tuple[StepProgression, StepMetrics]:
|
||||
step_ns = get_utc_timestamp_ns() - step_metrics.step_start_ns
|
||||
step_metrics.step_ns = step_ns
|
||||
if agent_step_span is not None:
|
||||
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
|
||||
agent_step_span.end()
|
||||
self._record_step_metrics(step_id=step_metrics.step_id, step_metrics=step_metrics)
|
||||
return StepProgression.FINISHED, step_metrics
|
||||
|
||||
def _update_global_usage_stats(self, step_usage_stats: LettaUsageStatistics):
|
||||
self.usage.step_count += step_usage_stats.step_count
|
||||
self.usage.completion_tokens += step_usage_stats.completion_tokens
|
||||
@@ -1018,6 +1040,29 @@ class LettaAgentV2(BaseAgentV2):
|
||||
|
||||
return new_in_context_messages
|
||||
|
||||
def _record_step_metrics(
|
||||
self,
|
||||
*,
|
||||
step_id: str,
|
||||
step_metrics: StepMetrics,
|
||||
run_id: str | None = None,
|
||||
):
|
||||
task = asyncio.create_task(
|
||||
self.step_manager.record_step_metrics_async(
|
||||
actor=self.actor,
|
||||
step_id=step_id,
|
||||
llm_request_ns=step_metrics.llm_request_ns,
|
||||
tool_execution_ns=step_metrics.tool_execution_ns,
|
||||
step_ns=step_metrics.step_ns,
|
||||
agent_id=self.agent_state.id,
|
||||
job_id=run_id,
|
||||
project_id=self.agent_state.project_id,
|
||||
template_id=self.agent_state.template_id,
|
||||
base_template_id=self.agent_state.base_template_id,
|
||||
)
|
||||
)
|
||||
return task
|
||||
|
||||
def get_finish_chunks_for_stream(
|
||||
self,
|
||||
usage: LettaUsageStatistics,
|
||||
|
||||
@@ -43,6 +43,16 @@ class StepMetrics(SqlalchemyBase, ProjectMixin, AgentMixin):
|
||||
nullable=True,
|
||||
doc="The unique identifier of the job",
|
||||
)
|
||||
step_start_ns: Mapped[Optional[int]] = mapped_column(
|
||||
BigInteger,
|
||||
nullable=True,
|
||||
doc="The timestamp of the start of the step in nanoseconds",
|
||||
)
|
||||
llm_request_start_ns: Mapped[Optional[int]] = mapped_column(
|
||||
BigInteger,
|
||||
nullable=True,
|
||||
doc="The timestamp of the start of the LLM request in nanoseconds",
|
||||
)
|
||||
llm_request_ns: Mapped[Optional[int]] = mapped_column(
|
||||
BigInteger,
|
||||
nullable=True,
|
||||
|
||||
@@ -15,6 +15,8 @@ class StepMetrics(StepMetricsBase):
|
||||
provider_id: Optional[str] = Field(None, description="The unique identifier of the provider.")
|
||||
job_id: Optional[str] = Field(None, description="The unique identifier of the job.")
|
||||
agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
|
||||
step_start_ns: Optional[int] = Field(None, description="The timestamp of the start of the step in nanoseconds.")
|
||||
llm_request_start_ns: Optional[int] = Field(None, description="The timestamp of the start of the llm request in nanoseconds.")
|
||||
llm_request_ns: Optional[int] = Field(None, description="Time spent on LLM requests in nanoseconds.")
|
||||
tool_execution_ns: Optional[int] = Field(None, description="Time spent on tool execution in nanoseconds.")
|
||||
step_ns: Optional[int] = Field(None, description="Total time for the step in nanoseconds.")
|
||||
|
||||
Reference in New Issue
Block a user