feat: add memory tracking to core (#6179)
* add memory tracking to core
* move from threading.Thread to asyncio
* remove threading.Thread entirely
* delay decorator monitoring initialization until after the event loop is registered
* convert the context manager to a decorator
* add the psutil dependency
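A minimal usage sketch of the API this commit introduces (not part of the diff; assumes the letta package from this branch is importable). It shows the lazy-start behavior the bullets describe: the tracker can be created at decoration time, but the background monitor task is only spawned once a tracked call runs inside a live event loop.

import asyncio

from letta.monitoring import get_memory_tracker, track_operation


@track_operation("demo_operation")  # safe at import time: no event loop needed yet
async def allocate_some_memory():
    # Any awaitable work; the memory delta is logged when it completes
    return [b"x" * 1024 for _ in range(1000)]


async def main():
    # The first tracked call runs inside a loop, so the monitor task starts here
    await allocate_some_memory()
    print(get_memory_tracker().get_report())


asyncio.run(main())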
letta/monitoring/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
"""Memory and request monitoring utilities for Letta application."""

from .memory_tracker import MemoryTracker, get_memory_tracker, track_operation
from .request_monitor import RequestBodyLogger, RequestSizeMonitoringMiddleware, identify_upload_endpoints

__all__ = [
    "MemoryTracker",
    "get_memory_tracker",
    "track_operation",
    "RequestSizeMonitoringMiddleware",
    "RequestBodyLogger",
    "identify_upload_endpoints",
]
letta/monitoring/memory_tracker.py (new file, 552 lines)
@@ -0,0 +1,552 @@
"""
Memory tracking utility for Letta application.
Provides real-time memory monitoring with proactive alerting using asyncio.
"""

import asyncio
import functools
import gc
import json
import os
import time
import traceback
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, Optional

import psutil

from letta.log import get_logger

logger = get_logger(__name__)


class MemoryTracker:
    """
    Track memory usage across different operations with proactive alerting.
    Uses asyncio for all async operations.

    Features:
    - Real-time memory monitoring using asyncio
    - Proactive alerts before OOM
    - Automatic memory dumps on critical thresholds
    - Per-operation tracking and reporting
    """

    # Memory thresholds (in MB)
    WARNING_THRESHOLD_MB = 1000  # Warning at 1GB
    CRITICAL_THRESHOLD_MB = 2000  # Critical at 2GB
    SPIKE_THRESHOLD_MB = 100  # Alert on 100MB+ spikes

    # System memory percentage thresholds
    MEMORY_PERCENT_WARNING = 70  # Warn at 70% system memory
    MEMORY_PERCENT_CRITICAL = 85  # Critical at 85% system memory
    MEMORY_PERCENT_FATAL = 95  # Fatal - likely to OOM

    def __init__(self, enable_background_monitor: bool = True, monitor_interval: int = 5):
        """
        Initialize the memory tracker.

        Args:
            enable_background_monitor: Whether to start background monitoring
            monitor_interval: Interval in seconds between background checks
        """
        self.process = psutil.Process(os.getpid())
        self.measurements = defaultdict(list)
        self.active_operations = {}
        self.lock = asyncio.Lock()  # asyncio.Lock instead of threading.Lock
        self.monitor_interval = monitor_interval
        self._monitoring = False
        self._monitor_task = None

        # Track memory history for trend analysis
        self.memory_history = []
        self.max_history_size = 100

        # Track whether we've already warned about each memory level
        self._warned_levels = set()

        # Start time for uptime tracking
        self.start_time = datetime.now()

        # Flag to track whether we should start the monitor
        self._should_start_monitor = enable_background_monitor
        self._monitor_started = False

        logger.info(
            f"Memory tracker initialized - PID: {os.getpid()}, "
            f"Warning: {self.WARNING_THRESHOLD_MB}MB, "
            f"Critical: {self.CRITICAL_THRESHOLD_MB}MB"
        )

    def _ensure_monitor_started(self):
        """Start the background monitor if needed and if there's a running event loop."""
        if self._should_start_monitor and not self._monitor_started:
            try:
                # Raises RuntimeError if no event loop is running yet
                asyncio.get_running_loop()
                # Create the monitor task
                asyncio.create_task(self.start_background_monitor())
                self._monitor_started = True
            except RuntimeError:
                # No event loop running yet; will try again later
                pass

    def get_memory_info(self) -> Dict[str, Any]:
        """Get current memory information."""
        try:
            mem_info = self.process.memory_info()
            mem_percent = self.process.memory_percent()

            # Get system-wide memory
            system_mem = psutil.virtual_memory()

            return {
                "rss_mb": mem_info.rss / 1024 / 1024,
                "vms_mb": mem_info.vms / 1024 / 1024,
                "percent": mem_percent,
                "system_available_mb": system_mem.available / 1024 / 1024,
                "system_percent": system_mem.percent,
                "timestamp": datetime.now().isoformat(),
            }
        except Exception as e:
            logger.error(f"Failed to get memory info: {e}")
            return {}

    def track_operation(self, operation_name: str):
        """
        Decorator to track memory for specific operations.
        Logs immediately on completion or error.
        """

        def decorator(func):
            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs):
                # Ensure the background monitor is started (an event loop is running now)
                self._ensure_monitor_started()

                start_mem_info = self.get_memory_info()
                start_time = time.time()

                # Log operation start if memory is already high
                if start_mem_info.get("rss_mb", 0) > self.WARNING_THRESHOLD_MB:
                    logger.warning(f"Starting operation '{operation_name}' with high memory: {start_mem_info['rss_mb']:.2f} MB")

                # Record the call stack for debugging
                stack = traceback.extract_stack()
                operation_id = f"{operation_name}_{id(func)}_{time.time()}"

                async with self.lock:
                    self.active_operations[operation_id] = {
                        "operation_name": operation_name,
                        "start_mem": start_mem_info,
                        "start_time": start_time,
                        "stack": stack,
                    }

                try:
                    result = await func(*args, **kwargs)
                    return result
                except Exception as e:
                    # Log memory state on error
                    await self._log_operation_completion(operation_id, error=str(e))
                    raise
                finally:
                    # Always log operation completion (no-op if already logged on error)
                    await self._log_operation_completion(operation_id)

            @functools.wraps(func)
            def sync_wrapper(*args, **kwargs):
                start_mem_info = self.get_memory_info()
                start_time = time.time()

                # Log operation start if memory is already high
                if start_mem_info.get("rss_mb", 0) > self.WARNING_THRESHOLD_MB:
                    logger.warning(f"Starting operation '{operation_name}' with high memory: {start_mem_info['rss_mb']:.2f} MB")

                operation_id = f"{operation_name}_{id(func)}_{time.time()}"

                # Sync functions can't use the async lock, so mutate the dict directly
                self.active_operations[operation_id] = {
                    "operation_name": operation_name,
                    "start_mem": start_mem_info,
                    "start_time": start_time,
                    "stack": traceback.extract_stack(),
                }

                try:
                    result = func(*args, **kwargs)
                    return result
                except Exception as e:
                    # Log memory state on error (sync version)
                    self._log_operation_completion_sync(operation_id, error=str(e))
                    raise
                finally:
                    # Always log operation completion (sync version)
                    self._log_operation_completion_sync(operation_id)

            if asyncio.iscoroutinefunction(func):
                return async_wrapper
            else:
                return sync_wrapper

        return decorator
    async def _log_operation_completion(self, operation_id: str, error: Optional[str] = None):
        """Log memory usage immediately when an operation completes (async version)."""
        async with self.lock:
            if operation_id not in self.active_operations:
                return

            operation_data = self.active_operations.pop(operation_id)

        end_mem_info = self.get_memory_info()
        end_time = time.time()

        start_mem = operation_data["start_mem"].get("rss_mb", 0)
        end_mem = end_mem_info.get("rss_mb", 0)
        mem_delta = end_mem - start_mem
        time_delta = end_time - operation_data["start_time"]
        operation_name = operation_data["operation_name"]

        # Record measurement (re-acquire the lock here; asyncio.Lock is not reentrant,
        # so the pop above must happen in its own critical section)
        async with self.lock:
            self.measurements[operation_name].append(
                {
                    "memory_delta_mb": mem_delta,
                    "peak_memory_mb": end_mem,
                    "time_seconds": time_delta,
                    "timestamp": datetime.now().isoformat(),
                    "error": error,
                    "system_percent": end_mem_info.get("system_percent", 0),
                }
            )

        self._log_memory_status(operation_name, mem_delta, end_mem, time_delta, end_mem_info, error, operation_data)

    def _log_operation_completion_sync(self, operation_id: str, error: Optional[str] = None):
        """Log memory usage immediately when an operation completes (sync version for non-async functions)."""
        if operation_id not in self.active_operations:
            return

        operation_data = self.active_operations.pop(operation_id)
        end_mem_info = self.get_memory_info()
        end_time = time.time()

        start_mem = operation_data["start_mem"].get("rss_mb", 0)
        end_mem = end_mem_info.get("rss_mb", 0)
        mem_delta = end_mem - start_mem
        time_delta = end_time - operation_data["start_time"]
        operation_name = operation_data["operation_name"]

        # Record measurement (sync version doesn't use the async lock)
        self.measurements[operation_name].append(
            {
                "memory_delta_mb": mem_delta,
                "peak_memory_mb": end_mem,
                "time_seconds": time_delta,
                "timestamp": datetime.now().isoformat(),
                "error": error,
                "system_percent": end_mem_info.get("system_percent", 0),
            }
        )

        self._log_memory_status(operation_name, mem_delta, end_mem, time_delta, end_mem_info, error, operation_data)

    def _log_memory_status(
        self,
        operation_name: str,
        mem_delta: float,
        end_mem: float,
        time_delta: float,
        end_mem_info: Dict,
        error: Optional[str],
        operation_data: Dict,
    ):
        """Common logging logic for memory status."""
        # Determine the log level based on the memory situation
        if error:
            logger.error(
                f"Operation '{operation_name}' failed after {time_delta:.2f}s - "
                f"Memory: {end_mem:.2f} MB (Δ{mem_delta:+.2f} MB), "
                f"System: {end_mem_info.get('system_percent', 0):.1f}%, "
                f"Error: {error}"
            )
        elif mem_delta > self.SPIKE_THRESHOLD_MB:
            logger.warning(
                f"MEMORY SPIKE: Operation '{operation_name}' - "
                f"Increased by {mem_delta:.2f} MB in {time_delta:.2f}s - "
                f"Current: {end_mem:.2f} MB, System: {end_mem_info.get('system_percent', 0):.1f}%"
            )

            # Log a stack trace for large spikes
            if mem_delta > self.SPIKE_THRESHOLD_MB * 2:
                stack = operation_data.get("stack", [])
                if stack and len(stack) > 3:
                    logger.warning("Call stack for memory spike:")
                    for frame in stack[-5:]:
                        logger.warning(f"  {frame.filename}:{frame.lineno} in {frame.name}")

        elif end_mem > self.CRITICAL_THRESHOLD_MB:
            logger.error(
                f"CRITICAL MEMORY: Operation '{operation_name}' completed - "
                f"Memory at {end_mem:.2f} MB (Δ{mem_delta:+.2f} MB), "
                f"System: {end_mem_info.get('system_percent', 0):.1f}%"
            )
        elif end_mem > self.WARNING_THRESHOLD_MB:
            logger.warning(f"High memory after '{operation_name}': {end_mem:.2f} MB (Δ{mem_delta:+.2f} MB) in {time_delta:.2f}s")
        else:
            # Only log normal operations in debug mode
            logger.debug(f"Operation '{operation_name}' completed: Memory {end_mem:.2f} MB (Δ{mem_delta:+.2f} MB) in {time_delta:.2f}s")
    async def start_background_monitor(self):
        """Start the background memory monitoring task."""
        if self._monitoring:
            return

        self._monitoring = True
        self._monitor_started = True
        self._monitor_task = asyncio.create_task(self._monitor_loop())
        logger.info(f"Background memory monitor started (interval: {self.monitor_interval}s)")

    async def stop_background_monitor(self):
        """Stop the background memory monitoring task."""
        self._monitoring = False
        self._monitor_started = False
        if self._monitor_task:
            self._monitor_task.cancel()
            try:
                await self._monitor_task
            except asyncio.CancelledError:
                pass
        logger.info("Background memory monitor stopped")

    async def _monitor_loop(self):
        """Background monitoring loop that runs continuously using asyncio."""
        consecutive_high_memory = 0
        last_gc_time = time.time()

        while self._monitoring:
            try:
                mem_info = self.get_memory_info()
                current_mb = mem_info.get("rss_mb", 0)
                system_percent = mem_info.get("system_percent", 0)

                # Add to history
                async with self.lock:
                    self.memory_history.append(mem_info)
                    if len(self.memory_history) > self.max_history_size:
                        self.memory_history.pop(0)

                # Check memory levels
                self._check_memory_thresholds(mem_info)

                # Track consecutive high memory readings
                if current_mb > self.WARNING_THRESHOLD_MB:
                    consecutive_high_memory += 1
                else:
                    consecutive_high_memory = 0

                # Force GC if memory is consistently high
                if consecutive_high_memory >= 3 and time.time() - last_gc_time > 30:
                    await asyncio.to_thread(self._force_gc_with_logging)
                    last_gc_time = time.time()

                # Check for memory leak patterns
                if len(self.memory_history) >= 10:
                    await self._check_memory_trend()

                # Log active operations if memory is critical
                if system_percent > self.MEMORY_PERCENT_CRITICAL:
                    await self._log_active_operations()

                await asyncio.sleep(self.monitor_interval)

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in memory monitor loop: {e}")
                await asyncio.sleep(self.monitor_interval)
    def _check_memory_thresholds(self, mem_info: Dict[str, Any]):
        """Check memory against thresholds and log appropriately."""
        current_mb = mem_info.get("rss_mb", 0)
        system_percent = mem_info.get("system_percent", 0)

        # Check system memory percentage thresholds
        if system_percent > self.MEMORY_PERCENT_FATAL and "fatal" not in self._warned_levels:
            logger.critical(
                f"FATAL MEMORY LEVEL: {system_percent:.1f}% of system memory used! Process: {current_mb:.2f} MB - OOM imminent!"
            )
            self._warned_levels.add("fatal")
            self._dump_memory_state()

        elif system_percent > self.MEMORY_PERCENT_CRITICAL and "critical" not in self._warned_levels:
            logger.error(f"CRITICAL: System memory at {system_percent:.1f}% - Process using {current_mb:.2f} MB")
            self._warned_levels.add("critical")
            self._dump_active_operations_summary()

        elif system_percent > self.MEMORY_PERCENT_WARNING and "warning" not in self._warned_levels:
            logger.warning(f"Memory warning: System at {system_percent:.1f}% - Process: {current_mb:.2f} MB")
            self._warned_levels.add("warning")

        # Reset warning levels if memory drops
        if system_percent < self.MEMORY_PERCENT_WARNING:
            self._warned_levels.clear()

    async def _check_memory_trend(self):
        """Check if memory is trending upward (potential leak)."""
        async with self.lock:
            if len(self.memory_history) < 10:
                return

            # Get the last 10 readings
            recent = self.memory_history[-10:]

        # Calculate the trend
        first_mb = recent[0].get("rss_mb", 0)
        last_mb = recent[-1].get("rss_mb", 0)
        increase_mb = last_mb - first_mb

        # Calculate the average rate of increase
        time_span = 10 * self.monitor_interval  # seconds
        rate_mb_per_min = (increase_mb / time_span) * 60

        # Warn if memory is increasing rapidly
        if rate_mb_per_min > 50:  # More than 50MB/minute
            logger.warning(
                f"MEMORY LEAK SUSPECTED: Memory increasing at {rate_mb_per_min:.1f} MB/min - From {first_mb:.2f} MB to {last_mb:.2f} MB"
            )
            self._dump_active_operations_summary()
    def _force_gc_with_logging(self):
        """Force garbage collection and log the results."""
        before_mb = self.get_memory_info().get("rss_mb", 0)
        collected = gc.collect()
        after_mb = self.get_memory_info().get("rss_mb", 0)
        freed_mb = before_mb - after_mb

        if freed_mb > 10:  # Only log if significant memory was freed
            logger.info(f"Garbage collection freed {freed_mb:.2f} MB (collected {collected} objects)")

    async def _log_active_operations(self):
        """Log currently active operations."""
        async with self.lock:
            if not self.active_operations:
                return

            logger.warning(f"Active operations during high memory ({len(self.active_operations)} running):")
            for op_data in self.active_operations.values():
                duration = time.time() - op_data["start_time"]
                logger.warning(
                    f"  - {op_data['operation_name']}: running for {duration:.1f}s, started at {op_data['start_mem'].get('rss_mb', 0):.2f} MB"
                )

    def _dump_active_operations_summary(self):
        """Dump a summary of recent operations."""
        logger.warning("=== Recent Operation Memory Usage ===")
        for operation_name, measurements in self.measurements.items():
            if measurements:
                recent = measurements[-5:]  # Last 5 measurements
                total_delta = sum(m["memory_delta_mb"] for m in recent)
                avg_delta = total_delta / len(recent)
                max_peak = max(m["peak_memory_mb"] for m in recent)

                logger.warning(f"{operation_name}: Avg Δ{avg_delta:+.1f} MB, Peak {max_peak:.1f} MB ({len(recent)} recent ops)")

    def _dump_memory_state(self):
        """Dump detailed memory state for debugging."""
        logger.critical("=== MEMORY STATE DUMP ===")

        # Current memory
        mem_info = self.get_memory_info()
        logger.critical(f"Current memory: {json.dumps(mem_info, indent=2)}")

        # Top memory consumers by operation
        logger.critical("Top memory consuming operations:")
        for op_name, measurements in sorted(self.measurements.items(), key=lambda x: sum(m["memory_delta_mb"] for m in x[1]), reverse=True)[
            :5
        ]:
            if measurements:
                total = sum(m["memory_delta_mb"] for m in measurements)
                logger.critical(f"  {op_name}: {total:.1f} MB total across {len(measurements)} calls")

        # Active operations
        if self.active_operations:
            logger.critical(f"Active operations: {len(self.active_operations)}")
            for op_data in self.active_operations.values():
                logger.critical(f"  - {op_data['operation_name']}")

        # System info
        logger.critical(f"Uptime: {datetime.now() - self.start_time}")
        logger.critical(f"PID: {os.getpid()}")

        # Attempt to get top memory objects (if available)
        try:
            logger.critical("Top objects by count:")
            obj_counts = defaultdict(int)
            for obj in gc.get_objects()[:1000]:  # Sample the first 1000 objects
                obj_counts[type(obj).__name__] += 1

            for obj_type, count in sorted(obj_counts.items(), key=lambda x: x[1], reverse=True)[:10]:
                logger.critical(f"  {obj_type}: {count}")
        except Exception as e:
            logger.error(f"Could not get object counts: {e}")
    def get_report(self) -> str:
        """Generate a summary report of memory usage."""
        lines = []
        lines.append("=== MEMORY USAGE REPORT ===")
        lines.append(f"Uptime: {datetime.now() - self.start_time}")

        current = self.get_memory_info()
        lines.append(f"Current memory: {current.get('rss_mb', 0):.2f} MB")
        lines.append(f"System memory: {current.get('system_percent', 0):.1f}%")

        lines.append("\nOperations summary:")
        for operation_name, measurements in self.measurements.items():
            if measurements:
                total_mem = sum(m["memory_delta_mb"] for m in measurements)
                avg_mem = total_mem / len(measurements)
                max_mem = max(m["memory_delta_mb"] for m in measurements)
                errors = sum(1 for m in measurements if m.get("error"))

                lines.append(
                    f"  {operation_name}: "
                    f"{len(measurements)} calls, "
                    f"Avg Δ{avg_mem:+.1f} MB, "
                    f"Max Δ{max_mem:+.1f} MB, "
                    f"Total Δ{total_mem:+.1f} MB"
                )
                if errors:
                    lines.append(f"    ({errors} errors)")

        return "\n".join(lines)


# Global tracker instance
_global_tracker = None


def get_memory_tracker(enable_background_monitor: bool = True, monitor_interval: int = 5) -> MemoryTracker:
    """Get or create the global memory tracker instance."""
    global _global_tracker
    if _global_tracker is None:
        _global_tracker = MemoryTracker(enable_background_monitor=enable_background_monitor, monitor_interval=monitor_interval)
    return _global_tracker


def track_operation(operation_name: str):
    """
    Convenience decorator that uses the global tracker.

    Usage:
        @track_operation("my_operation")
        async def my_function():
            ...
    """
    tracker = get_memory_tracker()
    return tracker.track_operation(operation_name)
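For long-running servers the monitor should be stopped before the event loop exits; a short illustrative sketch of the explicit start/stop flow, using only names defined in the file above:

import asyncio

from letta.monitoring.memory_tracker import get_memory_tracker


async def main():
    tracker = get_memory_tracker(enable_background_monitor=True, monitor_interval=1)
    await tracker.start_background_monitor()  # explicit start; decorated calls would do this lazily
    await asyncio.sleep(3)  # the monitor samples memory once per second here
    await tracker.stop_background_monitor()  # cancels the monitor task cleanly


asyncio.run(main())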
letta/monitoring/request_monitor.py (new file, 274 lines)
@@ -0,0 +1,274 @@
"""
Request size monitoring middleware for Letta application.
Tracks incoming request sizes to identify large uploads causing SSL memory spikes.
"""

import json
import time

from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware

from letta.log import get_logger
from letta.monitoring import get_memory_tracker

logger = get_logger(__name__)

# Size thresholds (in bytes)
SIZE_WARNING_THRESHOLD = 10 * 1024 * 1024  # 10MB
SIZE_ERROR_THRESHOLD = 50 * 1024 * 1024  # 50MB
SIZE_CRITICAL_THRESHOLD = 100 * 1024 * 1024  # 100MB


class RequestSizeMonitoringMiddleware(BaseHTTPMiddleware):
    """
    Middleware to monitor incoming request sizes and detect large uploads.
    This helps identify if SSL memory spikes are from large incoming data.
    """

    def __init__(self, app):
        super().__init__(app)
        self.tracker = get_memory_tracker()

    async def dispatch(self, request: Request, call_next):
        start_time = time.time()
        request_size = 0
        content_type = request.headers.get("content-type", "")

        # Track the endpoint
        endpoint = f"{request.method} {request.url.path}"

        # Special monitoring for known large-data endpoints
        critical_endpoints = [
            ("/upload", "file upload"),
            ("/messages", "message creation"),
            ("/agents", "agent creation/update"),
            ("/memory", "memory update"),
            ("/sources", "source upload"),
            ("/folders", "folder upload"),
        ]

        is_critical = any(pattern in request.url.path.lower() for pattern, _ in critical_endpoints)
        endpoint_type = next((desc for pattern, desc in critical_endpoints if pattern in request.url.path.lower()), "general")

        # Try to get the content length from headers
        content_length = request.headers.get("content-length")
        if content_length:
            request_size = int(content_length)

        # Get memory before processing
        memory_before = self.tracker.get_memory_info()

        # Enhanced monitoring for critical endpoints
        if is_critical and request_size > 1024 * 1024:  # Log all critical endpoints > 1MB
            logger.info(
                f"Critical endpoint access: {endpoint_type} - {endpoint} - "
                f"Size: {request_size / 1024 / 1024:.2f} MB - "
                f"Memory before: {memory_before.get('rss_mb', 0):.2f} MB"
            )

        # Log large incoming requests BEFORE processing
        if request_size > SIZE_CRITICAL_THRESHOLD:
            logger.critical(
                f"CRITICAL: Large request incoming - {endpoint} ({endpoint_type}) - "
                f"Size: {request_size / 1024 / 1024:.2f} MB - "
                f"Content-Type: {content_type} - "
                f"Memory before: {memory_before.get('rss_mb', 0):.2f} MB"
            )
        elif request_size > SIZE_ERROR_THRESHOLD:
            logger.error(
                f"Large request detected - {endpoint} ({endpoint_type}) - "
                f"Size: {request_size / 1024 / 1024:.2f} MB - "
                f"Content-Type: {content_type}"
            )
        elif request_size > SIZE_WARNING_THRESHOLD:
            logger.warning(
                f"Sizeable request - {endpoint} ({endpoint_type}) - "
                f"Size: {request_size / 1024 / 1024:.2f} MB - "
                f"Content-Type: {content_type}"
            )

        # For multipart/form-data (file uploads), note the expected size
        if "multipart/form-data" in content_type:
            logger.info(f"File upload detected at {endpoint} - Expected size: {request_size / 1024 / 1024:.2f} MB")
            # Note: the actual file reading happens when the endpoint accesses request.form();
            # that's when the SSL read would spike

        # Track the operation with memory monitoring
        operation_name = f"request_{request.method}_{request.url.path.replace('/', '_')}"

        # For large JSON payloads, log structure details
        if request_size > SIZE_WARNING_THRESHOLD and "application/json" in content_type:
            try:
                # Reading the body consumes the stream, so re-inject it below for the real handler
                body = await request.body()

                # Put it back for the actual handler
                async def receive():
                    return {"type": "http.request", "body": body}

                request._receive = receive

                # Log the structure
                if body:
                    await self._log_json_structure(body, endpoint)
            except Exception as e:
                logger.warning(f"Could not inspect request body: {e}")

        try:
            # Process the request with memory tracking by wrapping the downstream call
            @self.tracker.track_operation(operation_name)
            async def process_request():
                return await call_next(request)

            response = await process_request()

            # Get memory after processing
            memory_after = self.tracker.get_memory_info()
            memory_delta = memory_after.get("rss_mb", 0) - memory_before.get("rss_mb", 0)

            # Log if memory increased significantly during request processing
            if memory_delta > 100:  # More than 100MB increase
                process_time = time.time() - start_time
                logger.error(
                    f"MEMORY SPIKE during request: {endpoint} - "
                    f"Memory increased by {memory_delta:.2f} MB - "
                    f"Request size: {request_size / 1024 / 1024:.2f} MB - "
                    f"Process time: {process_time:.2f}s - "
                    f"Content-Type: {content_type}"
                )

            return response

        except Exception as e:
            # Log any errors with context
            logger.error(f"Error processing request {endpoint} - Size: {request_size / 1024 / 1024:.2f} MB - Error: {str(e)}")
            raise
    async def _log_json_structure(self, body: bytes, endpoint: str):
        """Helper method to log JSON body structure for large payloads."""
        try:
            data = json.loads(body)
            body_size = len(body)

            # Calculate field sizes
            field_sizes = {}
            if isinstance(data, dict):
                for key, value in data.items():
                    if isinstance(value, (str, bytes)):
                        field_sizes[key] = len(value)
                    elif isinstance(value, (list, dict)):
                        field_sizes[key] = len(json.dumps(value))
                    else:
                        field_sizes[key] = 0
            elif isinstance(data, list):
                field_sizes["list_items"] = len(data)
                if data:
                    # Sample the first item's size
                    field_sizes["first_item_size"] = len(json.dumps(data[0]))

            # Find the largest fields
            if field_sizes:
                largest_fields = sorted(field_sizes.items(), key=lambda x: x[1], reverse=True)[:5]

                logger.warning(
                    f"Large JSON payload structure at {endpoint}:\n"
                    f"  Total size: {body_size / 1024 / 1024:.2f} MB\n"
                    f"  Top fields by size:\n"
                    + "\n".join([f"    - {k}: {v / 1024 / 1024:.2f} MB" for k, v in largest_fields if v > 1024])  # Only show fields > 1KB
                )

            # Special monitoring for known problematic fields
            problematic_fields = ["messages", "memory", "system", "tools", "files", "context"]
            for field in problematic_fields:
                if field in field_sizes and field_sizes[field] > 5 * 1024 * 1024:  # > 5MB
                    logger.error(f"LARGE FIELD DETECTED: '{field}' is {field_sizes[field] / 1024 / 1024:.2f} MB at {endpoint}")

        except json.JSONDecodeError:
            logger.error(f"Could not parse JSON body at {endpoint} (size: {len(body) / 1024 / 1024:.2f} MB)")
        except Exception as e:
            logger.error(f"Error analyzing JSON structure: {e}")


class RequestBodyLogger:
    """
    Utility to log request body details for specific endpoints.
    Useful for debugging which fields contain large data.
    """

    @staticmethod
    async def log_json_body_structure(request: Request, endpoint: str):
        """Log the structure and size of JSON request bodies."""
        try:
            # Only for JSON requests
            if "application/json" not in request.headers.get("content-type", ""):
                return

            # Get the body
            body = await request.body()
            body_size = len(body)

            if body_size > SIZE_WARNING_THRESHOLD:
                # Try to parse the JSON to understand its structure
                try:
                    data = json.loads(body)

                    # Log field sizes
                    field_sizes = {}
                    for key, value in data.items() if isinstance(data, dict) else enumerate(data):
                        if isinstance(value, (str, bytes)):
                            field_sizes[key] = len(value)
                        elif isinstance(value, (list, dict)):
                            field_sizes[key] = len(json.dumps(value))
                        else:
                            field_sizes[key] = 0

                    # Find the largest fields
                    largest_fields = sorted(field_sizes.items(), key=lambda x: x[1], reverse=True)[:5]

                    logger.warning(
                        f"Large JSON body structure at {endpoint}:\n"
                        f"  Total size: {body_size / 1024 / 1024:.2f} MB\n"
                        f"  Top fields by size:\n" + "\n".join([f"    - {k}: {v / 1024 / 1024:.2f} MB" for k, v in largest_fields])
                    )

                except json.JSONDecodeError:
                    logger.error(f"Could not parse JSON body at {endpoint} (size: {body_size / 1024 / 1024:.2f} MB)")

        except Exception as e:
            logger.error(f"Error logging request body structure: {e}")


def identify_upload_endpoints(app):
    """
    Scan the app routes to identify potential upload endpoints.
    """
    upload_endpoints = []

    for route in app.routes:
        if hasattr(route, "path") and hasattr(route, "methods"):
            path = route.path
            methods = route.methods

            # Look for common upload patterns
            upload_keywords = ["upload", "file", "attachment", "media", "document", "image", "import"]
            if any(keyword in path.lower() for keyword in upload_keywords):
                upload_endpoints.append((path, methods))

            # Also check for POST/PUT/PATCH to any endpoint (potential large JSON);
            # elif so keyword-matched routes aren't listed twice
            elif "POST" in methods or "PUT" in methods or "PATCH" in methods:
                upload_endpoints.append((path, methods))

    logger.info("Potential upload/large data endpoints identified:")
    for path, methods in upload_endpoints[:20]:  # Log the first 20
        logger.info(f"  {', '.join(methods)}: {path}")

    return upload_endpoints
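A standalone sketch of wiring the middleware into a bare FastAPI app (illustrative only; the toy /upload endpoint is hypothetical, and the commit's real integration lives in the server's application factory shown in the hunks below):

from fastapi import FastAPI

from letta.monitoring import RequestSizeMonitoringMiddleware, identify_upload_endpoints

app = FastAPI()
app.add_middleware(RequestSizeMonitoringMiddleware)


@app.post("/upload")
async def upload(payload: dict):  # toy endpoint for illustration
    return {"ok": True}


identify_upload_endpoints(app)  # logs candidate large-data endpoints at startup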
@@ -14,6 +14,14 @@ import uvicorn

# Enable Python fault handler to get stack traces on segfaults
faulthandler.enable()

# Import memory tracking (if available)
try:
    from letta.monitoring import RequestSizeMonitoringMiddleware, get_memory_tracker, identify_upload_endpoints

    MEMORY_TRACKING_ENABLED = True
except ImportError:
    MEMORY_TRACKING_ENABLED = False
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse

@@ -133,6 +141,13 @@ async def lifespan(app_: FastAPI):
    """
    worker_id = os.getpid()

    # Initialize memory tracking
    if MEMORY_TRACKING_ENABLED:
        logger.info(f"[Worker {worker_id}] Initializing memory tracking")
        # Get the global tracker instance (the background monitor starts lazily on the first tracked operation)
        tracker = get_memory_tracker(enable_background_monitor=True, monitor_interval=5)
        logger.info(f"[Worker {worker_id}] Memory tracking enabled - monitoring every 5s with proactive alerts")

    if telemetry_settings.profiler:
        try:
            import googlecloudprofiler

@@ -174,6 +189,14 @@ async def lifespan(app_: FastAPI):

    # Cleanup on shutdown
    logger.info(f"[Worker {worker_id}] Starting lifespan shutdown")

    # Report memory usage before shutdown
    if MEMORY_TRACKING_ENABLED:
        logger.info(f"[Worker {worker_id}] Generating final memory report")
        tracker = get_memory_tracker()
        report = tracker.get_report()
        logger.info(f"[Worker {worker_id}] Memory report:\n{report}")

    try:
        from letta.jobs.scheduler import shutdown_scheduler_and_release_lock

@@ -556,6 +579,13 @@ def create_application() -> "FastAPI":
    # Add unified logging middleware - enriches log context and logs exceptions
    app.add_middleware(LoggingMiddleware)

    # Add request size monitoring middleware to detect large uploads
    if MEMORY_TRACKING_ENABLED:
        app.add_middleware(RequestSizeMonitoringMiddleware)
        logger.info("Request size monitoring middleware enabled")
        # Identify potential upload endpoints
        identify_upload_endpoints(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.cors_origins,

@@ -28,6 +28,7 @@ from letta.errors import (
from letta.groups.sleeptime_multi_agent_v4 import SleeptimeMultiAgentV4
from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns
from letta.log import get_logger
from letta.monitoring import track_operation
from letta.orm.errors import NoResultFound
from letta.otel.context import get_ctx_attributes
from letta.otel.metric_registry import MetricRegistry

@@ -329,6 +330,7 @@ async def _import_agent(

@router.post("/import", response_model=ImportedAgentsResponse, operation_id="import_agent")
@track_operation("import_agent")
async def import_agent(
    file: UploadFile = File(...),
    server: "SyncServer" = Depends(get_letta_server),

@@ -437,6 +439,7 @@ class CreateAgentRequest(CreateAgent):

@router.post("/", response_model=AgentState, operation_id="create_agent")
@track_operation("create_agent")
async def create_agent(
    agent: CreateAgentRequest = Body(...),
    server: "SyncServer" = Depends(get_letta_server),
@@ -17,6 +17,7 @@ from letta.helpers.pinecone_utils import (
)
from letta.helpers.tpuf_client import should_use_tpuf
from letta.log import get_logger
from letta.monitoring import track_operation
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig

@@ -227,6 +228,7 @@ async def delete_folder(

@router.post("/{folder_id}/upload", response_model=FileMetadata, operation_id="upload_file_to_folder")
@track_operation("file_upload_to_folder")
async def upload_file_to_folder(
    file: UploadFile,
    folder_id: FolderId,

@@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, List, Optional

from fastapi import APIRouter, Depends, Query

from letta.monitoring.memory_tracker import track_operation
from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.model import EmbeddingModel, Model
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server

@@ -13,6 +14,7 @@ router = APIRouter(prefix="/models", tags=["models", "llms"])

@router.get("/", response_model=List[Model], operation_id="list_models")
@track_operation("list_llm_models_endpoint")
async def list_llm_models(
    provider_category: Optional[List[ProviderCategory]] = Query(None),
    provider_name: Optional[str] = Query(None),

@@ -40,6 +42,7 @@ async def list_llm_models(

@router.get("/embedding", response_model=List[EmbeddingModel], operation_id="list_embedding_models")
@track_operation("list_embedding_models_endpoint")
async def list_embedding_models(
    server: "SyncServer" = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),

@@ -17,6 +17,7 @@ from letta.helpers.pinecone_utils import (
)
from letta.helpers.tpuf_client import should_use_tpuf
from letta.log import get_logger
from letta.monitoring import track_operation
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig

@@ -207,6 +208,7 @@ async def delete_source(

@router.post("/{source_id}/upload", response_model=FileMetadata, operation_id="upload_file_to_source", deprecated=True)
@track_operation("file_upload_to_source")
async def upload_file_to_source(
    file: UploadFile,
    source_id: SourceId,

@@ -31,6 +31,7 @@ from letta.interface import (
    CLIInterface,  # for printing to terminal
)
from letta.log import get_logger
from letta.monitoring.memory_tracker import track_operation
from letta.orm.errors import NoResultFound
from letta.otel.tracing import log_event, trace_method
from letta.prompts.gpt_system import get_system_text

@@ -972,6 +973,7 @@ class SyncServer(object):
        return passage_count, document_count

    @trace_method
    @track_operation("list_llm_models_server")
    async def list_llm_models_async(
        self,
        actor: User,

@@ -1023,6 +1025,7 @@ class SyncServer(object):

        return unique_models

    @track_operation("list_embedding_models_server")
    async def list_embedding_models_async(self, actor: User) -> List[EmbeddingConfig]:
        """Asynchronously list available embedding models with maximum concurrency"""
        import asyncio

@@ -1049,6 +1052,7 @@ class SyncServer(object):

        return embedding_models

    @track_operation("get_enabled_providers")
    async def get_enabled_providers_async(
        self,
        actor: User,
@@ -7,6 +7,19 @@ import sqlalchemy as sa
from sqlalchemy import delete, func, insert, literal, or_, select, tuple_
from sqlalchemy.dialects.postgresql import insert as pg_insert

# Import memory tracking if available
try:
    from letta.monitoring import track_operation

    MEMORY_TRACKING_ENABLED = True
except ImportError:
    MEMORY_TRACKING_ENABLED = False

    # Define a no-op decorator if tracking is not available
    def track_operation(name):
        return lambda f: f


from letta.constants import (
    BASE_MEMORY_TOOLS,
    BASE_MEMORY_TOOLS_V2,

@@ -327,6 +340,7 @@ class AgentManager:
    # ======================================================================================================================

    @trace_method
    @track_operation("agent_creation")
    async def create_agent_async(
        self,
        agent_create: CreateAgent,

@@ -4,6 +4,20 @@ from typing import Any, List, Optional, Tuple
from openai.types.beta.function_tool import FunctionTool as OpenAITool

from letta.log import get_logger

# Import memory tracking if available
try:
    from letta.monitoring import track_operation

    MEMORY_TRACKING_ENABLED = True
except ImportError:
    MEMORY_TRACKING_ENABLED = False

    # Define a no-op decorator if tracking is not available
    def track_operation(name):
        return lambda f: f


from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import TextContent

@@ -96,6 +110,7 @@ class ContextWindowCalculator:

        return None, 1

    @track_operation("calculate_context_window")
    async def calculate_context_window(
        self,
        agent_state: AgentState,

@@ -11,6 +11,7 @@ from sqlalchemy.orm import selectinload
from letta.constants import MAX_FILENAME_LENGTH
from letta.helpers.pinecone_utils import list_pinecone_index_for_files, should_use_pinecone
from letta.log import get_logger
from letta.monitoring import track_operation
from letta.orm.errors import NoResultFound
from letta.orm.file import FileContent as FileContentModel, FileMetadata as FileMetadataModel
from letta.orm.sqlalchemy_base import AccessType

@@ -53,6 +54,7 @@ class FileManager:

    @enforce_types
    @trace_method
    @track_operation("create_file_with_content")
    async def create_file(
        self,
        file_metadata: PydanticFileMetadata,

@@ -356,6 +358,7 @@ class FileManager:
    @enforce_types
    @trace_method
    @raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE)
    @track_operation("upsert_file_content")
    async def upsert_file_content(
        self,
        *,

@@ -402,6 +405,7 @@ class FileManager:
    @enforce_types
    @trace_method
    @raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
    @track_operation("list_files_with_content")
    async def list_files(
        self,
        source_id: str,

@@ -629,6 +633,7 @@ class FileManager:

    @enforce_types
    @trace_method
    @track_operation("batch_get_files_by_ids")
    async def get_files_by_ids_async(
        self, file_ids: List[str], actor: PydanticUser, *, include_content: bool = False
    ) -> List[PydanticFileMetadata]:

@@ -664,6 +669,7 @@ class FileManager:

    @enforce_types
    @trace_method
    @track_operation("batch_get_files_for_agents")
    async def get_files_for_agents_async(
        self, agent_ids: List[str], actor: PydanticUser, *, include_content: bool = False
    ) -> List[PydanticFileMetadata]:
@@ -4,6 +4,7 @@ from typing import List, Optional, Union
from sqlalchemy import and_, asc, delete, desc, or_, select
from sqlalchemy.orm import Session

from letta.monitoring import track_operation
from letta.orm.agent import Agent as AgentModel
from letta.orm.block import Block
from letta.orm.errors import NoResultFound

@@ -73,6 +74,7 @@ class GroupManager:
        return group.to_pydantic()

    @enforce_types
    @track_operation("create_multi_agent_group")
    async def create_group_async(self, group: Union[GroupCreate, InternalTemplateGroupCreate], actor: PydanticUser) -> PydanticGroup:
        async with db_registry.async_session() as session:
            new_group = GroupModel()

@@ -197,6 +199,7 @@ class GroupManager:
    @enforce_types
    @trace_method
    @raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
    @track_operation("list_multi_agent_messages")
    async def list_group_messages_async(
        self,
        actor: PydanticUser,

@@ -320,6 +323,7 @@ class GroupManager:
        else:
            raise ValueError("Extend relationship is not supported for groups.")

    @track_operation("process_multi_agent_relationships")
    async def _process_agent_relationship_async(self, session, group: GroupModel, agent_ids: List[str], allow_partial=False, replace=True):
        if not agent_ids:
            if replace:

@@ -7,6 +7,20 @@ from sqlalchemy import delete, exists, func, select, text

from letta.constants import CONVERSATION_SEARCH_TOOL_NAME, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.log import get_logger

# Import memory tracking if available
try:
    from letta.monitoring import track_operation

    MEMORY_TRACKING_ENABLED = True
except ImportError:
    MEMORY_TRACKING_ENABLED = False

    # Define a no-op decorator if tracking is not available
    def track_operation(name):
        return lambda f: f


from letta.orm.errors import NoResultFound
from letta.orm.message import Message as MessageModel
from letta.otel.tracing import trace_method

@@ -441,6 +455,7 @@ class MessageManager:

    @enforce_types
    @trace_method
    @track_operation("bulk_create_messages")
    async def create_many_messages_async(
        self,
        pydantic_msgs: List[PydanticMessage],

@@ -825,6 +840,7 @@ class MessageManager:

    @enforce_types
    @trace_method
    @track_operation("list_messages")
    async def list_messages(
        self,
        actor: PydanticUser,

@@ -949,6 +965,7 @@ class MessageManager:

    @enforce_types
    @trace_method
    @track_operation("bulk_delete_all_agent_messages")
    async def delete_all_messages_for_agent_async(
        self, agent_id: str, actor: PydanticUser, exclude_ids: Optional[List[str]] = None, strict_mode: bool = False
    ) -> int:

@@ -997,6 +1014,7 @@ class MessageManager:

    @enforce_types
    @trace_method
    @track_operation("bulk_delete_messages_by_ids")
    async def delete_messages_by_ids_async(self, message_ids: List[str], actor: PydanticUser, strict_mode: bool = False) -> int:
        """
        Efficiently deletes messages by their specific IDs,

@@ -1044,6 +1062,7 @@ class MessageManager:

    @enforce_types
    @trace_method
    @track_operation("search_messages")
    async def search_messages_async(
        self,
        agent_id: str,

@@ -1164,6 +1183,7 @@ class MessageManager:
            message_tuples.append((message, metadata))
        return message_tuples

    @track_operation("search_messages_org")
    async def search_messages_org_async(
        self,
        actor: PydanticUser,
@@ -11,6 +11,7 @@ from letta.constants import MAX_EMBEDDING_DIM
from letta.helpers.decorators import async_redis_cache
from letta.llm_api.llm_client import LLMClient
from letta.log import get_logger
from letta.monitoring import track_operation
from letta.orm import ArchivesAgents
from letta.orm.errors import NoResultFound
from letta.orm.passage import ArchivalPassage, SourcePassage

@@ -302,6 +303,7 @@ class PassageManager:

    @enforce_types
    @trace_method
    @track_operation("batch_create_archival_passages")
    async def create_many_archival_passages_async(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]:
        """Create multiple archival passages."""
        archival_passages = []

@@ -354,6 +356,7 @@ class PassageManager:

    @enforce_types
    @trace_method
    @track_operation("batch_create_source_passages")
    async def create_many_source_passages_async(
        self, passages: List[PydanticPassage], file_metadata: PydanticFileMetadata, actor: PydanticUser
    ) -> List[PydanticPassage]:

@@ -451,6 +454,7 @@ class PassageManager:

    @enforce_types
    @trace_method
    @track_operation("insert_passages_with_embeddings")
    async def insert_passage(
        self,
        agent_state: AgentState,

@@ -545,6 +549,7 @@ class PassageManager:
        except Exception as e:
            raise e

    @track_operation("generate_embeddings_batch")
    async def _generate_embeddings_concurrent(self, text_chunks: List[str], embedding_config, actor: PydanticUser) -> List[List[float]]:
        """Generate embeddings for all text chunks concurrently using LLMClient"""

@@ -764,6 +769,7 @@ class PassageManager:

    @enforce_types
    @trace_method
    @track_operation("bulk_delete_agent_passages")
    async def delete_agent_passages_async(
        self,
        passages: List[PydanticPassage],

@@ -818,6 +824,7 @@ class PassageManager:

    @enforce_types
    @trace_method
    @track_operation("bulk_delete_source_passages")
    async def delete_source_passages_async(
        self,
        actor: PydanticUser,

@@ -3,6 +3,20 @@ from typing import List, Optional, Tuple, Union
from letta.orm.provider import Provider as ProviderModel
from letta.orm.provider_model import ProviderModel as ProviderModelORM
from letta.otel.tracing import trace_method

# Import memory tracking if available
try:
    from letta.monitoring import track_operation

    MEMORY_TRACKING_ENABLED = True
except ImportError:
    MEMORY_TRACKING_ENABLED = False

    # Define a no-op decorator if tracking is not available
    def track_operation(name):
        return lambda f: f


from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import PrimitiveType, ProviderCategory, ProviderType
from letta.schemas.llm_config import LLMConfig

@@ -746,6 +760,7 @@ class ProviderManager:

    @enforce_types
    @trace_method
    @track_operation("list_models")
    async def list_models_async(
        self,
        actor: PydanticUser,

@@ -5,6 +5,7 @@ from sqlalchemy import and_, exists, select

from letta.helpers.pinecone_utils import should_use_pinecone
from letta.helpers.tpuf_client import should_use_tpuf
from letta.monitoring import track_operation
from letta.orm import Agent as AgentModel
from letta.orm.errors import NoResultFound
from letta.orm.source import Source as SourceModel

@@ -272,6 +273,7 @@ class SourceManager:
    @enforce_types
    @trace_method
    @raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
    @track_operation("list_all_attached_agents")
    async def list_attached_agents(
        self, source_id: str, actor: PydanticUser, ids_only: bool = False
    ) -> Union[List[PydanticAgentState], List[str]]:

@@ -465,6 +467,7 @@ class SourceManager:

    @enforce_types
    @trace_method
    @track_operation("batch_get_sources_by_ids")
    async def get_sources_by_ids_async(self, source_ids: List[str], actor: PydanticUser) -> List[PydanticSource]:
        """
        Get multiple sources by their IDs in a single query.

@@ -491,6 +494,7 @@ class SourceManager:

    @enforce_types
    @trace_method
    @track_operation("batch_get_sources_for_agents")
    async def get_sources_for_agents_async(self, agent_ids: List[str], actor: PydanticUser) -> List[PydanticSource]:
        """
        Get all sources associated with the given agents via sources-agents relationships.
@@ -71,6 +71,7 @@ dependencies = [
    "readability-lxml",
    "google-genai>=1.15.0",
    "datadog>=0.49.1",
    "psutil>=5.9.0",
]

[project.scripts]
uv.lock (generated, 2 changes)
@@ -2446,6 +2446,7 @@ dependencies = [
    { name = "orjson" },
    { name = "pathvalidate" },
    { name = "prettytable" },
    { name = "psutil" },
    { name = "pydantic" },
    { name = "pydantic-settings" },
    { name = "pyhumps" },

@@ -2620,6 +2621,7 @@ requires-dist = [
    { name = "pinecone", extras = ["asyncio"], marker = "extra == 'pinecone'", specifier = ">=7.3.0" },
    { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.5.0" },
    { name = "prettytable", specifier = ">=3.9.0" },
    { name = "psutil", specifier = ">=5.9.0" },
    { name = "psycopg2", marker = "extra == 'postgres'", specifier = ">=2.9.10" },
    { name = "psycopg2-binary", marker = "extra == 'postgres'", specifier = ">=2.9.10" },
    { name = "pydantic", specifier = ">=2.10.6" },