# Maps every accepted ``level`` name to the logger method that emits it.
# Aliases are folded into their canonical method so that the set of
# methods ever looked up on ``logger`` is small and known.
_LEVEL_METHODS = {
    "debug": "debug",
    "info": "info",
    "warning": "warning",
    "warn": "warning",
    "error": "error",
    "exception": "exception",
    "critical": "critical",
    "fatal": "critical",
}


def _resolve_log_method(level):
    """Return the bound logger method for *level*.

    Unknown or non-string levels fall back to ``logger.error`` so that a typo
    in ``level`` can never replace the exception being reported with an
    ``AttributeError`` raised by the logging helper itself.
    """
    method_name = _LEVEL_METHODS.get(str(level).lower(), "error")
    return getattr(logger, method_name)


def log_and_raise(
    exception: Exception,
    message: str,
    context: Optional[Dict[str, Any]] = None,
    level: str = "error",
) -> None:
    """
    Log an exception with structured context and then raise it.

    This function never returns normally: after logging (via ``log_exception``)
    it always raises ``exception``.

    Args:
        exception: The exception to log and raise
        message: Human-readable message to log
        context: Additional context to include in logs (dict)
        level: Log level name (default: "error"); unknown names log at "error"

    Raises:
        The ``exception`` object that was passed in.

    Example:
        try:
            result = some_operation()
        except ValueError as e:
            log_and_raise(
                e,
                "Failed to process operation",
                context={
                    "user_id": user.id,
                    "operation": "some_operation",
                    "input": input_data,
                }
            )
    """
    # Single source of truth for how exceptions are formatted and logged.
    log_exception(exception, message, context=context, level=level)
    raise exception


def log_exception(
    exception: Exception,
    message: str,
    context: Optional[Dict[str, Any]] = None,
    level: str = "error",
) -> None:
    """
    Log an exception with structured context without raising it.

    Use this when you want to log an exception but handle it gracefully.

    The log call receives ``extra`` containing ``exception_type``,
    ``exception_message`` and ``exception_module``, updated with ``context``
    (so a ``context`` key with the same name wins), and ``exc_info`` set to
    the exception.

    NOTE(review): with the standard-library logger, ``extra`` keys must not
    collide with built-in LogRecord attributes such as "message" or "name";
    confirm how the logger returned by ``letta.log.get_logger`` behaves.

    Args:
        exception: The exception to log
        message: Human-readable message to log
        context: Additional context to include in logs (dict)
        level: Log level name (default: "error"); unknown names log at "error"

    Example:
        try:
            result = some_operation()
        except ValueError as e:
            log_exception(
                e,
                "Operation failed, using fallback",
                context={"user_id": user.id}
            )
            result = fallback_operation()
    """
    exception_type = exception.__class__.__name__
    exception_message = str(exception)

    extra = {
        "exception_type": exception_type,
        "exception_message": exception_message,
        "exception_module": exception.__class__.__module__,
    }
    if context:
        extra.update(context)

    log_method = _resolve_log_method(level)
    log_method(
        f"{message}: {exception_type}: {exception_message}",
        extra=extra,
        exc_info=exception,
    )


def add_exception_context(exception: Exception, **context) -> Exception:
    """
    Attach structured context to an exception and return the same object.

    The context is stored in a dict on the ``__letta_context__`` attribute of
    the exception. The dict is created on first use and updated on later
    calls, so repeated calls accumulate keys (later values win on conflict).

    Args:
        exception: The exception to add context to
        **context: Key-value pairs to add as context

    Returns:
        The same exception with context attached

    Example:
        try:
            result = operation()
        except ValueError as e:
            raise add_exception_context(
                e,
                user_id=user.id,
                operation="do_thing",
                input_data=data,
            )
    """
    if not hasattr(exception, "__letta_context__"):
        exception.__letta_context__ = {}
    exception.__letta_context__.update(context)
    return exception
