chore: add tracing for request middleware (#8142)
* base * update * more
This commit is contained in:
@@ -4,17 +4,16 @@ from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
|
||||
from opentelemetry import trace
|
||||
from pydantic import BaseModel
|
||||
|
||||
from letta.constants import REDIS_DEFAULT_CACHE_PREFIX
|
||||
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import tracer
|
||||
from letta.plugins.plugins import get_experimental_checker
|
||||
from letta.settings import settings
|
||||
|
||||
logger = get_logger(__name__)
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
def experimental(feature_name: str, fallback_function: Callable, **kwargs):
|
||||
|
||||
@@ -37,6 +37,9 @@ _excluded_v1_endpoints_regex: List[str] = [
|
||||
|
||||
|
||||
async def _trace_request_middleware(request: Request, call_next):
|
||||
# Capture earliest possible timestamp when request enters application
|
||||
entry_time = time.time()
|
||||
|
||||
if not _is_tracing_initialized:
|
||||
return await call_next(request)
|
||||
initial_span_name = f"{request.method} {request.url.path}"
|
||||
@@ -47,8 +50,13 @@ async def _trace_request_middleware(request: Request, call_next):
|
||||
initial_span_name,
|
||||
kind=trace.SpanKind.SERVER,
|
||||
) as span:
|
||||
# Record when we entered the application (useful for detecting worker queuing)
|
||||
span.set_attribute("entry.timestamp_ms", int(entry_time * 1000))
|
||||
|
||||
try:
|
||||
response = await call_next(request)
|
||||
# This span captures all downstream middleware (CORS, RequestId, Logging) + handler
|
||||
with tracer.start_as_current_span("middleware.chain"):
|
||||
response = await call_next(request)
|
||||
span.set_attribute("http.status_code", response.status_code)
|
||||
span.set_status(Status(StatusCode.OK if response.status_code < 400 else StatusCode.ERROR))
|
||||
return response
|
||||
@@ -100,9 +108,10 @@ async def _update_trace_attributes(request: Request):
|
||||
|
||||
# Add request body if available
|
||||
try:
|
||||
body = await request.json()
|
||||
for key, value in body.items():
|
||||
span.set_attribute(f"http.request.body.{key}", str(value))
|
||||
with tracer.start_as_current_span("trace.request_body"):
|
||||
body = await request.json()
|
||||
for key, value in body.items():
|
||||
span.set_attribute(f"http.request.body.{key}", str(value))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
Unified logging middleware that enriches log context and ensures exceptions are logged.
|
||||
"""
|
||||
|
||||
import re
|
||||
import traceback
|
||||
from typing import Callable
|
||||
|
||||
@@ -11,6 +10,7 @@ from starlette.requests import Request
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.log_context import clear_log_context, update_log_context
|
||||
from letta.otel.tracing import tracer
|
||||
from letta.schemas.enums import PrimitiveType
|
||||
from letta.validators import PRIMITIVE_ID_PATTERNS
|
||||
|
||||
@@ -33,95 +33,96 @@ class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
clear_log_context()
|
||||
|
||||
try:
|
||||
# Extract and set log context
|
||||
context = {}
|
||||
with tracer.start_as_current_span("middleware.logging"):
|
||||
# Extract and set log context
|
||||
context = {}
|
||||
|
||||
# Headers
|
||||
actor_id = request.headers.get("user_id")
|
||||
if actor_id:
|
||||
context["actor_id"] = actor_id
|
||||
# Headers
|
||||
actor_id = request.headers.get("user_id")
|
||||
if actor_id:
|
||||
context["actor_id"] = actor_id
|
||||
|
||||
project_id = request.headers.get("x-project-id")
|
||||
if project_id:
|
||||
context["project_id"] = project_id
|
||||
project_id = request.headers.get("x-project-id")
|
||||
if project_id:
|
||||
context["project_id"] = project_id
|
||||
|
||||
org_id = request.headers.get("x-organization-id")
|
||||
if org_id:
|
||||
context["org_id"] = org_id
|
||||
org_id = request.headers.get("x-organization-id")
|
||||
if org_id:
|
||||
context["org_id"] = org_id
|
||||
|
||||
user_agent = request.headers.get("x-agent-id")
|
||||
if user_agent:
|
||||
context["agent_id"] = user_agent
|
||||
user_agent = request.headers.get("x-agent-id")
|
||||
if user_agent:
|
||||
context["agent_id"] = user_agent
|
||||
|
||||
run_id_header = request.headers.get("x-run-id") or request.headers.get("run-id")
|
||||
if run_id_header:
|
||||
context["run_id"] = run_id_header
|
||||
run_id_header = request.headers.get("x-run-id") or request.headers.get("run-id")
|
||||
if run_id_header:
|
||||
context["run_id"] = run_id_header
|
||||
|
||||
path = request.url.path
|
||||
path_parts = [p for p in path.split("/") if p]
|
||||
path = request.url.path
|
||||
path_parts = [p for p in path.split("/") if p]
|
||||
|
||||
# Path
|
||||
matched_parts = set()
|
||||
for part in path_parts:
|
||||
if part in matched_parts:
|
||||
continue
|
||||
# Path
|
||||
matched_parts = set()
|
||||
for part in path_parts:
|
||||
if part in matched_parts:
|
||||
continue
|
||||
|
||||
for primitive_type in PrimitiveType:
|
||||
prefix = primitive_type.value
|
||||
pattern = PRIMITIVE_ID_PATTERNS.get(prefix)
|
||||
for primitive_type in PrimitiveType:
|
||||
prefix = primitive_type.value
|
||||
pattern = PRIMITIVE_ID_PATTERNS.get(prefix)
|
||||
|
||||
if pattern and pattern.match(part):
|
||||
context_key = f"{primitive_type.name.lower()}_id"
|
||||
if pattern and pattern.match(part):
|
||||
context_key = f"{primitive_type.name.lower()}_id"
|
||||
|
||||
if primitive_type == PrimitiveType.ORGANIZATION:
|
||||
context_key = "org_id"
|
||||
elif primitive_type == PrimitiveType.USER:
|
||||
context_key = "user_id"
|
||||
if primitive_type == PrimitiveType.ORGANIZATION:
|
||||
context_key = "org_id"
|
||||
elif primitive_type == PrimitiveType.USER:
|
||||
context_key = "user_id"
|
||||
|
||||
context[context_key] = part
|
||||
matched_parts.add(part)
|
||||
break
|
||||
context[context_key] = part
|
||||
matched_parts.add(part)
|
||||
break
|
||||
|
||||
# Query Parameters
|
||||
for param_value in request.query_params.values():
|
||||
if param_value in matched_parts:
|
||||
continue
|
||||
# Query Parameters
|
||||
for param_value in request.query_params.values():
|
||||
if param_value in matched_parts:
|
||||
continue
|
||||
|
||||
for primitive_type in PrimitiveType:
|
||||
prefix = primitive_type.value
|
||||
pattern = PRIMITIVE_ID_PATTERNS.get(prefix)
|
||||
for primitive_type in PrimitiveType:
|
||||
prefix = primitive_type.value
|
||||
pattern = PRIMITIVE_ID_PATTERNS.get(prefix)
|
||||
|
||||
if pattern and pattern.match(param_value):
|
||||
context_key = f"{primitive_type.name.lower()}_id"
|
||||
if pattern and pattern.match(param_value):
|
||||
context_key = f"{primitive_type.name.lower()}_id"
|
||||
|
||||
if primitive_type == PrimitiveType.ORGANIZATION:
|
||||
context_key = "org_id"
|
||||
elif primitive_type == PrimitiveType.USER:
|
||||
context_key = "user_id"
|
||||
if primitive_type == PrimitiveType.ORGANIZATION:
|
||||
context_key = "org_id"
|
||||
elif primitive_type == PrimitiveType.USER:
|
||||
context_key = "user_id"
|
||||
|
||||
# Only set if not already set from path (path takes precedence over query params)
|
||||
# Query params never overwrite values already set from headers or the path
|
||||
if context_key not in context:
|
||||
context[context_key] = param_value
|
||||
matched_parts.add(param_value)
|
||||
break
|
||||
# Only set if not already set from path (path takes precedence over query params)
|
||||
# Query params never overwrite values already set from headers or the path
|
||||
if context_key not in context:
|
||||
context[context_key] = param_value
|
||||
matched_parts.add(param_value)
|
||||
break
|
||||
|
||||
if context:
|
||||
update_log_context(**context)
|
||||
if context:
|
||||
update_log_context(**context)
|
||||
|
||||
logger.debug(
|
||||
f"Incoming request: {request.method} {request.url.path}",
|
||||
extra={
|
||||
"method": request.method,
|
||||
"url": str(request.url),
|
||||
"path": request.url.path,
|
||||
"query_params": dict(request.query_params),
|
||||
"client_host": request.client.host if request.client else None,
|
||||
},
|
||||
)
|
||||
logger.debug(
|
||||
f"Incoming request: {request.method} {request.url.path}",
|
||||
extra={
|
||||
"method": request.method,
|
||||
"url": str(request.url),
|
||||
"path": request.url.path,
|
||||
"query_params": dict(request.query_params),
|
||||
"client_host": request.client.host if request.client else None,
|
||||
},
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
except Exception as exc:
|
||||
# Extract request context
|
||||
|
||||
@@ -18,6 +18,8 @@ from typing import Optional
|
||||
from starlette.requests import Request
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from letta.otel.tracing import tracer
|
||||
|
||||
# Contextvar for storing the request ID across async boundaries
|
||||
request_id_var: ContextVar[Optional[str]] = ContextVar("request_id", default=None)
|
||||
|
||||
@@ -47,17 +49,18 @@ class RequestIdMiddleware:
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
# Create a Request object for easier header access
|
||||
request = Request(scope)
|
||||
with tracer.start_as_current_span("middleware.request_id"):
|
||||
# Create a Request object for easier header access
|
||||
request = Request(scope)
|
||||
|
||||
# Extract request_id from header
|
||||
request_id = request.headers.get("x-api-request-log-id")
|
||||
# Extract request_id from header
|
||||
request_id = request.headers.get("x-api-request-log-id")
|
||||
|
||||
# Set in contextvar (for non-streaming code paths)
|
||||
request_id_var.set(request_id)
|
||||
# Set in contextvar (for non-streaming code paths)
|
||||
request_id_var.set(request_id)
|
||||
|
||||
# Also store in request.state for streaming responses where contextvars don't propagate
|
||||
# This is accessible via request.state.request_id throughout the request lifecycle
|
||||
request.state.request_id = request_id
|
||||
# Also store in request.state for streaming responses where contextvars don't propagate
|
||||
# This is accessible via request.state.request_id throughout the request lifecycle
|
||||
request.state.request_id = request_id
|
||||
|
||||
await self.app(scope, receive, send)
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
Reference in New Issue
Block a user