fix: metric tracking (#2785)

This commit is contained in:
Andy Li
2025-06-13 13:53:10 -07:00
committed by GitHub
parent 2e77ea6e76
commit 33bfd14017
7 changed files with 101 additions and 44 deletions

View File

@@ -5,6 +5,7 @@ from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk
from opentelemetry.trace import Span
from letta.agents.base_agent import BaseAgent
from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent
@@ -178,17 +179,13 @@ class LettaAgent(BaseAgent):
agent_state,
llm_client,
tool_rules_solver,
agent_step_span,
)
)
in_context_messages = current_in_context_messages + new_in_context_messages
log_event("agent.stream_no_tokens.llm_response.received") # [3^]
# log llm request time
now = get_utc_timestamp_ns()
llm_request_ns = now - step_start
agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": ns_to_ms(llm_request_ns)})
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
# update usage
@@ -197,6 +194,9 @@ class LettaAgent(BaseAgent):
usage.completion_tokens += response.usage.completion_tokens
usage.prompt_tokens += response.usage.prompt_tokens
usage.total_tokens += response.usage.total_tokens
MetricRegistry().message_output_tokens.record(
response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
)
if not response.choices[0].message.tool_calls:
# TODO: make into a real error
@@ -216,11 +216,6 @@ class LettaAgent(BaseAgent):
logger.info("No reasoning content found.")
reasoning = None
# log LLM request time
now = get_utc_timestamp_ns()
llm_request_ns = now - step_start
agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": ns_to_ms(llm_request_ns)})
persisted_messages, should_continue = await self._handle_ai_response(
tool_call,
valid_tool_names,
@@ -265,6 +260,8 @@ class LettaAgent(BaseAgent):
if include_return_message_types is None or message.message_type in include_return_message_types:
yield f"data: {message.model_dump_json()}\n\n"
MetricRegistry().step_execution_time_ms_histogram.record(step_start - get_utc_timestamp_ns(), get_ctx_attributes())
if not should_continue:
break
@@ -327,7 +324,7 @@ class LettaAgent(BaseAgent):
request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
await self._build_and_request_from_llm(
current_in_context_messages, new_in_context_messages, agent_state, llm_client, tool_rules_solver
current_in_context_messages, new_in_context_messages, agent_state, llm_client, tool_rules_solver, agent_step_span
)
)
in_context_messages = current_in_context_messages + new_in_context_messages
@@ -336,16 +333,14 @@ class LettaAgent(BaseAgent):
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
# log LLM request time
now = get_utc_timestamp_ns()
llm_request_ns = now - step_start
agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": ns_to_ms(llm_request_ns)})
# TODO: add run_id
usage.step_count += 1
usage.completion_tokens += response.usage.completion_tokens
usage.prompt_tokens += response.usage.prompt_tokens
usage.total_tokens += response.usage.total_tokens
MetricRegistry().message_output_tokens.record(
response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
)
if not response.choices[0].message.tool_calls:
# TODO: make into a real error
@@ -399,6 +394,8 @@ class LettaAgent(BaseAgent):
),
)
MetricRegistry().step_execution_time_ms_histogram.record(step_start - get_utc_timestamp_ns(), get_ctx_attributes())
if not should_continue:
break
@@ -458,24 +455,28 @@ class LettaAgent(BaseAgent):
request_span = tracer.start_span("time_to_first_token", start_time=request_start_timestamp_ns)
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
provider_request_start_timestamp_ns = None
for i in range(max_steps):
step_id = generate_step_id()
step_start = get_utc_timestamp_ns()
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
agent_step_span.set_attributes({"step_id": step_id})
request_data, stream, current_in_context_messages, new_in_context_messages, valid_tool_names = (
await self._build_and_request_from_llm_streaming(
first_chunk,
agent_step_span,
request_start_timestamp_ns,
current_in_context_messages,
new_in_context_messages,
agent_state,
llm_client,
tool_rules_solver,
)
(
request_data,
stream,
current_in_context_messages,
new_in_context_messages,
valid_tool_names,
provider_request_start_timestamp_ns,
) = await self._build_and_request_from_llm_streaming(
first_chunk,
agent_step_span,
request_start_timestamp_ns,
current_in_context_messages,
new_in_context_messages,
agent_state,
llm_client,
tool_rules_solver,
)
log_event("agent.stream.llm_response.received") # [3^]
@@ -502,12 +503,17 @@ class LettaAgent(BaseAgent):
now = get_utc_timestamp_ns()
ttft_ns = now - request_start_timestamp_ns
request_span.add_event(name="time_to_first_token_ms", attributes={"ttft_ms": ns_to_ms(ttft_ns)})
metric_attributes = get_ctx_attributes()
metric_attributes["model.name"] = agent_state.llm_config.model
MetricRegistry().ttft_ms_histogram.record(ns_to_ms(ttft_ns), metric_attributes)
first_chunk = False
if include_return_message_types is None or chunk.message_type in include_return_message_types:
# filter down returned data
yield f"data: {chunk.model_dump_json()}\n\n"
stream_end_time_ns = get_utc_timestamp_ns()
# update usage
usage.step_count += 1
usage.completion_tokens += interface.output_tokens
@@ -518,9 +524,12 @@ class LettaAgent(BaseAgent):
)
# log LLM request time
now = get_utc_timestamp_ns()
llm_request_ns = now - step_start
agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": ns_to_ms(llm_request_ns)})
llm_request_ms = ns_to_ms(stream_end_time_ns - request_start_timestamp_ns)
agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ms})
MetricRegistry().llm_execution_time_ms_histogram.record(
llm_request_ms,
dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}),
)
# Process resulting stream content
tool_call = interface.get_tool_call_object()
@@ -585,6 +594,9 @@ class LettaAgent(BaseAgent):
if include_return_message_types is None or tool_return.message_type in include_return_message_types:
yield f"data: {tool_return.model_dump_json()}\n\n"
# TODO (cliandy): consolidate and expand with trace
MetricRegistry().step_execution_time_ms_histogram.record(step_start - get_utc_timestamp_ns(), get_ctx_attributes())
if not should_continue:
break
@@ -608,6 +620,7 @@ class LettaAgent(BaseAgent):
for finish_chunk in self.get_finish_chunks_for_stream(usage):
yield f"data: {finish_chunk}\n\n"
# noinspection PyInconsistentReturns
async def _build_and_request_from_llm(
self,
current_in_context_messages: List[Message],
@@ -615,7 +628,8 @@ class LettaAgent(BaseAgent):
agent_state: AgentState,
llm_client: LLMClientBase,
tool_rules_solver: ToolRulesSolver,
) -> Tuple[Dict, Dict, List[Message], List[Message], List[str]]:
agent_step_span: "Span",
) -> Tuple[Dict, Dict, List[Message], List[Message], List[str]] | None:
for attempt in range(self.max_summarization_retries + 1):
try:
log_event("agent.stream_no_tokens.messages.refreshed")
@@ -629,13 +643,15 @@ class LettaAgent(BaseAgent):
log_event("agent.stream_no_tokens.llm_request.created")
async with AsyncTimer() as timer:
# Attempt LLM request
response = await llm_client.request_async(request_data, agent_state.llm_config)
MetricRegistry().llm_execution_time_ms_histogram.record(
timer.elapsed_ms,
dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}),
)
# Attempt LLM request
return (request_data, response, current_in_context_messages, new_in_context_messages, valid_tool_names)
agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": timer.elapsed_ms})
return request_data, response, current_in_context_messages, new_in_context_messages, valid_tool_names
except Exception as e:
if attempt == self.max_summarization_retries:
@@ -653,6 +669,7 @@ class LettaAgent(BaseAgent):
new_in_context_messages = []
log_event(f"agent.stream_no_tokens.retry_attempt.{attempt + 1}")
# noinspection PyInconsistentReturns
async def _build_and_request_from_llm_streaming(
self,
first_chunk: bool,
@@ -663,7 +680,7 @@ class LettaAgent(BaseAgent):
agent_state: AgentState,
llm_client: LLMClientBase,
tool_rules_solver: ToolRulesSolver,
) -> Tuple[Dict, AsyncStream[ChatCompletionChunk], List[Message], List[Message], List[str]]:
) -> Tuple[Dict, AsyncStream[ChatCompletionChunk], List[Message], List[Message], List[str], int] | None:
for attempt in range(self.max_summarization_retries + 1):
try:
log_event("agent.stream_no_tokens.messages.refreshed")
@@ -676,10 +693,13 @@ class LettaAgent(BaseAgent):
)
log_event("agent.stream.llm_request.created") # [2^]
provider_request_start_timestamp_ns = get_utc_timestamp_ns()
if first_chunk and ttft_span is not None:
provider_request_start_timestamp_ns = get_utc_timestamp_ns()
provider_req_start_ns = provider_request_start_timestamp_ns - request_start_timestamp_ns
ttft_span.add_event(name="provider_req_start_ns", attributes={"provider_req_start_ms": ns_to_ms(provider_req_start_ns)})
request_start_to_provider_request_start_ns = provider_request_start_timestamp_ns - request_start_timestamp_ns
ttft_span.add_event(
name="request_start_to_provider_request_start_ns",
attributes={"request_start_to_provider_request_start_ns": ns_to_ms(request_start_to_provider_request_start_ns)},
)
# Attempt LLM request
return (
@@ -688,6 +708,7 @@ class LettaAgent(BaseAgent):
current_in_context_messages,
new_in_context_messages,
valid_tool_names,
provider_request_start_timestamp_ns,
)
except Exception as e:
@@ -703,7 +724,7 @@ class LettaAgent(BaseAgent):
llm_config=agent_state.llm_config,
force=True,
)
new_in_context_messages = []
new_in_context_messages: list[Message] = []
log_event(f"agent.stream_no_tokens.retry_attempt.{attempt + 1}")
@trace_method

