feat: record step metrics to table

Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
This commit is contained in:
jnjpng
2025-08-13 11:29:18 -07:00
committed by GitHub
parent 9054919eba
commit 0179901e79
5 changed files with 316 additions and 7 deletions

View File

@@ -44,6 +44,7 @@ from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
from letta.schemas.provider_trace import ProviderTraceCreate
from letta.schemas.step import StepProgression
from letta.schemas.step_metrics import StepMetrics
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
@@ -242,6 +243,7 @@ class LettaAgent(BaseAgent):
step_progression = StepProgression.START
should_continue = False
step_metrics = StepMetrics(id=step_id) # Initialize metrics tracking
# Create step early with PENDING status
logged_step = await self.step_manager.log_step_async(
@@ -271,6 +273,7 @@ class LettaAgent(BaseAgent):
llm_client,
tool_rules_solver,
agent_step_span,
step_metrics,
)
)
in_context_messages = current_in_context_messages + new_in_context_messages
@@ -320,6 +323,7 @@ class LettaAgent(BaseAgent):
initial_messages=initial_messages,
agent_step_span=agent_step_span,
is_final_step=(i == max_steps - 1),
step_metrics=step_metrics,
)
step_progression = StepProgression.STEP_LOGGED
@@ -365,6 +369,17 @@ class LettaAgent(BaseAgent):
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
step_progression = StepProgression.FINISHED
# Record step metrics for successful completion
if logged_step and step_metrics:
# Set the step_ns that was already calculated
step_metrics.step_ns = step_ns
await self._record_step_metrics(
step_id=step_id,
agent_state=agent_state,
step_metrics=step_metrics,
)
except Exception as e:
# Handle any unexpected errors during step processing
self.logger.error(f"Error during step processing: {e}")
@@ -433,6 +448,17 @@ class LettaAgent(BaseAgent):
if settings.track_stop_reason:
await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=True)
# Record partial step metrics on failure (capture whatever timing data we have)
if logged_step and step_metrics and step_progression < StepProgression.FINISHED:
# Calculate total step time up to the failure point
step_metrics.step_ns = get_utc_timestamp_ns() - step_start
await self._record_step_metrics(
step_id=step_id,
agent_state=agent_state,
step_metrics=step_metrics,
job_id=locals().get("run_id", self.current_run_id),
)
except Exception as e:
self.logger.error("Failed to update step: %s", e)
@@ -513,6 +539,7 @@ class LettaAgent(BaseAgent):
step_progression = StepProgression.START
should_continue = False
step_metrics = StepMetrics(id=step_id) # Initialize metrics tracking
# Create step early with PENDING status
logged_step = await self.step_manager.log_step_async(
@@ -536,7 +563,13 @@ class LettaAgent(BaseAgent):
try:
request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
await self._build_and_request_from_llm(
current_in_context_messages, new_in_context_messages, agent_state, llm_client, tool_rules_solver, agent_step_span
current_in_context_messages,
new_in_context_messages,
agent_state,
llm_client,
tool_rules_solver,
agent_step_span,
step_metrics,
)
)
in_context_messages = current_in_context_messages + new_in_context_messages
@@ -587,6 +620,7 @@ class LettaAgent(BaseAgent):
agent_step_span=agent_step_span,
is_final_step=(i == max_steps - 1),
run_id=run_id,
step_metrics=step_metrics,
)
step_progression = StepProgression.STEP_LOGGED
@@ -622,6 +656,17 @@ class LettaAgent(BaseAgent):
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
step_progression = StepProgression.FINISHED
# Record step metrics for successful completion
if logged_step and step_metrics:
# Set the step_ns that was already calculated
step_metrics.step_ns = step_ns
await self._record_step_metrics(
step_id=step_id,
agent_state=agent_state,
step_metrics=step_metrics,
job_id=run_id if run_id else self.current_run_id,
)
except Exception as e:
# Handle any unexpected errors during step processing
self.logger.error(f"Error during step processing: {e}")
@@ -686,6 +731,17 @@ class LettaAgent(BaseAgent):
if settings.track_stop_reason:
await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=True)
# Record partial step metrics on failure (capture whatever timing data we have)
if logged_step and step_metrics and step_progression < StepProgression.FINISHED:
# Calculate total step time up to the failure point
step_metrics.step_ns = get_utc_timestamp_ns() - step_start
await self._record_step_metrics(
step_id=step_id,
agent_state=agent_state,
step_metrics=step_metrics,
job_id=locals().get("run_id", self.current_run_id),
)
except Exception as e:
self.logger.error("Failed to update step: %s", e)
@@ -774,6 +830,7 @@ class LettaAgent(BaseAgent):
step_progression = StepProgression.START
should_continue = False
step_metrics = StepMetrics(id=step_id) # Initialize metrics tracking
# Create step early with PENDING status
logged_step = await self.step_manager.log_step_async(
@@ -875,7 +932,10 @@ class LettaAgent(BaseAgent):
)
# log LLM request time
llm_request_ms = ns_to_ms(stream_end_time_ns - provider_request_start_timestamp_ns)
llm_request_ns = stream_end_time_ns - provider_request_start_timestamp_ns
step_metrics.llm_request_ns = llm_request_ns
llm_request_ms = ns_to_ms(llm_request_ns)
agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ms})
MetricRegistry().llm_execution_time_ms_histogram.record(
llm_request_ms,
@@ -908,6 +968,7 @@ class LettaAgent(BaseAgent):
initial_messages=initial_messages,
agent_step_span=agent_step_span,
is_final_step=(i == max_steps - 1),
step_metrics=step_metrics,
)
step_progression = StepProgression.STEP_LOGGED
@@ -979,6 +1040,25 @@ class LettaAgent(BaseAgent):
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
step_progression = StepProgression.FINISHED
# Record step metrics for successful completion
if logged_step and step_metrics:
try:
# Set the step_ns that was already calculated
step_metrics.step_ns = step_ns
# Get context attributes for project and template IDs
ctx_attrs = get_ctx_attributes()
await self._record_step_metrics(
step_id=step_id,
agent_state=agent_state,
step_metrics=step_metrics,
ctx_attrs=ctx_attrs,
job_id=self.current_run_id,
)
except Exception as metrics_error:
self.logger.warning(f"Failed to record step metrics: {metrics_error}")
except Exception as e:
# Handle any unexpected errors during step processing
self.logger.error(f"Error during step processing: {e}")
@@ -1047,6 +1127,25 @@ class LettaAgent(BaseAgent):
if settings.track_stop_reason:
await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=True)
# Record partial step metrics on failure (capture whatever timing data we have)
if logged_step and step_metrics and step_progression < StepProgression.FINISHED:
try:
# Calculate total step time up to the failure point
step_metrics.step_ns = get_utc_timestamp_ns() - step_start
# Get context attributes for project and template IDs
ctx_attrs = get_ctx_attributes()
await self._record_step_metrics(
step_id=step_id,
agent_state=agent_state,
step_metrics=step_metrics,
ctx_attrs=ctx_attrs,
job_id=locals().get("run_id", self.current_run_id),
)
except Exception as metrics_error:
self.logger.warning(f"Failed to record step metrics: {metrics_error}")
except Exception as e:
self.logger.error("Failed to update step: %s", e)
@@ -1087,6 +1186,32 @@ class LettaAgent(BaseAgent):
if request_span:
request_span.end()
async def _record_step_metrics(
    self,
    *,
    step_id: str,
    agent_state: AgentState,
    step_metrics: StepMetrics,
    ctx_attrs: dict | None = None,
    job_id: str | None = None,
) -> None:
    """Persist the timing metrics gathered for a single agent step.

    Best-effort by design: metrics persistence must never break agent
    execution, so any failure is logged as a warning and swallowed.

    Args:
        step_id: ID of the step the metrics belong to.
        agent_state: Agent whose id / project_id are used as fallbacks.
        step_metrics: Accumulated nanosecond timings for the step.
        ctx_attrs: Pre-fetched OTel context attributes; fetched here when
            not supplied (or supplied empty).
        job_id: Run/job to associate; falls back to the current run id.
    """
    try:
        context_attrs = ctx_attrs or get_ctx_attributes()
        # Context attributes win over agent state for project scoping;
        # the agent's project_id is only a fallback.
        await self.step_manager.record_step_metrics_async(
            actor=self.actor,
            step_id=step_id,
            llm_request_ns=step_metrics.llm_request_ns,
            tool_execution_ns=step_metrics.tool_execution_ns,
            step_ns=step_metrics.step_ns,
            agent_id=agent_state.id,
            job_id=job_id or self.current_run_id,
            project_id=context_attrs.get("project.id") or agent_state.project_id,
            template_id=context_attrs.get("template.id"),
            base_template_id=context_attrs.get("base_template.id"),
        )
    except Exception as metrics_error:
        self.logger.warning(f"Failed to record step metrics: {metrics_error}")
