# letta-server/letta/monitoring/memory_tracker.py
"""
Memory tracking utility for the Letta application.
Provides real-time memory monitoring with proactive alerting using asyncio.
"""
import asyncio
import functools
import gc
import json
import os
import time
import traceback
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, Optional
import psutil
from letta.log import get_logger
logger = get_logger(__name__)
class MemoryTracker:
"""
Track memory usage across different operations with proactive alerting.
Uses asyncio for all async operations.
Features:
- Real-time memory monitoring using asyncio
- Proactive alerts before OOM
- Automatic memory dumps on critical thresholds
- Per-operation tracking and reporting
"""
# Memory thresholds (in MB)
WARNING_THRESHOLD_MB = 1000 # Warning at 1GB
CRITICAL_THRESHOLD_MB = 2000 # Critical at 2GB
SPIKE_THRESHOLD_MB = 100 # Alert on 100MB+ spikes
# Memory percentage thresholds
MEMORY_PERCENT_WARNING = 70 # Warn at 70% system memory
MEMORY_PERCENT_CRITICAL = 85 # Critical at 85% system memory
MEMORY_PERCENT_FATAL = 95 # Fatal - likely to OOM
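    # These thresholds are plain class attributes, so a deployment can tune them
    # without touching call sites, e.g. (hypothetical override):
    #   MemoryTracker.CRITICAL_THRESHOLD_MB = 4000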
def __init__(self, enable_background_monitor: bool = True, monitor_interval: int = 5):
"""
Initialize the memory tracker.
Args:
enable_background_monitor: Whether to start background monitoring
monitor_interval: Interval in seconds between background checks
"""
self.process = psutil.Process(os.getpid())
self.measurements = defaultdict(list)
self.active_operations = {}
self.lock = asyncio.Lock() # Use asyncio.Lock instead of threading.Lock
self.monitor_interval = monitor_interval
self._monitoring = False
        self._monitor_task = None
        self._startup_task = None  # holds a reference to the task that kicks off the monitor
# Track memory history for trend analysis
self.memory_history = []
self.max_history_size = 100
# Track if we've already warned about memory levels
self._warned_levels = set()
# Start time for uptime tracking
self.start_time = datetime.now()
# Flag to track if we should start the monitor
self._should_start_monitor = enable_background_monitor
self._monitor_started = False
logger.info(
f"Memory tracker initialized - PID: {os.getpid()}, "
f"Warning: {self.WARNING_THRESHOLD_MB}MB, "
f"Critical: {self.CRITICAL_THRESHOLD_MB}MB"
)
    def _ensure_monitor_started(self):
        """Start the background monitor if needed and if there's an event loop."""
        if self._should_start_monitor and not self._monitor_started:
            try:
                # Raises RuntimeError if no event loop is running
                asyncio.get_running_loop()
                # Keep a reference so the task is not garbage-collected mid-flight
                self._startup_task = asyncio.create_task(self.start_background_monitor())
                self._monitor_started = True
                logger.debug("Monitor start task created from _ensure_monitor_started")
            except RuntimeError as e:
                # No event loop running yet; we'll try again later
                logger.debug(f"No event loop available yet for monitor start: {e}")
def get_memory_info(self) -> Dict[str, Any]:
"""Get current memory information."""
try:
mem_info = self.process.memory_info()
mem_percent = self.process.memory_percent()
# Get system-wide memory
system_mem = psutil.virtual_memory()
return {
"rss_mb": mem_info.rss / 1024 / 1024,
"vms_mb": mem_info.vms / 1024 / 1024,
"percent": mem_percent,
"system_available_mb": system_mem.available / 1024 / 1024,
"system_percent": system_mem.percent,
"timestamp": datetime.now().isoformat(),
}
except Exception as e:
logger.error(f"Failed to get memory info: {e}")
return {}
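    # Shape of the returned dict (values are illustrative):
    #   {"rss_mb": 412.3, "vms_mb": 1024.7, "percent": 2.6,
    #    "system_available_mb": 9832.1, "system_percent": 38.4,
    #    "timestamp": "2025-11-24T19:09:32.123456"}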
def track_operation(self, operation_name: str):
"""
Decorator to track memory for specific operations.
Logs immediately on completion or error.
"""
def decorator(func):
@functools.wraps(func)
async def async_wrapper(*args, **kwargs):
# Ensure background monitor is started (now we have an event loop)
self._ensure_monitor_started()
start_mem_info = self.get_memory_info()
start_time = time.time()
# Log operation start if memory is already high
if start_mem_info.get("rss_mb", 0) > self.WARNING_THRESHOLD_MB:
logger.warning(f"Starting operation '{operation_name}' with high memory: {start_mem_info['rss_mb']:.2f} MB")
# Record call stack for debugging
stack = traceback.extract_stack()
operation_id = f"{operation_name}_{id(func)}_{time.time()}"
async with self.lock:
self.active_operations[operation_id] = {
"operation_name": operation_name,
"start_mem": start_mem_info,
"start_time": start_time,
"stack": stack,
}
try:
result = await func(*args, **kwargs)
return result
                except Exception as e:
                    # Log memory state on error; the call in `finally` below then
                    # becomes a no-op because the operation has already been popped
                    await self._log_operation_completion(operation_id, error=str(e))
                    raise
                finally:
                    # Always log operation completion (idempotent; see note above)
                    await self._log_operation_completion(operation_id)
@functools.wraps(func)
def sync_wrapper(*args, **kwargs):
start_mem_info = self.get_memory_info()
start_time = time.time()
# Log operation start if memory is already high
if start_mem_info.get("rss_mb", 0) > self.WARNING_THRESHOLD_MB:
logger.warning(f"Starting operation '{operation_name}' with high memory: {start_mem_info['rss_mb']:.2f} MB")
operation_id = f"{operation_name}_{id(func)}_{time.time()}"
                # Sync functions cannot await the asyncio lock; mutating the dict
                # directly is acceptable here since dict operations are atomic
                # under the GIL and this data is purely diagnostic
self.active_operations[operation_id] = {
"operation_name": operation_name,
"start_mem": start_mem_info,
"start_time": start_time,
"stack": traceback.extract_stack(),
}
try:
result = func(*args, **kwargs)
return result
except Exception as e:
# Log memory state on error (sync version)
self._log_operation_completion_sync(operation_id, error=str(e))
raise
finally:
# Always log operation completion (sync version)
self._log_operation_completion_sync(operation_id)
if asyncio.iscoroutinefunction(func):
return async_wrapper
else:
return sync_wrapper
return decorator
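    # Usage sketch for the decorator above (hypothetical caller code):
    #
    #   tracker = get_memory_tracker()
    #
    #   @tracker.track_operation("load_agent_state")
    #   async def load_agent_state(agent_id: str):
    #       ...  # memory delta and duration are logged when this returns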
    async def _log_operation_completion(self, operation_id: str, error: Optional[str] = None):
        """Log memory usage immediately when an operation completes (async version)."""
        async with self.lock:
            # Popping makes completion logging idempotent: a second call for the
            # same operation_id (e.g. from `finally` after the except-path already
            # logged) returns early here
            if operation_id not in self.active_operations:
                return
            operation_data = self.active_operations.pop(operation_id)
        end_mem_info = self.get_memory_info()
        end_time = time.time()
        start_mem = operation_data["start_mem"].get("rss_mb", 0)
        end_mem = end_mem_info.get("rss_mb", 0)
        mem_delta = end_mem - start_mem
        time_delta = end_time - operation_data["start_time"]
        operation_name = operation_data["operation_name"]
        # Record measurement (re-acquire the lock; asyncio.Lock is not reentrant)
        async with self.lock:
            self.measurements[operation_name].append(
                {
                    "memory_delta_mb": mem_delta,
                    "peak_memory_mb": end_mem,
                    "time_seconds": time_delta,
                    "timestamp": datetime.now().isoformat(),
                    "error": error,
                    "system_percent": end_mem_info.get("system_percent", 0),
                }
            )
        self._log_memory_status(operation_name, mem_delta, end_mem, time_delta, end_mem_info, error, operation_data)
def _log_operation_completion_sync(self, operation_id: str, error: Optional[str] = None):
"""Log memory usage immediately when an operation completes (sync version for non-async functions)."""
if operation_id not in self.active_operations:
return
operation_data = self.active_operations.pop(operation_id)
end_mem_info = self.get_memory_info()
end_time = time.time()
start_mem = operation_data["start_mem"].get("rss_mb", 0)
end_mem = end_mem_info.get("rss_mb", 0)
mem_delta = end_mem - start_mem
time_delta = end_time - operation_data["start_time"]
operation_name = operation_data["operation_name"]
# Record measurement (sync version doesn't use async lock)
self.measurements[operation_name].append(
{
"memory_delta_mb": mem_delta,
"peak_memory_mb": end_mem,
"time_seconds": time_delta,
"timestamp": datetime.now().isoformat(),
"error": error,
"system_percent": end_mem_info.get("system_percent", 0),
}
)
self._log_memory_status(operation_name, mem_delta, end_mem, time_delta, end_mem_info, error, operation_data)
def _log_memory_status(
self,
operation_name: str,
mem_delta: float,
end_mem: float,
time_delta: float,
end_mem_info: Dict,
error: Optional[str],
operation_data: Dict,
):
"""Common logging logic for memory status."""
# Determine log level based on memory situation
if error:
logger.error(
f"Operation '{operation_name}' failed after {time_delta:.2f}s - "
f"Memory: {end_mem:.2f} MB (Δ{mem_delta:+.2f} MB), "
f"System: {end_mem_info.get('system_percent', 0):.1f}%, "
f"Error: {error}"
)
elif mem_delta > self.SPIKE_THRESHOLD_MB:
logger.warning(
f"MEMORY SPIKE: Operation '{operation_name}' - "
f"Increased by {mem_delta:.2f} MB in {time_delta:.2f}s - "
f"Current: {end_mem:.2f} MB, System: {end_mem_info.get('system_percent', 0):.1f}%"
)
# Log stack trace for large spikes
if mem_delta > self.SPIKE_THRESHOLD_MB * 2:
stack = operation_data.get("stack", [])
if stack and len(stack) > 3:
logger.warning("Call stack for memory spike:")
for frame in stack[-5:]:
logger.warning(f" {frame.filename}:{frame.lineno} in {frame.name}")
elif end_mem > self.CRITICAL_THRESHOLD_MB:
logger.error(
f"CRITICAL MEMORY: Operation '{operation_name}' completed - "
f"Memory at {end_mem:.2f} MB (Δ{mem_delta:+.2f} MB), "
f"System: {end_mem_info.get('system_percent', 0):.1f}%"
)
elif end_mem > self.WARNING_THRESHOLD_MB:
logger.warning(f"High memory after '{operation_name}': {end_mem:.2f} MB (Δ{mem_delta:+.2f} MB) in {time_delta:.2f}s")
else:
# Only log normal operations in debug mode
logger.debug(f"Operation '{operation_name}' completed: Memory {end_mem:.2f} MB (Δ{mem_delta:+.2f} MB) in {time_delta:.2f}s")
async def start_background_monitor(self):
"""Start the background memory monitoring task."""
if self._monitoring:
logger.info("Background monitor already running, skipping start")
return
self._monitoring = True
self._monitor_started = True
try:
self._monitor_task = asyncio.create_task(self._monitor_loop())
logger.info(f"Background memory monitor task created successfully (interval: {self.monitor_interval}s)")
except Exception as e:
logger.error(f"Failed to create monitor task: {e}", exc_info=True)
self._monitoring = False
self._monitor_started = False
raise
async def stop_background_monitor(self):
"""Stop the background memory monitoring task."""
self._monitoring = False
self._monitor_started = False
if self._monitor_task:
self._monitor_task.cancel()
try:
await self._monitor_task
except asyncio.CancelledError:
pass
logger.info("Background memory monitor stopped")
async def _monitor_loop(self):
"""Background monitoring loop that runs continuously using asyncio."""
consecutive_high_memory = 0
last_gc_time = time.time()
iteration_count = 0
logger.info(f"Memory monitor loop started (PID: {os.getpid()})")
while self._monitoring:
try:
iteration_count += 1
mem_info = self.get_memory_info()
current_mb = mem_info.get("rss_mb", 0)
system_percent = mem_info.get("system_percent", 0)
# Add to history
async with self.lock:
self.memory_history.append(mem_info)
if len(self.memory_history) > self.max_history_size:
self.memory_history.pop(0)
# Log periodic memory status
# Using INFO since production logging is at INFO level
percent = (current_mb / self.CRITICAL_THRESHOLD_MB) * 100
logger.info(
f"Memory Status: RSS: {current_mb:.2f} MB ({percent:.1f}% of {self.CRITICAL_THRESHOLD_MB} MB limit), "
f"System: {system_percent:.1f}%"
)
# Check memory levels
self._check_memory_thresholds(mem_info)
# Track consecutive high memory readings
if current_mb > self.WARNING_THRESHOLD_MB:
consecutive_high_memory += 1
else:
consecutive_high_memory = 0
# Force GC if memory is consistently high
if consecutive_high_memory >= 3 and time.time() - last_gc_time > 30:
await asyncio.to_thread(self._force_gc_with_logging)
last_gc_time = time.time()
# Check for memory leak patterns
if len(self.memory_history) >= 10:
await self._check_memory_trend()
# Log active operations if memory is critical
if system_percent > self.MEMORY_PERCENT_CRITICAL:
await self._log_active_operations()
await asyncio.sleep(self.monitor_interval)
except asyncio.CancelledError:
logger.info(f"Memory monitor loop cancelled after {iteration_count} iterations")
break
except Exception as e:
logger.error(f"Error in memory monitor loop (iteration {iteration_count}): {e}", exc_info=True)
await asyncio.sleep(self.monitor_interval)
logger.info(f"Memory monitor loop exited after {iteration_count} iterations")
def _check_memory_thresholds(self, mem_info: Dict[str, Any]):
"""Check memory against thresholds and log appropriately."""
current_mb = mem_info.get("rss_mb", 0)
system_percent = mem_info.get("system_percent", 0)
# Check system memory percentage thresholds
if system_percent > self.MEMORY_PERCENT_FATAL and "fatal" not in self._warned_levels:
logger.critical(
f"FATAL MEMORY LEVEL: {system_percent:.1f}% of system memory used! Process: {current_mb:.2f} MB - OOM imminent!"
)
self._warned_levels.add("fatal")
self._dump_memory_state()
elif system_percent > self.MEMORY_PERCENT_CRITICAL and "critical" not in self._warned_levels:
logger.error(f"CRITICAL: System memory at {system_percent:.1f}% - Process using {current_mb:.2f} MB")
self._warned_levels.add("critical")
self._dump_active_operations_summary()
elif system_percent > self.MEMORY_PERCENT_WARNING and "warning" not in self._warned_levels:
logger.warning(f"Memory warning: System at {system_percent:.1f}% - Process: {current_mb:.2f} MB")
self._warned_levels.add("warning")
# Reset warning levels if memory drops
if system_percent < self.MEMORY_PERCENT_WARNING:
self._warned_levels.clear()
async def _check_memory_trend(self):
"""Check if memory is trending upward (potential leak)."""
async with self.lock:
if len(self.memory_history) < 10:
return
# Get the last 10 readings
recent = self.memory_history[-10:]
# Calculate trend
first_mb = recent[0].get("rss_mb", 0)
last_mb = recent[-1].get("rss_mb", 0)
increase_mb = last_mb - first_mb
            # Average rate of increase; N samples span N-1 monitor intervals
            time_span = (len(recent) - 1) * self.monitor_interval  # seconds
            rate_mb_per_min = (increase_mb / time_span) * 60
# Warn if memory is increasing rapidly
if rate_mb_per_min > 50: # More than 50MB/minute
logger.warning(
f"MEMORY LEAK SUSPECTED: Memory increasing at {rate_mb_per_min:.1f} MB/min - From {first_mb:.2f} MB to {last_mb:.2f} MB"
)
self._dump_active_operations_summary()
def _force_gc_with_logging(self):
"""Force garbage collection and log the results."""
before_mb = self.get_memory_info().get("rss_mb", 0)
collected = gc.collect()
after_mb = self.get_memory_info().get("rss_mb", 0)
freed_mb = before_mb - after_mb
if freed_mb > 10: # Only log if significant memory was freed
logger.info(f"Garbage collection freed {freed_mb:.2f} MB (collected {collected} objects)")
async def _log_active_operations(self):
"""Log currently active operations."""
async with self.lock:
if not self.active_operations:
return
logger.warning(f"Active operations during high memory ({len(self.active_operations)} running):")
for op_id, op_data in self.active_operations.items():
duration = time.time() - op_data["start_time"]
logger.warning(
f" - {op_data['operation_name']}: running for {duration:.1f}s, started at {op_data['start_mem'].get('rss_mb', 0):.2f} MB"
)
def _dump_active_operations_summary(self):
"""Dump summary of recent operations."""
logger.warning("=== Recent Operation Memory Usage ===")
for operation_name, measurements in self.measurements.items():
if measurements:
recent = measurements[-5:] # Last 5 measurements
total_delta = sum(m["memory_delta_mb"] for m in recent)
avg_delta = total_delta / len(recent)
max_peak = max(m["peak_memory_mb"] for m in recent)
logger.warning(f"{operation_name}: Avg Δ{avg_delta:+.1f} MB, Peak {max_peak:.1f} MB ({len(recent)} recent ops)")
def _dump_memory_state(self):
"""Dump detailed memory state for debugging."""
logger.critical("=== MEMORY STATE DUMP ===")
# Current memory
mem_info = self.get_memory_info()
logger.critical(f"Current memory: {json.dumps(mem_info, indent=2)}")
# Top memory consumers by operation
logger.critical("Top memory consuming operations:")
        for op_name, measurements in sorted(
            self.measurements.items(), key=lambda x: sum(m["memory_delta_mb"] for m in x[1]), reverse=True
        )[:5]:
if measurements:
total = sum(m["memory_delta_mb"] for m in measurements)
logger.critical(f" {op_name}: {total:.1f} MB total across {len(measurements)} calls")
# Active operations
if self.active_operations:
logger.critical(f"Active operations: {len(self.active_operations)}")
for op_data in self.active_operations.values():
logger.critical(f" - {op_data['operation_name']}")
# System info
logger.critical(f"Uptime: {datetime.now() - self.start_time}")
logger.critical(f"PID: {os.getpid()}")
# Attempt to get top memory objects (if available)
try:
logger.critical("Top objects by count:")
obj_counts = defaultdict(int)
for obj in gc.get_objects()[:1000]: # Sample first 1000 objects
obj_counts[type(obj).__name__] += 1
for obj_type, count in sorted(obj_counts.items(), key=lambda x: x[1], reverse=True)[:10]:
logger.critical(f" {obj_type}: {count}")
except Exception as e:
logger.error(f"Could not get object counts: {e}")
def get_report(self) -> str:
"""Generate a summary report of memory usage."""
lines = []
lines.append("=== MEMORY USAGE REPORT ===")
lines.append(f"Uptime: {datetime.now() - self.start_time}")
current = self.get_memory_info()
lines.append(f"Current memory: {current.get('rss_mb', 0):.2f} MB")
lines.append(f"System memory: {current.get('system_percent', 0):.1f}%")
lines.append("\nOperations summary:")
for operation_name, measurements in self.measurements.items():
if measurements:
total_mem = sum(m["memory_delta_mb"] for m in measurements)
avg_mem = total_mem / len(measurements)
max_mem = max(m["memory_delta_mb"] for m in measurements)
errors = sum(1 for m in measurements if m.get("error"))
lines.append(
f" {operation_name}: "
f"{len(measurements)} calls, "
f"Avg Δ{avg_mem:+.1f} MB, "
f"Max Δ{max_mem:+.1f} MB, "
f"Total Δ{total_mem:+.1f} MB"
)
if errors:
lines.append(f" ({errors} errors)")
return "\n".join(lines)
# Global tracker instance
_global_tracker: Optional[MemoryTracker] = None
def get_memory_tracker(enable_background_monitor: bool = True, monitor_interval: int = 5) -> MemoryTracker:
"""Get or create the global memory tracker instance."""
global _global_tracker
if _global_tracker is None:
_global_tracker = MemoryTracker(enable_background_monitor=enable_background_monitor, monitor_interval=monitor_interval)
return _global_tracker
def track_operation(operation_name: str):
"""
Convenience decorator that uses the global tracker.
Usage:
@track_operation("my_operation")
async def my_function():
...
"""
tracker = get_memory_tracker()
return tracker.track_operation(operation_name)
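if __name__ == "__main__":
    # Minimal smoke-test sketch (hypothetical; not part of the module's public
    # surface): exercises the decorator path and the report under an asyncio loop.
    async def _demo():
        tracker = get_memory_tracker(monitor_interval=1)

        @tracker.track_operation("demo_allocation")
        async def allocate():
            data = bytearray(8 * 1024 * 1024)  # hold ~8 MB so the delta is visible
            await asyncio.sleep(0.2)  # yield so the background monitor can run
            return len(data)

        await allocate()
        print(tracker.get_report())
        await tracker.stop_background_monitor()

    asyncio.run(_demo())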