feat: add memory tracking to core (#6179)

* add memory tracking to core

* move to asyncio from threading.Thread

* remove threading.Thread entirely

* delay decorator monitoring initialization until after event loop is registered

* convert context manager to decorator

* add psutil
Kian Jones
2025-11-14 13:24:50 -08:00
committed by Caren Thomas
parent 3b030d1bb0
commit 848aa962b6
19 changed files with 971 additions and 0 deletions
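At a glance, the new module exposes a `track_operation` decorator backed by a global tracker. A minimal usage sketch (the import path and names are taken from the diff below; the decorated function itself is hypothetical):

    from letta.monitoring import get_memory_tracker, track_operation

    # Decorating a coroutine records its memory delta and duration, and lazily
    # starts the background monitor once an event loop is running.
    @track_operation("import_agent")
    async def import_agent(payload: bytes) -> None:
        ...  # hypothetical handler body

    # The global tracker can also be queried directly for a summary.
    print(get_memory_tracker().get_report())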


@@ -0,0 +1,13 @@
"""Memory and request monitoring utilities for Letta application."""
from .memory_tracker import MemoryTracker, get_memory_tracker, track_operation
from .request_monitor import RequestBodyLogger, RequestSizeMonitoringMiddleware, identify_upload_endpoints
__all__ = [
"MemoryTracker",
"get_memory_tracker",
"track_operation",
"RequestSizeMonitoringMiddleware",
"RequestBodyLogger",
"identify_upload_endpoints",
]


@@ -0,0 +1,552 @@
"""
Memory tracking utility for Letta application.
Provides real-time memory monitoring with proactive alerting using asyncio.
"""
import asyncio
import functools
import gc
import json
import os
import time
import traceback
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, Optional
import psutil
from letta.log import get_logger
logger = get_logger(__name__)
class MemoryTracker:
"""
Track memory usage across different operations with proactive alerting.
Uses asyncio for all async operations.
Features:
- Real-time memory monitoring using asyncio
- Proactive alerts before OOM
- Automatic memory dumps on critical thresholds
- Per-operation tracking and reporting
"""
# Memory thresholds (in MB)
WARNING_THRESHOLD_MB = 1000 # Warning at 1GB
CRITICAL_THRESHOLD_MB = 2000 # Critical at 2GB
SPIKE_THRESHOLD_MB = 100 # Alert on 100MB+ spikes
# Memory percentage thresholds
MEMORY_PERCENT_WARNING = 70 # Warn at 70% system memory
MEMORY_PERCENT_CRITICAL = 85 # Critical at 85% system memory
MEMORY_PERCENT_FATAL = 95 # Fatal - likely to OOM
def __init__(self, enable_background_monitor: bool = True, monitor_interval: int = 5):
"""
Initialize the memory tracker.
Args:
enable_background_monitor: Whether to start background monitoring
monitor_interval: Interval in seconds between background checks
"""
self.process = psutil.Process(os.getpid())
self.measurements = defaultdict(list)
self.active_operations = {}
self.lock = asyncio.Lock() # Use asyncio.Lock instead of threading.Lock
self.monitor_interval = monitor_interval
self._monitoring = False
self._monitor_task = None
# Track memory history for trend analysis
self.memory_history = []
self.max_history_size = 100
# Track if we've already warned about memory levels
self._warned_levels = set()
# Start time for uptime tracking
self.start_time = datetime.now()
# Flag to track if we should start the monitor
self._should_start_monitor = enable_background_monitor
self._monitor_started = False
logger.info(
f"Memory tracker initialized - PID: {os.getpid()}, "
f"Warning: {self.WARNING_THRESHOLD_MB}MB, "
f"Critical: {self.CRITICAL_THRESHOLD_MB}MB"
)
def _ensure_monitor_started(self):
"""Start the background monitor if needed and if there's an event loop."""
if self._should_start_monitor and not self._monitor_started:
try:
                # Raises RuntimeError if no event loop is running yet
                asyncio.get_running_loop()
# Create the monitor task
asyncio.create_task(self.start_background_monitor())
self._monitor_started = True
except RuntimeError:
# No event loop running yet, will try again later
pass
def get_memory_info(self) -> Dict[str, Any]:
"""Get current memory information."""
try:
mem_info = self.process.memory_info()
mem_percent = self.process.memory_percent()
# Get system-wide memory
system_mem = psutil.virtual_memory()
return {
"rss_mb": mem_info.rss / 1024 / 1024,
"vms_mb": mem_info.vms / 1024 / 1024,
"percent": mem_percent,
"system_available_mb": system_mem.available / 1024 / 1024,
"system_percent": system_mem.percent,
"timestamp": datetime.now().isoformat(),
}
except Exception as e:
logger.error(f"Failed to get memory info: {e}")
return {}
def track_operation(self, operation_name: str):
"""
Decorator to track memory for specific operations.
Logs immediately on completion or error.
"""
def decorator(func):
@functools.wraps(func)
async def async_wrapper(*args, **kwargs):
# Ensure background monitor is started (now we have an event loop)
self._ensure_monitor_started()
start_mem_info = self.get_memory_info()
start_time = time.time()
# Log operation start if memory is already high
if start_mem_info.get("rss_mb", 0) > self.WARNING_THRESHOLD_MB:
logger.warning(f"Starting operation '{operation_name}' with high memory: {start_mem_info['rss_mb']:.2f} MB")
# Record call stack for debugging
stack = traceback.extract_stack()
operation_id = f"{operation_name}_{id(func)}_{time.time()}"
async with self.lock:
self.active_operations[operation_id] = {
"operation_name": operation_name,
"start_mem": start_mem_info,
"start_time": start_time,
"stack": stack,
}
try:
result = await func(*args, **kwargs)
return result
except Exception as e:
# Log memory state on error
await self._log_operation_completion(operation_id, error=str(e))
raise
finally:
# Always log operation completion
await self._log_operation_completion(operation_id)
@functools.wraps(func)
def sync_wrapper(*args, **kwargs):
start_mem_info = self.get_memory_info()
start_time = time.time()
# Log operation start if memory is already high
if start_mem_info.get("rss_mb", 0) > self.WARNING_THRESHOLD_MB:
logger.warning(f"Starting operation '{operation_name}' with high memory: {start_mem_info['rss_mb']:.2f} MB")
operation_id = f"{operation_name}_{id(func)}_{time.time()}"
            # Sync functions can't await the asyncio lock, so mutate the dict directly
self.active_operations[operation_id] = {
"operation_name": operation_name,
"start_mem": start_mem_info,
"start_time": start_time,
"stack": traceback.extract_stack(),
}
try:
result = func(*args, **kwargs)
return result
except Exception as e:
# Log memory state on error (sync version)
self._log_operation_completion_sync(operation_id, error=str(e))
raise
finally:
# Always log operation completion (sync version)
self._log_operation_completion_sync(operation_id)
if asyncio.iscoroutinefunction(func):
return async_wrapper
else:
return sync_wrapper
return decorator
async def _log_operation_completion(self, operation_id: str, error: Optional[str] = None):
"""Log memory usage immediately when an operation completes (async version)."""
async with self.lock:
if operation_id not in self.active_operations:
return
operation_data = self.active_operations.pop(operation_id)
end_mem_info = self.get_memory_info()
end_time = time.time()
start_mem = operation_data["start_mem"].get("rss_mb", 0)
end_mem = end_mem_info.get("rss_mb", 0)
mem_delta = end_mem - start_mem
time_delta = end_time - operation_data["start_time"]
operation_name = operation_data["operation_name"]
# Record measurement
async with self.lock:
self.measurements[operation_name].append(
{
"memory_delta_mb": mem_delta,
"peak_memory_mb": end_mem,
"time_seconds": time_delta,
"timestamp": datetime.now().isoformat(),
"error": error,
"system_percent": end_mem_info.get("system_percent", 0),
}
)
self._log_memory_status(operation_name, mem_delta, end_mem, time_delta, end_mem_info, error, operation_data)
def _log_operation_completion_sync(self, operation_id: str, error: Optional[str] = None):
"""Log memory usage immediately when an operation completes (sync version for non-async functions)."""
if operation_id not in self.active_operations:
return
operation_data = self.active_operations.pop(operation_id)
end_mem_info = self.get_memory_info()
end_time = time.time()
start_mem = operation_data["start_mem"].get("rss_mb", 0)
end_mem = end_mem_info.get("rss_mb", 0)
mem_delta = end_mem - start_mem
time_delta = end_time - operation_data["start_time"]
operation_name = operation_data["operation_name"]
# Record measurement (sync version doesn't use async lock)
self.measurements[operation_name].append(
{
"memory_delta_mb": mem_delta,
"peak_memory_mb": end_mem,
"time_seconds": time_delta,
"timestamp": datetime.now().isoformat(),
"error": error,
"system_percent": end_mem_info.get("system_percent", 0),
}
)
self._log_memory_status(operation_name, mem_delta, end_mem, time_delta, end_mem_info, error, operation_data)
def _log_memory_status(
self,
operation_name: str,
mem_delta: float,
end_mem: float,
time_delta: float,
end_mem_info: Dict,
error: Optional[str],
operation_data: Dict,
):
"""Common logging logic for memory status."""
# Determine log level based on memory situation
if error:
logger.error(
f"Operation '{operation_name}' failed after {time_delta:.2f}s - "
f"Memory: {end_mem:.2f} MB (Δ{mem_delta:+.2f} MB), "
f"System: {end_mem_info.get('system_percent', 0):.1f}%, "
f"Error: {error}"
)
elif mem_delta > self.SPIKE_THRESHOLD_MB:
logger.warning(
f"MEMORY SPIKE: Operation '{operation_name}' - "
f"Increased by {mem_delta:.2f} MB in {time_delta:.2f}s - "
f"Current: {end_mem:.2f} MB, System: {end_mem_info.get('system_percent', 0):.1f}%"
)
# Log stack trace for large spikes
if mem_delta > self.SPIKE_THRESHOLD_MB * 2:
stack = operation_data.get("stack", [])
if stack and len(stack) > 3:
logger.warning("Call stack for memory spike:")
for frame in stack[-5:]:
logger.warning(f" {frame.filename}:{frame.lineno} in {frame.name}")
elif end_mem > self.CRITICAL_THRESHOLD_MB:
logger.error(
f"CRITICAL MEMORY: Operation '{operation_name}' completed - "
f"Memory at {end_mem:.2f} MB (Δ{mem_delta:+.2f} MB), "
f"System: {end_mem_info.get('system_percent', 0):.1f}%"
)
elif end_mem > self.WARNING_THRESHOLD_MB:
logger.warning(f"High memory after '{operation_name}': {end_mem:.2f} MB (Δ{mem_delta:+.2f} MB) in {time_delta:.2f}s")
else:
# Only log normal operations in debug mode
logger.debug(f"Operation '{operation_name}' completed: Memory {end_mem:.2f} MB (Δ{mem_delta:+.2f} MB) in {time_delta:.2f}s")
async def start_background_monitor(self):
"""Start the background memory monitoring task."""
if self._monitoring:
return
self._monitoring = True
self._monitor_started = True
self._monitor_task = asyncio.create_task(self._monitor_loop())
logger.info(f"Background memory monitor started (interval: {self.monitor_interval}s)")
async def stop_background_monitor(self):
"""Stop the background memory monitoring task."""
self._monitoring = False
self._monitor_started = False
if self._monitor_task:
self._monitor_task.cancel()
try:
await self._monitor_task
except asyncio.CancelledError:
pass
logger.info("Background memory monitor stopped")
async def _monitor_loop(self):
"""Background monitoring loop that runs continuously using asyncio."""
consecutive_high_memory = 0
last_gc_time = time.time()
while self._monitoring:
try:
mem_info = self.get_memory_info()
current_mb = mem_info.get("rss_mb", 0)
system_percent = mem_info.get("system_percent", 0)
# Add to history
async with self.lock:
self.memory_history.append(mem_info)
if len(self.memory_history) > self.max_history_size:
self.memory_history.pop(0)
# Check memory levels
self._check_memory_thresholds(mem_info)
# Track consecutive high memory readings
if current_mb > self.WARNING_THRESHOLD_MB:
consecutive_high_memory += 1
else:
consecutive_high_memory = 0
# Force GC if memory is consistently high
if consecutive_high_memory >= 3 and time.time() - last_gc_time > 30:
await asyncio.to_thread(self._force_gc_with_logging)
last_gc_time = time.time()
# Check for memory leak patterns
if len(self.memory_history) >= 10:
await self._check_memory_trend()
# Log active operations if memory is critical
if system_percent > self.MEMORY_PERCENT_CRITICAL:
await self._log_active_operations()
await asyncio.sleep(self.monitor_interval)
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Error in memory monitor loop: {e}")
await asyncio.sleep(self.monitor_interval)
def _check_memory_thresholds(self, mem_info: Dict[str, Any]):
"""Check memory against thresholds and log appropriately."""
current_mb = mem_info.get("rss_mb", 0)
system_percent = mem_info.get("system_percent", 0)
# Check system memory percentage thresholds
if system_percent > self.MEMORY_PERCENT_FATAL and "fatal" not in self._warned_levels:
logger.critical(
f"FATAL MEMORY LEVEL: {system_percent:.1f}% of system memory used! Process: {current_mb:.2f} MB - OOM imminent!"
)
self._warned_levels.add("fatal")
self._dump_memory_state()
elif system_percent > self.MEMORY_PERCENT_CRITICAL and "critical" not in self._warned_levels:
logger.error(f"CRITICAL: System memory at {system_percent:.1f}% - Process using {current_mb:.2f} MB")
self._warned_levels.add("critical")
self._dump_active_operations_summary()
elif system_percent > self.MEMORY_PERCENT_WARNING and "warning" not in self._warned_levels:
logger.warning(f"Memory warning: System at {system_percent:.1f}% - Process: {current_mb:.2f} MB")
self._warned_levels.add("warning")
# Reset warning levels if memory drops
if system_percent < self.MEMORY_PERCENT_WARNING:
self._warned_levels.clear()
async def _check_memory_trend(self):
"""Check if memory is trending upward (potential leak)."""
async with self.lock:
if len(self.memory_history) < 10:
return
# Get the last 10 readings
recent = self.memory_history[-10:]
# Calculate trend
first_mb = recent[0].get("rss_mb", 0)
last_mb = recent[-1].get("rss_mb", 0)
increase_mb = last_mb - first_mb
# Calculate average rate of increase
time_span = 10 * self.monitor_interval # seconds
rate_mb_per_min = (increase_mb / time_span) * 60
# Warn if memory is increasing rapidly
if rate_mb_per_min > 50: # More than 50MB/minute
logger.warning(
f"MEMORY LEAK SUSPECTED: Memory increasing at {rate_mb_per_min:.1f} MB/min - From {first_mb:.2f} MB to {last_mb:.2f} MB"
)
self._dump_active_operations_summary()
def _force_gc_with_logging(self):
"""Force garbage collection and log the results."""
before_mb = self.get_memory_info().get("rss_mb", 0)
collected = gc.collect()
after_mb = self.get_memory_info().get("rss_mb", 0)
freed_mb = before_mb - after_mb
if freed_mb > 10: # Only log if significant memory was freed
logger.info(f"Garbage collection freed {freed_mb:.2f} MB (collected {collected} objects)")
async def _log_active_operations(self):
"""Log currently active operations."""
async with self.lock:
if not self.active_operations:
return
logger.warning(f"Active operations during high memory ({len(self.active_operations)} running):")
for op_id, op_data in self.active_operations.items():
duration = time.time() - op_data["start_time"]
logger.warning(
f" - {op_data['operation_name']}: running for {duration:.1f}s, started at {op_data['start_mem'].get('rss_mb', 0):.2f} MB"
)
def _dump_active_operations_summary(self):
"""Dump summary of recent operations."""
logger.warning("=== Recent Operation Memory Usage ===")
for operation_name, measurements in self.measurements.items():
if measurements:
recent = measurements[-5:] # Last 5 measurements
total_delta = sum(m["memory_delta_mb"] for m in recent)
avg_delta = total_delta / len(recent)
max_peak = max(m["peak_memory_mb"] for m in recent)
logger.warning(f"{operation_name}: Avg Δ{avg_delta:+.1f} MB, Peak {max_peak:.1f} MB ({len(recent)} recent ops)")
def _dump_memory_state(self):
"""Dump detailed memory state for debugging."""
logger.critical("=== MEMORY STATE DUMP ===")
# Current memory
mem_info = self.get_memory_info()
logger.critical(f"Current memory: {json.dumps(mem_info, indent=2)}")
# Top memory consumers by operation
logger.critical("Top memory consuming operations:")
        top_operations = sorted(self.measurements.items(), key=lambda x: sum(m["memory_delta_mb"] for m in x[1]), reverse=True)[:5]
        for op_name, measurements in top_operations:
if measurements:
total = sum(m["memory_delta_mb"] for m in measurements)
logger.critical(f" {op_name}: {total:.1f} MB total across {len(measurements)} calls")
# Active operations
if self.active_operations:
logger.critical(f"Active operations: {len(self.active_operations)}")
for op_data in self.active_operations.values():
logger.critical(f" - {op_data['operation_name']}")
# System info
logger.critical(f"Uptime: {datetime.now() - self.start_time}")
logger.critical(f"PID: {os.getpid()}")
# Attempt to get top memory objects (if available)
try:
logger.critical("Top objects by count:")
obj_counts = defaultdict(int)
for obj in gc.get_objects()[:1000]: # Sample first 1000 objects
obj_counts[type(obj).__name__] += 1
for obj_type, count in sorted(obj_counts.items(), key=lambda x: x[1], reverse=True)[:10]:
logger.critical(f" {obj_type}: {count}")
except Exception as e:
logger.error(f"Could not get object counts: {e}")
def get_report(self) -> str:
"""Generate a summary report of memory usage."""
lines = []
lines.append("=== MEMORY USAGE REPORT ===")
lines.append(f"Uptime: {datetime.now() - self.start_time}")
current = self.get_memory_info()
lines.append(f"Current memory: {current.get('rss_mb', 0):.2f} MB")
lines.append(f"System memory: {current.get('system_percent', 0):.1f}%")
lines.append("\nOperations summary:")
for operation_name, measurements in self.measurements.items():
if measurements:
total_mem = sum(m["memory_delta_mb"] for m in measurements)
avg_mem = total_mem / len(measurements)
max_mem = max(m["memory_delta_mb"] for m in measurements)
errors = sum(1 for m in measurements if m.get("error"))
lines.append(
f" {operation_name}: "
f"{len(measurements)} calls, "
f"Avg Δ{avg_mem:+.1f} MB, "
f"Max Δ{max_mem:+.1f} MB, "
f"Total Δ{total_mem:+.1f} MB"
)
if errors:
lines.append(f" ({errors} errors)")
return "\n".join(lines)
# Global tracker instance
_global_tracker = None
def get_memory_tracker(enable_background_monitor: bool = True, monitor_interval: int = 5) -> MemoryTracker:
"""Get or create the global memory tracker instance."""
global _global_tracker
if _global_tracker is None:
_global_tracker = MemoryTracker(enable_background_monitor=enable_background_monitor, monitor_interval=monitor_interval)
return _global_tracker
def track_operation(operation_name: str):
"""
Convenience decorator that uses the global tracker.
Usage:
@track_operation("my_operation")
async def my_function():
...
"""
tracker = get_memory_tracker()
return tracker.track_operation(operation_name)
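For contexts outside the FastAPI lifespan (scripts, tests), the tracker can also be driven manually. A sketch assuming only the API defined above:

    import asyncio

    from letta.monitoring import get_memory_tracker

    async def main():
        # Skip lazy startup and manage the monitor loop explicitly.
        tracker = get_memory_tracker(enable_background_monitor=False, monitor_interval=2)
        await tracker.start_background_monitor()
        try:
            ...  # run the workload to be profiled here
            print(tracker.get_report())
        finally:
            await tracker.stop_background_monitor()

    asyncio.run(main())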


@@ -0,0 +1,274 @@
"""
Request size monitoring middleware for Letta application.
Tracks incoming request sizes to identify large uploads causing SSL memory spikes.
"""
import json
import time
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from letta.log import get_logger
from letta.monitoring import get_memory_tracker
logger = get_logger(__name__)
# Size thresholds (in bytes)
SIZE_WARNING_THRESHOLD = 10 * 1024 * 1024 # 10MB
SIZE_ERROR_THRESHOLD = 50 * 1024 * 1024 # 50MB
SIZE_CRITICAL_THRESHOLD = 100 * 1024 * 1024 # 100MB
class RequestSizeMonitoringMiddleware(BaseHTTPMiddleware):
"""
Middleware to monitor incoming request sizes and detect large uploads.
This helps identify if SSL memory spikes are from large incoming data.
"""
def __init__(self, app):
super().__init__(app)
self.tracker = get_memory_tracker()
async def dispatch(self, request: Request, call_next):
start_time = time.time()
request_size = 0
content_type = request.headers.get("content-type", "")
# Track the endpoint
endpoint = f"{request.method} {request.url.path}"
# Special monitoring for known large data endpoints
critical_endpoints = [
("/upload", "file upload"),
("/messages", "message creation"),
("/agents", "agent creation/update"),
("/memory", "memory update"),
("/sources", "source upload"),
("/folders", "folder upload"),
]
is_critical = any(pattern in request.url.path.lower() for pattern, _ in critical_endpoints)
endpoint_type = next((desc for pattern, desc in critical_endpoints if pattern in request.url.path.lower()), "general")
# Try to get content length from headers
content_length = request.headers.get("content-length")
if content_length:
request_size = int(content_length)
# Get memory before processing
memory_before = self.tracker.get_memory_info()
# Enhanced monitoring for critical endpoints
if is_critical and request_size > 1024 * 1024: # Log all critical endpoints > 1MB
logger.info(
f"Critical endpoint access: {endpoint_type} - {endpoint} - "
f"Size: {request_size / 1024 / 1024:.2f} MB - "
f"Memory before: {memory_before.get('rss_mb', 0):.2f} MB"
)
# Log large incoming requests BEFORE processing
if request_size > SIZE_CRITICAL_THRESHOLD:
logger.critical(
f"CRITICAL: Large request incoming - {endpoint} ({endpoint_type}) - "
f"Size: {request_size / 1024 / 1024:.2f} MB - "
f"Content-Type: {content_type} - "
f"Memory before: {memory_before.get('rss_mb', 0):.2f} MB"
)
elif request_size > SIZE_ERROR_THRESHOLD:
logger.error(
f"Large request detected - {endpoint} ({endpoint_type}) - "
f"Size: {request_size / 1024 / 1024:.2f} MB - "
f"Content-Type: {content_type}"
)
elif request_size > SIZE_WARNING_THRESHOLD:
logger.warning(
f"Sizeable request - {endpoint} ({endpoint_type}) - "
f"Size: {request_size / 1024 / 1024:.2f} MB - "
f"Content-Type: {content_type}"
)
# For multipart/form-data (file uploads), try to get more details
if "multipart/form-data" in content_type:
logger.info(f"File upload detected at {endpoint} - Expected size: {request_size / 1024 / 1024:.2f} MB")
# Note: The actual file reading happens when the endpoint accesses request.form()
# That's when SSL read would spike
# Track the operation with memory monitoring
operation_name = f"request_{request.method}_{request.url.path.replace('/', '_')}"
# For large JSON payloads, log structure details
if request_size > SIZE_WARNING_THRESHOLD and "application/json" in content_type:
            try:
                # request.body() consumes the stream; it is re-injected below so the handler still sees it
                body = await request.body()
# Put it back for the actual handler
async def receive():
return {"type": "http.request", "body": body}
request._receive = receive
# Log the structure
if body:
await self._log_json_structure(body, endpoint)
except Exception as e:
logger.warning(f"Could not inspect request body: {e}")
try:
# Process the request with memory tracking
# Use the decorator by creating a wrapped function
@self.tracker.track_operation(operation_name)
async def process_request():
return await call_next(request)
response = await process_request()
# Get memory after processing
memory_after = self.tracker.get_memory_info()
memory_delta = memory_after.get("rss_mb", 0) - memory_before.get("rss_mb", 0)
# Log if memory increased significantly during request processing
if memory_delta > 100: # More than 100MB increase
process_time = time.time() - start_time
logger.error(
f"MEMORY SPIKE during request: {endpoint} - "
f"Memory increased by {memory_delta:.2f} MB - "
f"Request size: {request_size / 1024 / 1024:.2f} MB - "
f"Process time: {process_time:.2f}s - "
f"Content-Type: {content_type}"
)
return response
except Exception as e:
# Log any errors with context
logger.error(f"Error processing request {endpoint} - Size: {request_size / 1024 / 1024:.2f} MB - Error: {str(e)}")
raise
async def _log_json_structure(self, body: bytes, endpoint: str):
"""Helper method to log JSON body structure for large payloads."""
try:
data = json.loads(body)
body_size = len(body)
# Calculate field sizes
field_sizes = {}
if isinstance(data, dict):
for key, value in data.items():
if isinstance(value, (str, bytes)):
field_sizes[key] = len(value)
elif isinstance(value, (list, dict)):
field_sizes[key] = len(json.dumps(value))
else:
field_sizes[key] = 0
elif isinstance(data, list):
field_sizes["list_items"] = len(data)
if data:
# Sample first item size
field_sizes["first_item_size"] = len(json.dumps(data[0]))
# Find largest fields
if field_sizes:
largest_fields = sorted(field_sizes.items(), key=lambda x: x[1], reverse=True)[:5]
logger.warning(
f"Large JSON payload structure at {endpoint}:\n"
f" Total size: {body_size / 1024 / 1024:.2f} MB\n"
f" Top fields by size:\n"
+ "\n".join([f" - {k}: {v / 1024 / 1024:.2f} MB" for k, v in largest_fields if v > 1024]) # Only show fields > 1KB
)
# Special monitoring for known problematic fields
problematic_fields = ["messages", "memory", "system", "tools", "files", "context"]
for field in problematic_fields:
if field in field_sizes and field_sizes[field] > 5 * 1024 * 1024: # > 5MB
logger.error(f"LARGE FIELD DETECTED: '{field}' is {field_sizes[field] / 1024 / 1024:.2f} MB at {endpoint}")
except json.JSONDecodeError:
logger.error(f"Could not parse JSON body at {endpoint} (size: {len(body) / 1024 / 1024:.2f} MB)")
except Exception as e:
logger.error(f"Error analyzing JSON structure: {e}")
class RequestBodyLogger:
"""
Utility to log request body details for specific endpoints.
Useful for debugging which fields contain large data.
"""
@staticmethod
async def log_json_body_structure(request: Request, endpoint: str):
"""Log the structure and size of JSON request bodies."""
try:
# Only for JSON requests
if "application/json" not in request.headers.get("content-type", ""):
return
# Get the body
body = await request.body()
body_size = len(body)
if body_size > SIZE_WARNING_THRESHOLD:
# Try to parse JSON to understand structure
try:
data = json.loads(body)
# Log field sizes
field_sizes = {}
                    # Iterate dict fields, or list indices for array payloads
                    items = data.items() if isinstance(data, dict) else enumerate(data)
                    for key, value in items:
if isinstance(value, (str, bytes)):
field_sizes[key] = len(value)
elif isinstance(value, (list, dict)):
field_sizes[key] = len(json.dumps(value))
else:
field_sizes[key] = 0
# Find largest fields
largest_fields = sorted(field_sizes.items(), key=lambda x: x[1], reverse=True)[:5]
logger.warning(
f"Large JSON body structure at {endpoint}:\n"
f" Total size: {body_size / 1024 / 1024:.2f} MB\n"
f" Top fields by size:\n" + "\n".join([f" - {k}: {v / 1024 / 1024:.2f} MB" for k, v in largest_fields])
)
except json.JSONDecodeError:
logger.error(f"Could not parse JSON body at {endpoint} (size: {body_size / 1024 / 1024:.2f} MB)")
except Exception as e:
logger.error(f"Error logging request body structure: {e}")
def identify_upload_endpoints(app):
"""
Scan the app routes to identify potential upload endpoints.
"""
upload_endpoints = []
for route in app.routes:
if hasattr(route, "path") and hasattr(route, "methods"):
path = route.path
methods = route.methods
# Look for common upload patterns
upload_keywords = ["upload", "file", "attachment", "media", "document", "image", "import"]
if any(keyword in path.lower() for keyword in upload_keywords):
upload_endpoints.append((path, methods))
            # Also flag POST/PUT/PATCH endpoints not already matched (potential large JSON)
            elif "POST" in methods or "PUT" in methods or "PATCH" in methods:
                upload_endpoints.append((path, methods))
logger.info("Potential upload/large data endpoints identified:")
for path, methods in upload_endpoints[:20]: # Log first 20
logger.info(f" {', '.join(methods)}: {path}")
return upload_endpoints
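Wiring the middleware into an application mirrors what the server diff below does. As a standalone sketch:

    from fastapi import FastAPI

    from letta.monitoring import RequestSizeMonitoringMiddleware, identify_upload_endpoints

    app = FastAPI()
    # Log request sizes, flag critical endpoints, and track per-request memory deltas.
    app.add_middleware(RequestSizeMonitoringMiddleware)
    # One-time scan of routes likely to receive large payloads.
    identify_upload_endpoints(app)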


@@ -14,6 +14,14 @@ import uvicorn
# Enable Python fault handler to get stack traces on segfaults
faulthandler.enable()
# Import memory tracking (if available)
try:
from letta.monitoring import RequestSizeMonitoringMiddleware, get_memory_tracker, identify_upload_endpoints
MEMORY_TRACKING_ENABLED = True
except ImportError:
MEMORY_TRACKING_ENABLED = False
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
@@ -133,6 +141,13 @@ async def lifespan(app_: FastAPI):
"""
worker_id = os.getpid()
# Initialize memory tracking
if MEMORY_TRACKING_ENABLED:
logger.info(f"[Worker {worker_id}] Initializing memory tracking")
        # Get the global tracker instance (the background monitor starts lazily on the first tracked operation)
tracker = get_memory_tracker(enable_background_monitor=True, monitor_interval=5)
logger.info(f"[Worker {worker_id}] Memory tracking enabled - monitoring every 5s with proactive alerts")
if telemetry_settings.profiler:
try:
import googlecloudprofiler
@@ -174,6 +189,14 @@ async def lifespan(app_: FastAPI):
# Cleanup on shutdown
logger.info(f"[Worker {worker_id}] Starting lifespan shutdown")
# Report memory usage before shutdown
if MEMORY_TRACKING_ENABLED:
logger.info(f"[Worker {worker_id}] Generating final memory report")
tracker = get_memory_tracker()
report = tracker.get_report()
logger.info(f"[Worker {worker_id}] Memory report:\n{report}")
try:
from letta.jobs.scheduler import shutdown_scheduler_and_release_lock
@@ -556,6 +579,13 @@ def create_application() -> "FastAPI":
# Add unified logging middleware - enriches log context and logs exceptions
app.add_middleware(LoggingMiddleware)
# Add request size monitoring middleware to detect large uploads
if MEMORY_TRACKING_ENABLED:
app.add_middleware(RequestSizeMonitoringMiddleware)
logger.info("Request size monitoring middleware enabled")
# Identify potential upload endpoints
identify_upload_endpoints(app)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,


@@ -28,6 +28,7 @@ from letta.errors import (
from letta.groups.sleeptime_multi_agent_v4 import SleeptimeMultiAgentV4
from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns
from letta.log import get_logger
from letta.monitoring import track_operation
from letta.orm.errors import NoResultFound
from letta.otel.context import get_ctx_attributes
from letta.otel.metric_registry import MetricRegistry
@@ -329,6 +330,7 @@ async def _import_agent(
@router.post("/import", response_model=ImportedAgentsResponse, operation_id="import_agent")
@track_operation("import_agent")
async def import_agent(
file: UploadFile = File(...),
server: "SyncServer" = Depends(get_letta_server),
@@ -437,6 +439,7 @@ class CreateAgentRequest(CreateAgent):
@router.post("/", response_model=AgentState, operation_id="create_agent")
@track_operation("create_agent")
async def create_agent(
agent: CreateAgentRequest = Body(...),
server: "SyncServer" = Depends(get_letta_server),


@@ -17,6 +17,7 @@ from letta.helpers.pinecone_utils import (
)
from letta.helpers.tpuf_client import should_use_tpuf
from letta.log import get_logger
from letta.monitoring import track_operation
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
@@ -227,6 +228,7 @@ async def delete_folder(
@router.post("/{folder_id}/upload", response_model=FileMetadata, operation_id="upload_file_to_folder")
@track_operation("file_upload_to_folder")
async def upload_file_to_folder(
file: UploadFile,
folder_id: FolderId,


@@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, List, Optional
from fastapi import APIRouter, Depends, Query
from letta.monitoring.memory_tracker import track_operation
from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.model import EmbeddingModel, Model
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
@@ -13,6 +14,7 @@ router = APIRouter(prefix="/models", tags=["models", "llms"])
@router.get("/", response_model=List[Model], operation_id="list_models")
@track_operation("list_llm_models_endpoint")
async def list_llm_models(
provider_category: Optional[List[ProviderCategory]] = Query(None),
provider_name: Optional[str] = Query(None),
@@ -40,6 +42,7 @@ async def list_llm_models(
@router.get("/embedding", response_model=List[EmbeddingModel], operation_id="list_embedding_models")
@track_operation("list_embedding_models_endpoint")
async def list_embedding_models(
server: "SyncServer" = Depends(get_letta_server),
headers: HeaderParams = Depends(get_headers),


@@ -17,6 +17,7 @@ from letta.helpers.pinecone_utils import (
)
from letta.helpers.tpuf_client import should_use_tpuf
from letta.log import get_logger
from letta.monitoring import track_operation
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
@@ -207,6 +208,7 @@ async def delete_source(
@router.post("/{source_id}/upload", response_model=FileMetadata, operation_id="upload_file_to_source", deprecated=True)
@track_operation("file_upload_to_source")
async def upload_file_to_source(
file: UploadFile,
source_id: SourceId,


@@ -31,6 +31,7 @@ from letta.interface import (
CLIInterface, # for printing to terminal
)
from letta.log import get_logger
from letta.monitoring.memory_tracker import track_operation
from letta.orm.errors import NoResultFound
from letta.otel.tracing import log_event, trace_method
from letta.prompts.gpt_system import get_system_text
@@ -972,6 +973,7 @@ class SyncServer(object):
return passage_count, document_count
@trace_method
@track_operation("list_llm_models_server")
async def list_llm_models_async(
self,
actor: User,
@@ -1023,6 +1025,7 @@ class SyncServer(object):
return unique_models
@track_operation("list_embedding_models_server")
async def list_embedding_models_async(self, actor: User) -> List[EmbeddingConfig]:
"""Asynchronously list available embedding models with maximum concurrency"""
import asyncio
@@ -1049,6 +1052,7 @@ class SyncServer(object):
return embedding_models
@track_operation("get_enabled_providers")
async def get_enabled_providers_async(
self,
actor: User,


@@ -7,6 +7,19 @@ import sqlalchemy as sa
from sqlalchemy import delete, func, insert, literal, or_, select, tuple_
from sqlalchemy.dialects.postgresql import insert as pg_insert
# Import memory tracking if available
try:
from letta.monitoring import track_operation
MEMORY_TRACKING_ENABLED = True
except ImportError:
MEMORY_TRACKING_ENABLED = False
# Define a no-op decorator if tracking is not available
def track_operation(name):
return lambda f: f
from letta.constants import (
BASE_MEMORY_TOOLS,
BASE_MEMORY_TOOLS_V2,
@@ -327,6 +340,7 @@ class AgentManager:
# ======================================================================================================================
@trace_method
@track_operation("agent_creation")
async def create_agent_async(
self,
agent_create: CreateAgent,


@@ -4,6 +4,20 @@ from typing import Any, List, Optional, Tuple
from openai.types.beta.function_tool import FunctionTool as OpenAITool
from letta.log import get_logger
# Import memory tracking if available
try:
from letta.monitoring import track_operation
MEMORY_TRACKING_ENABLED = True
except ImportError:
MEMORY_TRACKING_ENABLED = False
# Define a no-op decorator if tracking is not available
def track_operation(name):
return lambda f: f
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import TextContent
@@ -96,6 +110,7 @@ class ContextWindowCalculator:
return None, 1
@track_operation("calculate_context_window")
async def calculate_context_window(
self,
agent_state: AgentState,


@@ -11,6 +11,7 @@ from sqlalchemy.orm import selectinload
from letta.constants import MAX_FILENAME_LENGTH
from letta.helpers.pinecone_utils import list_pinecone_index_for_files, should_use_pinecone
from letta.log import get_logger
from letta.monitoring import track_operation
from letta.orm.errors import NoResultFound
from letta.orm.file import FileContent as FileContentModel, FileMetadata as FileMetadataModel
from letta.orm.sqlalchemy_base import AccessType
@@ -53,6 +54,7 @@ class FileManager:
@enforce_types
@trace_method
@track_operation("create_file_with_content")
async def create_file(
self,
file_metadata: PydanticFileMetadata,
@@ -356,6 +358,7 @@ class FileManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE)
@track_operation("upsert_file_content")
async def upsert_file_content(
self,
*,
@@ -402,6 +405,7 @@ class FileManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
@track_operation("list_files_with_content")
async def list_files(
self,
source_id: str,
@@ -629,6 +633,7 @@ class FileManager:
@enforce_types
@trace_method
@track_operation("batch_get_files_by_ids")
async def get_files_by_ids_async(
self, file_ids: List[str], actor: PydanticUser, *, include_content: bool = False
) -> List[PydanticFileMetadata]:
@@ -664,6 +669,7 @@ class FileManager:
@enforce_types
@trace_method
@track_operation("batch_get_files_for_agents")
async def get_files_for_agents_async(
self, agent_ids: List[str], actor: PydanticUser, *, include_content: bool = False
) -> List[PydanticFileMetadata]:


@@ -4,6 +4,7 @@ from typing import List, Optional, Union
from sqlalchemy import and_, asc, delete, desc, or_, select
from sqlalchemy.orm import Session
from letta.monitoring import track_operation
from letta.orm.agent import Agent as AgentModel
from letta.orm.block import Block
from letta.orm.errors import NoResultFound
@@ -73,6 +74,7 @@ class GroupManager:
return group.to_pydantic()
@enforce_types
@track_operation("create_multi_agent_group")
async def create_group_async(self, group: Union[GroupCreate, InternalTemplateGroupCreate], actor: PydanticUser) -> PydanticGroup:
async with db_registry.async_session() as session:
new_group = GroupModel()
@@ -197,6 +199,7 @@ class GroupManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
@track_operation("list_multi_agent_messages")
async def list_group_messages_async(
self,
actor: PydanticUser,
@@ -320,6 +323,7 @@ class GroupManager:
else:
raise ValueError("Extend relationship is not supported for groups.")
@track_operation("process_multi_agent_relationships")
async def _process_agent_relationship_async(self, session, group: GroupModel, agent_ids: List[str], allow_partial=False, replace=True):
if not agent_ids:
if replace:


@@ -7,6 +7,20 @@ from sqlalchemy import delete, exists, func, select, text
from letta.constants import CONVERSATION_SEARCH_TOOL_NAME, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.log import get_logger
# Import memory tracking if available
try:
from letta.monitoring import track_operation
MEMORY_TRACKING_ENABLED = True
except ImportError:
MEMORY_TRACKING_ENABLED = False
# Define a no-op decorator if tracking is not available
def track_operation(name):
return lambda f: f
from letta.orm.errors import NoResultFound
from letta.orm.message import Message as MessageModel
from letta.otel.tracing import trace_method
@@ -441,6 +455,7 @@ class MessageManager:
@enforce_types
@trace_method
@track_operation("bulk_create_messages")
async def create_many_messages_async(
self,
pydantic_msgs: List[PydanticMessage],
@@ -825,6 +840,7 @@ class MessageManager:
@enforce_types
@trace_method
@track_operation("list_messages")
async def list_messages(
self,
actor: PydanticUser,
@@ -949,6 +965,7 @@ class MessageManager:
@enforce_types
@trace_method
@track_operation("bulk_delete_all_agent_messages")
async def delete_all_messages_for_agent_async(
self, agent_id: str, actor: PydanticUser, exclude_ids: Optional[List[str]] = None, strict_mode: bool = False
) -> int:
@@ -997,6 +1014,7 @@ class MessageManager:
@enforce_types
@trace_method
@track_operation("bulk_delete_messages_by_ids")
async def delete_messages_by_ids_async(self, message_ids: List[str], actor: PydanticUser, strict_mode: bool = False) -> int:
"""
Efficiently deletes messages by their specific IDs,
@@ -1044,6 +1062,7 @@ class MessageManager:
@enforce_types
@trace_method
@track_operation("search_messages")
async def search_messages_async(
self,
agent_id: str,
@@ -1164,6 +1183,7 @@ class MessageManager:
message_tuples.append((message, metadata))
return message_tuples
@track_operation("search_messages_org")
async def search_messages_org_async(
self,
actor: PydanticUser,


@@ -11,6 +11,7 @@ from letta.constants import MAX_EMBEDDING_DIM
from letta.helpers.decorators import async_redis_cache
from letta.llm_api.llm_client import LLMClient
from letta.log import get_logger
from letta.monitoring import track_operation
from letta.orm import ArchivesAgents
from letta.orm.errors import NoResultFound
from letta.orm.passage import ArchivalPassage, SourcePassage
@@ -302,6 +303,7 @@ class PassageManager:
@enforce_types
@trace_method
@track_operation("batch_create_archival_passages")
async def create_many_archival_passages_async(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]:
"""Create multiple archival passages."""
archival_passages = []
@@ -354,6 +356,7 @@ class PassageManager:
@enforce_types
@trace_method
@track_operation("batch_create_source_passages")
async def create_many_source_passages_async(
self, passages: List[PydanticPassage], file_metadata: PydanticFileMetadata, actor: PydanticUser
) -> List[PydanticPassage]:
@@ -451,6 +454,7 @@ class PassageManager:
@enforce_types
@trace_method
@track_operation("insert_passages_with_embeddings")
async def insert_passage(
self,
agent_state: AgentState,
@@ -545,6 +549,7 @@ class PassageManager:
except Exception as e:
raise e
@track_operation("generate_embeddings_batch")
async def _generate_embeddings_concurrent(self, text_chunks: List[str], embedding_config, actor: PydanticUser) -> List[List[float]]:
"""Generate embeddings for all text chunks concurrently using LLMClient"""
@@ -764,6 +769,7 @@ class PassageManager:
@enforce_types
@trace_method
@track_operation("bulk_delete_agent_passages")
async def delete_agent_passages_async(
self,
passages: List[PydanticPassage],
@@ -818,6 +824,7 @@ class PassageManager:
@enforce_types
@trace_method
@track_operation("bulk_delete_source_passages")
async def delete_source_passages_async(
self,
actor: PydanticUser,


@@ -3,6 +3,20 @@ from typing import List, Optional, Tuple, Union
from letta.orm.provider import Provider as ProviderModel
from letta.orm.provider_model import ProviderModel as ProviderModelORM
from letta.otel.tracing import trace_method
# Import memory tracking if available
try:
from letta.monitoring import track_operation
MEMORY_TRACKING_ENABLED = True
except ImportError:
MEMORY_TRACKING_ENABLED = False
# Define a no-op decorator if tracking is not available
def track_operation(name):
return lambda f: f
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import PrimitiveType, ProviderCategory, ProviderType
from letta.schemas.llm_config import LLMConfig
@@ -746,6 +760,7 @@ class ProviderManager:
@enforce_types
@trace_method
@track_operation("list_models")
async def list_models_async(
self,
actor: PydanticUser,


@@ -5,6 +5,7 @@ from sqlalchemy import and_, exists, select
from letta.helpers.pinecone_utils import should_use_pinecone
from letta.helpers.tpuf_client import should_use_tpuf
from letta.monitoring import track_operation
from letta.orm import Agent as AgentModel
from letta.orm.errors import NoResultFound
from letta.orm.source import Source as SourceModel
@@ -272,6 +273,7 @@ class SourceManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
@track_operation("list_all_attached_agents")
async def list_attached_agents(
self, source_id: str, actor: PydanticUser, ids_only: bool = False
) -> Union[List[PydanticAgentState], List[str]]:
@@ -465,6 +467,7 @@ class SourceManager:
@enforce_types
@trace_method
@track_operation("batch_get_sources_by_ids")
async def get_sources_by_ids_async(self, source_ids: List[str], actor: PydanticUser) -> List[PydanticSource]:
"""
Get multiple sources by their IDs in a single query.
@@ -491,6 +494,7 @@ class SourceManager:
@enforce_types
@trace_method
@track_operation("batch_get_sources_for_agents")
async def get_sources_for_agents_async(self, agent_ids: List[str], actor: PydanticUser) -> List[PydanticSource]:
"""
Get all sources associated with the given agents via sources-agents relationships.


@@ -71,6 +71,7 @@ dependencies = [
"readability-lxml",
"google-genai>=1.15.0",
"datadog>=0.49.1",
"psutil>=5.9.0",
]
[project.scripts]

uv.lock (generated)

@@ -2446,6 +2446,7 @@ dependencies = [
{ name = "orjson" },
{ name = "pathvalidate" },
{ name = "prettytable" },
{ name = "psutil" },
{ name = "pydantic" },
{ name = "pydantic-settings" },
{ name = "pyhumps" },
@@ -2620,6 +2621,7 @@ requires-dist = [
{ name = "pinecone", extras = ["asyncio"], marker = "extra == 'pinecone'", specifier = ">=7.3.0" },
{ name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.5.0" },
{ name = "prettytable", specifier = ">=3.9.0" },
{ name = "psutil", specifier = ">=5.9.0" },
{ name = "psycopg2", marker = "extra == 'postgres'", specifier = ">=2.9.10" },
{ name = "psycopg2-binary", marker = "extra == 'postgres'", specifier = ">=2.9.10" },
{ name = "pydantic", specifier = ">=2.10.6" },