From 76679e3eccaf47b6d75831e6900d137e39c084de Mon Sep 17 00:00:00 2001 From: Andy Li <55300002+cliandy@users.noreply.github.com> Date: Wed, 6 Aug 2025 15:46:50 -0700 Subject: [PATCH] feat: track metrics for runs in db --- ...bc564286_add_metrics_to_agent_loop_runs.py | 33 ++++++ letta/agents/letta_agent.py | 107 ++++++++++-------- .../anthropic_streaming_interface.py | 19 ---- .../interfaces/openai_streaming_interface.py | 23 ---- letta/orm/job.py | 6 +- letta/schemas/job.py | 4 + letta/services/job_manager.py | 24 ++++ letta/services/tool_manager.py | 10 +- tests/test_managers.py | 71 ++++++++++++ 9 files changed, 202 insertions(+), 95 deletions(-) create mode 100644 alembic/versions/05c3bc564286_add_metrics_to_agent_loop_runs.py diff --git a/alembic/versions/05c3bc564286_add_metrics_to_agent_loop_runs.py b/alembic/versions/05c3bc564286_add_metrics_to_agent_loop_runs.py new file mode 100644 index 00000000..d76b064b --- /dev/null +++ b/alembic/versions/05c3bc564286_add_metrics_to_agent_loop_runs.py @@ -0,0 +1,33 @@ +"""add metrics to agent loop runs + +Revision ID: 05c3bc564286 +Revises: d007f4ca66bf +Create Date: 2025-08-06 14:30:48.255538 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "05c3bc564286" +down_revision: Union[str, None] = "d007f4ca66bf" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("jobs", sa.Column("ttft_ns", sa.BigInteger(), nullable=True)) + op.add_column("jobs", sa.Column("total_duration_ns", sa.BigInteger(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("jobs", "total_duration_ns") + op.drop_column("jobs", "ttft_ns") + # ### end Alembic commands ### diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 4a7bef96..c3b95194 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -361,8 +361,16 @@ class LettaAgent(BaseAgent): if settings.track_stop_reason: if step_progression == StepProgression.FINISHED and should_continue: continue + + self.logger.debug("Running cleanup for agent loop run: %s", self.current_run_id) self.logger.info("Running final update. Step Progression: %s", step_progression) try: + if step_progression == StepProgression.FINISHED and not should_continue: + if stop_reason is None: + stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value) + await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason) + break + if step_progression < StepProgression.STEP_LOGGED: await self.step_manager.log_step_async( actor=self.actor, @@ -391,12 +399,11 @@ class LettaAgent(BaseAgent): self.logger.error("Error in step after logging step") stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason) - elif step_progression == StepProgression.FINISHED and not should_continue: - if stop_reason is None: - stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value) - await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason) else: self.logger.error("Invalid StepProgression value") + + await self._log_request(request_start_timestamp_ns, request_span) + except Exception as e: self.logger.error("Failed to update step: %s", e) @@ -413,17 +420,7 @@ class LettaAgent(BaseAgent): force=False, ) - # log request time - if request_start_timestamp_ns: - now = get_utc_timestamp_ns() - duration_ms = ns_to_ms(now - request_start_timestamp_ns) - request_span.add_event(name="letta_request_ms", attributes={"duration_ms": duration_ms}) - - # update agent's last run metrics - now_datetime = get_utc_time() - await self._update_agent_last_run_metrics(now_datetime, duration_ms) - - request_span.end() + await self._log_request(request_start_timestamp_ns, request_span) # Return back usage for finish_chunk in self.get_finish_chunks_for_stream(usage, stop_reason): @@ -590,8 +587,16 @@ class LettaAgent(BaseAgent): if settings.track_stop_reason: if step_progression == StepProgression.FINISHED and should_continue: continue + + self.logger.debug("Running cleanup for agent loop run: %s", self.current_run_id) self.logger.info("Running final update. Step Progression: %s", step_progression) try: + if step_progression == StepProgression.FINISHED and not should_continue: + if stop_reason is None: + stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value) + await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason) + break + if step_progression < StepProgression.STEP_LOGGED: await self.step_manager.log_step_async( actor=self.actor, @@ -620,30 +625,17 @@ class LettaAgent(BaseAgent): self.logger.error("Error in step after logging step") stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason) - elif step_progression == StepProgression.FINISHED and not should_continue: - if stop_reason is None: - stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value) - await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason) else: self.logger.error("Invalid StepProgression value") + + await self._log_request(request_start_timestamp_ns, request_span) + except Exception as e: self.logger.error("Failed to update step: %s", e) if not should_continue: break - # log request time - if request_start_timestamp_ns: - now = get_utc_timestamp_ns() - duration_ms = ns_to_ms(now - request_start_timestamp_ns) - request_span.add_event(name="request_ms", attributes={"duration_ms": duration_ms}) - - # update agent's last run metrics - now_datetime = get_utc_time() - await self._update_agent_last_run_metrics(now_datetime, duration_ms) - - request_span.end() - # Extend the in context message ids if not agent_state.message_buffer_autoclear: await self._rebuild_context_window( @@ -654,6 +646,8 @@ class LettaAgent(BaseAgent): force=False, ) + await self._log_request(request_start_timestamp_ns, request_span) + return current_in_context_messages, new_in_context_messages, stop_reason, usage async def _update_agent_last_run_metrics(self, completion_time: datetime, duration_ms: float) -> None: @@ -755,7 +749,6 @@ class LettaAgent(BaseAgent): elif agent_state.llm_config.model_endpoint_type == ProviderType.openai: interface = OpenAIStreamingInterface( use_assistant_message=use_assistant_message, - put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs, is_openai_proxy=agent_state.llm_config.provider_name == "lmstudio_openai", messages=current_in_context_messages + new_in_context_messages, tools=request_data.get("tools", []), @@ -766,16 +759,20 @@ class LettaAgent(BaseAgent): async for chunk in interface.process( stream, ttft_span=request_span, - provider_request_start_timestamp_ns=provider_request_start_timestamp_ns, ): - # Measure time to first token + # Measure TTFT (trace, metric, and db). This should be consolidated. if first_chunk and request_span is not None: now = get_utc_timestamp_ns() ttft_ns = now - request_start_timestamp_ns + request_span.add_event(name="time_to_first_token_ms", attributes={"ttft_ms": ns_to_ms(ttft_ns)}) metric_attributes = get_ctx_attributes() metric_attributes["model.name"] = agent_state.llm_config.model MetricRegistry().ttft_ms_histogram.record(ns_to_ms(ttft_ns), metric_attributes) + + if self.current_run_id and self.job_manager: + await self.job_manager.record_ttft(self.current_run_id, ttft_ns, self.actor) + first_chunk = False if include_return_message_types is None or chunk.message_type in include_return_message_types: @@ -913,8 +910,16 @@ class LettaAgent(BaseAgent): if settings.track_stop_reason: if step_progression == StepProgression.FINISHED and should_continue: continue + + self.logger.debug("Running cleanup for agent loop run: %s", self.current_run_id) self.logger.info("Running final update. Step Progression: %s", step_progression) try: + if step_progression == StepProgression.FINISHED and not should_continue: + if stop_reason is None: + stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value) + await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason) + break + if step_progression < StepProgression.STEP_LOGGED: await self.step_manager.log_step_async( actor=self.actor, @@ -942,12 +947,12 @@ class LettaAgent(BaseAgent): self.logger.error("Error in step after logging step") stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason) - elif step_progression == StepProgression.FINISHED and not should_continue: - if stop_reason is None: - stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value) - await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason) else: self.logger.error("Invalid StepProgression value") + + # Do tracking for failure cases. Can consolidate with success conditions later. + await self._log_request(request_start_timestamp_ns, request_span) + except Exception as e: self.logger.error("Failed to update step: %s", e) @@ -963,21 +968,23 @@ class LettaAgent(BaseAgent): force=False, ) - # log time of entire request - if request_start_timestamp_ns: - now = get_utc_timestamp_ns() - duration_ms = ns_to_ms(now - request_start_timestamp_ns) - request_span.add_event(name="letta_request_ms", attributes={"duration_ms": duration_ms}) - - # update agent's last run metrics - completion_time = get_utc_time() - await self._update_agent_last_run_metrics(completion_time, duration_ms) - - request_span.end() + await self._log_request(request_start_timestamp_ns, request_span) for finish_chunk in self.get_finish_chunks_for_stream(usage, stop_reason): yield f"data: {finish_chunk}\n\n" + async def _log_request(self, request_start_timestamp_ns: int, request_span: "Span | None"): + if request_start_timestamp_ns: + now_ns, now = get_utc_timestamp_ns(), get_utc_time() + duration_ns = now_ns - request_start_timestamp_ns + if request_span: + request_span.add_event(name="letta_request_ms", attributes={"duration_ms": ns_to_ms(duration_ns)}) + await self._update_agent_last_run_metrics(now, ns_to_ms(duration_ns)) + if self.current_run_id: + await self.job_manager.record_response_duration(self.current_run_id, duration_ns, self.actor) + if request_span: + request_span.end() + # noinspection PyInconsistentReturns async def _build_and_request_from_llm( self, @@ -1428,6 +1435,8 @@ class LettaAgent(BaseAgent): status="error", ) + print(target_tool) + # TODO: This temp. Move this logic and code to executors if agent_step_span: diff --git a/letta/interfaces/anthropic_streaming_interface.py b/letta/interfaces/anthropic_streaming_interface.py index 3022a11c..82d9287f 100644 --- a/letta/interfaces/anthropic_streaming_interface.py +++ b/letta/interfaces/anthropic_streaming_interface.py @@ -25,11 +25,8 @@ from anthropic.types.beta import ( ) from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG -from letta.helpers.datetime_helpers import get_utc_timestamp_ns, ns_to_ms from letta.local_llm.constants import INNER_THOUGHTS_KWARG from letta.log import get_logger -from letta.otel.context import get_ctx_attributes -from letta.otel.metric_registry import MetricRegistry from letta.schemas.letta_message import ( AssistantMessage, HiddenReasoningMessage, @@ -133,28 +130,12 @@ class AnthropicStreamingInterface: self, stream: AsyncStream[BetaRawMessageStreamEvent], ttft_span: Optional["Span"] = None, - provider_request_start_timestamp_ns: int | None = None, ) -> AsyncGenerator[LettaMessage | LettaStopReason, None]: prev_message_type = None message_index = 0 - first_chunk = True try: async with stream: async for event in stream: - # TODO (cliandy): reconsider in stream cancellations - # await cancellation_token.check_and_raise_if_cancelled() - if first_chunk and ttft_span is not None and provider_request_start_timestamp_ns is not None: - now = get_utc_timestamp_ns() - ttft_ns = now - provider_request_start_timestamp_ns - ttft_span.add_event( - name="anthropic_time_to_first_token_ms", attributes={"anthropic_time_to_first_token_ms": ns_to_ms(ttft_ns)} - ) - metric_attributes = get_ctx_attributes() - if isinstance(event, BetaRawMessageStartEvent): - metric_attributes["model.name"] = event.message.model - MetricRegistry().ttft_ms_histogram.record(ns_to_ms(ttft_ns), metric_attributes) - first_chunk = False - # TODO: Support BetaThinkingBlock, BetaRedactedThinkingBlock if isinstance(event, BetaRawContentBlockStartEvent): content = event.content_block diff --git a/letta/interfaces/openai_streaming_interface.py b/letta/interfaces/openai_streaming_interface.py index 710c3ba0..656402f2 100644 --- a/letta/interfaces/openai_streaming_interface.py +++ b/letta/interfaces/openai_streaming_interface.py @@ -7,12 +7,9 @@ from openai import AsyncStream from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG -from letta.helpers.datetime_helpers import get_utc_timestamp_ns, ns_to_ms from letta.llm_api.openai_client import is_openai_reasoning_model from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages from letta.log import get_logger -from letta.otel.context import get_ctx_attributes -from letta.otel.metric_registry import MetricRegistry from letta.schemas.letta_message import AssistantMessage, LettaMessage, ReasoningMessage, ToolCallDelta, ToolCallMessage from letta.schemas.letta_message_content import OmittedReasoningContent, TextContent from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType @@ -35,7 +32,6 @@ class OpenAIStreamingInterface: def __init__( self, use_assistant_message: bool = False, - put_inner_thoughts_in_kwarg: bool = False, is_openai_proxy: bool = False, messages: Optional[list] = None, tools: Optional[list] = None, @@ -107,7 +103,6 @@ class OpenAIStreamingInterface: self, stream: AsyncStream[ChatCompletionChunk], ttft_span: Optional["Span"] = None, - provider_request_start_timestamp_ns: int | None = None, ) -> AsyncGenerator[LettaMessage | LettaStopReason, None]: """ Iterates over the OpenAI stream, yielding SSE events. @@ -125,29 +120,11 @@ class OpenAIStreamingInterface: tool_dicts = [tool["function"] if isinstance(tool, dict) and "function" in tool else tool for tool in self.tools] self.fallback_input_tokens += num_tokens_from_functions(tool_dicts) - first_chunk = True try: async with stream: prev_message_type = None message_index = 0 async for chunk in stream: - # TODO (cliandy): reconsider in stream cancellations - # await cancellation_token.check_and_raise_if_cancelled() - if first_chunk and ttft_span is not None and provider_request_start_timestamp_ns is not None: - now = get_utc_timestamp_ns() - ttft_ns = now - provider_request_start_timestamp_ns - ttft_span.add_event( - name="openai_time_to_first_token_ms", attributes={"openai_time_to_first_token_ms": ns_to_ms(ttft_ns)} - ) - metric_attributes = get_ctx_attributes() - metric_attributes["model.name"] = chunk.model - MetricRegistry().ttft_ms_histogram.record(ns_to_ms(ttft_ns), metric_attributes) - - if self.is_openai_proxy: - self.fallback_output_tokens += count_tokens(chunk.model_dump_json()) - - first_chunk = False - if not self.model or not self.message_id: self.model = chunk.model self.message_id = chunk.id diff --git a/letta/orm/job.py b/letta/orm/job.py index 5e2e14cc..fb349170 100644 --- a/letta/orm/job.py +++ b/letta/orm/job.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import TYPE_CHECKING, List, Optional -from sqlalchemy import JSON, Index, String +from sqlalchemy import JSON, BigInteger, Index, String from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.mixins import UserMixin @@ -46,6 +46,10 @@ class Job(SqlalchemyBase, UserMixin): nullable=True, doc="Optional error message from attempting to POST the callback endpoint." ) + # timing metrics (in nanoseconds for precision) + ttft_ns: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True, doc="Time to first token in nanoseconds") + total_duration_ns: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True, doc="Total run duration in nanoseconds") + # relationships user: Mapped["User"] = relationship("User", back_populates="jobs") job_messages: Mapped[List["JobMessage"]] = relationship("JobMessage", back_populates="job", cascade="all, delete-orphan") diff --git a/letta/schemas/job.py b/letta/schemas/job.py index fadde684..9eff28bb 100644 --- a/letta/schemas/job.py +++ b/letta/schemas/job.py @@ -21,6 +21,10 @@ class JobBase(OrmMetadataBase): callback_status_code: Optional[int] = Field(None, description="HTTP status code returned by the callback endpoint.") callback_error: Optional[str] = Field(None, description="Optional error message from attempting to POST the callback endpoint.") + # Timing metrics (in nanoseconds for precision) + ttft_ns: int | None = Field(None, description="Time to first token for a run in nanoseconds") + total_duration_ns: int | None = Field(None, description="Total run duration in nanoseconds") + class Job(JobBase): """ diff --git a/letta/services/job_manager.py b/letta/services/job_manager.py index 4f8b76ff..aaa17f74 100644 --- a/letta/services/job_manager.py +++ b/letta/services/job_manager.py @@ -806,6 +806,30 @@ class JobManager: request_config = job.request_config or LettaRequestConfig() return request_config + @enforce_types + async def record_ttft(self, job_id: str, ttft_ns: int, actor: PydanticUser) -> None: + """Record time to first token for a run""" + try: + async with db_registry.async_session() as session: + job = await self._verify_job_access_async(session=session, job_id=job_id, actor=actor, access=["write"]) + job.ttft_ns = ttft_ns + await job.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True) + await session.commit() + except Exception as e: + logger.warning(f"Failed to record TTFT for job {job_id}: {e}") + + @enforce_types + async def record_response_duration(self, job_id: str, total_duration_ns: int, actor: PydanticUser) -> None: + """Record total response duration for a run""" + try: + async with db_registry.async_session() as session: + job = await self._verify_job_access_async(session=session, job_id=job_id, actor=actor, access=["write"]) + job.total_duration_ns = total_duration_ns + await job.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True) + await session.commit() + except Exception as e: + logger.warning(f"Failed to record response duration for job {job_id}: {e}") + @trace_method def _dispatch_callback_sync(self, callback_info: dict) -> dict: """ diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 9c76a825..6cae38ae 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -364,9 +364,13 @@ class ToolManager: results.append(pydantic_tool) except (ValueError, ModuleNotFoundError, AttributeError) as e: tools_to_delete.append(tool) - logger.warning(f"Deleting malformed tool with id={tool.id} and name={tool.name}, error was:\n{e}") - logger.warning("Deleted tool: ") - logger.warning(tool.pretty_print_columns()) + logger.warning( + "Deleting malformed tool with id=%s and name=%s. Error was:\n%s\nDeleted tool:%s", + tool.id, + tool.name, + e, + tool.pretty_print_columns(), + ) for tool in tools_to_delete: await self.delete_tool_by_id_async(tool.id, actor=actor) diff --git a/tests/test_managers.py b/tests/test_managers.py index 4fdb5859..5f6b0d1c 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -8043,6 +8043,77 @@ def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user): job_manager.get_job_usage(job_id="nonexistent_job", actor=default_user) +@pytest.mark.asyncio +async def test_record_ttft(server: SyncServer, default_user, event_loop): + """Test recording time to first token for a job.""" + # Create a job + job_data = PydanticJob( + status=JobStatus.created, + metadata={"type": "test_timing"}, + ) + created_job = await server.job_manager.create_job_async(pydantic_job=job_data, actor=default_user) + + # Record TTFT + ttft_ns = 1_500_000_000 # 1.5 seconds in nanoseconds + await server.job_manager.record_ttft(created_job.id, ttft_ns, default_user) + + # Fetch the job and verify TTFT was recorded + updated_job = await server.job_manager.get_job_by_id_async(created_job.id, default_user) + assert updated_job.ttft_ns == ttft_ns + + +@pytest.mark.asyncio +async def test_record_response_duration(server: SyncServer, default_user, event_loop): + """Test recording total response duration for a job.""" + # Create a job + job_data = PydanticJob( + status=JobStatus.created, + metadata={"type": "test_timing"}, + ) + created_job = await server.job_manager.create_job_async(pydantic_job=job_data, actor=default_user) + + # Record response duration + duration_ns = 5_000_000_000 # 5 seconds in nanoseconds + await server.job_manager.record_response_duration(created_job.id, duration_ns, default_user) + + # Fetch the job and verify duration was recorded + updated_job = await server.job_manager.get_job_by_id_async(created_job.id, default_user) + assert updated_job.total_duration_ns == duration_ns + + +@pytest.mark.asyncio +async def test_record_timing_metrics_together(server: SyncServer, default_user, event_loop): + """Test recording both TTFT and response duration for a job.""" + # Create a job + job_data = PydanticJob( + status=JobStatus.created, + metadata={"type": "test_timing_combined"}, + ) + created_job = await server.job_manager.create_job_async(pydantic_job=job_data, actor=default_user) + + # Record both metrics + ttft_ns = 2_000_000_000 # 2 seconds in nanoseconds + duration_ns = 8_500_000_000 # 8.5 seconds in nanoseconds + + await server.job_manager.record_ttft(created_job.id, ttft_ns, default_user) + await server.job_manager.record_response_duration(created_job.id, duration_ns, default_user) + + # Fetch the job and verify both metrics were recorded + updated_job = await server.job_manager.get_job_by_id_async(created_job.id, default_user) + assert updated_job.ttft_ns == ttft_ns + assert updated_job.total_duration_ns == duration_ns + + +@pytest.mark.asyncio +async def test_record_timing_invalid_job(server: SyncServer, default_user, event_loop): + """Test recording timing metrics for non-existent job fails gracefully.""" + # Try to record TTFT for non-existent job - should not raise exception but log warning + await server.job_manager.record_ttft("nonexistent_job_id", 1_000_000_000, default_user) + + # Try to record response duration for non-existent job - should not raise exception but log warning + await server.job_manager.record_response_duration("nonexistent_job_id", 2_000_000_000, default_user) + + def test_list_tags(server: SyncServer, default_user, default_organization): """Test listing tags functionality.""" # Create multiple agents with different tags