def setup_global_exception_handlers():
    """
    Install process-wide hooks that log exceptions raised outside request handling.

    Hooks installed (each replaces whatever hook was set before):
    - ``sys.excepthook``: uncaught exceptions in the main thread
    - ``threading.excepthook``: uncaught exceptions in background threads

    Asyncio task exceptions are not covered by this function; call
    ``setup_asyncio_exception_handler`` with the event loop for those.
    """

    # 1. Handle uncaught exceptions in the main thread
    def global_exception_hook(exc_type, exc_value, exc_traceback):
        """
        Global exception hook for uncaught exceptions in the main thread.
        Logs the exception at CRITICAL level with structured fields.
        """
        # Don't log KeyboardInterrupt (Ctrl+C); defer to the interpreter's default hook
        if issubclass(exc_type, KeyboardInterrupt):
            sys.__excepthook__(exc_type, exc_value, exc_traceback)
            return

        logger.critical(
            f"Uncaught exception in main thread: {exc_type.__name__}: {exc_value}",
            extra={
                "exception_type": exc_type.__name__,
                "exception_message": str(exc_value),
                "exception_module": exc_type.__module__,
                "traceback": "".join(traceback.format_exception(exc_type, exc_value, exc_traceback)),
            },
            exc_info=(exc_type, exc_value, exc_traceback),
        )

    sys.excepthook = global_exception_hook

    # 2. Handle exceptions in threading
    def thread_exception_hook(args):
        """
        Hook for uncaught exceptions in threads.

        ``args`` carries ``exc_type``, ``exc_value``, ``exc_traceback`` and
        ``thread``; per the ``threading.excepthook`` documentation ``thread``
        (and ``exc_value`` / ``exc_traceback``) may be None.
        """
        # The default threading hook silently ignores SystemExit; keep that
        # behaviour so a thread ending via sys.exit() is not reported as an error.
        if issubclass(args.exc_type, SystemExit):
            return

        # Guard against a missing thread object instead of crashing inside the hook.
        thread = args.thread
        thread_name = thread.name if thread is not None else "<unknown>"
        thread_id = thread.ident if thread is not None else None

        logger.error(
            f"Uncaught exception in thread {thread_name}: {args.exc_type.__name__}: {args.exc_value}",
            extra={
                "exception_type": args.exc_type.__name__,
                "exception_message": str(args.exc_value),
                "exception_module": args.exc_type.__module__,
                "thread_name": thread_name,
                "thread_id": thread_id,
                "traceback": "".join(traceback.format_exception(args.exc_type, args.exc_value, args.exc_traceback)),
            },
            exc_info=(args.exc_type, args.exc_value, args.exc_traceback),
        )

    threading.excepthook = thread_exception_hook

    logger.info("Global exception handlers initialized")


def setup_asyncio_exception_handler(loop):
    """
    Install a logging exception handler on an asyncio event loop.

    Args:
        loop: The event loop; its ``set_exception_handler`` is called with a
            handler that logs the asyncio error context at ERROR level.
    """

    def asyncio_exception_handler(loop, context):
        """
        Handler for exceptions reported through the loop's exception handler.

        ``context`` is the dict supplied by asyncio; this handler reads its
        "exception", "message" and "task" entries, all of which are optional here.
        """
        exception = context.get("exception")
        message = context.get("message", "Unhandled exception in asyncio")

        extra = {
            "asyncio_context": str(context),
            "task": str(context.get("task")),
        }

        # Identity check: an exception object that happens to be falsy
        # must still be logged with its details.
        if exception is not None:
            extra.update(
                {
                    "exception_type": exception.__class__.__name__,
                    "exception_message": str(exception),
                    "exception_module": exception.__class__.__module__,
                }
            )
            logger.error(
                f"Asyncio exception: {message}: {exception}",
                extra=extra,
                exc_info=exception,
            )
        else:
            logger.error(
                f"Asyncio exception: {message}",
                extra=extra,
            )

    loop.set_exception_handler(asyncio_exception_handler)
    logger.info("Asyncio exception handler initialized")
