feat: global exception middleware (#6017)

* global exception middleware

* redo both logging middlewares as one

* remove extra middleware files
This commit is contained in:
Kian Jones
2025-11-06 20:34:52 -08:00
committed by Caren Thomas
parent a217c8f1b6
commit 4acda9c80f
8 changed files with 644 additions and 70 deletions

137
letta/exceptions/logging.py Normal file
View File

@@ -0,0 +1,137 @@
"""
Helper utilities for structured exception logging.
Use these when you need to add context to exceptions before raising them.
"""
from typing import Any, Dict, Optional
from letta.log import get_logger
logger = get_logger(__name__)
def log_and_raise(
    exception: Exception,
    message: str,
    context: Optional[Dict[str, Any]] = None,
    level: str = "error",
) -> None:
    """
    Log an exception with structured context and then raise it.

    This is useful when you want to ensure an exception is logged with
    full context before raising it. The function never returns normally:
    its last statement always raises ``exception``.

    Args:
        exception: The exception to log and raise
        message: Human-readable message to log
        context: Additional context to include in logs (dict). Entries are
            merged into the ``extra`` payload and win over the default
            ``exception_*`` fields when keys collide.
        level: Log level name, case-insensitive (default: "error").
            "warn" is treated as "warning"; any unrecognised value falls
            back to "error" so that a bad level can never replace the
            exception the caller wanted raised.

    Raises:
        Exception: Always raises ``exception`` after logging it.

    Example:
        try:
            result = some_operation()
        except ValueError as e:
            log_and_raise(
                e,
                "Failed to process operation",
                context={
                    "user_id": user.id,
                    "operation": "some_operation",
                    "input": input_data,
                }
            )
    """
    # Resolve the logging method from a fixed set of names. A bare
    # getattr(logger, level) raises AttributeError for an unknown level (and
    # would happily return non-logging attributes for names such as "name"),
    # which would surface instead of `exception`.
    level_name = str(level).lower()
    if level_name == "warn":
        # Logger.warn is a deprecated alias that newer Python versions removed.
        level_name = "warning"
    if level_name not in ("debug", "info", "warning", "error", "exception", "critical", "fatal"):
        level_name = "error"
    log_method = getattr(logger, level_name)

    extra = {
        "exception_type": exception.__class__.__name__,
        "exception_message": str(exception),
        "exception_module": exception.__class__.__module__,
    }
    # NOTE(review): assuming `logger` behaves like a stdlib logging.Logger,
    # context keys must not clash with built-in LogRecord attributes
    # (e.g. "name", "module", "message") — confirm against letta.log.
    if context:
        extra.update(context)

    log_method(
        f"{message}: {exception.__class__.__name__}: {exception}",
        extra=extra,
        exc_info=exception,
    )
    raise exception
def log_exception(
    exception: Exception,
    message: str,
    context: Optional[Dict[str, Any]] = None,
    level: str = "error",
) -> None:
    """
    Log an exception with structured context without raising it.

    Use this when you want to log an exception but handle it gracefully.

    Args:
        exception: The exception to log
        message: Human-readable message to log
        context: Additional context to include in logs (dict). Entries are
            merged into the ``extra`` payload and win over the default
            ``exception_*`` fields when keys collide.
        level: Log level name, case-insensitive (default: "error").
            "warn" is treated as "warning"; any unrecognised value falls
            back to "error" instead of raising, since this helper exists
            precisely so that callers can keep going.

    Example:
        try:
            result = some_operation()
        except ValueError as e:
            log_exception(
                e,
                "Operation failed, using fallback",
                context={"user_id": user.id}
            )
            result = fallback_operation()
    """
    # Resolve the logging method from a fixed set of names. A bare
    # getattr(logger, level) raises AttributeError for an unknown level (and
    # would happily return non-logging attributes for names such as "name"),
    # turning a "log and continue" call into a crash.
    level_name = str(level).lower()
    if level_name == "warn":
        # Logger.warn is a deprecated alias that newer Python versions removed.
        level_name = "warning"
    if level_name not in ("debug", "info", "warning", "error", "exception", "critical", "fatal"):
        level_name = "error"
    log_method = getattr(logger, level_name)

    extra = {
        "exception_type": exception.__class__.__name__,
        "exception_message": str(exception),
        "exception_module": exception.__class__.__module__,
    }
    # NOTE(review): assuming `logger` behaves like a stdlib logging.Logger,
    # context keys must not clash with built-in LogRecord attributes
    # (e.g. "name", "module", "message") — confirm against letta.log.
    if context:
        extra.update(context)

    log_method(
        f"{message}: {exception.__class__.__name__}: {exception}",
        extra=extra,
        exc_info=exception,
    )
def add_exception_context(exception: Exception, **context: Any) -> Exception:
    """
    Add context to an exception that will be picked up by the global exception handler.

    This attaches a __letta_context__ attribute to the exception with structured data.
    Repeated calls on the same exception merge into the same dict, with later
    values overriding earlier ones for the same key.

    The dict is always stored on the exception *instance*. If the exception's
    class already defines a ``__letta_context__`` dict, its entries are copied
    as the starting point rather than mutated in place, so context never leaks
    between different exception objects of the same class.

    Args:
        exception: The exception to add context to
        **context: Key-value pairs to add as context

    Returns:
        The same exception with context attached

    Example:
        try:
            result = operation()
        except ValueError as e:
            raise add_exception_context(
                e,
                user_id=user.id,
                operation="do_thing",
                input_data=data,
            )
    """
    # Look only at the instance's own attributes: hasattr() would also match a
    # class-level attribute, and updating that would modify shared state.
    attached = vars(exception).get("__letta_context__")
    if not isinstance(attached, dict):
        inherited = getattr(exception, "__letta_context__", None)
        attached = dict(inherited) if isinstance(inherited, dict) else {}
        exception.__letta_context__ = attached

    attached.update(context)
    return exception

View File

@@ -0,0 +1,108 @@
"""
Global exception handlers for non-request exceptions (background tasks, startup, etc.)
"""
import sys
import threading
import traceback
from letta.log import get_logger
logger = get_logger(__name__)
def setup_global_exception_handlers():
    """
    Set up global exception handlers to catch exceptions that occur outside of request handling.

    Installs two process-wide hooks:
    - ``sys.excepthook``: uncaught exceptions in the main thread
    - ``threading.excepthook``: uncaught exceptions in background threads

    Asyncio exceptions are NOT covered here; the asyncio handler is installed
    per event loop by ``setup_asyncio_exception_handler``.

    Both hooks are replaced unconditionally, so calling this function more
    than once simply re-installs equivalent hooks.
    """

    # 1. Handle uncaught exceptions in the main thread
    def global_exception_hook(exc_type, exc_value, exc_traceback):
        """
        Global exception hook for uncaught exceptions in the main thread.

        Logs the exception at CRITICAL level with structured fields.
        """
        # Don't log KeyboardInterrupt (Ctrl+C); defer to the interpreter's default hook
        if issubclass(exc_type, KeyboardInterrupt):
            sys.__excepthook__(exc_type, exc_value, exc_traceback)
            return

        logger.critical(
            f"Uncaught exception in main thread: {exc_type.__name__}: {exc_value}",
            extra={
                "exception_type": exc_type.__name__,
                "exception_message": str(exc_value),
                "exception_module": exc_type.__module__,
                "traceback": "".join(traceback.format_exception(exc_type, exc_value, exc_traceback)),
            },
            exc_info=(exc_type, exc_value, exc_traceback),
        )

    sys.excepthook = global_exception_hook

    # 2. Handle exceptions in threading
    def thread_exception_hook(args):
        """
        Hook for exceptions in threads.

        ``args`` carries ``exc_type``, ``exc_value``, ``exc_traceback`` and ``thread``.
        """
        # The default threading.excepthook silently ignores SystemExit
        # (a thread calling sys.exit() is a normal way to stop); do the same
        # instead of reporting it as an error.
        if issubclass(args.exc_type, SystemExit):
            return

        # args.thread is documented as possibly None; guard the attribute
        # access so the hook itself cannot fail while reporting.
        thread = args.thread
        thread_name = thread.name if thread is not None else "<unknown>"
        thread_id = thread.ident if thread is not None else None

        logger.error(
            f"Uncaught exception in thread {thread_name}: {args.exc_type.__name__}: {args.exc_value}",
            extra={
                "exception_type": args.exc_type.__name__,
                "exception_message": str(args.exc_value),
                "exception_module": args.exc_type.__module__,
                "thread_name": thread_name,
                "thread_id": thread_id,
                "traceback": "".join(traceback.format_exception(args.exc_type, args.exc_value, args.exc_traceback)),
            },
            exc_info=(args.exc_type, args.exc_value, args.exc_traceback),
        )

    threading.excepthook = thread_exception_hook

    logger.info("Global exception handlers initialized")
def setup_asyncio_exception_handler(loop):
    """
    Install a logging exception handler on the given asyncio event loop.

    Args:
        loop: The event loop to configure; its ``set_exception_handler``
            is called with a handler that logs every reported context
            at ERROR level.
    """

    def asyncio_exception_handler(loop, context):
        """
        Log one asyncio error context.

        Reads the optional ``exception``, ``message`` and ``task`` entries
        of ``context``; the whole context is also logged in string form.
        """
        exc = context.get("exception")
        description = context.get("message", "Unhandled exception in asyncio")
        log_fields = {
            "asyncio_context": str(context),
            "task": str(context.get("task")),
        }

        # No exception object attached: only the message can be reported.
        if not exc:
            logger.error(f"Asyncio exception: {description}", extra=log_fields)
            return

        exc_class = exc.__class__
        log_fields["exception_type"] = exc_class.__name__
        log_fields["exception_message"] = str(exc)
        log_fields["exception_module"] = exc_class.__module__
        logger.error(
            f"Asyncio exception: {description}: {exc}",
            extra=log_fields,
            exc_info=exc,
        )

    loop.set_exception_handler(asyncio_exception_handler)
    logger.info("Asyncio exception handler initialized")

View File

@@ -58,11 +58,12 @@ from letta.schemas.letta_message_content import (
)
from letta.server.constants import REST_DEFAULT_PORT
from letta.server.db import db_registry
from letta.server.global_exception_handler import setup_global_exception_handlers
# NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests
from letta.server.rest_api.auth.index import setup_auth_router # TODO: probably remove right?
from letta.server.rest_api.interface import StreamingServerInterface
from letta.server.rest_api.middleware import CheckPasswordMiddleware, LogContextMiddleware, ProfilerContextMiddleware
from letta.server.rest_api.middleware import CheckPasswordMiddleware, LoggingMiddleware, ProfilerContextMiddleware
from letta.server.rest_api.routers.v1 import ROUTERS as v1_routes
from letta.server.rest_api.routers.v1.organizations import router as organizations_router
from letta.server.rest_api.routers.v1.users import router as users_router # TODO: decide on admin
@@ -260,12 +261,41 @@ def create_application() -> "FastAPI":
lifespan=lifespan,
)
# === Global Exception Handlers ===
# Set up handlers for exceptions outside of request context (background tasks, threads, etc.)
setup_global_exception_handlers()
# === Exception Handlers ===
# TODO (cliandy): move to separate file
@app.exception_handler(Exception)
async def generic_error_handler(request: Request, exc: Exception):
logger.error(f"Unhandled error: {str(exc)}", exc_info=True)
# Log with structured context
request_context = {
"method": request.method,
"url": str(request.url),
"path": request.url.path,
}
# Extract user context if available
user_context = {}
if hasattr(request.state, "user_id"):
user_context["user_id"] = request.state.user_id
if hasattr(request.state, "org_id"):
user_context["org_id"] = request.state.org_id
logger.error(
f"Unhandled error: {exc.__class__.__name__}: {str(exc)}",
extra={
"exception_type": exc.__class__.__name__,
"exception_message": str(exc),
"exception_module": exc.__class__.__module__,
"request": request_context,
"user": user_context,
},
exc_info=True,
)
if SENTRY_ENABLED:
sentry_sdk.capture_exception(exc)
@@ -519,7 +549,8 @@ def create_application() -> "FastAPI":
if telemetry_settings.profiler:
app.add_middleware(ProfilerContextMiddleware)
app.add_middleware(LogContextMiddleware)
# Add unified logging middleware - enriches log context and logs exceptions
app.add_middleware(LoggingMiddleware)
app.add_middleware(
CORSMiddleware,

View File

@@ -1,5 +1,5 @@
from letta.server.rest_api.middleware.check_password import CheckPasswordMiddleware
from letta.server.rest_api.middleware.log_context import LogContextMiddleware
from letta.server.rest_api.middleware.logging import LoggingMiddleware
from letta.server.rest_api.middleware.profiler_context import ProfilerContextMiddleware
__all__ = ["CheckPasswordMiddleware", "LogContextMiddleware", "ProfilerContextMiddleware"]
__all__ = ["CheckPasswordMiddleware", "LoggingMiddleware", "ProfilerContextMiddleware"]

View File

@@ -1,63 +0,0 @@
import re
from starlette.middleware.base import BaseHTTPMiddleware
from letta.log_context import clear_log_context, update_log_context
from letta.schemas.enums import PrimitiveType
from letta.validators import PRIMITIVE_ID_PATTERNS
class LogContextMiddleware(BaseHTTPMiddleware):
    """
    Middleware that tags the log context with identifiers found on the request.

    For every request it collects:
    - actor_id: the value of the 'user_id' header, when present
    - one ``<primitive>_id`` entry per URL path segment that matches a known
      primitive ID pattern (organization IDs are stored as ``org_id`` and
      user IDs as ``user_id``)

    The log context is cleared before the request is processed and again
    once it finishes, whether it succeeded or raised.
    """

    async def dispatch(self, request, call_next):
        clear_log_context()
        try:
            log_fields = {}

            header_actor = request.headers.get("user_id")
            if header_actor:
                log_fields["actor_id"] = header_actor

            seen_segments = set()
            # Empty segments (leading/trailing/double slashes) are skipped.
            for segment in filter(None, request.url.path.split("/")):
                if segment in seen_segments:
                    continue

                # The first primitive type whose ID pattern matches wins.
                for primitive in PrimitiveType:
                    id_pattern = PRIMITIVE_ID_PATTERNS.get(primitive.value)
                    if not (id_pattern and id_pattern.match(segment)):
                        continue

                    if primitive == PrimitiveType.ORGANIZATION:
                        field_name = "org_id"
                    elif primitive == PrimitiveType.USER:
                        field_name = "user_id"
                    else:
                        field_name = f"{primitive.name.lower()}_id"

                    log_fields[field_name] = segment
                    seen_segments.add(segment)
                    break

            if log_fields:
                update_log_context(**log_fields)

            return await call_next(request)
        finally:
            clear_log_context()

View File

@@ -0,0 +1,115 @@
"""
Unified logging middleware that enriches log context and ensures exceptions are logged.
"""
import re
import traceback
from typing import Callable
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from letta.log import get_logger
from letta.log_context import clear_log_context, update_log_context
from letta.schemas.enums import PrimitiveType
from letta.validators import PRIMITIVE_ID_PATTERNS
logger = get_logger(__name__)
class LoggingMiddleware(BaseHTTPMiddleware):
    """
    Middleware that tags the log context with request identifiers and logs exceptions.

    For every request it collects:
    - actor_id: the value of the 'user_id' header, when present
    - one ``<primitive>_id`` entry per URL path segment that matches a known
      primitive ID pattern (organization IDs are stored as ``org_id`` and
      user IDs as ``user_id``)

    Any exception raised while handling the request is logged at ERROR level
    with request, user and exception-attached context, then re-raised
    unchanged. The log context is cleared before and after every request.
    """

    async def dispatch(self, request: Request, call_next: Callable):
        clear_log_context()
        try:
            # --- enrich the log context -------------------------------------
            log_fields = {}

            header_actor = request.headers.get("user_id")
            if header_actor:
                log_fields["actor_id"] = header_actor

            seen_segments = set()
            # Empty segments (leading/trailing/double slashes) are skipped.
            for segment in filter(None, request.url.path.split("/")):
                if segment in seen_segments:
                    continue

                # The first primitive type whose ID pattern matches wins.
                for primitive in PrimitiveType:
                    id_pattern = PRIMITIVE_ID_PATTERNS.get(primitive.value)
                    if not (id_pattern and id_pattern.match(segment)):
                        continue

                    if primitive == PrimitiveType.ORGANIZATION:
                        field_name = "org_id"
                    elif primitive == PrimitiveType.USER:
                        field_name = "user_id"
                    else:
                        field_name = f"{primitive.name.lower()}_id"

                    log_fields[field_name] = segment
                    seen_segments.add(segment)
                    break

            if log_fields:
                update_log_context(**log_fields)

            return await call_next(request)

        except Exception as exc:
            # --- describe the failing request -------------------------------
            client = request.client
            request_info = {
                "method": request.method,
                "url": str(request.url),
                "path": request.url.path,
                "query_params": dict(request.query_params),
                "client_host": client.host if client else None,
                "user_agent": request.headers.get("user-agent"),
            }

            # Only report the identifiers that were actually set on request.state
            user_info = {}
            for attr in ("user_id", "org_id"):
                if hasattr(request.state, attr):
                    user_info[attr] = getattr(request.state, attr)

            # Context attached to the exception itself, if any
            attached_context = getattr(exc, "__letta_context__", {})

            exc_class = exc.__class__
            logger.error(
                f"Unhandled exception in request: {exc_class.__name__}: {exc}",
                extra={
                    "exception_type": exc_class.__name__,
                    "exception_message": str(exc),
                    "exception_module": exc_class.__module__,
                    "request": request_info,
                    "user": user_info,
                    "custom_context": attached_context,
                    "traceback": traceback.format_exc(),
                },
                exc_info=True,
            )

            # Propagate unchanged so outer exception handlers still run
            raise
        finally:
            clear_log_context()

View File

@@ -0,0 +1,246 @@
"""
Tests for global exception logging system.
"""
import asyncio
import logging
from unittest.mock import MagicMock, patch
import pytest
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient
from starlette.middleware.base import BaseHTTPMiddleware
from letta.exceptions.logging import add_exception_context, log_and_raise, log_exception
from letta.server.rest_api.middleware.logging import LoggingMiddleware
@pytest.fixture
def app_with_exception_middleware():
    """
    Build a minimal FastAPI app wrapped in LoggingMiddleware.

    Routes:
    - /test-error: raises a plain ValueError
    - /test-error-with-context: raises a ValueError carrying attached context
    - /test-success: returns a small JSON payload
    """
    application = FastAPI()
    application.add_middleware(LoggingMiddleware)

    # NOTE: the route function names are asserted on in the traceback test.
    @application.get("/test-error")
    def test_error():
        raise ValueError("Test error message")

    @application.get("/test-error-with-context")
    def test_error_with_context():
        raise add_exception_context(
            ValueError("Test error with context"),
            user_id="test-user-123",
            operation="test_operation",
        )

    @application.get("/test-success")
    def test_success():
        return {"status": "ok"}

    return application
def test_exception_middleware_logs_basic_exception(app_with_exception_middleware):
    """A failing route returns 500 and produces an error log describing the exception and the request."""
    http = TestClient(app_with_exception_middleware, raise_server_exceptions=False)

    with patch("letta.server.rest_api.middleware.logging.logger") as mock_logger:
        response = http.get("/test-error")

    # The unhandled error surfaces to the client as a 500
    assert response.status_code == 500

    # The middleware emitted an error-level log
    assert mock_logger.error.called
    positional, keyword = mock_logger.error.call_args

    # The log line names the exception type and its message
    log_line = positional[0]
    assert "ValueError" in log_line
    assert "Test error message" in log_line

    # The structured payload carries exception and request details
    payload = keyword["extra"]
    assert payload["exception_type"] == "ValueError"
    assert payload["exception_message"] == "Test error message"
    assert "request" in payload
    request_info = payload["request"]
    assert request_info["method"] == "GET"
    assert "/test-error" in request_info["path"]
def test_exception_middleware_logs_custom_context(app_with_exception_middleware):
    """Context attached via add_exception_context shows up under ``custom_context`` in the error log."""
    http = TestClient(app_with_exception_middleware, raise_server_exceptions=False)

    with patch("letta.server.rest_api.middleware.logging.logger") as mock_logger:
        response = http.get("/test-error-with-context")

    # The unhandled error surfaces to the client as a 500
    assert response.status_code == 500

    # The middleware emitted an error-level log
    assert mock_logger.error.called
    _, keyword = mock_logger.error.call_args
    payload = keyword["extra"]

    # The values attached by the route are passed through untouched
    assert "custom_context" in payload
    attached = payload["custom_context"]
    assert attached["user_id"] == "test-user-123"
    assert attached["operation"] == "test_operation"
def test_exception_middleware_does_not_interfere_with_success(app_with_exception_middleware):
    """A route that does not raise is served normally through the middleware."""
    ok_response = TestClient(app_with_exception_middleware).get("/test-success")

    assert ok_response.status_code == 200
    assert ok_response.json() == {"status": "ok"}
def test_add_exception_context():
    """add_exception_context returns the same exception object with every keyword stored on it."""
    original = ValueError("Test error")

    returned = add_exception_context(
        original,
        user_id="user-123",
        agent_id="agent-456",
        operation="test_op",
    )

    # The exception is annotated in place, not copied
    assert returned is original

    # Every keyword argument is available on the attached context
    assert hasattr(original, "__letta_context__")
    attached = original.__letta_context__
    assert attached["user_id"] == "user-123"
    assert attached["agent_id"] == "agent-456"
    assert attached["operation"] == "test_op"
def test_add_exception_context_multiple_times():
    """Successive add_exception_context calls accumulate on the same exception."""
    error = ValueError("Test error")

    # Two separate calls, each contributing one key
    for fragment in ({"user_id": "user-123"}, {"agent_id": "agent-456"}):
        add_exception_context(error, **fragment)

    # The second call must not have discarded the first call's key
    attached = error.__letta_context__
    assert attached["user_id"] == "user-123"
    assert attached["agent_id"] == "agent-456"
def test_log_and_raise():
    """log_and_raise emits an error log with merged context and then raises the given exception."""
    error = ValueError("Test error")

    with patch("letta.exceptions.logging.logger") as mock_logger:
        with pytest.raises(ValueError, match="Test error"):
            log_and_raise(
                error,
                "Operation failed",
                context={"user_id": "user-123"},
            )

    # The exception was logged before being raised
    assert mock_logger.error.called
    positional, keyword = mock_logger.error.call_args

    # The log line combines the caller's message with the exception details
    log_line = positional[0]
    for fragment in ("Operation failed", "ValueError", "Test error"):
        assert fragment in log_line

    # Default exception fields and caller context share one payload
    payload = keyword["extra"]
    assert payload["exception_type"] == "ValueError"
    assert payload["user_id"] == "user-123"
def test_log_exception():
    """log_exception emits an error log with merged context and returns without raising."""
    error = ValueError("Test error")

    with patch("letta.exceptions.logging.logger") as mock_logger:
        # Completing this call without an exception is part of the check
        log_exception(
            error,
            "Operation failed, using fallback",
            context={"user_id": "user-123"},
        )

    # The exception was logged at error level
    assert mock_logger.error.called
    positional, keyword = mock_logger.error.call_args

    # The log line combines the caller's message with the exception type
    log_line = positional[0]
    assert "Operation failed, using fallback" in log_line
    assert "ValueError" in log_line

    # Default exception fields and caller context share one payload
    payload = keyword["extra"]
    assert payload["exception_type"] == "ValueError"
    assert payload["user_id"] == "user-123"
def test_log_exception_with_different_levels():
    """log_exception dispatches to the logger method named by ``level``."""
    error = ValueError("Test error")
    cases = (
        ("warning", "Warning message"),
        ("info", "Info message"),
    )

    with patch("letta.exceptions.logging.logger") as mock_logger:
        for level_name, text in cases:
            log_exception(error, text, level=level_name)
            # The method matching the requested level must have been used
            assert getattr(mock_logger, level_name).called
def test_global_exception_handler_setup():
    """
    setup_global_exception_handlers installs both process-wide hooks.

    The hooks are global interpreter state, so the previous ones are saved
    and restored; otherwise every test running afterwards in the same process
    would inherit the logging hooks installed here.

    The test is synchronous on purpose: nothing in it awaits.
    """
    import sys
    import threading

    from letta.server.global_exception_handler import setup_global_exception_handlers

    saved_sys_hook = sys.excepthook
    saved_thread_hook = threading.excepthook
    try:
        # Should not raise
        setup_global_exception_handlers()

        # The main-thread hook is no longer the interpreter default
        assert sys.excepthook is not sys.__excepthook__
        # Both hooks were replaced by this call
        assert sys.excepthook is not saved_sys_hook
        assert threading.excepthook is not saved_thread_hook
    finally:
        sys.excepthook = saved_sys_hook
        threading.excepthook = saved_thread_hook
@pytest.mark.asyncio
async def test_asyncio_exception_handler():
    """
    setup_asyncio_exception_handler installs a handler that logs reported errors.

    The loop's previous handler is restored afterwards so the handler does
    not stay attached to a loop other tests may reuse.
    """
    from letta.server.global_exception_handler import setup_asyncio_exception_handler

    # Inside a coroutine the running loop is the one to configure
    loop = asyncio.get_running_loop()
    previous_handler = loop.get_exception_handler()
    try:
        # Should not raise
        setup_asyncio_exception_handler(loop)

        installed = loop.get_exception_handler()
        assert installed is not None
        assert installed is not previous_handler

        # Report an error through the loop and check that it is logged
        with patch("letta.server.global_exception_handler.logger") as mock_logger:
            loop.call_exception_handler(
                {
                    "message": "Test asyncio failure",
                    "exception": ValueError("Test error"),
                }
            )

        assert mock_logger.error.called
        positional, keyword = mock_logger.error.call_args
        assert "Test asyncio failure" in positional[0]
        assert keyword["extra"]["exception_type"] == "ValueError"
    finally:
        loop.set_exception_handler(previous_handler)
def test_exception_middleware_preserves_traceback(app_with_exception_middleware):
    """The error log carries traceback information both via exc_info and as a formatted string."""
    http = TestClient(app_with_exception_middleware, raise_server_exceptions=False)

    with patch("letta.server.rest_api.middleware.logging.logger") as mock_logger:
        response = http.get("/test-error")

    assert response.status_code == 500

    _, keyword = mock_logger.error.call_args

    # exc_info=True lets the logging layer attach the active traceback
    assert keyword["exc_info"] is True

    # The formatted traceback names the exception type and the failing route function
    payload = keyword["extra"]
    assert "traceback" in payload
    formatted = payload["traceback"]
    assert "ValueError" in formatted
    assert "test_error" in formatted

View File

@@ -3,13 +3,13 @@ from fastapi import FastAPI
from fastapi.testclient import TestClient
from letta.log_context import get_log_context
from letta.server.rest_api.middleware import LogContextMiddleware
from letta.server.rest_api.middleware import LoggingMiddleware
@pytest.fixture
def app():
app = FastAPI()
app.add_middleware(LogContextMiddleware)
app.add_middleware(LoggingMiddleware)
@app.get("/v1/agents/{agent_id}")
async def get_agent(agent_id: str):