# noinspection PyInconsistentReturns
async def _build_and_request_from_llm(
self,
@@ -1096,6 +1221,7 @@ class LettaAgent(BaseAgent):
llm_client: LLMClientBase,
tool_rules_solver: ToolRulesSolver,
agent_step_span: "Span",
step_metrics: StepMetrics,
) -> tuple[dict, dict, list[Message], list[Message], list[str]] | None:
for attempt in range(self.max_summarization_retries + 1):
try:
@@ -1112,6 +1238,10 @@ class LettaAgent(BaseAgent):
async with AsyncTimer() as timer:
# Attempt LLM request
response = await llm_client.request_async(request_data, agent_state.llm_config)
# Track LLM request time
step_metrics.llm_request_ns = int(timer.elapsed_ns)
MetricRegistry().llm_execution_time_ms_histogram.record(
timer.elapsed_ms,
dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}),
@@ -1352,6 +1482,7 @@ class LettaAgent(BaseAgent):
agent_step_span: Optional["Span"] = None,
is_final_step: bool | None = None,
run_id: str | None = None,
step_metrics: StepMetrics = None,
) -> tuple[list[Message], bool, LettaStopReason | None]:
"""
Handle the final AI response once streaming completes, execute / validate the
@@ -1378,6 +1509,8 @@ class LettaAgent(BaseAgent):
if tool_rule_violated:
tool_execution_result = _build_rule_violation_result(tool_call_name, valid_tool_names, tool_rules_solver)
else:
# Track tool execution time
tool_start_time = get_utc_timestamp_ns()
tool_execution_result = await self._execute_tool(
tool_name=tool_call_name,
tool_args=tool_args,
@@ -1385,6 +1518,10 @@ class LettaAgent(BaseAgent):
agent_step_span=agent_step_span,
step_id=step_id,
)
tool_end_time = get_utc_timestamp_ns()
# Store tool execution time in metrics
step_metrics.tool_execution_ns = tool_end_time - tool_start_time
log_telemetry(
self.logger, "_handle_ai_response execute tool finish", tool_execution_result=tool_execution_result, tool_call_id=tool_call_id

View File

@@ -118,7 +118,7 @@ class AsyncTimer:
def __init__(self, callback_func: Callable | None = None):
self._start_time_ns = None
self._end_time_ns = None
self.elapsed_ns = None
self._elapsed_ns = None
self.callback_func = callback_func
async def __aenter__(self):
@@ -127,7 +127,7 @@ class AsyncTimer:
async def __aexit__(self, exc_type, exc, tb):
self._end_time_ns = time.perf_counter_ns()
self.elapsed_ns = self._end_time_ns - self._start_time_ns
self._elapsed_ns = self._end_time_ns - self._start_time_ns
if self.callback_func:
from asyncio import iscoroutinefunction
@@ -139,6 +139,10 @@ class AsyncTimer:
@property
def elapsed_ms(self):
if self.elapsed_ns is not None:
return ns_to_ms(self.elapsed_ns)
if self._elapsed_ns is not None:
return ns_to_ms(self._elapsed_ns)
return None
@property
def elapsed_ns(self):
return self._elapsed_ns

View File

@@ -1,11 +1,15 @@
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Optional
from sqlalchemy import BigInteger, ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, Session, mapped_column, relationship
from letta.orm.mixins import AgentMixin, ProjectMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.step_metrics import StepMetrics as PydanticStepMetrics
from letta.schemas.user import User
from letta.settings import DatabaseChoice, settings
if TYPE_CHECKING:
from letta.orm.agent import Agent
@@ -69,3 +73,38 @@ class StepMetrics(SqlalchemyBase, ProjectMixin, AgentMixin):
step: Mapped["Step"] = relationship("Step", back_populates="metrics", uselist=False)
job: Mapped[Optional["Job"]] = relationship("Job")
agent: Mapped[Optional["Agent"]] = relationship("Agent")
def _backfill_sqlite_timestamps(self) -> None:
    """Set created_at/updated_at explicitly on SQLite.

    SQLite may not apply the column server_default the way Postgres does,
    so missing timestamps are filled in with the current UTC time before
    the row is inserted. Shared by create() and create_async() to avoid
    duplicating the logic.
    """
    if settings.database_engine == DatabaseChoice.SQLITE:
        now = datetime.now(timezone.utc)
        if not self.created_at:
            self.created_at = now
        if not self.updated_at:
            self.updated_at = now

def create(
    self,
    db_session: Session,
    actor: Optional[User] = None,
    no_commit: bool = False,
) -> "StepMetrics":
    """Override create to handle SQLite timestamp issues.

    Args:
        db_session: Active synchronous SQLAlchemy session.
        actor: User performing the operation, if any.
        no_commit: When True, skip the commit so the caller controls the
            transaction boundary.

    Returns:
        The persisted StepMetrics instance.
    """
    self._backfill_sqlite_timestamps()
    return super().create(db_session, actor=actor, no_commit=no_commit)

async def create_async(
    self,
    db_session: AsyncSession,
    actor: Optional[User] = None,
    no_commit: bool = False,
    no_refresh: bool = False,
) -> "StepMetrics":
    """Async variant of create(); applies the same SQLite timestamp backfill.

    Args:
        db_session: Active async SQLAlchemy session.
        actor: User performing the operation, if any.
        no_commit: When True, skip the commit.
        no_refresh: When True, skip refreshing the instance after insert.

    Returns:
        The persisted StepMetrics instance.
    """
    self._backfill_sqlite_timestamps()
    return await super().create_async(db_session, actor=actor, no_commit=no_commit, no_refresh=no_refresh)

View File

@@ -11,11 +11,13 @@ from letta.orm.errors import NoResultFound
from letta.orm.job import Job as JobModel
from letta.orm.sqlalchemy_base import AccessType
from letta.orm.step import Step as StepModel
from letta.orm.step_metrics import StepMetrics as StepMetricsModel
from letta.otel.tracing import get_trace_id, trace_method
from letta.schemas.enums import StepStatus
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.step import Step as PydanticStep
from letta.schemas.step_metrics import StepMetrics as PydanticStepMetrics
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types
@@ -372,6 +374,65 @@ class StepManager:
await session.commit()
return step.to_pydantic()
@enforce_types
@trace_method
async def record_step_metrics_async(
    self,
    actor: PydanticUser,
    step_id: str,
    llm_request_ns: Optional[int] = None,
    tool_execution_ns: Optional[int] = None,
    step_ns: Optional[int] = None,
    agent_id: Optional[str] = None,
    job_id: Optional[str] = None,
    project_id: Optional[str] = None,
    template_id: Optional[str] = None,
    base_template_id: Optional[str] = None,
) -> PydanticStepMetrics:
    """Persist performance metrics for an existing step.

    Args:
        actor: The user making the request; must belong to the step's org.
        step_id: The step the metrics belong to (also used as the metrics
            row's primary key).
        llm_request_ns: LLM request duration in nanoseconds.
        tool_execution_ns: Tool execution duration in nanoseconds.
        step_ns: Total step duration in nanoseconds.
        agent_id: Agent id; defaults to the step's agent when omitted.
        job_id: Job id; defaults to the step's job when omitted.
        project_id: Project id; defaults to the step's project when omitted.
        template_id: Template id, if any.
        base_template_id: Base template id, if any.

    Returns:
        The created step metrics as a Pydantic model.

    Raises:
        NoResultFound: If the step does not exist.
    """
    async with db_registry.async_session() as session:
        step = await session.get(StepModel, step_id)
        if not step:
            raise NoResultFound(f"Step with id {step_id} does not exist")
        # NOTE(review): a generic Exception for an authz failure is hard for
        # callers to handle precisely; consider a dedicated error type.
        if step.organization_id != actor.organization_id:
            raise Exception("Unauthorized")
        # Fall back to the step's own associations for anything not supplied.
        metrics = StepMetricsModel(
            id=step_id,
            organization_id=actor.organization_id,
            agent_id=agent_id or step.agent_id,
            job_id=job_id or step.job_id,
            project_id=project_id or step.project_id,
            llm_request_ns=llm_request_ns,
            tool_execution_ns=tool_execution_ns,
            step_ns=step_ns,
            template_id=template_id,
            base_template_id=base_template_id,
        )
        await metrics.create_async(session)
        return metrics.to_pydantic()
def _verify_job_access(
self,
session: Session,

View File

@@ -8340,6 +8340,74 @@ async def test_step_manager_list_steps_with_status_filter(server: SyncServer, sa
assert status_counts[status] >= 1, f"No steps found with status {status}"
async def test_step_manager_record_metrics(server: SyncServer, sarah_agent, default_job, default_user, event_loop):
    """Metrics recorded via record_step_metrics_async round-trip with the supplied values."""
    step_manager = server.step_manager

    # A metrics row is keyed by a step, so create one to attach to.
    step = await step_manager.log_step_async(
        agent_id=sarah_agent.id,
        provider_name="openai",
        provider_category="base",
        model="gpt-4o-mini",
        model_endpoint="https://api.openai.com/v1",
        context_window_limit=8192,
        job_id=default_job.id,
        usage=UsageStatistics(completion_tokens=10, prompt_tokens=20, total_tokens=30),
        actor=default_user,
        project_id=sarah_agent.project_id,
        status=StepStatus.PENDING,
    )

    # Nanosecond timings: 1.5s LLM, 0.5s tool, 2.1s total.
    timings = {
        "llm_request_ns": 1_500_000_000,
        "tool_execution_ns": 500_000_000,
        "step_ns": 2_100_000_000,
    }
    metrics = await step_manager.record_step_metrics_async(
        actor=default_user,
        step_id=step.id,
        agent_id=sarah_agent.id,
        job_id=default_job.id,
        project_id=sarah_agent.project_id,
        template_id="template-id",
        base_template_id="base-template-id",
        **timings,
    )

    # The metrics row shares its primary key with the step and echoes every input.
    assert metrics.id == step.id
    assert metrics.llm_request_ns == timings["llm_request_ns"]
    assert metrics.tool_execution_ns == timings["tool_execution_ns"]
    assert metrics.step_ns == timings["step_ns"]
    assert metrics.agent_id == sarah_agent.id
    assert metrics.job_id == default_job.id
    assert metrics.project_id == sarah_agent.project_id
    assert metrics.template_id == "template-id"
    assert metrics.base_template_id == "base-template-id"
async def test_step_manager_record_metrics_nonexistent_step(server: SyncServer, default_user, event_loop):
    """Recording metrics against an unknown step id raises NoResultFound."""
    with pytest.raises(NoResultFound):
        await server.step_manager.record_step_metrics_async(
            actor=default_user,
            step_id="nonexistent-step-id",
            llm_request_ns=1_000_000_000,
            tool_execution_ns=500_000_000,
            step_ns=1_600_000_000,
        )
def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user):
"""Test getting usage statistics for a nonexistent job."""
job_manager = server.job_manager