From 848aa962b6132fb3099958ca07acb9041f09bbf7 Mon Sep 17 00:00:00 2001 From: Kian Jones <11655409+kianjones9@users.noreply.github.com> Date: Fri, 14 Nov 2025 13:24:50 -0800 Subject: [PATCH] feat: add memory tracking to core (#6179) * add memory tracking to core * move to asyncio from threading.Thread * remove threading.thread all the way * delay decorator monitoring initialization until after event loop is registered * context manager to decorator * add psutil --- letta/monitoring/__init__.py | 13 + letta/monitoring/memory_tracker.py | 552 ++++++++++++++++++ letta/monitoring/request_monitor.py | 274 +++++++++ letta/server/rest_api/app.py | 30 + letta/server/rest_api/routers/v1/agents.py | 3 + letta/server/rest_api/routers/v1/folders.py | 2 + letta/server/rest_api/routers/v1/llms.py | 3 + letta/server/rest_api/routers/v1/sources.py | 2 + letta/server/server.py | 4 + letta/services/agent_manager.py | 14 + .../context_window_calculator.py | 15 + letta/services/file_manager.py | 6 + letta/services/group_manager.py | 4 + letta/services/message_manager.py | 20 + letta/services/passage_manager.py | 7 + letta/services/provider_manager.py | 15 + letta/services/source_manager.py | 4 + pyproject.toml | 1 + uv.lock | 2 + 19 files changed, 971 insertions(+) create mode 100644 letta/monitoring/__init__.py create mode 100644 letta/monitoring/memory_tracker.py create mode 100644 letta/monitoring/request_monitor.py diff --git a/letta/monitoring/__init__.py b/letta/monitoring/__init__.py new file mode 100644 index 00000000..668d0cd4 --- /dev/null +++ b/letta/monitoring/__init__.py @@ -0,0 +1,13 @@ +"""Memory and request monitoring utilities for Letta application.""" + +from .memory_tracker import MemoryTracker, get_memory_tracker, track_operation +from .request_monitor import RequestBodyLogger, RequestSizeMonitoringMiddleware, identify_upload_endpoints + +__all__ = [ + "MemoryTracker", + "get_memory_tracker", + "track_operation", + "RequestSizeMonitoringMiddleware", + "RequestBodyLogger", + "identify_upload_endpoints", +] diff --git a/letta/monitoring/memory_tracker.py b/letta/monitoring/memory_tracker.py new file mode 100644 index 00000000..a0670eca --- /dev/null +++ b/letta/monitoring/memory_tracker.py @@ -0,0 +1,552 @@ +""" +Memory tracking utility for Letta application. +Provides real-time memory monitoring with proactive alerting using asyncio. +""" + +import asyncio +import functools +import gc +import json +import os +import sys +import time +import traceback +from collections import defaultdict +from datetime import datetime, timedelta +from typing import Any, Callable, Dict, List, Optional, Tuple + +import psutil + +from letta.log import get_logger + +logger = get_logger(__name__) + + +class MemoryTracker: + """ + Track memory usage across different operations with proactive alerting. + Uses asyncio for all async operations. + + Features: + - Real-time memory monitoring using asyncio + - Proactive alerts before OOM + - Automatic memory dumps on critical thresholds + - Per-operation tracking and reporting + """ + + # Memory thresholds (in MB) + WARNING_THRESHOLD_MB = 1000 # Warning at 1GB + CRITICAL_THRESHOLD_MB = 2000 # Critical at 2GB + SPIKE_THRESHOLD_MB = 100 # Alert on 100MB+ spikes + + # Memory percentage thresholds + MEMORY_PERCENT_WARNING = 70 # Warn at 70% system memory + MEMORY_PERCENT_CRITICAL = 85 # Critical at 85% system memory + MEMORY_PERCENT_FATAL = 95 # Fatal - likely to OOM + + def __init__(self, enable_background_monitor: bool = True, monitor_interval: int = 5): + """ + Initialize the memory tracker. + + Args: + enable_background_monitor: Whether to start background monitoring + monitor_interval: Interval in seconds between background checks + """ + self.process = psutil.Process(os.getpid()) + self.measurements = defaultdict(list) + self.active_operations = {} + self.lock = asyncio.Lock() # Use asyncio.Lock instead of threading.Lock + self.monitor_interval = monitor_interval + self._monitoring = False + self._monitor_task = None + + # Track memory history for trend analysis + self.memory_history = [] + self.max_history_size = 100 + + # Track if we've already warned about memory levels + self._warned_levels = set() + + # Start time for uptime tracking + self.start_time = datetime.now() + + # Flag to track if we should start the monitor + self._should_start_monitor = enable_background_monitor + self._monitor_started = False + + logger.info( + f"Memory tracker initialized - PID: {os.getpid()}, " + f"Warning: {self.WARNING_THRESHOLD_MB}MB, " + f"Critical: {self.CRITICAL_THRESHOLD_MB}MB" + ) + + def _ensure_monitor_started(self): + """Start the background monitor if needed and if there's an event loop.""" + if self._should_start_monitor and not self._monitor_started: + try: + # Check if there's a running event loop + loop = asyncio.get_running_loop() + # Create the monitor task + asyncio.create_task(self.start_background_monitor()) + self._monitor_started = True + except RuntimeError: + # No event loop running yet, will try again later + pass + + def get_memory_info(self) -> Dict[str, Any]: + """Get current memory information.""" + try: + mem_info = self.process.memory_info() + mem_percent = self.process.memory_percent() + + # Get system-wide memory + system_mem = psutil.virtual_memory() + + return { + "rss_mb": mem_info.rss / 1024 / 1024, + "vms_mb": mem_info.vms / 1024 / 1024, + "percent": mem_percent, + "system_available_mb": system_mem.available / 1024 / 1024, + "system_percent": system_mem.percent, + "timestamp": datetime.now().isoformat(), + } + except Exception as e: + logger.error(f"Failed to get memory info: {e}") + return {} + + def track_operation(self, operation_name: str): + """ + Decorator to track memory for specific operations. + Logs immediately on completion or error. + """ + + def decorator(func): + @functools.wraps(func) + async def async_wrapper(*args, **kwargs): + # Ensure background monitor is started (now we have an event loop) + self._ensure_monitor_started() + + start_mem_info = self.get_memory_info() + start_time = time.time() + + # Log operation start if memory is already high + if start_mem_info.get("rss_mb", 0) > self.WARNING_THRESHOLD_MB: + logger.warning(f"Starting operation '{operation_name}' with high memory: {start_mem_info['rss_mb']:.2f} MB") + + # Record call stack for debugging + stack = traceback.extract_stack() + operation_id = f"{operation_name}_{id(func)}_{time.time()}" + + async with self.lock: + self.active_operations[operation_id] = { + "operation_name": operation_name, + "start_mem": start_mem_info, + "start_time": start_time, + "stack": stack, + } + + try: + result = await func(*args, **kwargs) + return result + except Exception as e: + # Log memory state on error + await self._log_operation_completion(operation_id, error=str(e)) + raise + finally: + # Always log operation completion + await self._log_operation_completion(operation_id) + + @functools.wraps(func) + def sync_wrapper(*args, **kwargs): + start_mem_info = self.get_memory_info() + start_time = time.time() + + # Log operation start if memory is already high + if start_mem_info.get("rss_mb", 0) > self.WARNING_THRESHOLD_MB: + logger.warning(f"Starting operation '{operation_name}' with high memory: {start_mem_info['rss_mb']:.2f} MB") + + operation_id = f"{operation_name}_{id(func)}_{time.time()}" + + # For sync functions, we can't use async lock, so we'll use the operation directly + self.active_operations[operation_id] = { + "operation_name": operation_name, + "start_mem": start_mem_info, + "start_time": start_time, + "stack": traceback.extract_stack(), + } + + try: + result = func(*args, **kwargs) + return result + except Exception as e: + # Log memory state on error (sync version) + self._log_operation_completion_sync(operation_id, error=str(e)) + raise + finally: + # Always log operation completion (sync version) + self._log_operation_completion_sync(operation_id) + + if asyncio.iscoroutinefunction(func): + return async_wrapper + else: + return sync_wrapper + + return decorator + + async def _log_operation_completion(self, operation_id: str, error: Optional[str] = None): + """Log memory usage immediately when an operation completes (async version).""" + async with self.lock: + if operation_id not in self.active_operations: + return + + operation_data = self.active_operations.pop(operation_id) + + end_mem_info = self.get_memory_info() + end_time = time.time() + + start_mem = operation_data["start_mem"].get("rss_mb", 0) + end_mem = end_mem_info.get("rss_mb", 0) + mem_delta = end_mem - start_mem + time_delta = end_time - operation_data["start_time"] + operation_name = operation_data["operation_name"] + + # Record measurement + async with self.lock: + self.measurements[operation_name].append( + { + "memory_delta_mb": mem_delta, + "peak_memory_mb": end_mem, + "time_seconds": time_delta, + "timestamp": datetime.now().isoformat(), + "error": error, + "system_percent": end_mem_info.get("system_percent", 0), + } + ) + + self._log_memory_status(operation_name, mem_delta, end_mem, time_delta, end_mem_info, error, operation_data) + + def _log_operation_completion_sync(self, operation_id: str, error: Optional[str] = None): + """Log memory usage immediately when an operation completes (sync version for non-async functions).""" + if operation_id not in self.active_operations: + return + + operation_data = self.active_operations.pop(operation_id) + end_mem_info = self.get_memory_info() + end_time = time.time() + + start_mem = operation_data["start_mem"].get("rss_mb", 0) + end_mem = end_mem_info.get("rss_mb", 0) + mem_delta = end_mem - start_mem + time_delta = end_time - operation_data["start_time"] + operation_name = operation_data["operation_name"] + + # Record measurement (sync version doesn't use async lock) + self.measurements[operation_name].append( + { + "memory_delta_mb": mem_delta, + "peak_memory_mb": end_mem, + "time_seconds": time_delta, + "timestamp": datetime.now().isoformat(), + "error": error, + "system_percent": end_mem_info.get("system_percent", 0), + } + ) + + self._log_memory_status(operation_name, mem_delta, end_mem, time_delta, end_mem_info, error, operation_data) + + def _log_memory_status( + self, + operation_name: str, + mem_delta: float, + end_mem: float, + time_delta: float, + end_mem_info: Dict, + error: Optional[str], + operation_data: Dict, + ): + """Common logging logic for memory status.""" + # Determine log level based on memory situation + if error: + logger.error( + f"Operation '{operation_name}' failed after {time_delta:.2f}s - " + f"Memory: {end_mem:.2f} MB (Δ{mem_delta:+.2f} MB), " + f"System: {end_mem_info.get('system_percent', 0):.1f}%, " + f"Error: {error}" + ) + elif mem_delta > self.SPIKE_THRESHOLD_MB: + logger.warning( + f"MEMORY SPIKE: Operation '{operation_name}' - " + f"Increased by {mem_delta:.2f} MB in {time_delta:.2f}s - " + f"Current: {end_mem:.2f} MB, System: {end_mem_info.get('system_percent', 0):.1f}%" + ) + + # Log stack trace for large spikes + if mem_delta > self.SPIKE_THRESHOLD_MB * 2: + stack = operation_data.get("stack", []) + if stack and len(stack) > 3: + logger.warning("Call stack for memory spike:") + for frame in stack[-5:]: + logger.warning(f" {frame.filename}:{frame.lineno} in {frame.name}") + + elif end_mem > self.CRITICAL_THRESHOLD_MB: + logger.error( + f"CRITICAL MEMORY: Operation '{operation_name}' completed - " + f"Memory at {end_mem:.2f} MB (Δ{mem_delta:+.2f} MB), " + f"System: {end_mem_info.get('system_percent', 0):.1f}%" + ) + elif end_mem > self.WARNING_THRESHOLD_MB: + logger.warning(f"High memory after '{operation_name}': {end_mem:.2f} MB (Δ{mem_delta:+.2f} MB) in {time_delta:.2f}s") + else: + # Only log normal operations in debug mode + logger.debug(f"Operation '{operation_name}' completed: Memory {end_mem:.2f} MB (Δ{mem_delta:+.2f} MB) in {time_delta:.2f}s") + + async def start_background_monitor(self): + """Start the background memory monitoring task.""" + if self._monitoring: + return + + self._monitoring = True + self._monitor_started = True + self._monitor_task = asyncio.create_task(self._monitor_loop()) + logger.info(f"Background memory monitor started (interval: {self.monitor_interval}s)") + + async def stop_background_monitor(self): + """Stop the background memory monitoring task.""" + self._monitoring = False + self._monitor_started = False + if self._monitor_task: + self._monitor_task.cancel() + try: + await self._monitor_task + except asyncio.CancelledError: + pass + logger.info("Background memory monitor stopped") + + async def _monitor_loop(self): + """Background monitoring loop that runs continuously using asyncio.""" + consecutive_high_memory = 0 + last_gc_time = time.time() + + while self._monitoring: + try: + mem_info = self.get_memory_info() + current_mb = mem_info.get("rss_mb", 0) + system_percent = mem_info.get("system_percent", 0) + + # Add to history + async with self.lock: + self.memory_history.append(mem_info) + if len(self.memory_history) > self.max_history_size: + self.memory_history.pop(0) + + # Check memory levels + self._check_memory_thresholds(mem_info) + + # Track consecutive high memory readings + if current_mb > self.WARNING_THRESHOLD_MB: + consecutive_high_memory += 1 + else: + consecutive_high_memory = 0 + + # Force GC if memory is consistently high + if consecutive_high_memory >= 3 and time.time() - last_gc_time > 30: + await asyncio.to_thread(self._force_gc_with_logging) + last_gc_time = time.time() + + # Check for memory leak patterns + if len(self.memory_history) >= 10: + await self._check_memory_trend() + + # Log active operations if memory is critical + if system_percent > self.MEMORY_PERCENT_CRITICAL: + await self._log_active_operations() + + await asyncio.sleep(self.monitor_interval) + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Error in memory monitor loop: {e}") + await asyncio.sleep(self.monitor_interval) + + def _check_memory_thresholds(self, mem_info: Dict[str, Any]): + """Check memory against thresholds and log appropriately.""" + current_mb = mem_info.get("rss_mb", 0) + system_percent = mem_info.get("system_percent", 0) + + # Check system memory percentage thresholds + if system_percent > self.MEMORY_PERCENT_FATAL and "fatal" not in self._warned_levels: + logger.critical( + f"FATAL MEMORY LEVEL: {system_percent:.1f}% of system memory used! Process: {current_mb:.2f} MB - OOM imminent!" + ) + self._warned_levels.add("fatal") + self._dump_memory_state() + + elif system_percent > self.MEMORY_PERCENT_CRITICAL and "critical" not in self._warned_levels: + logger.error(f"CRITICAL: System memory at {system_percent:.1f}% - Process using {current_mb:.2f} MB") + self._warned_levels.add("critical") + self._dump_active_operations_summary() + + elif system_percent > self.MEMORY_PERCENT_WARNING and "warning" not in self._warned_levels: + logger.warning(f"Memory warning: System at {system_percent:.1f}% - Process: {current_mb:.2f} MB") + self._warned_levels.add("warning") + + # Reset warning levels if memory drops + if system_percent < self.MEMORY_PERCENT_WARNING: + self._warned_levels.clear() + + async def _check_memory_trend(self): + """Check if memory is trending upward (potential leak).""" + async with self.lock: + if len(self.memory_history) < 10: + return + + # Get the last 10 readings + recent = self.memory_history[-10:] + + # Calculate trend + first_mb = recent[0].get("rss_mb", 0) + last_mb = recent[-1].get("rss_mb", 0) + increase_mb = last_mb - first_mb + + # Calculate average rate of increase + time_span = 10 * self.monitor_interval # seconds + rate_mb_per_min = (increase_mb / time_span) * 60 + + # Warn if memory is increasing rapidly + if rate_mb_per_min > 50: # More than 50MB/minute + logger.warning( + f"MEMORY LEAK SUSPECTED: Memory increasing at {rate_mb_per_min:.1f} MB/min - From {first_mb:.2f} MB to {last_mb:.2f} MB" + ) + self._dump_active_operations_summary() + + def _force_gc_with_logging(self): + """Force garbage collection and log the results.""" + before_mb = self.get_memory_info().get("rss_mb", 0) + collected = gc.collect() + after_mb = self.get_memory_info().get("rss_mb", 0) + freed_mb = before_mb - after_mb + + if freed_mb > 10: # Only log if significant memory was freed + logger.info(f"Garbage collection freed {freed_mb:.2f} MB (collected {collected} objects)") + + async def _log_active_operations(self): + """Log currently active operations.""" + async with self.lock: + if not self.active_operations: + return + + logger.warning(f"Active operations during high memory ({len(self.active_operations)} running):") + for op_id, op_data in self.active_operations.items(): + duration = time.time() - op_data["start_time"] + logger.warning( + f" - {op_data['operation_name']}: running for {duration:.1f}s, started at {op_data['start_mem'].get('rss_mb', 0):.2f} MB" + ) + + def _dump_active_operations_summary(self): + """Dump summary of recent operations.""" + logger.warning("=== Recent Operation Memory Usage ===") + for operation_name, measurements in self.measurements.items(): + if measurements: + recent = measurements[-5:] # Last 5 measurements + total_delta = sum(m["memory_delta_mb"] for m in recent) + avg_delta = total_delta / len(recent) + max_peak = max(m["peak_memory_mb"] for m in recent) + + logger.warning(f"{operation_name}: Avg Δ{avg_delta:+.1f} MB, Peak {max_peak:.1f} MB ({len(recent)} recent ops)") + + def _dump_memory_state(self): + """Dump detailed memory state for debugging.""" + logger.critical("=== MEMORY STATE DUMP ===") + + # Current memory + mem_info = self.get_memory_info() + logger.critical(f"Current memory: {json.dumps(mem_info, indent=2)}") + + # Top memory consumers by operation + logger.critical("Top memory consuming operations:") + for op_name, measurements in sorted(self.measurements.items(), key=lambda x: sum(m["memory_delta_mb"] for m in x[1]), reverse=True)[ + :5 + ]: + if measurements: + total = sum(m["memory_delta_mb"] for m in measurements) + logger.critical(f" {op_name}: {total:.1f} MB total across {len(measurements)} calls") + + # Active operations + if self.active_operations: + logger.critical(f"Active operations: {len(self.active_operations)}") + for op_data in self.active_operations.values(): + logger.critical(f" - {op_data['operation_name']}") + + # System info + logger.critical(f"Uptime: {datetime.now() - self.start_time}") + logger.critical(f"PID: {os.getpid()}") + + # Attempt to get top memory objects (if available) + try: + import sys + + logger.critical("Top objects by count:") + obj_counts = defaultdict(int) + for obj in gc.get_objects()[:1000]: # Sample first 1000 objects + obj_counts[type(obj).__name__] += 1 + + for obj_type, count in sorted(obj_counts.items(), key=lambda x: x[1], reverse=True)[:10]: + logger.critical(f" {obj_type}: {count}") + except Exception as e: + logger.error(f"Could not get object counts: {e}") + + def get_report(self) -> str: + """Generate a summary report of memory usage.""" + lines = [] + lines.append("=== MEMORY USAGE REPORT ===") + lines.append(f"Uptime: {datetime.now() - self.start_time}") + + current = self.get_memory_info() + lines.append(f"Current memory: {current.get('rss_mb', 0):.2f} MB") + lines.append(f"System memory: {current.get('system_percent', 0):.1f}%") + + lines.append("\nOperations summary:") + for operation_name, measurements in self.measurements.items(): + if measurements: + total_mem = sum(m["memory_delta_mb"] for m in measurements) + avg_mem = total_mem / len(measurements) + max_mem = max(m["memory_delta_mb"] for m in measurements) + errors = sum(1 for m in measurements if m.get("error")) + + lines.append( + f" {operation_name}: " + f"{len(measurements)} calls, " + f"Avg Δ{avg_mem:+.1f} MB, " + f"Max Δ{max_mem:+.1f} MB, " + f"Total Δ{total_mem:+.1f} MB" + ) + if errors: + lines.append(f" ({errors} errors)") + + return "\n".join(lines) + + +# Global tracker instance +_global_tracker = None + + +def get_memory_tracker(enable_background_monitor: bool = True, monitor_interval: int = 5) -> MemoryTracker: + """Get or create the global memory tracker instance.""" + global _global_tracker + if _global_tracker is None: + _global_tracker = MemoryTracker(enable_background_monitor=enable_background_monitor, monitor_interval=monitor_interval) + return _global_tracker + + +def track_operation(operation_name: str): + """ + Convenience decorator that uses the global tracker. + + Usage: + @track_operation("my_operation") + async def my_function(): + ... + """ + tracker = get_memory_tracker() + return tracker.track_operation(operation_name) diff --git a/letta/monitoring/request_monitor.py b/letta/monitoring/request_monitor.py new file mode 100644 index 00000000..d46e9e83 --- /dev/null +++ b/letta/monitoring/request_monitor.py @@ -0,0 +1,274 @@ +""" +Request size monitoring middleware for Letta application. +Tracks incoming request sizes to identify large uploads causing SSL memory spikes. +""" + +import time +from typing import Optional + +from fastapi import Request +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.responses import Response + +from letta.log import get_logger +from letta.monitoring import get_memory_tracker + +logger = get_logger(__name__) + +# Size thresholds (in bytes) +SIZE_WARNING_THRESHOLD = 10 * 1024 * 1024 # 10MB +SIZE_ERROR_THRESHOLD = 50 * 1024 * 1024 # 50MB +SIZE_CRITICAL_THRESHOLD = 100 * 1024 * 1024 # 100MB + + +class RequestSizeMonitoringMiddleware(BaseHTTPMiddleware): + """ + Middleware to monitor incoming request sizes and detect large uploads. + This helps identify if SSL memory spikes are from large incoming data. + """ + + def __init__(self, app): + super().__init__(app) + self.tracker = get_memory_tracker() + + async def dispatch(self, request: Request, call_next): + start_time = time.time() + request_size = 0 + content_type = request.headers.get("content-type", "") + + # Track the endpoint + endpoint = f"{request.method} {request.url.path}" + + # Special monitoring for known large data endpoints + critical_endpoints = [ + ("/upload", "file upload"), + ("/messages", "message creation"), + ("/agents", "agent creation/update"), + ("/memory", "memory update"), + ("/sources", "source upload"), + ("/folders", "folder upload"), + ] + + is_critical = any(pattern in request.url.path.lower() for pattern, _ in critical_endpoints) + endpoint_type = next((desc for pattern, desc in critical_endpoints if pattern in request.url.path.lower()), "general") + + # Try to get content length from headers + content_length = request.headers.get("content-length") + if content_length: + request_size = int(content_length) + + # Get memory before processing + memory_before = self.tracker.get_memory_info() + + # Enhanced monitoring for critical endpoints + if is_critical and request_size > 1024 * 1024: # Log all critical endpoints > 1MB + logger.info( + f"Critical endpoint access: {endpoint_type} - {endpoint} - " + f"Size: {request_size / 1024 / 1024:.2f} MB - " + f"Memory before: {memory_before.get('rss_mb', 0):.2f} MB" + ) + + # Log large incoming requests BEFORE processing + if request_size > SIZE_CRITICAL_THRESHOLD: + logger.critical( + f"CRITICAL: Large request incoming - {endpoint} ({endpoint_type}) - " + f"Size: {request_size / 1024 / 1024:.2f} MB - " + f"Content-Type: {content_type} - " + f"Memory before: {memory_before.get('rss_mb', 0):.2f} MB" + ) + elif request_size > SIZE_ERROR_THRESHOLD: + logger.error( + f"Large request detected - {endpoint} ({endpoint_type}) - " + f"Size: {request_size / 1024 / 1024:.2f} MB - " + f"Content-Type: {content_type}" + ) + elif request_size > SIZE_WARNING_THRESHOLD: + logger.warning( + f"Sizeable request - {endpoint} ({endpoint_type}) - " + f"Size: {request_size / 1024 / 1024:.2f} MB - " + f"Content-Type: {content_type}" + ) + + # For multipart/form-data (file uploads), try to get more details + if "multipart/form-data" in content_type: + logger.info(f"File upload detected at {endpoint} - Expected size: {request_size / 1024 / 1024:.2f} MB") + # Note: The actual file reading happens when the endpoint accesses request.form() + # That's when SSL read would spike + + # Track the operation with memory monitoring + operation_name = f"request_{request.method}_{request.url.path.replace('/', '_')}" + + # For large JSON payloads, log structure details + if request_size > SIZE_WARNING_THRESHOLD and "application/json" in content_type: + # Create a copy of the request for body logging + body_logger = RequestBodyLogger() + try: + # Clone the body for inspection (this won't consume the original) + body = await request.body() + + # Put it back for the actual handler + async def receive(): + return {"type": "http.request", "body": body} + + request._receive = receive + + # Log the structure + if body: + await self._log_json_structure(body, endpoint) + except Exception as e: + logger.warning(f"Could not inspect request body: {e}") + + try: + # Process the request with memory tracking + # Use the decorator by creating a wrapped function + @self.tracker.track_operation(operation_name) + async def process_request(): + return await call_next(request) + + response = await process_request() + + # Get memory after processing + memory_after = self.tracker.get_memory_info() + memory_delta = memory_after.get("rss_mb", 0) - memory_before.get("rss_mb", 0) + + # Log if memory increased significantly during request processing + if memory_delta > 100: # More than 100MB increase + process_time = time.time() - start_time + logger.error( + f"MEMORY SPIKE during request: {endpoint} - " + f"Memory increased by {memory_delta:.2f} MB - " + f"Request size: {request_size / 1024 / 1024:.2f} MB - " + f"Process time: {process_time:.2f}s - " + f"Content-Type: {content_type}" + ) + + return response + + except Exception as e: + # Log any errors with context + logger.error(f"Error processing request {endpoint} - Size: {request_size / 1024 / 1024:.2f} MB - Error: {str(e)}") + raise + + async def _log_json_structure(self, body: bytes, endpoint: str): + """Helper method to log JSON body structure for large payloads.""" + import json + + try: + data = json.loads(body) + body_size = len(body) + + # Calculate field sizes + field_sizes = {} + if isinstance(data, dict): + for key, value in data.items(): + if isinstance(value, (str, bytes)): + field_sizes[key] = len(value) + elif isinstance(value, (list, dict)): + field_sizes[key] = len(json.dumps(value)) + else: + field_sizes[key] = 0 + elif isinstance(data, list): + field_sizes["list_items"] = len(data) + if data: + # Sample first item size + field_sizes["first_item_size"] = len(json.dumps(data[0])) + + # Find largest fields + if field_sizes: + largest_fields = sorted(field_sizes.items(), key=lambda x: x[1], reverse=True)[:5] + + logger.warning( + f"Large JSON payload structure at {endpoint}:\n" + f" Total size: {body_size / 1024 / 1024:.2f} MB\n" + f" Top fields by size:\n" + + "\n".join([f" - {k}: {v / 1024 / 1024:.2f} MB" for k, v in largest_fields if v > 1024]) # Only show fields > 1KB + ) + + # Special monitoring for known problematic fields + problematic_fields = ["messages", "memory", "system", "tools", "files", "context"] + for field in problematic_fields: + if field in field_sizes and field_sizes[field] > 5 * 1024 * 1024: # > 5MB + logger.error(f"LARGE FIELD DETECTED: '{field}' is {field_sizes[field] / 1024 / 1024:.2f} MB at {endpoint}") + + except json.JSONDecodeError: + logger.error(f"Could not parse JSON body at {endpoint} (size: {len(body) / 1024 / 1024:.2f} MB)") + except Exception as e: + logger.error(f"Error analyzing JSON structure: {e}") + + +class RequestBodyLogger: + """ + Utility to log request body details for specific endpoints. + Useful for debugging which fields contain large data. + """ + + @staticmethod + async def log_json_body_structure(request: Request, endpoint: str): + """Log the structure and size of JSON request bodies.""" + try: + # Only for JSON requests + if "application/json" not in request.headers.get("content-type", ""): + return + + # Get the body + body = await request.body() + body_size = len(body) + + if body_size > SIZE_WARNING_THRESHOLD: + # Try to parse JSON to understand structure + import json + + try: + data = json.loads(body) + + # Log field sizes + field_sizes = {} + for key, value in data.items() if isinstance(data, dict) else enumerate(data): + if isinstance(value, (str, bytes)): + field_sizes[key] = len(value) + elif isinstance(value, (list, dict)): + field_sizes[key] = len(json.dumps(value)) + else: + field_sizes[key] = 0 + + # Find largest fields + largest_fields = sorted(field_sizes.items(), key=lambda x: x[1], reverse=True)[:5] + + logger.warning( + f"Large JSON body structure at {endpoint}:\n" + f" Total size: {body_size / 1024 / 1024:.2f} MB\n" + f" Top fields by size:\n" + "\n".join([f" - {k}: {v / 1024 / 1024:.2f} MB" for k, v in largest_fields]) + ) + + except json.JSONDecodeError: + logger.error(f"Could not parse JSON body at {endpoint} (size: {body_size / 1024 / 1024:.2f} MB)") + + except Exception as e: + logger.error(f"Error logging request body structure: {e}") + + +def identify_upload_endpoints(app): + """ + Scan the app routes to identify potential upload endpoints. + """ + upload_endpoints = [] + + for route in app.routes: + if hasattr(route, "path") and hasattr(route, "methods"): + path = route.path + methods = route.methods + + # Look for common upload patterns + upload_keywords = ["upload", "file", "attachment", "media", "document", "image", "import"] + if any(keyword in path.lower() for keyword in upload_keywords): + upload_endpoints.append((path, methods)) + + # Also check for POST/PUT to any endpoint (potential large JSON) + if "POST" in methods or "PUT" in methods or "PATCH" in methods: + upload_endpoints.append((path, methods)) + + logger.info("Potential upload/large data endpoints identified:") + for path, methods in upload_endpoints[:20]: # Log first 20 + logger.info(f" {', '.join(methods)}: {path}") + + return upload_endpoints diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py index e6fd174f..e97f564b 100644 --- a/letta/server/rest_api/app.py +++ b/letta/server/rest_api/app.py @@ -14,6 +14,14 @@ import uvicorn # Enable Python fault handler to get stack traces on segfaults faulthandler.enable() + +# Import memory tracking (if available) +try: + from letta.monitoring import RequestSizeMonitoringMiddleware, get_memory_tracker, identify_upload_endpoints + + MEMORY_TRACKING_ENABLED = True +except ImportError: + MEMORY_TRACKING_ENABLED = False from fastapi import FastAPI, Request from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse @@ -133,6 +141,13 @@ async def lifespan(app_: FastAPI): """ worker_id = os.getpid() + # Initialize memory tracking + if MEMORY_TRACKING_ENABLED: + logger.info(f"[Worker {worker_id}] Initializing memory tracking") + # Get the global tracker instance (will start background monitor automatically) + tracker = get_memory_tracker(enable_background_monitor=True, monitor_interval=5) + logger.info(f"[Worker {worker_id}] Memory tracking enabled - monitoring every 5s with proactive alerts") + if telemetry_settings.profiler: try: import googlecloudprofiler @@ -174,6 +189,14 @@ async def lifespan(app_: FastAPI): # Cleanup on shutdown logger.info(f"[Worker {worker_id}] Starting lifespan shutdown") + + # Report memory usage before shutdown + if MEMORY_TRACKING_ENABLED: + logger.info(f"[Worker {worker_id}] Generating final memory report") + tracker = get_memory_tracker() + report = tracker.get_report() + logger.info(f"[Worker {worker_id}] Memory report:\n{report}") + try: from letta.jobs.scheduler import shutdown_scheduler_and_release_lock @@ -556,6 +579,13 @@ def create_application() -> "FastAPI": # Add unified logging middleware - enriches log context and logs exceptions app.add_middleware(LoggingMiddleware) + # Add request size monitoring middleware to detect large uploads + if MEMORY_TRACKING_ENABLED: + app.add_middleware(RequestSizeMonitoringMiddleware) + logger.info("Request size monitoring middleware enabled") + # Identify potential upload endpoints + identify_upload_endpoints(app) + app.add_middleware( CORSMiddleware, allow_origins=settings.cors_origins, diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index b8bb6d6a..da94e221 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -28,6 +28,7 @@ from letta.errors import ( from letta.groups.sleeptime_multi_agent_v4 import SleeptimeMultiAgentV4 from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns from letta.log import get_logger +from letta.monitoring import track_operation from letta.orm.errors import NoResultFound from letta.otel.context import get_ctx_attributes from letta.otel.metric_registry import MetricRegistry @@ -329,6 +330,7 @@ async def _import_agent( @router.post("/import", response_model=ImportedAgentsResponse, operation_id="import_agent") +@track_operation("import_agent") async def import_agent( file: UploadFile = File(...), server: "SyncServer" = Depends(get_letta_server), @@ -437,6 +439,7 @@ class CreateAgentRequest(CreateAgent): @router.post("/", response_model=AgentState, operation_id="create_agent") +@track_operation("create_agent") async def create_agent( agent: CreateAgentRequest = Body(...), server: "SyncServer" = Depends(get_letta_server), diff --git a/letta/server/rest_api/routers/v1/folders.py b/letta/server/rest_api/routers/v1/folders.py index 6d3c953b..1ccb9dc8 100644 --- a/letta/server/rest_api/routers/v1/folders.py +++ b/letta/server/rest_api/routers/v1/folders.py @@ -17,6 +17,7 @@ from letta.helpers.pinecone_utils import ( ) from letta.helpers.tpuf_client import should_use_tpuf from letta.log import get_logger +from letta.monitoring import track_operation from letta.otel.tracing import trace_method from letta.schemas.agent import AgentState from letta.schemas.embedding_config import EmbeddingConfig @@ -227,6 +228,7 @@ async def delete_folder( @router.post("/{folder_id}/upload", response_model=FileMetadata, operation_id="upload_file_to_folder") +@track_operation("file_upload_to_folder") async def upload_file_to_folder( file: UploadFile, folder_id: FolderId, diff --git a/letta/server/rest_api/routers/v1/llms.py b/letta/server/rest_api/routers/v1/llms.py index 0f97de3e..5f13f2f6 100644 --- a/letta/server/rest_api/routers/v1/llms.py +++ b/letta/server/rest_api/routers/v1/llms.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, List, Optional from fastapi import APIRouter, Depends, Query +from letta.monitoring.memory_tracker import track_operation from letta.schemas.enums import ProviderCategory, ProviderType from letta.schemas.model import EmbeddingModel, Model from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server @@ -13,6 +14,7 @@ router = APIRouter(prefix="/models", tags=["models", "llms"]) @router.get("/", response_model=List[Model], operation_id="list_models") +@track_operation("list_llm_models_endpoint") async def list_llm_models( provider_category: Optional[List[ProviderCategory]] = Query(None), provider_name: Optional[str] = Query(None), @@ -40,6 +42,7 @@ async def list_llm_models( @router.get("/embedding", response_model=List[EmbeddingModel], operation_id="list_embedding_models") +@track_operation("list_embedding_models_endpoint") async def list_embedding_models( server: "SyncServer" = Depends(get_letta_server), headers: HeaderParams = Depends(get_headers), diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index c2cfa290..e9d3b9d8 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -17,6 +17,7 @@ from letta.helpers.pinecone_utils import ( ) from letta.helpers.tpuf_client import should_use_tpuf from letta.log import get_logger +from letta.monitoring import track_operation from letta.otel.tracing import trace_method from letta.schemas.agent import AgentState from letta.schemas.embedding_config import EmbeddingConfig @@ -207,6 +208,7 @@ async def delete_source( @router.post("/{source_id}/upload", response_model=FileMetadata, operation_id="upload_file_to_source", deprecated=True) +@track_operation("file_upload_to_source") async def upload_file_to_source( file: UploadFile, source_id: SourceId, diff --git a/letta/server/server.py b/letta/server/server.py index 957fa4ff..6324c564 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -31,6 +31,7 @@ from letta.interface import ( CLIInterface, # for printing to terminal ) from letta.log import get_logger +from letta.monitoring.memory_tracker import track_operation from letta.orm.errors import NoResultFound from letta.otel.tracing import log_event, trace_method from letta.prompts.gpt_system import get_system_text @@ -972,6 +973,7 @@ class SyncServer(object): return passage_count, document_count @trace_method + @track_operation("list_llm_models_server") async def list_llm_models_async( self, actor: User, @@ -1023,6 +1025,7 @@ class SyncServer(object): return unique_models + @track_operation("list_embedding_models_server") async def list_embedding_models_async(self, actor: User) -> List[EmbeddingConfig]: """Asynchronously list available embedding models with maximum concurrency""" import asyncio @@ -1049,6 +1052,7 @@ class SyncServer(object): return embedding_models + @track_operation("get_enabled_providers") async def get_enabled_providers_async( self, actor: User, diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 521b3ced..2e41117a 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -7,6 +7,19 @@ import sqlalchemy as sa from sqlalchemy import delete, func, insert, literal, or_, select, tuple_ from sqlalchemy.dialects.postgresql import insert as pg_insert +# Import memory tracking if available +try: + from letta.monitoring import track_operation + + MEMORY_TRACKING_ENABLED = True +except ImportError: + MEMORY_TRACKING_ENABLED = False + + # Define a no-op decorator if tracking is not available + def track_operation(name): + return lambda f: f + + from letta.constants import ( BASE_MEMORY_TOOLS, BASE_MEMORY_TOOLS_V2, @@ -327,6 +340,7 @@ class AgentManager: # ====================================================================================================================== @trace_method + @track_operation("agent_creation") async def create_agent_async( self, agent_create: CreateAgent, diff --git a/letta/services/context_window_calculator/context_window_calculator.py b/letta/services/context_window_calculator/context_window_calculator.py index f1e4f79b..9dbc21e6 100644 --- a/letta/services/context_window_calculator/context_window_calculator.py +++ b/letta/services/context_window_calculator/context_window_calculator.py @@ -4,6 +4,20 @@ from typing import Any, List, Optional, Tuple from openai.types.beta.function_tool import FunctionTool as OpenAITool from letta.log import get_logger + +# Import memory tracking if available +try: + from letta.monitoring import track_operation + + MEMORY_TRACKING_ENABLED = True +except ImportError: + MEMORY_TRACKING_ENABLED = False + + # Define a no-op decorator if tracking is not available + def track_operation(name): + return lambda f: f + + from letta.schemas.agent import AgentState from letta.schemas.enums import MessageRole from letta.schemas.letta_message_content import TextContent @@ -96,6 +110,7 @@ class ContextWindowCalculator: return None, 1 + @track_operation("calculate_context_window") async def calculate_context_window( self, agent_state: AgentState, diff --git a/letta/services/file_manager.py b/letta/services/file_manager.py index 596b6356..bf1e4998 100644 --- a/letta/services/file_manager.py +++ b/letta/services/file_manager.py @@ -11,6 +11,7 @@ from sqlalchemy.orm import selectinload from letta.constants import MAX_FILENAME_LENGTH from letta.helpers.pinecone_utils import list_pinecone_index_for_files, should_use_pinecone from letta.log import get_logger +from letta.monitoring import track_operation from letta.orm.errors import NoResultFound from letta.orm.file import FileContent as FileContentModel, FileMetadata as FileMetadataModel from letta.orm.sqlalchemy_base import AccessType @@ -53,6 +54,7 @@ class FileManager: @enforce_types @trace_method + @track_operation("create_file_with_content") async def create_file( self, file_metadata: PydanticFileMetadata, @@ -356,6 +358,7 @@ class FileManager: @enforce_types @trace_method @raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE) + @track_operation("upsert_file_content") async def upsert_file_content( self, *, @@ -402,6 +405,7 @@ class FileManager: @enforce_types @trace_method @raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE) + @track_operation("list_files_with_content") async def list_files( self, source_id: str, @@ -629,6 +633,7 @@ class FileManager: @enforce_types @trace_method + @track_operation("batch_get_files_by_ids") async def get_files_by_ids_async( self, file_ids: List[str], actor: PydanticUser, *, include_content: bool = False ) -> List[PydanticFileMetadata]: @@ -664,6 +669,7 @@ class FileManager: @enforce_types @trace_method + @track_operation("batch_get_files_for_agents") async def get_files_for_agents_async( self, agent_ids: List[str], actor: PydanticUser, *, include_content: bool = False ) -> List[PydanticFileMetadata]: diff --git a/letta/services/group_manager.py b/letta/services/group_manager.py index 4dbdb149..dd911876 100644 --- a/letta/services/group_manager.py +++ b/letta/services/group_manager.py @@ -4,6 +4,7 @@ from typing import List, Optional, Union from sqlalchemy import and_, asc, delete, desc, or_, select from sqlalchemy.orm import Session +from letta.monitoring import track_operation from letta.orm.agent import Agent as AgentModel from letta.orm.block import Block from letta.orm.errors import NoResultFound @@ -73,6 +74,7 @@ class GroupManager: return group.to_pydantic() @enforce_types + @track_operation("create_multi_agent_group") async def create_group_async(self, group: Union[GroupCreate, InternalTemplateGroupCreate], actor: PydanticUser) -> PydanticGroup: async with db_registry.async_session() as session: new_group = GroupModel() @@ -197,6 +199,7 @@ class GroupManager: @enforce_types @trace_method @raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP) + @track_operation("list_multi_agent_messages") async def list_group_messages_async( self, actor: PydanticUser, @@ -320,6 +323,7 @@ class GroupManager: else: raise ValueError("Extend relationship is not supported for groups.") + @track_operation("process_multi_agent_relationships") async def _process_agent_relationship_async(self, session, group: GroupModel, agent_ids: List[str], allow_partial=False, replace=True): if not agent_ids: if replace: diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index 91a9b71f..113fb7a6 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -7,6 +7,20 @@ from sqlalchemy import delete, exists, func, select, text from letta.constants import CONVERSATION_SEARCH_TOOL_NAME, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.log import get_logger + +# Import memory tracking if available +try: + from letta.monitoring import track_operation + + MEMORY_TRACKING_ENABLED = True +except ImportError: + MEMORY_TRACKING_ENABLED = False + + # Define a no-op decorator if tracking is not available + def track_operation(name): + return lambda f: f + + from letta.orm.errors import NoResultFound from letta.orm.message import Message as MessageModel from letta.otel.tracing import trace_method @@ -441,6 +455,7 @@ class MessageManager: @enforce_types @trace_method + @track_operation("bulk_create_messages") async def create_many_messages_async( self, pydantic_msgs: List[PydanticMessage], @@ -825,6 +840,7 @@ class MessageManager: @enforce_types @trace_method + @track_operation("list_messages") async def list_messages( self, actor: PydanticUser, @@ -949,6 +965,7 @@ class MessageManager: @enforce_types @trace_method + @track_operation("bulk_delete_all_agent_messages") async def delete_all_messages_for_agent_async( self, agent_id: str, actor: PydanticUser, exclude_ids: Optional[List[str]] = None, strict_mode: bool = False ) -> int: @@ -997,6 +1014,7 @@ class MessageManager: @enforce_types @trace_method + @track_operation("bulk_delete_messages_by_ids") async def delete_messages_by_ids_async(self, message_ids: List[str], actor: PydanticUser, strict_mode: bool = False) -> int: """ Efficiently deletes messages by their specific IDs, @@ -1044,6 +1062,7 @@ class MessageManager: @enforce_types @trace_method + @track_operation("search_messages") async def search_messages_async( self, agent_id: str, @@ -1164,6 +1183,7 @@ class MessageManager: message_tuples.append((message, metadata)) return message_tuples + @track_operation("search_messages_org") async def search_messages_org_async( self, actor: PydanticUser, diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py index 55f628ed..faa65748 100644 --- a/letta/services/passage_manager.py +++ b/letta/services/passage_manager.py @@ -11,6 +11,7 @@ from letta.constants import MAX_EMBEDDING_DIM from letta.helpers.decorators import async_redis_cache from letta.llm_api.llm_client import LLMClient from letta.log import get_logger +from letta.monitoring import track_operation from letta.orm import ArchivesAgents from letta.orm.errors import NoResultFound from letta.orm.passage import ArchivalPassage, SourcePassage @@ -302,6 +303,7 @@ class PassageManager: @enforce_types @trace_method + @track_operation("batch_create_archival_passages") async def create_many_archival_passages_async(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]: """Create multiple archival passages.""" archival_passages = [] @@ -354,6 +356,7 @@ class PassageManager: @enforce_types @trace_method + @track_operation("batch_create_source_passages") async def create_many_source_passages_async( self, passages: List[PydanticPassage], file_metadata: PydanticFileMetadata, actor: PydanticUser ) -> List[PydanticPassage]: @@ -451,6 +454,7 @@ class PassageManager: @enforce_types @trace_method + @track_operation("insert_passages_with_embeddings") async def insert_passage( self, agent_state: AgentState, @@ -545,6 +549,7 @@ class PassageManager: except Exception as e: raise e + @track_operation("generate_embeddings_batch") async def _generate_embeddings_concurrent(self, text_chunks: List[str], embedding_config, actor: PydanticUser) -> List[List[float]]: """Generate embeddings for all text chunks concurrently using LLMClient""" @@ -764,6 +769,7 @@ class PassageManager: @enforce_types @trace_method + @track_operation("bulk_delete_agent_passages") async def delete_agent_passages_async( self, passages: List[PydanticPassage], @@ -818,6 +824,7 @@ class PassageManager: @enforce_types @trace_method + @track_operation("bulk_delete_source_passages") async def delete_source_passages_async( self, actor: PydanticUser, diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index f99f7944..42b09c0b 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -3,6 +3,20 @@ from typing import List, Optional, Tuple, Union from letta.orm.provider import Provider as ProviderModel from letta.orm.provider_model import ProviderModel as ProviderModelORM from letta.otel.tracing import trace_method + +# Import memory tracking if available +try: + from letta.monitoring import track_operation + + MEMORY_TRACKING_ENABLED = True +except ImportError: + MEMORY_TRACKING_ENABLED = False + + # Define a no-op decorator if tracking is not available + def track_operation(name): + return lambda f: f + + from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import PrimitiveType, ProviderCategory, ProviderType from letta.schemas.llm_config import LLMConfig @@ -746,6 +760,7 @@ class ProviderManager: @enforce_types @trace_method + @track_operation("list_models") async def list_models_async( self, actor: PydanticUser, diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index d79e3c1c..48fba820 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -5,6 +5,7 @@ from sqlalchemy import and_, exists, select from letta.helpers.pinecone_utils import should_use_pinecone from letta.helpers.tpuf_client import should_use_tpuf +from letta.monitoring import track_operation from letta.orm import Agent as AgentModel from letta.orm.errors import NoResultFound from letta.orm.source import Source as SourceModel @@ -272,6 +273,7 @@ class SourceManager: @enforce_types @trace_method @raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE) + @track_operation("list_all_attached_agents") async def list_attached_agents( self, source_id: str, actor: PydanticUser, ids_only: bool = False ) -> Union[List[PydanticAgentState], List[str]]: @@ -465,6 +467,7 @@ class SourceManager: @enforce_types @trace_method + @track_operation("batch_get_sources_by_ids") async def get_sources_by_ids_async(self, source_ids: List[str], actor: PydanticUser) -> List[PydanticSource]: """ Get multiple sources by their IDs in a single query. @@ -491,6 +494,7 @@ class SourceManager: @enforce_types @trace_method + @track_operation("batch_get_sources_for_agents") async def get_sources_for_agents_async(self, agent_ids: List[str], actor: PydanticUser) -> List[PydanticSource]: """ Get all sources associated with the given agents via sources-agents relationships. diff --git a/pyproject.toml b/pyproject.toml index e4c5ab42..6b781839 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ dependencies = [ "readability-lxml", "google-genai>=1.15.0", "datadog>=0.49.1", + "psutil>=5.9.0", ] [project.scripts] diff --git a/uv.lock b/uv.lock index 1a61c45f..3b0fdfe9 100644 --- a/uv.lock +++ b/uv.lock @@ -2446,6 +2446,7 @@ dependencies = [ { name = "orjson" }, { name = "pathvalidate" }, { name = "prettytable" }, + { name = "psutil" }, { name = "pydantic" }, { name = "pydantic-settings" }, { name = "pyhumps" }, @@ -2620,6 +2621,7 @@ requires-dist = [ { name = "pinecone", extras = ["asyncio"], marker = "extra == 'pinecone'", specifier = ">=7.3.0" }, { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.5.0" }, { name = "prettytable", specifier = ">=3.9.0" }, + { name = "psutil", specifier = ">=5.9.0" }, { name = "psycopg2", marker = "extra == 'postgres'", specifier = ">=2.9.10" }, { name = "psycopg2-binary", marker = "extra == 'postgres'", specifier = ">=2.9.10" }, { name = "pydantic", specifier = ">=2.10.6" },