routes that are not part of v1 but we still need to mount to pass tests from letta.server.rest_api.auth.index import setup_auth_router # TODO: probably remove right? from letta.server.rest_api.interface import StreamingServerInterface -from letta.server.rest_api.middleware import CheckPasswordMiddleware, LogContextMiddleware, ProfilerContextMiddleware +from letta.server.rest_api.middleware import CheckPasswordMiddleware, LoggingMiddleware, ProfilerContextMiddleware from letta.server.rest_api.routers.v1 import ROUTERS as v1_routes from letta.server.rest_api.routers.v1.organizations import router as organizations_router from letta.server.rest_api.routers.v1.users import router as users_router # TODO: decide on admin @@ -260,12 +261,41 @@ def create_application() -> "FastAPI": lifespan=lifespan, ) + # === Global Exception Handlers === + # Set up handlers for exceptions outside of request context (background tasks, threads, etc.) + setup_global_exception_handlers() + # === Exception Handlers === # TODO (cliandy): move to separate file @app.exception_handler(Exception) async def generic_error_handler(request: Request, exc: Exception): - logger.error(f"Unhandled error: {str(exc)}", exc_info=True) + # Log with structured context + request_context = { + "method": request.method, + "url": str(request.url), + "path": request.url.path, + } + + # Extract user context if available + user_context = {} + if hasattr(request.state, "user_id"): + user_context["user_id"] = request.state.user_id + if hasattr(request.state, "org_id"): + user_context["org_id"] = request.state.org_id + + logger.error( + f"Unhandled error: {exc.__class__.__name__}: {str(exc)}", + extra={ + "exception_type": exc.__class__.__name__, + "exception_message": str(exc), + "exception_module": exc.__class__.__module__, + "request": request_context, + "user": user_context, + }, + exc_info=True, + ) + if SENTRY_ENABLED: sentry_sdk.capture_exception(exc) @@ -519,7 +549,8 @@ def create_application() -> "FastAPI": if 
telemetry_settings.profiler: app.add_middleware(ProfilerContextMiddleware) - app.add_middleware(LogContextMiddleware) + # Add unified logging middleware - enriches log context and logs exceptions + app.add_middleware(LoggingMiddleware) app.add_middleware( CORSMiddleware, diff --git a/letta/server/rest_api/middleware/__init__.py b/letta/server/rest_api/middleware/__init__.py index ab4da3b7..ed194d9f 100644 --- a/letta/server/rest_api/middleware/__init__.py +++ b/letta/server/rest_api/middleware/__init__.py @@ -1,5 +1,5 @@ from letta.server.rest_api.middleware.check_password import CheckPasswordMiddleware -from letta.server.rest_api.middleware.log_context import LogContextMiddleware +from letta.server.rest_api.middleware.logging import LoggingMiddleware from letta.server.rest_api.middleware.profiler_context import ProfilerContextMiddleware -__all__ = ["CheckPasswordMiddleware", "LogContextMiddleware", "ProfilerContextMiddleware"] +__all__ = ["CheckPasswordMiddleware", "LoggingMiddleware", "ProfilerContextMiddleware"] diff --git a/letta/server/rest_api/middleware/log_context.py b/letta/server/rest_api/middleware/log_context.py deleted file mode 100644 index a7e27328..00000000 --- a/letta/server/rest_api/middleware/log_context.py +++ /dev/null @@ -1,63 +0,0 @@ -import re - -from starlette.middleware.base import BaseHTTPMiddleware - -from letta.log_context import clear_log_context, update_log_context -from letta.schemas.enums import PrimitiveType -from letta.validators import PRIMITIVE_ID_PATTERNS - - -class LogContextMiddleware(BaseHTTPMiddleware): - """ - Middleware that enriches log context with request-specific attributes. - - Automatically extracts and sets: - - actor_id: From the 'user_id' header - - org_id: From organization-related endpoints - - Letta primitive IDs: agent_id, tool_id, block_id, etc. 
class LoggingMiddleware(BaseHTTPMiddleware):
    """
    Request middleware that tags logs with request attributes and logs failures.

    For every request it resets the log context, then populates it with:
    - actor_id: value of the 'user_id' request header, when present
    - one ``<primitive>_id`` entry per URL path segment that matches a pattern
      in PRIMITIVE_ID_PATTERNS (organization ids are stored as ``org_id``,
      user ids as ``user_id``)

    Any exception escaping the downstream handler is logged with structured
    request/user/custom context and then re-raised unchanged. The log context
    is cleared again when the request finishes, whether it succeeded or not.
    """

    async def dispatch(self, request: Request, call_next: Callable):
        clear_log_context()

        try:
            log_fields = self._collect_log_context(request)
            if log_fields:
                update_log_context(**log_fields)

            return await call_next(request)

        except Exception as exc:
            self._log_unhandled(request, exc)
            # Re-raise to let FastAPI's exception handlers deal with it
            raise

        finally:
            clear_log_context()

    @staticmethod
    def _context_key_for(segment):
        """Return the log-context key for a path segment, or None if it matches no primitive id pattern."""
        for primitive_type in PrimitiveType:
            pattern = PRIMITIVE_ID_PATTERNS.get(primitive_type.value)
            if not (pattern and pattern.match(segment)):
                continue

            # First matching primitive type wins; two types use short key names.
            if primitive_type == PrimitiveType.ORGANIZATION:
                return "org_id"
            if primitive_type == PrimitiveType.USER:
                return "user_id"
            return f"{primitive_type.name.lower()}_id"

        return None

    @classmethod
    def _collect_log_context(cls, request):
        """Build the dict of log-context fields derived from the request header and URL path."""
        fields = {}

        header_actor = request.headers.get("user_id")
        if header_actor:
            fields["actor_id"] = header_actor

        already_matched = set()
        # Empty segments (leading/trailing/double slashes) are skipped.
        for segment in filter(None, request.url.path.split("/")):
            if segment in already_matched:
                continue

            key = cls._context_key_for(segment)
            if key is None:
                continue

            fields[key] = segment
            already_matched.add(segment)

        return fields

    @staticmethod
    def _log_unhandled(request, exc):
        """Log *exc* with request, user and custom context. Must be called while *exc* is being handled."""
        request_details = {
            "method": request.method,
            "url": str(request.url),
            "path": request.url.path,
            "query_params": dict(request.query_params),
            "client_host": request.client.host if request.client else None,
            "user_agent": request.headers.get("user-agent"),
        }

        # Only the identifiers actually present on request.state are reported.
        user_details = {
            attr: getattr(request.state, attr)
            for attr in ("user_id", "org_id")
            if hasattr(request.state, attr)
        }

        # Context attached to the exception object itself, if any.
        attached_context = getattr(exc, "__letta_context__", {})

        exc_name = exc.__class__.__name__
        logger.error(
            f"Unhandled exception in request: {exc_name}: {str(exc)}",
            extra={
                "exception_type": exc_name,
                "exception_message": str(exc),
                "exception_module": exc.__class__.__module__,
                "request": request_details,
                "user": user_details,
                "custom_context": attached_context,
                "traceback": traceback.format_exc(),
            },
            exc_info=True,
        )