View File

@@ -26,6 +26,8 @@ from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.helpers.datetime_helpers import get_utc_timestamp_ns, ns_to_ms
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.log import get_logger
from letta.otel.context import get_ctx_attributes
from letta.otel.metric_registry import MetricRegistry
from letta.schemas.letta_message import (
AssistantMessage,
HiddenReasoningMessage,
@@ -142,6 +144,10 @@ class AnthropicStreamingInterface:
ttft_span.add_event(
name="anthropic_time_to_first_token_ms", attributes={"anthropic_time_to_first_token_ms": ns_to_ms(ttft_ns)}
)
metric_attributes = get_ctx_attributes()
if isinstance(event, BetaRawMessageStartEvent):
metric_attributes["model.name"] = event.message.model
MetricRegistry().ttft_ms_histogram.record(ns_to_ms(ttft_ns), metric_attributes)
first_chunk = False
# TODO: Support BetaThinkingBlock, BetaRedactedThinkingBlock

View File

@@ -7,6 +7,8 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.helpers.datetime_helpers import get_utc_timestamp_ns, ns_to_ms
from letta.log import get_logger
from letta.otel.context import get_ctx_attributes
from letta.otel.metric_registry import MetricRegistry
from letta.schemas.letta_message import AssistantMessage, LettaMessage, ReasoningMessage, ToolCallDelta, ToolCallMessage
from letta.schemas.letta_message_content import TextContent
from letta.schemas.message import Message
@@ -95,6 +97,10 @@ class OpenAIStreamingInterface:
ttft_span.add_event(
name="openai_time_to_first_token_ms", attributes={"openai_time_to_first_token_ms": ns_to_ms(ttft_ns)}
)
metric_attributes = get_ctx_attributes()
metric_attributes["model.name"] = chunk.model
MetricRegistry().ttft_ms_histogram.record(ns_to_ms(ttft_ns), metric_attributes)
first_chunk = False
if not self.model or not self.message_id:

View File

@@ -95,6 +95,18 @@ class MetricRegistry:
),
)
@property
def step_execution_time_ms_histogram(self) -> Histogram:
    """Lazily-created histogram tracking agent step execution time in milliseconds.

    Returns the shared ``hist_step_execution_time_ms`` instrument, creating it
    on first access via the registry's get-or-create cache so repeated lookups
    reuse one Histogram instance.
    """
    metric_name = "hist_step_execution_time_ms"
    # Defer instrument creation to the registry cache; the factory only runs
    # the first time this metric name is requested.
    factory = partial(
        self._meter.create_histogram,
        name=metric_name,
        description="Histogram for step execution time (ms)",
        unit="ms",
    )
    return self._get_or_create_metric(metric_name, factory)
