Files
letta-server/letta/monitoring/request_monitor.py
Kian Jones 848aa962b6 feat: add memory tracking to core (#6179)
* add memory tracking to core

* move to asyncio from threading.Thread

* remove threading.thread all the way

* delay decorator monitoring initialization until after event loop is registered

* context manager to decorator

* add psutil
2025-11-24 19:09:32 -08:00

275 lines
11 KiB
Python

"""
Request size monitoring middleware for Letta application.
Tracks incoming request sizes to identify large uploads causing SSL memory spikes.
"""
import time
from typing import Optional
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from letta.log import get_logger
from letta.monitoring import get_memory_tracker
logger = get_logger(__name__)
# Size thresholds (in bytes)
SIZE_WARNING_THRESHOLD = 10 * 1024 * 1024 # 10MB
SIZE_ERROR_THRESHOLD = 50 * 1024 * 1024 # 50MB
SIZE_CRITICAL_THRESHOLD = 100 * 1024 * 1024 # 100MB
class RequestSizeMonitoringMiddleware(BaseHTTPMiddleware):
    """
    Middleware to monitor incoming request sizes and detect large uploads.

    This helps identify if SSL memory spikes are from large incoming data.
    Log severity escalates as the declared Content-Length crosses the
    WARNING/ERROR/CRITICAL thresholds, and process memory is sampled before
    and after each request so per-request RSS growth can be reported.
    """

    # (path substring, human-readable description) for endpoints known to
    # receive large payloads; matched case-insensitively against the URL path.
    CRITICAL_ENDPOINTS = [
        ("/upload", "file upload"),
        ("/messages", "message creation"),
        ("/agents", "agent creation/update"),
        ("/memory", "memory update"),
        ("/sources", "source upload"),
        ("/folders", "folder upload"),
    ]

    def __init__(self, app):
        super().__init__(app)
        # Shared memory tracker used to sample RSS and to wrap request
        # handling in an operation-level tracking decorator.
        self.tracker = get_memory_tracker()

    async def dispatch(self, request: Request, call_next):
        start_time = time.time()
        content_type = request.headers.get("content-type", "")

        # Track the endpoint
        endpoint = f"{request.method} {request.url.path}"

        path_lower = request.url.path.lower()
        is_critical = any(pattern in path_lower for pattern, _ in self.CRITICAL_ENDPOINTS)
        endpoint_type = next(
            (desc for pattern, desc in self.CRITICAL_ENDPOINTS if pattern in path_lower),
            "general",
        )

        # Declared size from the Content-Length header. A malformed header is
        # treated as size-unknown (0) instead of raising from the middleware.
        request_size = 0
        content_length = request.headers.get("content-length")
        if content_length:
            try:
                request_size = int(content_length)
            except ValueError:
                logger.warning(f"Invalid Content-Length header at {endpoint}: {content_length!r}")

        # Get memory before processing
        memory_before = self.tracker.get_memory_info()

        # Enhanced monitoring for critical endpoints
        if is_critical and request_size > 1024 * 1024:  # Log all critical endpoints > 1MB
            logger.info(
                f"Critical endpoint access: {endpoint_type} - {endpoint} - "
                f"Size: {request_size / 1024 / 1024:.2f} MB - "
                f"Memory before: {memory_before.get('rss_mb', 0):.2f} MB"
            )

        # Log large incoming requests BEFORE processing
        if request_size > SIZE_CRITICAL_THRESHOLD:
            logger.critical(
                f"CRITICAL: Large request incoming - {endpoint} ({endpoint_type}) - "
                f"Size: {request_size / 1024 / 1024:.2f} MB - "
                f"Content-Type: {content_type} - "
                f"Memory before: {memory_before.get('rss_mb', 0):.2f} MB"
            )
        elif request_size > SIZE_ERROR_THRESHOLD:
            logger.error(
                f"Large request detected - {endpoint} ({endpoint_type}) - "
                f"Size: {request_size / 1024 / 1024:.2f} MB - "
                f"Content-Type: {content_type}"
            )
        elif request_size > SIZE_WARNING_THRESHOLD:
            logger.warning(
                f"Sizeable request - {endpoint} ({endpoint_type}) - "
                f"Size: {request_size / 1024 / 1024:.2f} MB - "
                f"Content-Type: {content_type}"
            )

        # For multipart/form-data (file uploads), log the expected size.
        # The actual file reading (where an SSL read would spike) happens
        # later, when the endpoint accesses request.form().
        if "multipart/form-data" in content_type:
            logger.info(f"File upload detected at {endpoint} - Expected size: {request_size / 1024 / 1024:.2f} MB")

        # Operation name reported to the memory tracker for this request.
        operation_name = f"request_{request.method}_{request.url.path.replace('/', '_')}"

        # For large JSON payloads, log structure details
        if request_size > SIZE_WARNING_THRESHOLD and "application/json" in content_type:
            try:
                # Reading the body consumes the ASGI receive channel, so
                # replace request._receive with a replay of the cached bytes
                # for the downstream handler.
                body = await request.body()

                async def receive():
                    return {"type": "http.request", "body": body}

                request._receive = receive

                # Log the structure
                if body:
                    await self._log_json_structure(body, endpoint)
            except Exception as e:
                # Inspection is best-effort; never block the request over it.
                logger.warning(f"Could not inspect request body: {e}")

        try:
            # Process the request with memory tracking.
            # Use the decorator by creating a wrapped function.
            @self.tracker.track_operation(operation_name)
            async def process_request():
                return await call_next(request)

            response = await process_request()

            # Get memory after processing
            memory_after = self.tracker.get_memory_info()
            memory_delta = memory_after.get("rss_mb", 0) - memory_before.get("rss_mb", 0)

            # Log if memory increased significantly during request processing
            if memory_delta > 100:  # More than 100MB increase
                process_time = time.time() - start_time
                logger.error(
                    f"MEMORY SPIKE during request: {endpoint} - "
                    f"Memory increased by {memory_delta:.2f} MB - "
                    f"Request size: {request_size / 1024 / 1024:.2f} MB - "
                    f"Process time: {process_time:.2f}s - "
                    f"Content-Type: {content_type}"
                )

            return response
        except Exception as e:
            # Log any errors with context, then propagate to the framework.
            logger.error(f"Error processing request {endpoint} - Size: {request_size / 1024 / 1024:.2f} MB - Error: {str(e)}")
            raise

    async def _log_json_structure(self, body: bytes, endpoint: str):
        """Log the top-level structure and per-field sizes of a large JSON body."""
        import json

        try:
            data = json.loads(body)
            body_size = len(body)

            # Approximate serialized size of each top-level field.
            field_sizes = {}
            if isinstance(data, dict):
                for key, value in data.items():
                    if isinstance(value, (str, bytes)):
                        field_sizes[key] = len(value)
                    elif isinstance(value, (list, dict)):
                        field_sizes[key] = len(json.dumps(value))
                    else:
                        field_sizes[key] = 0
            elif isinstance(data, list):
                field_sizes["list_items"] = len(data)
                if data:
                    # Sample first item size
                    field_sizes["first_item_size"] = len(json.dumps(data[0]))

            # Find largest fields
            if field_sizes:
                largest_fields = sorted(field_sizes.items(), key=lambda x: x[1], reverse=True)[:5]
                logger.warning(
                    f"Large JSON payload structure at {endpoint}:\n"
                    f"  Total size: {body_size / 1024 / 1024:.2f} MB\n"
                    f"  Top fields by size:\n"
                    + "\n".join([f"    - {k}: {v / 1024 / 1024:.2f} MB" for k, v in largest_fields if v > 1024])  # Only show fields > 1KB
                )

            # Special monitoring for known problematic fields
            problematic_fields = ["messages", "memory", "system", "tools", "files", "context"]
            for field in problematic_fields:
                if field in field_sizes and field_sizes[field] > 5 * 1024 * 1024:  # > 5MB
                    logger.error(f"LARGE FIELD DETECTED: '{field}' is {field_sizes[field] / 1024 / 1024:.2f} MB at {endpoint}")
        except json.JSONDecodeError:
            logger.error(f"Could not parse JSON body at {endpoint} (size: {len(body) / 1024 / 1024:.2f} MB)")
        except Exception as e:
            logger.error(f"Error analyzing JSON structure: {e}")
class RequestBodyLogger:
    """
    Utility to log request body details for specific endpoints.
    Useful for debugging which fields contain large data.
    """

    @staticmethod
    async def log_json_body_structure(request: Request, endpoint: str):
        """Log the structure and size of JSON request bodies.

        Reads the request body, then restores ``request._receive`` so the
        downstream handler can still consume the body (the same replay
        pattern used by RequestSizeMonitoringMiddleware.dispatch). All
        failures are logged and swallowed — this is a diagnostic helper and
        must never break request handling.
        """
        try:
            # Only for JSON requests
            if "application/json" not in request.headers.get("content-type", ""):
                return

            # Get the body. This consumes the ASGI receive channel, so we
            # replace request._receive with a replay of the cached bytes
            # for whoever reads the body next.
            body = await request.body()

            async def receive():
                return {"type": "http.request", "body": body}

            request._receive = receive

            body_size = len(body)
            if body_size > SIZE_WARNING_THRESHOLD:
                # Try to parse JSON to understand structure
                import json

                try:
                    data = json.loads(body)

                    # Approximate serialized size of each top-level field
                    # (list indices are used as keys for top-level arrays).
                    field_sizes = {}
                    items = data.items() if isinstance(data, dict) else enumerate(data)
                    for key, value in items:
                        if isinstance(value, (str, bytes)):
                            field_sizes[key] = len(value)
                        elif isinstance(value, (list, dict)):
                            field_sizes[key] = len(json.dumps(value))
                        else:
                            field_sizes[key] = 0

                    # Find largest fields
                    largest_fields = sorted(field_sizes.items(), key=lambda x: x[1], reverse=True)[:5]
                    logger.warning(
                        f"Large JSON body structure at {endpoint}:\n"
                        f"  Total size: {body_size / 1024 / 1024:.2f} MB\n"
                        f"  Top fields by size:\n" + "\n".join([f"    - {k}: {v / 1024 / 1024:.2f} MB" for k, v in largest_fields])
                    )
                except json.JSONDecodeError:
                    logger.error(f"Could not parse JSON body at {endpoint} (size: {body_size / 1024 / 1024:.2f} MB)")
        except Exception as e:
            logger.error(f"Error logging request body structure: {e}")
def identify_upload_endpoints(app):
    """
    Scan the app routes to identify potential upload endpoints.

    A route is included if its path contains an upload-related keyword OR it
    accepts a request body (POST/PUT/PATCH). Each matching route is recorded
    exactly once — the original version appended a route twice when it
    matched both checks. The first 20 matches are logged at INFO level.

    Returns:
        list[tuple[str, set[str]]]: (path, methods) pairs for candidate routes.
    """
    # Look for common upload patterns (hoisted out of the loop).
    upload_keywords = ["upload", "file", "attachment", "media", "document", "image", "import"]
    body_methods = {"POST", "PUT", "PATCH"}

    upload_endpoints = []
    for route in app.routes:
        if hasattr(route, "path") and hasattr(route, "methods"):
            path = route.path
            methods = route.methods
            keyword_match = any(keyword in path.lower() for keyword in upload_keywords)
            # Also check for POST/PUT/PATCH to any endpoint (potential large JSON)
            accepts_body = bool(body_methods & set(methods))
            if keyword_match or accepts_body:
                upload_endpoints.append((path, methods))

    logger.info("Potential upload/large data endpoints identified:")
    for path, methods in upload_endpoints[:20]:  # Log first 20
        logger.info(f"  {', '.join(methods)}: {path}")
    return upload_endpoints