diff --git a/letta/helpers/decorators.py b/letta/helpers/decorators.py index 8c0bbe15..e75b8db5 100644 --- a/letta/helpers/decorators.py +++ b/letta/helpers/decorators.py @@ -4,17 +4,16 @@ from dataclasses import dataclass from functools import wraps from typing import Callable -from opentelemetry import trace from pydantic import BaseModel from letta.constants import REDIS_DEFAULT_CACHE_PREFIX from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client from letta.log import get_logger +from letta.otel.tracing import tracer from letta.plugins.plugins import get_experimental_checker from letta.settings import settings logger = get_logger(__name__) -tracer = trace.get_tracer(__name__) def experimental(feature_name: str, fallback_function: Callable, **kwargs): diff --git a/letta/otel/tracing.py b/letta/otel/tracing.py index 3a2238ed..fea75efb 100644 --- a/letta/otel/tracing.py +++ b/letta/otel/tracing.py @@ -37,6 +37,9 @@ _excluded_v1_endpoints_regex: List[str] = [ async def _trace_request_middleware(request: Request, call_next): + # Capture earliest possible timestamp when request enters application + entry_time = time.time() + if not _is_tracing_initialized: return await call_next(request) initial_span_name = f"{request.method} {request.url.path}" @@ -47,8 +50,13 @@ async def _trace_request_middleware(request: Request, call_next): initial_span_name, kind=trace.SpanKind.SERVER, ) as span: + # Record when we entered the application (useful for detecting worker queuing) + span.set_attribute("entry.timestamp_ms", int(entry_time * 1000)) + try: - response = await call_next(request) + # This span captures all downstream middleware (CORS, RequestId, Logging) + handler + with tracer.start_as_current_span("middleware.chain"): + response = await call_next(request) span.set_attribute("http.status_code", response.status_code) span.set_status(Status(StatusCode.OK if response.status_code < 400 else StatusCode.ERROR)) return response @@ -100,9 +108,10 @@ async def _update_trace_attributes(request: Request): # Add request body if available try: - body = await request.json() - for key, value in body.items(): - span.set_attribute(f"http.request.body.{key}", str(value)) + with tracer.start_as_current_span("trace.request_body"): + body = await request.json() + for key, value in body.items(): + span.set_attribute(f"http.request.body.{key}", str(value)) except Exception: pass diff --git a/letta/server/rest_api/middleware/logging.py b/letta/server/rest_api/middleware/logging.py index 15534a7e..5caa043d 100644 --- a/letta/server/rest_api/middleware/logging.py +++ b/letta/server/rest_api/middleware/logging.py @@ -2,7 +2,6 @@ Unified logging middleware that enriches log context and ensures exceptions are logged. """ -import re import traceback from typing import Callable @@ -11,6 +10,7 @@ from starlette.requests import Request from letta.log import get_logger from letta.log_context import clear_log_context, update_log_context +from letta.otel.tracing import tracer from letta.schemas.enums import PrimitiveType from letta.validators import PRIMITIVE_ID_PATTERNS @@ -33,95 +33,96 @@ class LoggingMiddleware(BaseHTTPMiddleware): clear_log_context() try: - # Extract and set log context - context = {} + with tracer.start_as_current_span("middleware.logging"): + # Extract and set log context + context = {} - # Headers - actor_id = request.headers.get("user_id") - if actor_id: - context["actor_id"] = actor_id + # Headers + actor_id = request.headers.get("user_id") + if actor_id: + context["actor_id"] = actor_id - project_id = request.headers.get("x-project-id") - if project_id: - context["project_id"] = project_id + project_id = request.headers.get("x-project-id") + if project_id: + context["project_id"] = project_id - org_id = request.headers.get("x-organization-id") - if org_id: - context["org_id"] = org_id + org_id = request.headers.get("x-organization-id") + if org_id: + context["org_id"] = org_id - user_agent = request.headers.get("x-agent-id") - if user_agent: - context["agent_id"] = user_agent + user_agent = request.headers.get("x-agent-id") + if user_agent: + context["agent_id"] = user_agent - run_id_header = request.headers.get("x-run-id") or request.headers.get("run-id") - if run_id_header: - context["run_id"] = run_id_header + run_id_header = request.headers.get("x-run-id") or request.headers.get("run-id") + if run_id_header: + context["run_id"] = run_id_header - path = request.url.path - path_parts = [p for p in path.split("/") if p] + path = request.url.path + path_parts = [p for p in path.split("/") if p] - # Path - matched_parts = set() - for part in path_parts: - if part in matched_parts: - continue + # Path + matched_parts = set() + for part in path_parts: + if part in matched_parts: + continue - for primitive_type in PrimitiveType: - prefix = primitive_type.value - pattern = PRIMITIVE_ID_PATTERNS.get(prefix) + for primitive_type in PrimitiveType: + prefix = primitive_type.value + pattern = PRIMITIVE_ID_PATTERNS.get(prefix) - if pattern and pattern.match(part): - context_key = f"{primitive_type.name.lower()}_id" + if pattern and pattern.match(part): + context_key = f"{primitive_type.name.lower()}_id" - if primitive_type == PrimitiveType.ORGANIZATION: - context_key = "org_id" - elif primitive_type == PrimitiveType.USER: - context_key = "user_id" + if primitive_type == PrimitiveType.ORGANIZATION: + context_key = "org_id" + elif primitive_type == PrimitiveType.USER: + context_key = "user_id" - context[context_key] = part - matched_parts.add(part) - break + context[context_key] = part + matched_parts.add(part) + break - # Query Parameters - for param_value in request.query_params.values(): - if param_value in matched_parts: - continue + # Query Parameters + for param_value in request.query_params.values(): + if param_value in matched_parts: + continue - for primitive_type in PrimitiveType: - prefix = primitive_type.value - pattern = PRIMITIVE_ID_PATTERNS.get(prefix) + for primitive_type in PrimitiveType: + prefix = primitive_type.value + pattern = PRIMITIVE_ID_PATTERNS.get(prefix) - if pattern and pattern.match(param_value): - context_key = f"{primitive_type.name.lower()}_id" + if pattern and pattern.match(param_value): + context_key = f"{primitive_type.name.lower()}_id" - if primitive_type == PrimitiveType.ORGANIZATION: - context_key = "org_id" - elif primitive_type == PrimitiveType.USER: - context_key = "user_id" + if primitive_type == PrimitiveType.ORGANIZATION: + context_key = "org_id" + elif primitive_type == PrimitiveType.USER: + context_key = "user_id" - # Only set if not already set from path (path takes precedence over query params) - # Query params can overwrite headers, but path values take precedence - if context_key not in context: - context[context_key] = param_value - matched_parts.add(param_value) - break + # Only set if not already set from path (path takes precedence over query params) + # Query params can overwrite headers, but path values take precedence + if context_key not in context: + context[context_key] = param_value + matched_parts.add(param_value) + break - if context: - update_log_context(**context) + if context: + update_log_context(**context) - logger.debug( - f"Incoming request: {request.method} {request.url.path}", - extra={ - "method": request.method, - "url": str(request.url), - "path": request.url.path, - "query_params": dict(request.query_params), - "client_host": request.client.host if request.client else None, - }, - ) + logger.debug( + f"Incoming request: {request.method} {request.url.path}", + extra={ + "method": request.method, + "url": str(request.url), + "path": request.url.path, + "query_params": dict(request.query_params), + "client_host": request.client.host if request.client else None, + }, + ) - response = await call_next(request) - return response + response = await call_next(request) + return response except Exception as exc: # Extract request context diff --git a/letta/server/rest_api/middleware/request_id.py b/letta/server/rest_api/middleware/request_id.py index b147ee61..70fd2cbb 100644 --- a/letta/server/rest_api/middleware/request_id.py +++ b/letta/server/rest_api/middleware/request_id.py @@ -18,6 +18,8 @@ from typing import Optional from starlette.requests import Request from starlette.types import ASGIApp, Receive, Scope, Send +from letta.otel.tracing import tracer + # Contextvar for storing the request ID across async boundaries request_id_var: ContextVar[Optional[str]] = ContextVar("request_id", default=None) @@ -47,17 +49,18 @@ class RequestIdMiddleware: await self.app(scope, receive, send) return - # Create a Request object for easier header access - request = Request(scope) + with tracer.start_as_current_span("middleware.request_id"): + # Create a Request object for easier header access + request = Request(scope) - # Extract request_id from header - request_id = request.headers.get("x-api-request-log-id") + # Extract request_id from header + request_id = request.headers.get("x-api-request-log-id") - # Set in contextvar (for non-streaming code paths) - request_id_var.set(request_id) + # Set in contextvar (for non-streaming code paths) + request_id_var.set(request_id) - # Also store in request.state for streaming responses where contextvars don't propagate - # This is accessible via request.state.request_id throughout the request lifecycle - request.state.request_id = request_id + # Also store in request.state for streaming responses where contextvars don't propagate + # This is accessible via request.state.request_id throughout the request lifecycle + request.state.request_id = request_id - await self.app(scope, receive, send) + await self.app(scope, receive, send)