# TODO (cliandy): instrument this
@property
def message_cost(self) -> Histogram:

View File

@@ -5,14 +5,15 @@ from typing import List
from fastapi import FastAPI, Request
from opentelemetry import metrics
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
from opentelemetry.metrics import NoOpMeter
from opentelemetry.metrics import Counter, Histogram, NoOpMeter
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.metrics.export import AggregationTemporality, PeriodicExportingMetricReader
from letta.helpers.datetime_helpers import ns_to_ms
from letta.log import get_logger
from letta.otel.context import add_ctx_attribute, get_ctx_attributes
from letta.otel.resource import get_resource, is_pytest_environment
from letta.settings import settings
logger = get_logger(__name__)
@@ -110,9 +111,17 @@ def setup_metrics(
assert endpoint
global _is_metrics_initialized, _meter
otlp_metric_exporter = OTLPMetricExporter(endpoint=endpoint)
preferred_temporality = AggregationTemporality(settings.otel_preferred_temporality)
otlp_metric_exporter = OTLPMetricExporter(
endpoint=endpoint,
preferred_temporality={
# Add more as needed here.
Counter: preferred_temporality,
Histogram: preferred_temporality,
},
)
metric_reader = PeriodicExportingMetricReader(exporter=otlp_metric_exporter)
meter_provider = MeterProvider(resource=get_resource(service_name), metric_readers=[metric_reader])
metrics.set_meter_provider(meter_provider)
_meter = metrics.get_meter(__name__)

View File

@@ -88,7 +88,7 @@ async def sse_async_generator(
metric_attributes = get_ctx_attributes()
if llm_config:
metric_attributes["model.name"] = llm_config.model
MetricRegistry().ttft_ms_histogram.record(ns_to_ms(ttft_ns), metric_attributes)
MetricRegistry().ttft_ms_histogram.record(ns_to_ms(ttft_ns), metric_attributes)
first_chunk = False
# yield f"data: {json.dumps(chunk)}\n\n"

View File

@@ -205,6 +205,9 @@ class Settings(BaseSettings):
# telemetry logging
otel_exporter_otlp_endpoint: Optional[str] = None # otel default: "http://localhost:4317"
otel_preferred_temporality: Optional[int] = Field(
default=1, ge=0, le=2, description="Exported metric temporality. {0: UNSPECIFIED, 1: DELTA, 2: CUMULATIVE}"
)
disable_tracing: bool = False
llm_api_logging: bool = True