""" Unified logging middleware that enriches log context and ensures exceptions are logged. """ import traceback from typing import Callable from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from letta.log import get_logger from letta.log_context import clear_log_context, update_log_context from letta.otel.tracing import tracer from letta.schemas.enums import PrimitiveType from letta.validators import PRIMITIVE_ID_PATTERNS logger = get_logger(__name__) class LoggingMiddleware(BaseHTTPMiddleware): """ Middleware that enriches log context with request-specific attributes and logs exceptions. Automatically extracts and sets: - actor_id: From the 'user_id' header - org_id: From organization-related endpoints - Letta primitive IDs: agent_id, tool_id, block_id, etc. from URL paths Also catches all exceptions and logs them with structured context before re-raising. """ async def dispatch(self, request: Request, call_next: Callable): clear_log_context() try: with tracer.start_as_current_span("middleware.logging"): # Extract and set log context context = {} with tracer.start_as_current_span("middleware.logging.context"): # Headers actor_id = request.headers.get("user_id") if actor_id: context["actor_id"] = actor_id project_id = request.headers.get("x-project-id") if project_id: context["project_id"] = project_id org_id = request.headers.get("x-organization-id") if org_id: context["org_id"] = org_id user_agent = request.headers.get("x-agent-id") if user_agent: context["agent_id"] = user_agent run_id_header = request.headers.get("x-run-id") or request.headers.get("run-id") if run_id_header: context["run_id"] = run_id_header path = request.url.path path_parts = [p for p in path.split("/") if p] # Path matched_parts = set() for part in path_parts: if part in matched_parts: continue for primitive_type in PrimitiveType: prefix = primitive_type.value pattern = PRIMITIVE_ID_PATTERNS.get(prefix) if pattern and pattern.match(part): context_key = f"{primitive_type.name.lower()}_id" if primitive_type == PrimitiveType.ORGANIZATION: context_key = "org_id" elif primitive_type == PrimitiveType.USER: context_key = "user_id" context[context_key] = part matched_parts.add(part) break # Query Parameters for param_value in request.query_params.values(): if param_value in matched_parts: continue for primitive_type in PrimitiveType: prefix = primitive_type.value pattern = PRIMITIVE_ID_PATTERNS.get(prefix) if pattern and pattern.match(param_value): context_key = f"{primitive_type.name.lower()}_id" if primitive_type == PrimitiveType.ORGANIZATION: context_key = "org_id" elif primitive_type == PrimitiveType.USER: context_key = "user_id" # Only set if not already set from path (path takes precedence over query params) # Query params can overwrite headers, but path values take precedence if context_key not in context: context[context_key] = param_value matched_parts.add(param_value) break if context: update_log_context(**context) logger.debug( f"Incoming request: {request.method} {request.url.path}", extra={ "method": request.method, "url": str(request.url), "path": request.url.path, "query_params": dict(request.query_params), "client_host": request.client.host if request.client else None, }, ) response = await call_next(request) return response except Exception as exc: import anyio if isinstance(exc, (anyio.BrokenResourceError, anyio.ClosedResourceError)): logger.info(f"Client disconnected during request: {request.method} {request.url.path}") raise # Extract request context request_context = { "method": request.method, "url": str(request.url), "path": request.url.path, "query_params": dict(request.query_params), "client_host": request.client.host if request.client else None, "user_agent": request.headers.get("user-agent"), } # Extract user context if available user_context = {} if hasattr(request.state, "user_id"): user_context["user_id"] = request.state.user_id if hasattr(request.state, "org_id"): user_context["org_id"] = request.state.org_id # Check for custom context attached to the exception custom_context = {} if hasattr(exc, "__letta_context__"): custom_context = exc.__letta_context__ # Log with structured data logger.error( f"Unhandled exception in request: {exc.__class__.__name__}: {str(exc)}", extra={ "exception_type": exc.__class__.__name__, "exception_message": str(exc), "exception_module": exc.__class__.__module__, "request": request_context, "user": user_context, "custom_context": custom_context, "traceback": traceback.format_exc(), }, exc_info=True, ) # Re-raise to let FastAPI's exception handlers deal with it raise finally: clear_log_context()