feat: global exception middleware (#6017)
* global exception middleware * redo both logging middlewares as one * remove extra middleware files
This commit is contained in:
137
letta/exceptions/logging.py
Normal file
137
letta/exceptions/logging.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
Helper utilities for structured exception logging.
|
||||
Use these when you need to add context to exceptions before raising them.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from letta.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def log_and_raise(
|
||||
exception: Exception,
|
||||
message: str,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
level: str = "error",
|
||||
) -> None:
|
||||
"""
|
||||
Log an exception with structured context and then raise it.
|
||||
|
||||
This is useful when you want to ensure an exception is logged with
|
||||
full context before raising it.
|
||||
|
||||
Args:
|
||||
exception: The exception to log and raise
|
||||
message: Human-readable message to log
|
||||
context: Additional context to include in logs (dict)
|
||||
level: Log level (default: "error")
|
||||
|
||||
Example:
|
||||
try:
|
||||
result = some_operation()
|
||||
except ValueError as e:
|
||||
log_and_raise(
|
||||
e,
|
||||
"Failed to process operation",
|
||||
context={
|
||||
"user_id": user.id,
|
||||
"operation": "some_operation",
|
||||
"input": input_data,
|
||||
}
|
||||
)
|
||||
"""
|
||||
extra = {
|
||||
"exception_type": exception.__class__.__name__,
|
||||
"exception_message": str(exception),
|
||||
"exception_module": exception.__class__.__module__,
|
||||
}
|
||||
|
||||
if context:
|
||||
extra.update(context)
|
||||
|
||||
log_method = getattr(logger, level.lower())
|
||||
log_method(
|
||||
f"{message}: {exception.__class__.__name__}: {str(exception)}",
|
||||
extra=extra,
|
||||
exc_info=exception,
|
||||
)
|
||||
|
||||
raise exception
|
||||
|
||||
|
||||
def log_exception(
|
||||
exception: Exception,
|
||||
message: str,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
level: str = "error",
|
||||
) -> None:
|
||||
"""
|
||||
Log an exception with structured context without raising it.
|
||||
|
||||
Use this when you want to log an exception but handle it gracefully.
|
||||
|
||||
Args:
|
||||
exception: The exception to log
|
||||
message: Human-readable message to log
|
||||
context: Additional context to include in logs (dict)
|
||||
level: Log level (default: "error")
|
||||
|
||||
Example:
|
||||
try:
|
||||
result = some_operation()
|
||||
except ValueError as e:
|
||||
log_exception(
|
||||
e,
|
||||
"Operation failed, using fallback",
|
||||
context={"user_id": user.id}
|
||||
)
|
||||
result = fallback_operation()
|
||||
"""
|
||||
extra = {
|
||||
"exception_type": exception.__class__.__name__,
|
||||
"exception_message": str(exception),
|
||||
"exception_module": exception.__class__.__module__,
|
||||
}
|
||||
|
||||
if context:
|
||||
extra.update(context)
|
||||
|
||||
log_method = getattr(logger, level.lower())
|
||||
log_method(
|
||||
f"{message}: {exception.__class__.__name__}: {str(exception)}",
|
||||
extra=extra,
|
||||
exc_info=exception,
|
||||
)
|
||||
|
||||
|
||||
def add_exception_context(exception: Exception, **context) -> Exception:
|
||||
"""
|
||||
Add context to an exception that will be picked up by the global exception handler.
|
||||
|
||||
This attaches a __letta_context__ attribute to the exception with structured data.
|
||||
The global exception handler will automatically include this context in logs.
|
||||
|
||||
Args:
|
||||
exception: The exception to add context to
|
||||
**context: Key-value pairs to add as context
|
||||
|
||||
Returns:
|
||||
The same exception with context attached
|
||||
|
||||
Example:
|
||||
try:
|
||||
result = operation()
|
||||
except ValueError as e:
|
||||
raise add_exception_context(
|
||||
e,
|
||||
user_id=user.id,
|
||||
operation="do_thing",
|
||||
input_data=data,
|
||||
)
|
||||
"""
|
||||
if not hasattr(exception, "__letta_context__"):
|
||||
exception.__letta_context__ = {}
|
||||
exception.__letta_context__.update(context)
|
||||
return exception
|
||||
108
letta/server/global_exception_handler.py
Normal file
108
letta/server/global_exception_handler.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
Global exception handlers for non-request exceptions (background tasks, startup, etc.)
|
||||
"""
|
||||
|
||||
import sys
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
from letta.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def setup_global_exception_handlers():
|
||||
"""
|
||||
Set up global exception handlers to catch exceptions that occur outside of request handling.
|
||||
This includes:
|
||||
- Uncaught exceptions in the main thread
|
||||
- Exceptions in background threads
|
||||
- Asyncio task exceptions
|
||||
"""
|
||||
|
||||
# 1. Handle uncaught exceptions in the main thread
|
||||
def global_exception_hook(exc_type, exc_value, exc_traceback):
|
||||
"""
|
||||
Global exception hook for uncaught exceptions in the main thread.
|
||||
This catches exceptions that would otherwise crash the application.
|
||||
"""
|
||||
# Don't log KeyboardInterrupt (Ctrl+C)
|
||||
if issubclass(exc_type, KeyboardInterrupt):
|
||||
sys.__excepthook__(exc_type, exc_value, exc_traceback)
|
||||
return
|
||||
|
||||
logger.critical(
|
||||
f"Uncaught exception in main thread: {exc_type.__name__}: {exc_value}",
|
||||
extra={
|
||||
"exception_type": exc_type.__name__,
|
||||
"exception_message": str(exc_value),
|
||||
"exception_module": exc_type.__module__,
|
||||
"traceback": "".join(traceback.format_exception(exc_type, exc_value, exc_traceback)),
|
||||
},
|
||||
exc_info=(exc_type, exc_value, exc_traceback),
|
||||
)
|
||||
|
||||
sys.excepthook = global_exception_hook
|
||||
|
||||
# 2. Handle exceptions in threading
|
||||
def thread_exception_hook(args):
|
||||
"""
|
||||
Hook for exceptions in threads.
|
||||
"""
|
||||
logger.error(
|
||||
f"Uncaught exception in thread {args.thread.name}: {args.exc_type.__name__}: {args.exc_value}",
|
||||
extra={
|
||||
"exception_type": args.exc_type.__name__,
|
||||
"exception_message": str(args.exc_value),
|
||||
"exception_module": args.exc_type.__module__,
|
||||
"thread_name": args.thread.name,
|
||||
"thread_id": args.thread.ident,
|
||||
"traceback": "".join(traceback.format_exception(args.exc_type, args.exc_value, args.exc_traceback)),
|
||||
},
|
||||
exc_info=(args.exc_type, args.exc_value, args.exc_traceback),
|
||||
)
|
||||
|
||||
threading.excepthook = thread_exception_hook
|
||||
|
||||
logger.info("Global exception handlers initialized")
|
||||
|
||||
|
||||
def setup_asyncio_exception_handler(loop):
|
||||
"""
|
||||
Set up exception handler for asyncio loop.
|
||||
Call this with your event loop.
|
||||
"""
|
||||
|
||||
def asyncio_exception_handler(loop, context):
|
||||
"""
|
||||
Handler for exceptions in asyncio tasks.
|
||||
"""
|
||||
exception = context.get("exception")
|
||||
message = context.get("message", "Unhandled exception in asyncio")
|
||||
|
||||
extra = {
|
||||
"asyncio_context": str(context),
|
||||
"task": str(context.get("task")),
|
||||
}
|
||||
|
||||
if exception:
|
||||
extra.update(
|
||||
{
|
||||
"exception_type": exception.__class__.__name__,
|
||||
"exception_message": str(exception),
|
||||
"exception_module": exception.__class__.__module__,
|
||||
}
|
||||
)
|
||||
logger.error(
|
||||
f"Asyncio exception: {message}: {exception}",
|
||||
extra=extra,
|
||||
exc_info=exception,
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Asyncio exception: {message}",
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
loop.set_exception_handler(asyncio_exception_handler)
|
||||
logger.info("Asyncio exception handler initialized")
|
||||
@@ -58,11 +58,12 @@ from letta.schemas.letta_message_content import (
|
||||
)
|
||||
from letta.server.constants import REST_DEFAULT_PORT
|
||||
from letta.server.db import db_registry
|
||||
from letta.server.global_exception_handler import setup_global_exception_handlers
|
||||
|
||||
# NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests
|
||||
from letta.server.rest_api.auth.index import setup_auth_router # TODO: probably remove right?
|
||||
from letta.server.rest_api.interface import StreamingServerInterface
|
||||
from letta.server.rest_api.middleware import CheckPasswordMiddleware, LogContextMiddleware, ProfilerContextMiddleware
|
||||
from letta.server.rest_api.middleware import CheckPasswordMiddleware, LoggingMiddleware, ProfilerContextMiddleware
|
||||
from letta.server.rest_api.routers.v1 import ROUTERS as v1_routes
|
||||
from letta.server.rest_api.routers.v1.organizations import router as organizations_router
|
||||
from letta.server.rest_api.routers.v1.users import router as users_router # TODO: decide on admin
|
||||
@@ -260,12 +261,41 @@ def create_application() -> "FastAPI":
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# === Global Exception Handlers ===
|
||||
# Set up handlers for exceptions outside of request context (background tasks, threads, etc.)
|
||||
setup_global_exception_handlers()
|
||||
|
||||
# === Exception Handlers ===
|
||||
# TODO (cliandy): move to separate file
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def generic_error_handler(request: Request, exc: Exception):
|
||||
logger.error(f"Unhandled error: {str(exc)}", exc_info=True)
|
||||
# Log with structured context
|
||||
request_context = {
|
||||
"method": request.method,
|
||||
"url": str(request.url),
|
||||
"path": request.url.path,
|
||||
}
|
||||
|
||||
# Extract user context if available
|
||||
user_context = {}
|
||||
if hasattr(request.state, "user_id"):
|
||||
user_context["user_id"] = request.state.user_id
|
||||
if hasattr(request.state, "org_id"):
|
||||
user_context["org_id"] = request.state.org_id
|
||||
|
||||
logger.error(
|
||||
f"Unhandled error: {exc.__class__.__name__}: {str(exc)}",
|
||||
extra={
|
||||
"exception_type": exc.__class__.__name__,
|
||||
"exception_message": str(exc),
|
||||
"exception_module": exc.__class__.__module__,
|
||||
"request": request_context,
|
||||
"user": user_context,
|
||||
},
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if SENTRY_ENABLED:
|
||||
sentry_sdk.capture_exception(exc)
|
||||
|
||||
@@ -519,7 +549,8 @@ def create_application() -> "FastAPI":
|
||||
if telemetry_settings.profiler:
|
||||
app.add_middleware(ProfilerContextMiddleware)
|
||||
|
||||
app.add_middleware(LogContextMiddleware)
|
||||
# Add unified logging middleware - enriches log context and logs exceptions
|
||||
app.add_middleware(LoggingMiddleware)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from letta.server.rest_api.middleware.check_password import CheckPasswordMiddleware
|
||||
from letta.server.rest_api.middleware.log_context import LogContextMiddleware
|
||||
from letta.server.rest_api.middleware.logging import LoggingMiddleware
|
||||
from letta.server.rest_api.middleware.profiler_context import ProfilerContextMiddleware
|
||||
|
||||
__all__ = ["CheckPasswordMiddleware", "LogContextMiddleware", "ProfilerContextMiddleware"]
|
||||
__all__ = ["CheckPasswordMiddleware", "LoggingMiddleware", "ProfilerContextMiddleware"]
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
import re
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from letta.log_context import clear_log_context, update_log_context
|
||||
from letta.schemas.enums import PrimitiveType
|
||||
from letta.validators import PRIMITIVE_ID_PATTERNS
|
||||
|
||||
|
||||
class LogContextMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware that enriches log context with request-specific attributes.
|
||||
|
||||
Automatically extracts and sets:
|
||||
- actor_id: From the 'user_id' header
|
||||
- org_id: From organization-related endpoints
|
||||
- Letta primitive IDs: agent_id, tool_id, block_id, etc. from URL paths
|
||||
|
||||
This enables all logs within a request to be automatically tagged with
|
||||
relevant context for better filtering and correlation in monitoring systems.
|
||||
"""
|
||||
|
||||
async def dispatch(self, request, call_next):
|
||||
clear_log_context()
|
||||
|
||||
try:
|
||||
context = {}
|
||||
|
||||
actor_id = request.headers.get("user_id")
|
||||
if actor_id:
|
||||
context["actor_id"] = actor_id
|
||||
|
||||
path = request.url.path
|
||||
path_parts = [p for p in path.split("/") if p]
|
||||
|
||||
matched_parts = set()
|
||||
for part in path_parts:
|
||||
if part in matched_parts:
|
||||
continue
|
||||
|
||||
for primitive_type in PrimitiveType:
|
||||
prefix = primitive_type.value
|
||||
pattern = PRIMITIVE_ID_PATTERNS.get(prefix)
|
||||
|
||||
if pattern and pattern.match(part):
|
||||
context_key = f"{primitive_type.name.lower()}_id"
|
||||
|
||||
if primitive_type == PrimitiveType.ORGANIZATION:
|
||||
context_key = "org_id"
|
||||
elif primitive_type == PrimitiveType.USER:
|
||||
context_key = "user_id"
|
||||
|
||||
context[context_key] = part
|
||||
matched_parts.add(part)
|
||||
break
|
||||
|
||||
if context:
|
||||
update_log_context(**context)
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
finally:
|
||||
clear_log_context()
|
||||
115
letta/server/rest_api/middleware/logging.py
Normal file
115
letta/server/rest_api/middleware/logging.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
Unified logging middleware that enriches log context and ensures exceptions are logged.
|
||||
"""
|
||||
|
||||
import re
|
||||
import traceback
|
||||
from typing import Callable
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.log_context import clear_log_context, update_log_context
|
||||
from letta.schemas.enums import PrimitiveType
|
||||
from letta.validators import PRIMITIVE_ID_PATTERNS
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware that enriches log context with request-specific attributes and logs exceptions.
|
||||
|
||||
Automatically extracts and sets:
|
||||
- actor_id: From the 'user_id' header
|
||||
- org_id: From organization-related endpoints
|
||||
- Letta primitive IDs: agent_id, tool_id, block_id, etc. from URL paths
|
||||
|
||||
Also catches all exceptions and logs them with structured context before re-raising.
|
||||
"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable):
|
||||
clear_log_context()
|
||||
|
||||
try:
|
||||
# Extract and set log context
|
||||
context = {}
|
||||
|
||||
actor_id = request.headers.get("user_id")
|
||||
if actor_id:
|
||||
context["actor_id"] = actor_id
|
||||
|
||||
path = request.url.path
|
||||
path_parts = [p for p in path.split("/") if p]
|
||||
|
||||
matched_parts = set()
|
||||
for part in path_parts:
|
||||
if part in matched_parts:
|
||||
continue
|
||||
|
||||
for primitive_type in PrimitiveType:
|
||||
prefix = primitive_type.value
|
||||
pattern = PRIMITIVE_ID_PATTERNS.get(prefix)
|
||||
|
||||
if pattern and pattern.match(part):
|
||||
context_key = f"{primitive_type.name.lower()}_id"
|
||||
|
||||
if primitive_type == PrimitiveType.ORGANIZATION:
|
||||
context_key = "org_id"
|
||||
elif primitive_type == PrimitiveType.USER:
|
||||
context_key = "user_id"
|
||||
|
||||
context[context_key] = part
|
||||
matched_parts.add(part)
|
||||
break
|
||||
|
||||
if context:
|
||||
update_log_context(**context)
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
except Exception as exc:
|
||||
# Extract request context
|
||||
request_context = {
|
||||
"method": request.method,
|
||||
"url": str(request.url),
|
||||
"path": request.url.path,
|
||||
"query_params": dict(request.query_params),
|
||||
"client_host": request.client.host if request.client else None,
|
||||
"user_agent": request.headers.get("user-agent"),
|
||||
}
|
||||
|
||||
# Extract user context if available
|
||||
user_context = {}
|
||||
if hasattr(request.state, "user_id"):
|
||||
user_context["user_id"] = request.state.user_id
|
||||
if hasattr(request.state, "org_id"):
|
||||
user_context["org_id"] = request.state.org_id
|
||||
|
||||
# Check for custom context attached to the exception
|
||||
custom_context = {}
|
||||
if hasattr(exc, "__letta_context__"):
|
||||
custom_context = exc.__letta_context__
|
||||
|
||||
# Log with structured data
|
||||
logger.error(
|
||||
f"Unhandled exception in request: {exc.__class__.__name__}: {str(exc)}",
|
||||
extra={
|
||||
"exception_type": exc.__class__.__name__,
|
||||
"exception_message": str(exc),
|
||||
"exception_module": exc.__class__.__module__,
|
||||
"request": request_context,
|
||||
"user": user_context,
|
||||
"custom_context": custom_context,
|
||||
"traceback": traceback.format_exc(),
|
||||
},
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Re-raise to let FastAPI's exception handlers deal with it
|
||||
raise
|
||||
|
||||
finally:
|
||||
clear_log_context()
|
||||
246
tests/test_exception_logging.py
Normal file
246
tests/test_exception_logging.py
Normal file
@@ -0,0 +1,246 @@
|
||||
"""
|
||||
Tests for global exception logging system.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.testclient import TestClient
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from letta.exceptions.logging import add_exception_context, log_and_raise, log_exception
|
||||
from letta.server.rest_api.middleware.logging import LoggingMiddleware
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_with_exception_middleware():
|
||||
"""Create a test FastAPI app with logging middleware."""
|
||||
app = FastAPI()
|
||||
app.add_middleware(LoggingMiddleware)
|
||||
|
||||
@app.get("/test-error")
|
||||
def test_error():
|
||||
raise ValueError("Test error message")
|
||||
|
||||
@app.get("/test-error-with-context")
|
||||
def test_error_with_context():
|
||||
exc = ValueError("Test error with context")
|
||||
exc = add_exception_context(
|
||||
exc,
|
||||
user_id="test-user-123",
|
||||
operation="test_operation",
|
||||
)
|
||||
raise exc
|
||||
|
||||
@app.get("/test-success")
|
||||
def test_success():
|
||||
return {"status": "ok"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def test_exception_middleware_logs_basic_exception(app_with_exception_middleware):
|
||||
"""Test that the middleware logs exceptions with basic context."""
|
||||
client = TestClient(app_with_exception_middleware, raise_server_exceptions=False)
|
||||
|
||||
with patch("letta.server.rest_api.middleware.logging.logger") as mock_logger:
|
||||
response = client.get("/test-error")
|
||||
|
||||
# Should return 500
|
||||
assert response.status_code == 500
|
||||
|
||||
# Should log the error
|
||||
assert mock_logger.error.called
|
||||
call_args = mock_logger.error.call_args
|
||||
|
||||
# Check the message
|
||||
assert "ValueError" in call_args[0][0]
|
||||
assert "Test error message" in call_args[0][0]
|
||||
|
||||
# Check the extra context
|
||||
extra = call_args[1]["extra"]
|
||||
assert extra["exception_type"] == "ValueError"
|
||||
assert extra["exception_message"] == "Test error message"
|
||||
assert "request" in extra
|
||||
assert extra["request"]["method"] == "GET"
|
||||
assert "/test-error" in extra["request"]["path"]
|
||||
|
||||
|
||||
def test_exception_middleware_logs_custom_context(app_with_exception_middleware):
|
||||
"""Test that the middleware logs custom context attached to exceptions."""
|
||||
client = TestClient(app_with_exception_middleware, raise_server_exceptions=False)
|
||||
|
||||
with patch("letta.server.rest_api.middleware.logging.logger") as mock_logger:
|
||||
response = client.get("/test-error-with-context")
|
||||
|
||||
# Should return 500
|
||||
assert response.status_code == 500
|
||||
|
||||
# Should log the error with custom context
|
||||
assert mock_logger.error.called
|
||||
call_args = mock_logger.error.call_args
|
||||
extra = call_args[1]["extra"]
|
||||
|
||||
# Check custom context
|
||||
assert "custom_context" in extra
|
||||
assert extra["custom_context"]["user_id"] == "test-user-123"
|
||||
assert extra["custom_context"]["operation"] == "test_operation"
|
||||
|
||||
|
||||
def test_exception_middleware_does_not_interfere_with_success(app_with_exception_middleware):
|
||||
"""Test that the middleware doesn't interfere with successful requests."""
|
||||
client = TestClient(app_with_exception_middleware)
|
||||
|
||||
response = client.get("/test-success")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
|
||||
|
||||
def test_add_exception_context():
|
||||
"""Test that add_exception_context properly attaches context to exceptions."""
|
||||
exc = ValueError("Test error")
|
||||
|
||||
# Add context
|
||||
exc_with_context = add_exception_context(
|
||||
exc,
|
||||
user_id="user-123",
|
||||
agent_id="agent-456",
|
||||
operation="test_op",
|
||||
)
|
||||
|
||||
# Should be the same exception object
|
||||
assert exc_with_context is exc
|
||||
|
||||
# Should have context attached
|
||||
assert hasattr(exc, "__letta_context__")
|
||||
assert exc.__letta_context__["user_id"] == "user-123"
|
||||
assert exc.__letta_context__["agent_id"] == "agent-456"
|
||||
assert exc.__letta_context__["operation"] == "test_op"
|
||||
|
||||
|
||||
def test_add_exception_context_multiple_times():
|
||||
"""Test that add_exception_context can be called multiple times."""
|
||||
exc = ValueError("Test error")
|
||||
|
||||
# Add context in multiple calls
|
||||
add_exception_context(exc, user_id="user-123")
|
||||
add_exception_context(exc, agent_id="agent-456")
|
||||
|
||||
# Both should be present
|
||||
assert exc.__letta_context__["user_id"] == "user-123"
|
||||
assert exc.__letta_context__["agent_id"] == "agent-456"
|
||||
|
||||
|
||||
def test_log_and_raise():
|
||||
"""Test that log_and_raise logs and then raises the exception."""
|
||||
exc = ValueError("Test error")
|
||||
|
||||
with patch("letta.exceptions.logging.logger") as mock_logger:
|
||||
with pytest.raises(ValueError, match="Test error"):
|
||||
log_and_raise(
|
||||
exc,
|
||||
"Operation failed",
|
||||
context={"user_id": "user-123"},
|
||||
)
|
||||
|
||||
# Should have logged
|
||||
assert mock_logger.error.called
|
||||
call_args = mock_logger.error.call_args
|
||||
|
||||
# Check message
|
||||
assert "Operation failed" in call_args[0][0]
|
||||
assert "ValueError" in call_args[0][0]
|
||||
assert "Test error" in call_args[0][0]
|
||||
|
||||
# Check extra context
|
||||
extra = call_args[1]["extra"]
|
||||
assert extra["exception_type"] == "ValueError"
|
||||
assert extra["user_id"] == "user-123"
|
||||
|
||||
|
||||
def test_log_exception():
|
||||
"""Test that log_exception logs without raising."""
|
||||
exc = ValueError("Test error")
|
||||
|
||||
with patch("letta.exceptions.logging.logger") as mock_logger:
|
||||
# Should not raise
|
||||
log_exception(
|
||||
exc,
|
||||
"Operation failed, using fallback",
|
||||
context={"user_id": "user-123"},
|
||||
)
|
||||
|
||||
# Should have logged
|
||||
assert mock_logger.error.called
|
||||
call_args = mock_logger.error.call_args
|
||||
|
||||
# Check message
|
||||
assert "Operation failed, using fallback" in call_args[0][0]
|
||||
assert "ValueError" in call_args[0][0]
|
||||
|
||||
# Check extra context
|
||||
extra = call_args[1]["extra"]
|
||||
assert extra["exception_type"] == "ValueError"
|
||||
assert extra["user_id"] == "user-123"
|
||||
|
||||
|
||||
def test_log_exception_with_different_levels():
|
||||
"""Test that log_exception respects different log levels."""
|
||||
exc = ValueError("Test error")
|
||||
|
||||
with patch("letta.exceptions.logging.logger") as mock_logger:
|
||||
# Test warning level
|
||||
log_exception(exc, "Warning message", level="warning")
|
||||
assert mock_logger.warning.called
|
||||
|
||||
# Test info level
|
||||
log_exception(exc, "Info message", level="info")
|
||||
assert mock_logger.info.called
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_exception_handler_setup():
|
||||
"""Test that global exception handlers can be set up without errors."""
|
||||
from letta.server.global_exception_handler import setup_global_exception_handlers
|
||||
|
||||
# Should not raise
|
||||
setup_global_exception_handlers()
|
||||
|
||||
# Verify sys.excepthook was modified
|
||||
import sys
|
||||
|
||||
assert sys.excepthook != sys.__excepthook__
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asyncio_exception_handler():
|
||||
"""Test that asyncio exception handler can be set up."""
|
||||
from letta.server.global_exception_handler import setup_asyncio_exception_handler
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# Should not raise
|
||||
setup_asyncio_exception_handler(loop)
|
||||
|
||||
|
||||
def test_exception_middleware_preserves_traceback(app_with_exception_middleware):
|
||||
"""Test that the middleware preserves traceback information."""
|
||||
client = TestClient(app_with_exception_middleware, raise_server_exceptions=False)
|
||||
|
||||
with patch("letta.server.rest_api.middleware.logging.logger") as mock_logger:
|
||||
response = client.get("/test-error")
|
||||
|
||||
assert response.status_code == 500
|
||||
call_args = mock_logger.error.call_args
|
||||
|
||||
# Check that exc_info was passed (enables traceback)
|
||||
assert call_args[1]["exc_info"] is True
|
||||
|
||||
# Check that traceback is in extra
|
||||
extra = call_args[1]["extra"]
|
||||
assert "traceback" in extra
|
||||
assert "ValueError" in extra["traceback"]
|
||||
assert "test_error" in extra["traceback"]
|
||||
@@ -3,13 +3,13 @@ from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from letta.log_context import get_log_context
|
||||
from letta.server.rest_api.middleware import LogContextMiddleware
|
||||
from letta.server.rest_api.middleware import LoggingMiddleware
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = FastAPI()
|
||||
app.add_middleware(LogContextMiddleware)
|
||||
app.add_middleware(LoggingMiddleware)
|
||||
|
||||
@app.get("/v1/agents/{agent_id}")
|
||||
async def get_agent(agent_id: str):
|
||||
|
||||
Reference in New Issue
Block a user