feat: track metrics for runs in db

This commit is contained in:
Andy Li
2025-08-06 15:46:50 -07:00
committed by GitHub
parent 8faa8711bc
commit ca6f474c4e
9 changed files with 202 additions and 95 deletions

View File

@@ -0,0 +1,33 @@
"""add metrics to agent loop runs
Revision ID: 05c3bc564286
Revises: d007f4ca66bf
Create Date: 2025-08-06 14:30:48.255538
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "05c3bc564286"
down_revision: Union[str, None] = "d007f4ca66bf"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add per-run timing metric columns to the jobs table."""
    # ### commands auto generated by Alembic - please adjust! ###
    for column_name in ("ttft_ns", "total_duration_ns"):
        op.add_column("jobs", sa.Column(column_name, sa.BigInteger(), nullable=True))
    # ### end Alembic commands ###
def downgrade() -> None:
    """Drop the per-run timing metric columns (reverse order of upgrade)."""
    # ### commands auto generated by Alembic - please adjust! ###
    for column_name in ("total_duration_ns", "ttft_ns"):
        op.drop_column("jobs", column_name)
    # ### end Alembic commands ###

View File

@@ -361,8 +361,16 @@ class LettaAgent(BaseAgent):
if settings.track_stop_reason:
if step_progression == StepProgression.FINISHED and should_continue:
continue
self.logger.debug("Running cleanup for agent loop run: %s", self.current_run_id)
self.logger.info("Running final update. Step Progression: %s", step_progression)
try:
if step_progression == StepProgression.FINISHED and not should_continue:
if stop_reason is None:
stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
break
if step_progression < StepProgression.STEP_LOGGED:
await self.step_manager.log_step_async(
actor=self.actor,
@@ -391,12 +399,11 @@ class LettaAgent(BaseAgent):
self.logger.error("Error in step after logging step")
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
elif step_progression == StepProgression.FINISHED and not should_continue:
if stop_reason is None:
stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
else:
self.logger.error("Invalid StepProgression value")
await self._log_request(request_start_timestamp_ns, request_span)
except Exception as e:
self.logger.error("Failed to update step: %s", e)
@@ -413,17 +420,7 @@ class LettaAgent(BaseAgent):
force=False,
)
# log request time
if request_start_timestamp_ns:
now = get_utc_timestamp_ns()
duration_ms = ns_to_ms(now - request_start_timestamp_ns)
request_span.add_event(name="letta_request_ms", attributes={"duration_ms": duration_ms})
# update agent's last run metrics
now_datetime = get_utc_time()
await self._update_agent_last_run_metrics(now_datetime, duration_ms)
request_span.end()
await self._log_request(request_start_timestamp_ns, request_span)
# Return back usage
for finish_chunk in self.get_finish_chunks_for_stream(usage, stop_reason):
@@ -590,8 +587,16 @@ class LettaAgent(BaseAgent):
if settings.track_stop_reason:
if step_progression == StepProgression.FINISHED and should_continue:
continue
self.logger.debug("Running cleanup for agent loop run: %s", self.current_run_id)
self.logger.info("Running final update. Step Progression: %s", step_progression)
try:
if step_progression == StepProgression.FINISHED and not should_continue:
if stop_reason is None:
stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
break
if step_progression < StepProgression.STEP_LOGGED:
await self.step_manager.log_step_async(
actor=self.actor,
@@ -620,30 +625,17 @@ class LettaAgent(BaseAgent):
self.logger.error("Error in step after logging step")
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
elif step_progression == StepProgression.FINISHED and not should_continue:
if stop_reason is None:
stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
else:
self.logger.error("Invalid StepProgression value")
await self._log_request(request_start_timestamp_ns, request_span)
except Exception as e:
self.logger.error("Failed to update step: %s", e)
if not should_continue:
break
# log request time
if request_start_timestamp_ns:
now = get_utc_timestamp_ns()
duration_ms = ns_to_ms(now - request_start_timestamp_ns)
request_span.add_event(name="request_ms", attributes={"duration_ms": duration_ms})
# update agent's last run metrics
now_datetime = get_utc_time()
await self._update_agent_last_run_metrics(now_datetime, duration_ms)
request_span.end()
# Extend the in context message ids
if not agent_state.message_buffer_autoclear:
await self._rebuild_context_window(
@@ -654,6 +646,8 @@ class LettaAgent(BaseAgent):
force=False,
)
await self._log_request(request_start_timestamp_ns, request_span)
return current_in_context_messages, new_in_context_messages, stop_reason, usage
async def _update_agent_last_run_metrics(self, completion_time: datetime, duration_ms: float) -> None:
@@ -755,7 +749,6 @@ class LettaAgent(BaseAgent):
elif agent_state.llm_config.model_endpoint_type == ProviderType.openai:
interface = OpenAIStreamingInterface(
use_assistant_message=use_assistant_message,
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
is_openai_proxy=agent_state.llm_config.provider_name == "lmstudio_openai",
messages=current_in_context_messages + new_in_context_messages,
tools=request_data.get("tools", []),
@@ -766,16 +759,20 @@ class LettaAgent(BaseAgent):
async for chunk in interface.process(
stream,
ttft_span=request_span,
provider_request_start_timestamp_ns=provider_request_start_timestamp_ns,
):
# Measure time to first token
# Measure TTFT (trace, metric, and db). This should be consolidated.
if first_chunk and request_span is not None:
now = get_utc_timestamp_ns()
ttft_ns = now - request_start_timestamp_ns
request_span.add_event(name="time_to_first_token_ms", attributes={"ttft_ms": ns_to_ms(ttft_ns)})
metric_attributes = get_ctx_attributes()
metric_attributes["model.name"] = agent_state.llm_config.model
MetricRegistry().ttft_ms_histogram.record(ns_to_ms(ttft_ns), metric_attributes)
if self.current_run_id and self.job_manager:
await self.job_manager.record_ttft(self.current_run_id, ttft_ns, self.actor)
first_chunk = False
if include_return_message_types is None or chunk.message_type in include_return_message_types:
@@ -913,8 +910,16 @@ class LettaAgent(BaseAgent):
if settings.track_stop_reason:
if step_progression == StepProgression.FINISHED and should_continue:
continue
self.logger.debug("Running cleanup for agent loop run: %s", self.current_run_id)
self.logger.info("Running final update. Step Progression: %s", step_progression)
try:
if step_progression == StepProgression.FINISHED and not should_continue:
if stop_reason is None:
stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
break
if step_progression < StepProgression.STEP_LOGGED:
await self.step_manager.log_step_async(
actor=self.actor,
@@ -942,12 +947,12 @@ class LettaAgent(BaseAgent):
self.logger.error("Error in step after logging step")
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
elif step_progression == StepProgression.FINISHED and not should_continue:
if stop_reason is None:
stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
else:
self.logger.error("Invalid StepProgression value")
# Do tracking for failure cases. Can consolidate with success conditions later.
await self._log_request(request_start_timestamp_ns, request_span)
except Exception as e:
self.logger.error("Failed to update step: %s", e)
@@ -963,21 +968,23 @@ class LettaAgent(BaseAgent):
force=False,
)
# log time of entire request
if request_start_timestamp_ns:
now = get_utc_timestamp_ns()
duration_ms = ns_to_ms(now - request_start_timestamp_ns)
request_span.add_event(name="letta_request_ms", attributes={"duration_ms": duration_ms})
# update agent's last run metrics
completion_time = get_utc_time()
await self._update_agent_last_run_metrics(completion_time, duration_ms)
request_span.end()
await self._log_request(request_start_timestamp_ns, request_span)
for finish_chunk in self.get_finish_chunks_for_stream(usage, stop_reason):
yield f"data: {finish_chunk}\n\n"
async def _log_request(self, request_start_timestamp_ns: int, request_span: "Span | None"):
    """Record total request duration to the trace span, the agent's last-run metrics, and the run's db row.

    Args:
        request_start_timestamp_ns: When the request started (ns since epoch); falsy skips duration logging.
        request_span: Optional tracing span; receives a duration event and is ended here.
    """
    if request_start_timestamp_ns:
        now_ns, now = get_utc_timestamp_ns(), get_utc_time()
        duration_ns = now_ns - request_start_timestamp_ns
        if request_span:
            request_span.add_event(name="letta_request_ms", attributes={"duration_ms": ns_to_ms(duration_ns)})
        await self._update_agent_last_run_metrics(now, ns_to_ms(duration_ns))
        # Guard job_manager as well — the TTFT recording path checks
        # `self.current_run_id and self.job_manager`, so job_manager may be unset.
        if self.current_run_id and self.job_manager:
            await self.job_manager.record_response_duration(self.current_run_id, duration_ns, self.actor)
    if request_span:
        request_span.end()
# noinspection PyInconsistentReturns
async def _build_and_request_from_llm(
self,
@@ -1428,6 +1435,8 @@ class LettaAgent(BaseAgent):
status="error",
)
print(target_tool)
# TODO: This temp. Move this logic and code to executors
if agent_step_span:

View File

@@ -25,11 +25,8 @@ from anthropic.types.beta import (
)
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.helpers.datetime_helpers import get_utc_timestamp_ns, ns_to_ms
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.log import get_logger
from letta.otel.context import get_ctx_attributes
from letta.otel.metric_registry import MetricRegistry
from letta.schemas.letta_message import (
AssistantMessage,
HiddenReasoningMessage,
@@ -133,28 +130,12 @@ class AnthropicStreamingInterface:
self,
stream: AsyncStream[BetaRawMessageStreamEvent],
ttft_span: Optional["Span"] = None,
provider_request_start_timestamp_ns: int | None = None,
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
prev_message_type = None
message_index = 0
first_chunk = True
try:
async with stream:
async for event in stream:
# TODO (cliandy): reconsider in stream cancellations
# await cancellation_token.check_and_raise_if_cancelled()
if first_chunk and ttft_span is not None and provider_request_start_timestamp_ns is not None:
now = get_utc_timestamp_ns()
ttft_ns = now - provider_request_start_timestamp_ns
ttft_span.add_event(
name="anthropic_time_to_first_token_ms", attributes={"anthropic_time_to_first_token_ms": ns_to_ms(ttft_ns)}
)
metric_attributes = get_ctx_attributes()
if isinstance(event, BetaRawMessageStartEvent):
metric_attributes["model.name"] = event.message.model
MetricRegistry().ttft_ms_histogram.record(ns_to_ms(ttft_ns), metric_attributes)
first_chunk = False
# TODO: Support BetaThinkingBlock, BetaRedactedThinkingBlock
if isinstance(event, BetaRawContentBlockStartEvent):
content = event.content_block

View File

@@ -7,12 +7,9 @@ from openai import AsyncStream
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.helpers.datetime_helpers import get_utc_timestamp_ns, ns_to_ms
from letta.llm_api.openai_client import is_openai_reasoning_model
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.log import get_logger
from letta.otel.context import get_ctx_attributes
from letta.otel.metric_registry import MetricRegistry
from letta.schemas.letta_message import AssistantMessage, LettaMessage, ReasoningMessage, ToolCallDelta, ToolCallMessage
from letta.schemas.letta_message_content import OmittedReasoningContent, TextContent
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
@@ -35,7 +32,6 @@ class OpenAIStreamingInterface:
def __init__(
self,
use_assistant_message: bool = False,
put_inner_thoughts_in_kwarg: bool = False,
is_openai_proxy: bool = False,
messages: Optional[list] = None,
tools: Optional[list] = None,
@@ -107,7 +103,6 @@ class OpenAIStreamingInterface:
self,
stream: AsyncStream[ChatCompletionChunk],
ttft_span: Optional["Span"] = None,
provider_request_start_timestamp_ns: int | None = None,
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
"""
Iterates over the OpenAI stream, yielding SSE events.
@@ -125,29 +120,11 @@ class OpenAIStreamingInterface:
tool_dicts = [tool["function"] if isinstance(tool, dict) and "function" in tool else tool for tool in self.tools]
self.fallback_input_tokens += num_tokens_from_functions(tool_dicts)
first_chunk = True
try:
async with stream:
prev_message_type = None
message_index = 0
async for chunk in stream:
# TODO (cliandy): reconsider in stream cancellations
# await cancellation_token.check_and_raise_if_cancelled()
if first_chunk and ttft_span is not None and provider_request_start_timestamp_ns is not None:
now = get_utc_timestamp_ns()
ttft_ns = now - provider_request_start_timestamp_ns
ttft_span.add_event(
name="openai_time_to_first_token_ms", attributes={"openai_time_to_first_token_ms": ns_to_ms(ttft_ns)}
)
metric_attributes = get_ctx_attributes()
metric_attributes["model.name"] = chunk.model
MetricRegistry().ttft_ms_histogram.record(ns_to_ms(ttft_ns), metric_attributes)
if self.is_openai_proxy:
self.fallback_output_tokens += count_tokens(chunk.model_dump_json())
first_chunk = False
if not self.model or not self.message_id:
self.model = chunk.model
self.message_id = chunk.id

View File

@@ -1,7 +1,7 @@
from datetime import datetime
from typing import TYPE_CHECKING, List, Optional
from sqlalchemy import JSON, Index, String
from sqlalchemy import JSON, BigInteger, Index, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import UserMixin
@@ -46,6 +46,10 @@ class Job(SqlalchemyBase, UserMixin):
nullable=True, doc="Optional error message from attempting to POST the callback endpoint."
)
# timing metrics (in nanoseconds for precision)
ttft_ns: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True, doc="Time to first token in nanoseconds")
total_duration_ns: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True, doc="Total run duration in nanoseconds")
# relationships
user: Mapped["User"] = relationship("User", back_populates="jobs")
job_messages: Mapped[List["JobMessage"]] = relationship("JobMessage", back_populates="job", cascade="all, delete-orphan")

View File

@@ -21,6 +21,10 @@ class JobBase(OrmMetadataBase):
callback_status_code: Optional[int] = Field(None, description="HTTP status code returned by the callback endpoint.")
callback_error: Optional[str] = Field(None, description="Optional error message from attempting to POST the callback endpoint.")
# Timing metrics (in nanoseconds for precision)
ttft_ns: int | None = Field(None, description="Time to first token for a run in nanoseconds")
total_duration_ns: int | None = Field(None, description="Total run duration in nanoseconds")
class Job(JobBase):
"""

View File

@@ -806,6 +806,30 @@ class JobManager:
request_config = job.request_config or LettaRequestConfig()
return request_config
@enforce_types
async def record_ttft(self, job_id: str, ttft_ns: int, actor: PydanticUser) -> None:
    """Record time to first token for a run"""
    try:
        async with db_registry.async_session() as session:
            run = await self._verify_job_access_async(session=session, job_id=job_id, actor=actor, access=["write"])
            run.ttft_ns = ttft_ns
            await run.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
            await session.commit()
    except Exception as e:
        # Best-effort: metric bookkeeping must never break the request path.
        logger.warning(f"Failed to record TTFT for job {job_id}: {e}")
@enforce_types
async def record_response_duration(self, job_id: str, total_duration_ns: int, actor: PydanticUser) -> None:
    """Record total response duration for a run"""
    try:
        async with db_registry.async_session() as session:
            run = await self._verify_job_access_async(session=session, job_id=job_id, actor=actor, access=["write"])
            run.total_duration_ns = total_duration_ns
            await run.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
            await session.commit()
    except Exception as e:
        # Best-effort: metric bookkeeping must never break the request path.
        logger.warning(f"Failed to record response duration for job {job_id}: {e}")
@trace_method
def _dispatch_callback_sync(self, callback_info: dict) -> dict:
"""

View File

@@ -364,9 +364,13 @@ class ToolManager:
results.append(pydantic_tool)
except (ValueError, ModuleNotFoundError, AttributeError) as e:
tools_to_delete.append(tool)
logger.warning(f"Deleting malformed tool with id={tool.id} and name={tool.name}, error was:\n{e}")
logger.warning("Deleted tool: ")
logger.warning(tool.pretty_print_columns())
logger.warning(
"Deleting malformed tool with id=%s and name=%s. Error was:\n%s\nDeleted tool:%s",
tool.id,
tool.name,
e,
tool.pretty_print_columns(),
)
for tool in tools_to_delete:
await self.delete_tool_by_id_async(tool.id, actor=actor)

View File

@@ -8043,6 +8043,77 @@ def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user):
job_manager.get_job_usage(job_id="nonexistent_job", actor=default_user)
@pytest.mark.asyncio
async def test_record_ttft(server: SyncServer, default_user, event_loop):
    """Test recording time to first token for a job."""
    # Create a job to attach the metric to.
    created_job = await server.job_manager.create_job_async(
        pydantic_job=PydanticJob(status=JobStatus.created, metadata={"type": "test_timing"}),
        actor=default_user,
    )

    # Record TTFT, then re-fetch to confirm it round-tripped through the db.
    ttft_ns = 1_500_000_000  # 1.5 seconds in nanoseconds
    await server.job_manager.record_ttft(created_job.id, ttft_ns, default_user)

    updated_job = await server.job_manager.get_job_by_id_async(created_job.id, default_user)
    assert updated_job.ttft_ns == ttft_ns
@pytest.mark.asyncio
async def test_record_response_duration(server: SyncServer, default_user, event_loop):
    """Test recording total response duration for a job."""
    # Create a job to attach the metric to.
    created_job = await server.job_manager.create_job_async(
        pydantic_job=PydanticJob(status=JobStatus.created, metadata={"type": "test_timing"}),
        actor=default_user,
    )

    # Record the duration, then re-fetch to confirm it round-tripped through the db.
    duration_ns = 5_000_000_000  # 5 seconds in nanoseconds
    await server.job_manager.record_response_duration(created_job.id, duration_ns, default_user)

    updated_job = await server.job_manager.get_job_by_id_async(created_job.id, default_user)
    assert updated_job.total_duration_ns == duration_ns
@pytest.mark.asyncio
async def test_record_timing_metrics_together(server: SyncServer, default_user, event_loop):
    """Test recording both TTFT and response duration for a job."""
    # Create a job to attach both metrics to.
    created_job = await server.job_manager.create_job_async(
        pydantic_job=PydanticJob(status=JobStatus.created, metadata={"type": "test_timing_combined"}),
        actor=default_user,
    )

    ttft_ns = 2_000_000_000  # 2 seconds in nanoseconds
    duration_ns = 8_500_000_000  # 8.5 seconds in nanoseconds
    await server.job_manager.record_ttft(created_job.id, ttft_ns, default_user)
    await server.job_manager.record_response_duration(created_job.id, duration_ns, default_user)

    # Both metrics should persist independently on the same row.
    updated_job = await server.job_manager.get_job_by_id_async(created_job.id, default_user)
    assert updated_job.ttft_ns == ttft_ns
    assert updated_job.total_duration_ns == duration_ns
@pytest.mark.asyncio
async def test_record_timing_invalid_job(server: SyncServer, default_user, event_loop):
    """Test recording timing metrics for non-existent job fails gracefully."""
    # Both recorders swallow errors (logging a warning) rather than raising,
    # so calling them on a missing job id must simply return.
    await server.job_manager.record_ttft("nonexistent_job_id", 1_000_000_000, default_user)
    await server.job_manager.record_response_duration("nonexistent_job_id", 2_000_000_000, default_user)
def test_list_tags(server: SyncServer, default_user, default_organization):
"""Test listing tags functionality."""
# Create multiple agents with different tags