@pytest.fixture
def app_with_exception_middleware():
    """Create a test FastAPI app with logging middleware."""
    app = FastAPI()
    app.add_middleware(LoggingMiddleware)

    @app.get("/test-error")
    def test_error():
        raise ValueError("Test error message")

    @app.get("/test-error-with-context")
    def test_error_with_context():
        exc = ValueError("Test error with context")
        exc = add_exception_context(
            exc,
            user_id="test-user-123",
            operation="test_operation",
        )
        raise exc

    @app.get("/test-success")
    def test_success():
        return {"status": "ok"}

    return app


def test_exception_middleware_logs_basic_exception(app_with_exception_middleware):
    """Test that the middleware logs exceptions with basic context."""
    client = TestClient(app_with_exception_middleware, raise_server_exceptions=False)

    with patch("letta.server.rest_api.middleware.logging.logger") as mock_logger:
        response = client.get("/test-error")

        # Should return 500
        assert response.status_code == 500

        # Should log the error
        assert mock_logger.error.called
        call_args = mock_logger.error.call_args

        # Check the message
        assert "ValueError" in call_args[0][0]
        assert "Test error message" in call_args[0][0]

        # Check the extra context
        extra = call_args[1]["extra"]
        assert extra["exception_type"] == "ValueError"
        assert extra["exception_message"] == "Test error message"
        assert "request" in extra
        assert extra["request"]["method"] == "GET"
        assert "/test-error" in extra["request"]["path"]


def test_exception_middleware_logs_custom_context(app_with_exception_middleware):
    """Test that the middleware logs custom context attached to exceptions."""
    client = TestClient(app_with_exception_middleware, raise_server_exceptions=False)

    with patch("letta.server.rest_api.middleware.logging.logger") as mock_logger:
        response = client.get("/test-error-with-context")

        # Should return 500
        assert response.status_code == 500

        # Should log the error with custom context
        assert mock_logger.error.called
        call_args = mock_logger.error.call_args
        extra = call_args[1]["extra"]

        # Check custom context
        assert "custom_context" in extra
        assert extra["custom_context"]["user_id"] == "test-user-123"
        assert extra["custom_context"]["operation"] == "test_operation"


def test_exception_middleware_does_not_interfere_with_success(app_with_exception_middleware):
    """Test that the middleware doesn't interfere with successful requests."""
    client = TestClient(app_with_exception_middleware)

    response = client.get("/test-success")
    assert response.status_code == 200
    assert response.json() == {"status": "ok"}


