Files
letta-server/letta/services/tool_sandbox/safe_pickle.py
Kian Jones f5c4ab50f4 chore: add ty + pre-commit hook and repeal even more ruff rules (#9504)
* auto fixes

* auto fix pt2 and transitive deps and undefined var checking locals()

* manual fixes (ignored or letta-code fixed)

* fix circular import

* remove all ignores, add FastAPI rules and Ruff rules

* add ty and precommit

* ruff stuff

* ty check fixes

* ty check fixes pt 2

* error on invalid
2026-02-24 10:55:11 -08:00

194 lines
5.8 KiB
Python

"""Safe pickle serialization wrapper for Modal sandbox.
This module provides defensive serialization utilities to prevent segmentation
faults and other crashes when passing complex objects to Modal containers.
"""
import pickle
import sys
from typing import Any, Optional, Tuple
from letta.log import get_logger
logger = get_logger(__name__)

# Serialization limits shared by the helpers in this module.
MAX_PICKLE_SIZE = 10 << 20  # 10 MiB (10 * 1024 * 1024 bytes) cap on a pickle payload
MAX_RECURSION_DEPTH = 50  # nesting cap enforced by safe_pickle_dumps' object-graph walk
PICKLE_PROTOCOL = 4  # pinned protocol (not pickle.HIGHEST_PROTOCOL) for better compatibility
class SafePickleError(Exception):
    """Error raised by the safe-pickle helpers when (un)pickling fails or a limit is exceeded."""
class RecursionLimiter:
    """Context manager that temporarily caps the interpreter's recursion limit.

    On entry the limit becomes ``max_depth`` (it is never raised above the limit
    already in effect). On exit the previous limit is restored, including when
    the body raised.

    NOTE(review): ``sys.setrecursionlimit`` changes the limit for the whole
    interpreter, not just the current call. Python also raises RecursionError
    if the new limit is too low for the current stack depth. Confirm callers
    account for both.
    """

    def __init__(self, max_depth: int):
        self.max_depth = max_depth
        # Limit that was in effect before __enter__; stays None until entered.
        self.original_limit = None

    def __enter__(self):
        previous = sys.getrecursionlimit()
        self.original_limit = previous
        # Only ever lower the limit, never raise it.
        sys.setrecursionlimit(previous if previous < self.max_depth else self.max_depth)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.original_limit is None:
            # Never entered: nothing to restore.
            return
        sys.setrecursionlimit(self.original_limit)
def _check_graph_depth(root: Any, max_depth: int) -> None:
    """Raise SafePickleError if *root*'s container nesting exceeds *max_depth*.

    Walks list/tuple items, dict values, and the ``__dict__`` of other objects.
    Every visited node counts as one level, leaves included.

    Two guards keep the walk finite and cheap:
    * A container that is already on the current path (a cycle) is not
      re-entered. Pickle writes such back-references as memo lookups rather
      than recursing, so a cycle must not be reported as "too deep".
    * A shared container is walked again only when it is reached at a greater
      depth than before. Heavily shared DAGs therefore cost
      O(max_depth * edges) instead of one visit per distinct path, which can
      be exponential.
    """
    # id(node) -> (deepest depth walked so far, node). The node is stored only
    # to keep it alive so its id cannot be reused during the walk.
    deepest = {}
    on_path = set()

    def visit(node: Any, depth: int) -> None:
        if depth > max_depth:
            raise SafePickleError(f"Object graph too deep (depth > {max_depth})")
        if isinstance(node, (list, tuple)):
            children = node
        elif isinstance(node, dict):
            children = node.values()
        elif hasattr(node, "__dict__"):
            children = (node.__dict__,)
        else:
            return  # leaf: nothing to descend into
        key = id(node)
        if key in on_path:
            return  # back-reference (cycle)
        seen = deepest.get(key)
        if seen is not None and seen[0] >= depth:
            return  # already walked from at least this depth
        deepest[key] = (depth, node)
        on_path.add(key)
        try:
            for child in children:
                visit(child, depth + 1)
        finally:
            on_path.discard(key)

    visit(root, 0)


def safe_pickle_dumps(obj: Any, max_size: int = MAX_PICKLE_SIZE, max_depth: int = MAX_RECURSION_DEPTH) -> bytes:
    """Safely pickle an object with size and nesting-depth limits.

    The object is pickled with PICKLE_PROTOCOL. The resulting size is then
    checked against ``max_size``, and the object graph is walked to enforce
    ``max_depth``.

    Args:
        obj: The object to pickle.
        max_size: Maximum allowed pickle size in bytes.
        max_depth: Maximum allowed container nesting depth. Defaults to
            MAX_RECURSION_DEPTH. Cyclic references do not count as extra depth.

    Returns:
        bytes: The pickled object.

    Raises:
        SafePickleError: If pickling fails, the payload exceeds ``max_size``,
            or the object graph is nested deeper than ``max_depth``.
    """
    try:
        payload = pickle.dumps(obj, protocol=PICKLE_PROTOCOL)
        if len(payload) > max_size:
            raise SafePickleError(f"Pickle size {len(payload)} exceeds limit {max_size}")
        _check_graph_depth(obj, max_depth)
    except SafePickleError:
        raise
    except RecursionError as e:
        # Raised by pickle itself on very deep graphs.
        raise SafePickleError(f"Object graph too deep: {e}") from e
    except Exception as e:
        raise SafePickleError(f"Failed to pickle object: {e}") from e
    else:
        logger.debug(f"Successfully pickled object of size {len(payload)} bytes")
        return payload
def safe_pickle_loads(data: bytes, max_size: int = MAX_PICKLE_SIZE) -> Any:
    """Unpickle data after emptiness and size checks.

    WARNING: ``pickle.loads`` can execute arbitrary code embedded in the
    payload. The checks here only reject empty or oversized input; they do
    not make untrusted data safe to load.
    NOTE(review): confirm every caller passes data from a trusted producer.
    For example, output coming back from a sandbox that runs user code may be
    attacker-influenced.

    Args:
        data: The pickled data.
        max_size: Maximum accepted payload size in bytes. Defaults to
            MAX_PICKLE_SIZE. Pass the same value given to safe_pickle_dumps
            when that call used a larger limit.

    Returns:
        Any: The unpickled object.

    Raises:
        SafePickleError: If ``data`` is empty, larger than ``max_size``, or
            fails to unpickle.
    """
    if not data:
        raise SafePickleError("Cannot unpickle empty data")
    if len(data) > max_size:
        raise SafePickleError(f"Pickle data size {len(data)} exceeds limit {max_size}")
    try:
        obj = pickle.loads(data)
    except Exception as e:
        raise SafePickleError(f"Failed to unpickle data: {e}") from e
    logger.debug(f"Successfully unpickled object from {len(data)} bytes")
    return obj
def try_pickle_with_fallback(obj: Any, fallback_value: Any = None, max_size: int = MAX_PICKLE_SIZE) -> Tuple[Optional[bytes], bool]:
    """Pickle obj, falling back to pickling fallback_value on failure.

    Args:
        obj: The object to pickle.
        fallback_value: Value to pickle instead if ``obj`` cannot be pickled.
            ``None`` means "no fallback": a ``None`` fallback is never pickled.
        max_size: Maximum allowed pickle size in bytes, applied to both attempts.

    Returns:
        Tuple of ``(pickled_data, success_flag)``:
        * ``(bytes, True)`` if ``obj`` itself was pickled.
        * ``(bytes, False)`` if the fallback was pickled instead.
        * ``(None, False)`` if there was no usable fallback.
    """
    try:
        return safe_pickle_dumps(obj, max_size), True
    except SafePickleError as e:
        logger.warning(f"Failed to pickle object, using fallback: {e}")

    if fallback_value is None:
        return None, False

    try:
        return safe_pickle_dumps(fallback_value, max_size), False
    except SafePickleError as e:
        # Log instead of swallowing silently, so a broken fallback is visible.
        # The caller still receives (None, False).
        logger.warning(f"Failed to pickle fallback value: {e}")
        return None, False
def validate_pickleable(obj: Any) -> bool:
    """Report whether obj passes safe_pickle_dumps' checks.

    Args:
        obj: The object to validate.

    Returns:
        bool: True if ``safe_pickle_dumps`` succeeds with the default
        MAX_PICKLE_SIZE limit. False if it raises SafePickleError.
    """
    try:
        # The pickled bytes are discarded; only success or failure matters.
        safe_pickle_dumps(obj, max_size=MAX_PICKLE_SIZE)
    except SafePickleError:
        return False
    return True
def sanitize_for_pickle(obj: Any) -> Any:
    """Sanitize an object for safe pickling.

    For an object that has a ``__dict__``, this returns a new dict built from
    its public attributes. Names starting with ``_`` are skipped. Each value
    is mapped as follows:

    * Callables become the placeholder string ``"<function NAME>"``. NAME is
      the callable's ``__name__``, or its type name if it has none.
    * Values with a ``__module__`` attribute become ``"<ClassName object>"``.
      This includes instances of classes defined in Python, even pickleable ones.
    * Any other value is kept if it pickles cleanly, otherwise it is replaced
      by ``str(value)``.

    An object without ``__dict__`` is returned unchanged.

    Args:
        obj: The object to sanitize.

    Returns:
        Any: A dict of sanitized attributes, or ``obj`` itself.
    """
    if not hasattr(obj, "__dict__"):
        # Nothing to sanitize; let pickle handle it as-is.
        return obj

    sanitized = {}
    for key, value in obj.__dict__.items():
        # Skip private attributes. The isinstance guard avoids AttributeError on
        # non-string keys, which can be injected into __dict__ directly.
        if isinstance(key, str) and key.startswith("_"):
            continue
        if callable(value):
            # Not every callable has __name__ (e.g. functools.partial objects or
            # instances defining __call__); fall back to the type name.
            name = getattr(value, "__name__", type(value).__name__)
            sanitized[key] = f"<function {name}>"
        elif hasattr(value, "__module__"):
            sanitized[key] = f"<{value.__class__.__name__} object>"
        else:
            try:
                # Keep the value only if it actually pickles.
                pickle.dumps(value, protocol=PICKLE_PROTOCOL)
                sanitized[key] = value
            except Exception:
                sanitized[key] = str(value)
    return sanitized