chore: add tracing for request middleware (#8142)

* base

* update

* more
This commit is contained in:
jnjpng
2025-12-29 14:33:29 -08:00
committed by Caren Thomas
parent 190dbfa93b
commit 7e8088adc5
4 changed files with 99 additions and 87 deletions

View File

@@ -4,17 +4,16 @@ from dataclasses import dataclass
from functools import wraps
from typing import Callable
from opentelemetry import trace
from pydantic import BaseModel
from letta.constants import REDIS_DEFAULT_CACHE_PREFIX
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
from letta.log import get_logger
from letta.otel.tracing import tracer
from letta.plugins.plugins import get_experimental_checker
from letta.settings import settings
logger = get_logger(__name__)
tracer = trace.get_tracer(__name__)
def experimental(feature_name: str, fallback_function: Callable, **kwargs):

View File

@@ -37,6 +37,9 @@ _excluded_v1_endpoints_regex: List[str] = [
async def _trace_request_middleware(request: Request, call_next):
# Capture earliest possible timestamp when request enters application
entry_time = time.time()
if not _is_tracing_initialized:
return await call_next(request)
initial_span_name = f"{request.method} {request.url.path}"
@@ -47,8 +50,13 @@ async def _trace_request_middleware(request: Request, call_next):
initial_span_name,
kind=trace.SpanKind.SERVER,
) as span:
# Record when we entered the application (useful for detecting worker queuing)
span.set_attribute("entry.timestamp_ms", int(entry_time * 1000))
try:
response = await call_next(request)
# This span captures all downstream middleware (CORS, RequestId, Logging) + handler
with tracer.start_as_current_span("middleware.chain"):
response = await call_next(request)
span.set_attribute("http.status_code", response.status_code)
span.set_status(Status(StatusCode.OK if response.status_code < 400 else StatusCode.ERROR))
return response
@@ -100,9 +108,10 @@ async def _update_trace_attributes(request: Request):
# Add request body if available
try:
body = await request.json()
for key, value in body.items():
span.set_attribute(f"http.request.body.{key}", str(value))
with tracer.start_as_current_span("trace.request_body"):
body = await request.json()
for key, value in body.items():
span.set_attribute(f"http.request.body.{key}", str(value))
except Exception:
pass

View File

@@ -2,7 +2,6 @@
Unified logging middleware that enriches log context and ensures exceptions are logged.
"""
import re
import traceback
from typing import Callable
@@ -11,6 +10,7 @@ from starlette.requests import Request
from letta.log import get_logger
from letta.log_context import clear_log_context, update_log_context
from letta.otel.tracing import tracer
from letta.schemas.enums import PrimitiveType
from letta.validators import PRIMITIVE_ID_PATTERNS
@@ -33,95 +33,96 @@ class LoggingMiddleware(BaseHTTPMiddleware):
clear_log_context()
try:
# Extract and set log context
context = {}
with tracer.start_as_current_span("middleware.logging"):
# Extract and set log context
context = {}
# Headers
actor_id = request.headers.get("user_id")
if actor_id:
context["actor_id"] = actor_id
# Headers
actor_id = request.headers.get("user_id")
if actor_id:
context["actor_id"] = actor_id
project_id = request.headers.get("x-project-id")
if project_id:
context["project_id"] = project_id
project_id = request.headers.get("x-project-id")
if project_id:
context["project_id"] = project_id
org_id = request.headers.get("x-organization-id")
if org_id:
context["org_id"] = org_id
org_id = request.headers.get("x-organization-id")
if org_id:
context["org_id"] = org_id
user_agent = request.headers.get("x-agent-id")
if user_agent:
context["agent_id"] = user_agent
user_agent = request.headers.get("x-agent-id")
if user_agent:
context["agent_id"] = user_agent
run_id_header = request.headers.get("x-run-id") or request.headers.get("run-id")
if run_id_header:
context["run_id"] = run_id_header
run_id_header = request.headers.get("x-run-id") or request.headers.get("run-id")
if run_id_header:
context["run_id"] = run_id_header
path = request.url.path
path_parts = [p for p in path.split("/") if p]
path = request.url.path
path_parts = [p for p in path.split("/") if p]
# Path
matched_parts = set()
for part in path_parts:
if part in matched_parts:
continue
# Path
matched_parts = set()
for part in path_parts:
if part in matched_parts:
continue
for primitive_type in PrimitiveType:
prefix = primitive_type.value
pattern = PRIMITIVE_ID_PATTERNS.get(prefix)
for primitive_type in PrimitiveType:
prefix = primitive_type.value
pattern = PRIMITIVE_ID_PATTERNS.get(prefix)
if pattern and pattern.match(part):
context_key = f"{primitive_type.name.lower()}_id"
if pattern and pattern.match(part):
context_key = f"{primitive_type.name.lower()}_id"
if primitive_type == PrimitiveType.ORGANIZATION:
context_key = "org_id"
elif primitive_type == PrimitiveType.USER:
context_key = "user_id"
if primitive_type == PrimitiveType.ORGANIZATION:
context_key = "org_id"
elif primitive_type == PrimitiveType.USER:
context_key = "user_id"
context[context_key] = part
matched_parts.add(part)
break
context[context_key] = part
matched_parts.add(part)
break
# Query Parameters
for param_value in request.query_params.values():
if param_value in matched_parts:
continue
# Query Parameters
for param_value in request.query_params.values():
if param_value in matched_parts:
continue
for primitive_type in PrimitiveType:
prefix = primitive_type.value
pattern = PRIMITIVE_ID_PATTERNS.get(prefix)
for primitive_type in PrimitiveType:
prefix = primitive_type.value
pattern = PRIMITIVE_ID_PATTERNS.get(prefix)
if pattern and pattern.match(param_value):
context_key = f"{primitive_type.name.lower()}_id"
if pattern and pattern.match(param_value):
context_key = f"{primitive_type.name.lower()}_id"
if primitive_type == PrimitiveType.ORGANIZATION:
context_key = "org_id"
elif primitive_type == PrimitiveType.USER:
context_key = "user_id"
if primitive_type == PrimitiveType.ORGANIZATION:
context_key = "org_id"
elif primitive_type == PrimitiveType.USER:
context_key = "user_id"
# Only set if not already set from path (path takes precedence over query params)
# Query params can overwrite headers, but path values take precedence
if context_key not in context:
context[context_key] = param_value
matched_parts.add(param_value)
break
# Only set if not already set from path (path takes precedence over query params)
# Query params can overwrite headers, but path values take precedence
if context_key not in context:
context[context_key] = param_value
matched_parts.add(param_value)
break
if context:
update_log_context(**context)
if context:
update_log_context(**context)
logger.debug(
f"Incoming request: {request.method} {request.url.path}",
extra={
"method": request.method,
"url": str(request.url),
"path": request.url.path,
"query_params": dict(request.query_params),
"client_host": request.client.host if request.client else None,
},
)
logger.debug(
f"Incoming request: {request.method} {request.url.path}",
extra={
"method": request.method,
"url": str(request.url),
"path": request.url.path,
"query_params": dict(request.query_params),
"client_host": request.client.host if request.client else None,
},
)
response = await call_next(request)
return response
response = await call_next(request)
return response
except Exception as exc:
# Extract request context

View File

@@ -18,6 +18,8 @@ from typing import Optional
from starlette.requests import Request
from starlette.types import ASGIApp, Receive, Scope, Send
from letta.otel.tracing import tracer
# Contextvar for storing the request ID across async boundaries
request_id_var: ContextVar[Optional[str]] = ContextVar("request_id", default=None)
@@ -47,17 +49,18 @@ class RequestIdMiddleware:
await self.app(scope, receive, send)
return
# Create a Request object for easier header access
request = Request(scope)
with tracer.start_as_current_span("middleware.request_id"):
# Create a Request object for easier header access
request = Request(scope)
# Extract request_id from header
request_id = request.headers.get("x-api-request-log-id")
# Extract request_id from header
request_id = request.headers.get("x-api-request-log-id")
# Set in contextvar (for non-streaming code paths)
request_id_var.set(request_id)
# Set in contextvar (for non-streaming code paths)
request_id_var.set(request_id)
# Also store in request.state for streaming responses where contextvars don't propagate
# This is accessible via request.state.request_id throughout the request lifecycle
request.state.request_id = request_id
# Also store in request.state for streaming responses where contextvars don't propagate
# This is accessible via request.state.request_id throughout the request lifecycle
request.state.request_id = request_id
await self.app(scope, receive, send)
await self.app(scope, receive, send)