def test_add_exception_context():
    """Test that add_exception_context properly attaches context to exceptions."""
    exc = ValueError("Test error")

    # Add context
    exc_with_context = add_exception_context(
        exc,
        user_id="user-123",
        agent_id="agent-456",
        operation="test_op",
    )

    # Should be the same exception object
    assert exc_with_context is exc

    # Should have context attached
    assert hasattr(exc, "__letta_context__")
    assert exc.__letta_context__["user_id"] == "user-123"
    assert exc.__letta_context__["agent_id"] == "agent-456"
    assert exc.__letta_context__["operation"] == "test_op"


def test_add_exception_context_multiple_times():
    """Test that add_exception_context can be called multiple times."""
    exc = ValueError("Test error")

    # Add context in multiple calls
    add_exception_context(exc, user_id="user-123")
    add_exception_context(exc, agent_id="agent-456")

    # Both should be present
    assert exc.__letta_context__["user_id"] == "user-123"
    assert exc.__letta_context__["agent_id"] == "agent-456"


def test_log_and_raise():
    """Test that log_and_raise logs and then raises the exception."""
    exc = ValueError("Test error")

    with patch("letta.exceptions.logging.logger") as mock_logger:
        with pytest.raises(ValueError, match="Test error"):
            log_and_raise(
                exc,
                "Operation failed",
                context={"user_id": "user-123"},
            )

        # Should have logged
        assert mock_logger.error.called
        call_args = mock_logger.error.call_args

        # Check message
        assert "Operation failed" in call_args[0][0]
        assert "ValueError" in call_args[0][0]
        assert "Test error" in call_args[0][0]

        # Check extra context
        extra = call_args[1]["extra"]
        assert extra["exception_type"] == "ValueError"
        assert extra["user_id"] == "user-123"


def test_log_exception():
    """Test that log_exception logs without raising."""
    exc = ValueError("Test error")

    with patch("letta.exceptions.logging.logger") as mock_logger:
        # Should not raise
        log_exception(
            exc,
            "Operation failed, using fallback",
            context={"user_id": "user-123"},
        )

        # Should have logged
        assert mock_logger.error.called
        call_args = mock_logger.error.call_args

        # Check message
        assert "Operation failed, using fallback" in call_args[0][0]
        assert "ValueError" in call_args[0][0]

        # Check extra context
        extra = call_args[1]["extra"]
        assert extra["exception_type"] == "ValueError"
        assert extra["user_id"] == "user-123"


def test_log_exception_with_different_levels():
    """Test that log_exception respects different log levels."""
    exc = ValueError("Test error")

    with patch("letta.exceptions.logging.logger") as mock_logger:
        # Test warning level
        log_exception(exc, "Warning message", level="warning")
        assert mock_logger.warning.called

        # Test info level
        log_exception(exc, "Info message", level="info")
        assert mock_logger.info.called


def test_global_exception_handler_setup(monkeypatch):
    """Test that global exception handlers can be set up without errors."""
    import sys
    import threading

    from letta.server.global_exception_handler import setup_global_exception_handlers

    # Register the current hooks with monkeypatch so they are restored on
    # teardown; otherwise the hooks installed below would leak into every
    # test that runs afterwards in the same process.
    monkeypatch.setattr(sys, "excepthook", sys.excepthook)
    monkeypatch.setattr(threading, "excepthook", threading.excepthook)

    # Should not raise
    setup_global_exception_handlers()

    # Verify sys.excepthook was modified
    assert sys.excepthook != sys.__excepthook__


@pytest.mark.asyncio
async def test_asyncio_exception_handler():
    """Test that asyncio exception handler can be set up."""
    from letta.server.global_exception_handler import setup_asyncio_exception_handler

    # Inside a coroutine the running loop is the one to configure.
    loop = asyncio.get_running_loop()
    previous_handler = loop.get_exception_handler()

    try:
        # Should not raise
        setup_asyncio_exception_handler(loop)
        assert loop.get_exception_handler() is not None
    finally:
        # Do not leave the handler installed on a loop other tests may reuse.
        loop.set_exception_handler(previous_handler)


def test_exception_middleware_preserves_traceback(app_with_exception_middleware):
    """Test that the middleware preserves traceback information."""
    client = TestClient(app_with_exception_middleware, raise_server_exceptions=False)

    with patch("letta.server.rest_api.middleware.logging.logger") as mock_logger:
        response = client.get("/test-error")

        assert response.status_code == 500
        call_args = mock_logger.error.call_args

        # Check that exc_info was passed (enables traceback)
        assert call_args[1]["exc_info"] is True

        # Check that traceback is in extra
        extra = call_args[1]["extra"]
        assert "traceback" in extra
        assert "ValueError" in extra["traceback"]
        assert "test_error" in extra["traceback"]