From 0179901e79177faf13f6755b15fb4e9dbefc069c Mon Sep 17 00:00:00 2001
From: jnjpng
Date: Wed, 13 Aug 2025 11:29:18 -0700
Subject: [PATCH] feat: record step metrics to table

Co-authored-by: Jin Peng
---
 letta/agents/letta_agent.py       | 141 +++++++++++++++++++++++++++++-
 letta/helpers/datetime_helpers.py |  12 ++-
 letta/orm/step_metrics.py         |  41 ++++++++-
 letta/services/step_manager.py    |  61 +++++++++++++
 tests/test_managers.py            |  68 ++++++++++++++
 5 files changed, 316 insertions(+), 7 deletions(-)

diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py
index 46af4fad..8c7255aa 100644
--- a/letta/agents/letta_agent.py
+++ b/letta/agents/letta_agent.py
@@ -44,6 +44,7 @@ from letta.schemas.message import Message, MessageCreate
 from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
 from letta.schemas.provider_trace import ProviderTraceCreate
 from letta.schemas.step import StepProgression
+from letta.schemas.step_metrics import StepMetrics
 from letta.schemas.tool_execution_result import ToolExecutionResult
 from letta.schemas.usage import LettaUsageStatistics
 from letta.schemas.user import User
@@ -242,6 +243,7 @@ class LettaAgent(BaseAgent):

         step_progression = StepProgression.START
         should_continue = False
+        step_metrics = StepMetrics(id=step_id)  # Initialize metrics tracking

         # Create step early with PENDING status
         logged_step = await self.step_manager.log_step_async(
@@ -271,6 +273,7 @@
                     llm_client,
                     tool_rules_solver,
                     agent_step_span,
+                    step_metrics,
                 )
             )
             in_context_messages = current_in_context_messages + new_in_context_messages
@@ -320,6 +323,7 @@
                 initial_messages=initial_messages,
                 agent_step_span=agent_step_span,
                 is_final_step=(i == max_steps - 1),
+                step_metrics=step_metrics,
             )

             step_progression = StepProgression.STEP_LOGGED
@@ -365,6 +369,17 @@
             MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())

             step_progression = StepProgression.FINISHED
+
+            # Record step metrics for successful completion
+            if logged_step and step_metrics:
+                # Set the step_ns that was already calculated
+                step_metrics.step_ns = step_ns
+                await self._record_step_metrics(
+                    step_id=step_id,
+                    agent_state=agent_state,
+                    step_metrics=step_metrics,
+                )
+
         except Exception as e:
             # Handle any unexpected errors during step processing
             self.logger.error(f"Error during step processing: {e}")
@@ -433,6 +448,17 @@
             if settings.track_stop_reason:
                 await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=True)

+            # Record partial step metrics on failure (capture whatever timing data we have)
+            if logged_step and step_metrics and step_progression < StepProgression.FINISHED:
+                # Calculate total step time up to the failure point
+                step_metrics.step_ns = get_utc_timestamp_ns() - step_start
+                await self._record_step_metrics(
+                    step_id=step_id,
+                    agent_state=agent_state,
+                    step_metrics=step_metrics,
+                    job_id=locals().get("run_id", self.current_run_id),
+                )
+
         except Exception as e:
             self.logger.error("Failed to update step: %s", e)
@@ -513,6 +539,7 @@

         step_progression = StepProgression.START
         should_continue = False
+        step_metrics = StepMetrics(id=step_id)  # Initialize metrics tracking

         # Create step early with PENDING status
         logged_step = await self.step_manager.log_step_async(
@@ -536,7 +563,13 @@
         try:
             request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
                 await self._build_and_request_from_llm(
-                    current_in_context_messages, new_in_context_messages, agent_state, llm_client, tool_rules_solver, agent_step_span
+                    current_in_context_messages,
+                    new_in_context_messages,
+                    agent_state,
+                    llm_client,
+                    tool_rules_solver,
+                    agent_step_span,
+                    step_metrics,
                 )
             )
             in_context_messages = current_in_context_messages + new_in_context_messages
@@ -587,6 +620,7 @@
                 agent_step_span=agent_step_span,
                 is_final_step=(i == max_steps - 1),
                 run_id=run_id,
+                step_metrics=step_metrics,
             )

             step_progression = StepProgression.STEP_LOGGED
@@ -622,6 +656,17 @@
             MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())

             step_progression = StepProgression.FINISHED
+            # Record step metrics for successful completion
+            if logged_step and step_metrics:
+                # Set the step_ns that was already calculated
+                step_metrics.step_ns = step_ns
+                await self._record_step_metrics(
+                    step_id=step_id,
+                    agent_state=agent_state,
+                    step_metrics=step_metrics,
+                    job_id=run_id if run_id else self.current_run_id,
+                )
+
         except Exception as e:
             # Handle any unexpected errors during step processing
             self.logger.error(f"Error during step processing: {e}")
@@ -686,6 +731,17 @@
             if settings.track_stop_reason:
                 await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=True)

+            # Record partial step metrics on failure (capture whatever timing data we have)
+            if logged_step and step_metrics and step_progression < StepProgression.FINISHED:
+                # Calculate total step time up to the failure point
+                step_metrics.step_ns = get_utc_timestamp_ns() - step_start
+                await self._record_step_metrics(
+                    step_id=step_id,
+                    agent_state=agent_state,
+                    step_metrics=step_metrics,
+                    job_id=locals().get("run_id", self.current_run_id),
+                )
+
         except Exception as e:
             self.logger.error("Failed to update step: %s", e)
@@ -774,6 +830,7 @@

         step_progression = StepProgression.START
         should_continue = False
+        step_metrics = StepMetrics(id=step_id)  # Initialize metrics tracking

         # Create step early with PENDING status
         logged_step = await self.step_manager.log_step_async(
@@ -875,7 +932,10 @@
             )

             # log LLM request time
-            llm_request_ms = ns_to_ms(stream_end_time_ns - provider_request_start_timestamp_ns)
+            llm_request_ns = stream_end_time_ns - provider_request_start_timestamp_ns
+            step_metrics.llm_request_ns = llm_request_ns
+
+            llm_request_ms = ns_to_ms(llm_request_ns)
             agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ms})
             MetricRegistry().llm_execution_time_ms_histogram.record(
                 llm_request_ms,
@@ -908,6 +968,7 @@
                 initial_messages=initial_messages,
                 agent_step_span=agent_step_span,
                 is_final_step=(i == max_steps - 1),
+                step_metrics=step_metrics,
             )

             step_progression = StepProgression.STEP_LOGGED
@@ -979,6 +1040,25 @@
             MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())

             step_progression = StepProgression.FINISHED
+            # Record step metrics for successful completion
+            if logged_step and step_metrics:
+                try:
+                    # Set the step_ns that was already calculated
+                    step_metrics.step_ns = step_ns
+
+                    # Get context attributes for project and template IDs
+                    ctx_attrs = get_ctx_attributes()
+
+                    await self._record_step_metrics(
+                        step_id=step_id,
+                        agent_state=agent_state,
+                        step_metrics=step_metrics,
+                        ctx_attrs=ctx_attrs,
+                        job_id=self.current_run_id,
+                    )
+                except Exception as metrics_error:
+                    self.logger.warning(f"Failed to record step metrics: {metrics_error}")
+
         except Exception as e:
             # Handle any unexpected errors during step processing
             self.logger.error(f"Error during step processing: {e}")
@@ -1047,6 +1127,25 @@
             if settings.track_stop_reason:
                 await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=True)

+            # Record partial step metrics on failure (capture whatever timing data we have)
+            if logged_step and step_metrics and step_progression < StepProgression.FINISHED:
+                try:
+                    # Calculate total step time up to the failure point
+                    step_metrics.step_ns = get_utc_timestamp_ns() - step_start
+
+                    # Get context attributes for project and template IDs
+                    ctx_attrs = get_ctx_attributes()
+
+                    await self._record_step_metrics(
+                        step_id=step_id,
+                        agent_state=agent_state,
+                        step_metrics=step_metrics,
+                        ctx_attrs=ctx_attrs,
+                        job_id=locals().get("run_id", self.current_run_id),
+                    )
+                except Exception as metrics_error:
+                    self.logger.warning(f"Failed to record step metrics: {metrics_error}")
+
         except Exception as e:
             self.logger.error("Failed to update step: %s", e)
@@ -1087,6 +1186,32 @@

         if request_span:
             request_span.end()

+    async def _record_step_metrics(
+        self,
+        *,
+        step_id: str,
+        agent_state: AgentState,
+        step_metrics: StepMetrics,
+        ctx_attrs: dict | None = None,
+        job_id: str | None = None,
+    ) -> None:
+        try:
+            attrs = ctx_attrs or get_ctx_attributes()
+            await self.step_manager.record_step_metrics_async(
+                actor=self.actor,
+                step_id=step_id,
+                llm_request_ns=step_metrics.llm_request_ns,
+                tool_execution_ns=step_metrics.tool_execution_ns,
+                step_ns=step_metrics.step_ns,
+                agent_id=agent_state.id,
+                job_id=job_id or self.current_run_id,
+                project_id=attrs.get("project.id") or agent_state.project_id,
+                template_id=attrs.get("template.id"),
+                base_template_id=attrs.get("base_template.id"),
+            )
+        except Exception as metrics_error:
+            self.logger.warning(f"Failed to record step metrics: {metrics_error}")
+
     # noinspection PyInconsistentReturns
     async def _build_and_request_from_llm(
         self,
@@ -1096,6 +1221,7 @@
         llm_client: LLMClientBase,
         tool_rules_solver: ToolRulesSolver,
         agent_step_span: "Span",
+        step_metrics: StepMetrics,
     ) -> tuple[dict, dict, list[Message], list[Message], list[str]] | None:
         for attempt in range(self.max_summarization_retries + 1):
             try:
@@ -1112,6 +1238,10 @@
                 async with AsyncTimer() as timer:
                     # Attempt LLM request
                     response = await llm_client.request_async(request_data, agent_state.llm_config)
+
+                # Track LLM request time
+                step_metrics.llm_request_ns = int(timer.elapsed_ns)
+
                 MetricRegistry().llm_execution_time_ms_histogram.record(
                     timer.elapsed_ms,
                     dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}),
@@ -1352,6 +1482,7 @@
         agent_step_span: Optional["Span"] = None,
         is_final_step: bool | None = None,
         run_id: str | None = None,
+        step_metrics: StepMetrics | None = None,
     ) -> tuple[list[Message], bool, LettaStopReason | None]:
         """
         Handle the final AI response once streaming completes, execute / validate the
@@ -1378,6 +1509,8 @@
             if tool_rule_violated:
                 tool_execution_result = _build_rule_violation_result(tool_call_name, valid_tool_names, tool_rules_solver)
             else:
+                # Track tool execution time
+                tool_start_time = get_utc_timestamp_ns()
                 tool_execution_result = await self._execute_tool(
                     tool_name=tool_call_name,
                     tool_args=tool_args,
                     agent_state=agent_state,
                     agent_step_span=agent_step_span,
                     step_id=step_id,
                 )
+                tool_end_time = get_utc_timestamp_ns()
+
+                # Store tool execution time in metrics
+                step_metrics.tool_execution_ns = tool_end_time - tool_start_time

             log_telemetry(
                 self.logger, "_handle_ai_response execute tool finish", tool_execution_result=tool_execution_result, tool_call_id=tool_call_id
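# NOTE — illustrative sketch, not part of the diff. The letta_agent.py hunks
# above all follow one pattern: a single StepMetrics carrier per step, filled
# in as each phase completes, then flushed once through the fire-and-forget
# helper. Names come from the patch; the surrounding agent loop, spans, and
# error handling are elided, and the nanosecond values are placeholders.
#
#     step_metrics = StepMetrics(id=step_id)         # created with the step
#     step_metrics.llm_request_ns = 1_200_000_000    # set after the LLM call
#     step_metrics.tool_execution_ns = 300_000_000   # set after the tool runs
#     step_metrics.step_ns = 1_600_000_000           # set when the step ends
#     await self._record_step_metrics(
#         step_id=step_id,
#         agent_state=agent_state,
#         step_metrics=step_metrics,                 # helper only warns on failure
#     )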
diff --git a/letta/helpers/datetime_helpers.py b/letta/helpers/datetime_helpers.py
index 07856495..789b34e2 100644
--- a/letta/helpers/datetime_helpers.py
+++ b/letta/helpers/datetime_helpers.py
@@ -118,7 +118,7 @@ class AsyncTimer:
     def __init__(self, callback_func: Callable | None = None):
         self._start_time_ns = None
         self._end_time_ns = None
-        self.elapsed_ns = None
+        self._elapsed_ns = None
         self.callback_func = callback_func

     async def __aenter__(self):
@@ -127,7 +127,7 @@

     async def __aexit__(self, exc_type, exc, tb):
         self._end_time_ns = time.perf_counter_ns()
-        self.elapsed_ns = self._end_time_ns - self._start_time_ns
+        self._elapsed_ns = self._end_time_ns - self._start_time_ns

         if self.callback_func:
             from asyncio import iscoroutinefunction
@@ -139,6 +139,10 @@

     @property
     def elapsed_ms(self):
-        if self.elapsed_ns is not None:
-            return ns_to_ms(self.elapsed_ns)
+        if self._elapsed_ns is not None:
+            return ns_to_ms(self._elapsed_ns)
         return None
+
+    @property
+    def elapsed_ns(self):
+        return self._elapsed_ns
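# NOTE — illustrative sketch, not part of the diff. The AsyncTimer change
# above turns elapsed_ns into a read-only property backed by the private
# _elapsed_ns, so callers can no longer clobber the measurement. Minimal
# usage, assuming the letta package is importable and asyncio.sleep stands
# in for the LLM request:
#
#     import asyncio
#     from letta.helpers.datetime_helpers import AsyncTimer
#
#     async def main():
#         async with AsyncTimer() as timer:
#             await asyncio.sleep(0.05)
#         # _elapsed_ns is set in __aexit__, so read the properties only
#         # after the block exits; inside it they would still be None.
#         print(timer.elapsed_ns, timer.elapsed_ms)
#
#     asyncio.run(main())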
diff --git a/letta/orm/step_metrics.py b/letta/orm/step_metrics.py
index c85607cd..760db52e 100644
--- a/letta/orm/step_metrics.py
+++ b/letta/orm/step_metrics.py
@@ -1,11 +1,15 @@
+from datetime import datetime, timezone
 from typing import TYPE_CHECKING, Optional

 from sqlalchemy import BigInteger, ForeignKey, String
-from sqlalchemy.orm import Mapped, mapped_column, relationship
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.orm import Mapped, Session, mapped_column, relationship

 from letta.orm.mixins import AgentMixin, ProjectMixin
 from letta.orm.sqlalchemy_base import SqlalchemyBase
 from letta.schemas.step_metrics import StepMetrics as PydanticStepMetrics
+from letta.schemas.user import User
+from letta.settings import DatabaseChoice, settings

 if TYPE_CHECKING:
     from letta.orm.agent import Agent
@@ -69,3 +73,38 @@
     step: Mapped["Step"] = relationship("Step", back_populates="metrics", uselist=False)
     job: Mapped[Optional["Job"]] = relationship("Job")
     agent: Mapped[Optional["Agent"]] = relationship("Agent")
+
+    def create(
+        self,
+        db_session: Session,
+        actor: Optional[User] = None,
+        no_commit: bool = False,
+    ) -> "StepMetrics":
+        """Override create to handle SQLite timestamp issues"""
+        # For SQLite, explicitly set timestamps as server_default may not work
+        if settings.database_engine == DatabaseChoice.SQLITE:
+            now = datetime.now(timezone.utc)
+            if not self.created_at:
+                self.created_at = now
+            if not self.updated_at:
+                self.updated_at = now
+
+        return super().create(db_session, actor=actor, no_commit=no_commit)
+
+    async def create_async(
+        self,
+        db_session: AsyncSession,
+        actor: Optional[User] = None,
+        no_commit: bool = False,
+        no_refresh: bool = False,
+    ) -> "StepMetrics":
+        """Override create_async to handle SQLite timestamp issues"""
+        # For SQLite, explicitly set timestamps as server_default may not work
+        if settings.database_engine == DatabaseChoice.SQLITE:
+            now = datetime.now(timezone.utc)
+            if not self.created_at:
+                self.created_at = now
+            if not self.updated_at:
+                self.updated_at = now
+
+        return await super().create_async(db_session, actor=actor, no_commit=no_commit, no_refresh=no_refresh)
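# NOTE — illustrative sketch, not part of the diff. The create/create_async
# overrides above guard against SQLite not applying server_default
# timestamps. Their shared logic, as a hypothetical free function (same
# behavior as the overrides; the name is illustrative):
#
#     from datetime import datetime, timezone
#
#     def backfill_timestamps(row) -> None:
#         now = datetime.now(timezone.utc)
#         if not row.created_at:
#             row.created_at = now
#         if not row.updated_at:
#             row.updated_at = now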
diff --git a/letta/services/step_manager.py b/letta/services/step_manager.py
index 26257179..7aa1ee3a 100644
--- a/letta/services/step_manager.py
+++ b/letta/services/step_manager.py
@@ -11,11 +11,13 @@ from letta.orm.errors import NoResultFound
 from letta.orm.job import Job as JobModel
 from letta.orm.sqlalchemy_base import AccessType
 from letta.orm.step import Step as StepModel
+from letta.orm.step_metrics import StepMetrics as StepMetricsModel
 from letta.otel.tracing import get_trace_id, trace_method
 from letta.schemas.enums import StepStatus
 from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
 from letta.schemas.openai.chat_completion_response import UsageStatistics
 from letta.schemas.step import Step as PydanticStep
+from letta.schemas.step_metrics import StepMetrics as PydanticStepMetrics
 from letta.schemas.user import User as PydanticUser
 from letta.server.db import db_registry
 from letta.utils import enforce_types
@@ -372,6 +374,65 @@ class StepManager:
             await session.commit()
             return step.to_pydantic()

+    @enforce_types
+    @trace_method
+    async def record_step_metrics_async(
+        self,
+        actor: PydanticUser,
+        step_id: str,
+        llm_request_ns: Optional[int] = None,
+        tool_execution_ns: Optional[int] = None,
+        step_ns: Optional[int] = None,
+        agent_id: Optional[str] = None,
+        job_id: Optional[str] = None,
+        project_id: Optional[str] = None,
+        template_id: Optional[str] = None,
+        base_template_id: Optional[str] = None,
+    ) -> PydanticStepMetrics:
+        """Record performance metrics for a step.
+
+        Args:
+            actor: The user making the request
+            step_id: The ID of the step to record metrics for
+            llm_request_ns: Time spent on LLM request in nanoseconds
+            tool_execution_ns: Time spent on tool execution in nanoseconds
+            step_ns: Total time for the step in nanoseconds
+            agent_id: The ID of the agent
+            job_id: The ID of the job
+            project_id: The ID of the project
+            template_id: The ID of the template
+            base_template_id: The ID of the base template
+
+        Returns:
+            The created step metrics
+
+        Raises:
+            NoResultFound: If the step does not exist
+        """
+        async with db_registry.async_session() as session:
+            step = await session.get(StepModel, step_id)
+            if not step:
+                raise NoResultFound(f"Step with id {step_id} does not exist")
+            if step.organization_id != actor.organization_id:
+                raise Exception("Unauthorized")
+
+            metrics_data = {
+                "id": step_id,
+                "organization_id": actor.organization_id,
+                "agent_id": agent_id or step.agent_id,
+                "job_id": job_id or step.job_id,
+                "project_id": project_id or step.project_id,
+                "llm_request_ns": llm_request_ns,
+                "tool_execution_ns": tool_execution_ns,
+                "step_ns": step_ns,
+                "template_id": template_id,
+                "base_template_id": base_template_id,
+            }
+
+            metrics = StepMetricsModel(**metrics_data)
+            await metrics.create_async(session)
+            return metrics.to_pydantic()
+
     def _verify_job_access(
         self,
         session: Session,
diff --git a/tests/test_managers.py b/tests/test_managers.py
index ccd2802c..319d9ce8 100644
--- a/tests/test_managers.py
+++ b/tests/test_managers.py
@@ -8340,6 +8340,74 @@ async def test_step_manager_list_steps_with_status_filter(server: SyncServer, sa
     assert status_counts[status] >= 1, f"No steps found with status {status}"


+async def test_step_manager_record_metrics(server: SyncServer, sarah_agent, default_job, default_user, event_loop):
+    """Test recording step metrics functionality."""
+    step_manager = server.step_manager
+
+    # Create a step first
+    step = await step_manager.log_step_async(
+        agent_id=sarah_agent.id,
+        provider_name="openai",
+        provider_category="base",
+        model="gpt-4o-mini",
+        model_endpoint="https://api.openai.com/v1",
+        context_window_limit=8192,
+        job_id=default_job.id,
+        usage=UsageStatistics(
+            completion_tokens=10,
+            prompt_tokens=20,
+            total_tokens=30,
+        ),
+        actor=default_user,
+        project_id=sarah_agent.project_id,
+        status=StepStatus.PENDING,
+    )
+
+    # Record metrics for the step
+    llm_request_ns = 1_500_000_000  # 1.5 seconds
+    tool_execution_ns = 500_000_000  # 0.5 seconds
+    step_ns = 2_100_000_000  # 2.1 seconds
+
+    metrics = await step_manager.record_step_metrics_async(
+        actor=default_user,
+        step_id=step.id,
+        llm_request_ns=llm_request_ns,
+        tool_execution_ns=tool_execution_ns,
+        step_ns=step_ns,
+        agent_id=sarah_agent.id,
+        job_id=default_job.id,
+        project_id=sarah_agent.project_id,
+        template_id="template-id",
+        base_template_id="base-template-id",
+    )
+
+    # Verify the metrics were recorded correctly
+    assert metrics.id == step.id
+    assert metrics.llm_request_ns == llm_request_ns
+    assert metrics.tool_execution_ns == tool_execution_ns
+    assert metrics.step_ns == step_ns
+    assert metrics.agent_id == sarah_agent.id
+    assert metrics.job_id == default_job.id
+    assert metrics.project_id == sarah_agent.project_id
+    assert metrics.template_id == "template-id"
+    assert metrics.base_template_id == "base-template-id"
+
+
+async def test_step_manager_record_metrics_nonexistent_step(server: SyncServer, default_user, event_loop):
+    """Test recording metrics for a nonexistent step."""
+    step_manager = server.step_manager
+
+    # Try to record metrics for a step that doesn't exist
+    with pytest.raises(NoResultFound):
+        await step_manager.record_step_metrics_async(
+            actor=default_user,
+            step_id="nonexistent-step-id",
+            llm_request_ns=1_000_000_000,
+            tool_execution_ns=500_000_000,
+            step_ns=1_600_000_000,
+        )
+
+
 def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user):
     """Test getting usage statistics for a nonexistent job."""
     job_manager = server.job_manager
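# NOTE — illustrative sketch, not part of the diff. End to end, the new
# service API is exercised the way the added tests do. Hedged example;
# step_manager, actor, and step are assumed to come from an existing Letta
# server or test fixture:
#
#     metrics = await step_manager.record_step_metrics_async(
#         actor=actor,
#         step_id=step.id,
#         llm_request_ns=1_200_000_000,   # 1.2 s in the LLM call
#         tool_execution_ns=300_000_000,  # 0.3 s executing the tool
#         step_ns=1_600_000_000,          # 1.6 s wall clock for the step
#     )
#     assert metrics.id == step.id  # the metrics row shares the step's primary key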