Files
letta-server/letta/otel/sqlalchemy_instrumentation.py
Kian Jones 720fc9c758 feat: add statement_timeout to SQLAlchemy OTEL spans (#9220)
Queries Postgres for statement_timeout on connection checkout and adds
it as db.statement_timeout attribute on cursor execution spans.

🤖 Generated with [Letta Code](https://letta.com)

Co-authored-by: Letta <noreply@letta.com>
2026-02-24 10:52:06 -08:00

570 lines
19 KiB
Python

import asyncio
import threading
import traceback
from contextlib import contextmanager
from functools import wraps
from typing import Any, Callable, Dict, List, Optional
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode
from sqlalchemy import Engine, event
from sqlalchemy.orm import Session
from sqlalchemy.orm.loading import load_on_ident, load_on_pk_identity
from sqlalchemy.orm.strategies import ImmediateLoader, JoinedLoader, LazyLoader, SelectInLoader, SubqueryLoader
_config = {
"enabled": True,
"sql_truncate_length": 1000,
"monitor_joined_loading": True,
"log_instrumentation_errors": True,
}
_instrumentation_state = {
"engine_listeners": [],
"session_listeners": [],
"original_methods": {},
"active": False,
}
_context = threading.local()
def _get_tracer():
    """Return the OpenTelemetry tracer used by the SQLAlchemy instrumentation."""
    instrumentation_name = "sqlalchemy_sync_instrumentation"
    instrumentation_version = "1.0.0"
    return trace.get_tracer(instrumentation_name, instrumentation_version)
def _is_event_loop_running() -> bool:
"""Check if an asyncio event loop is running in the current thread."""
try:
loop = asyncio.get_running_loop()
return loop.is_running()
except RuntimeError:
return False
def _is_main_thread() -> bool:
"""Check if we're running on the main thread."""
return threading.current_thread() is threading.main_thread()
def _truncate_sql(sql: str, max_length: int = 1000) -> str:
"""Truncate SQL statement to specified length."""
if len(sql) <= max_length:
return sql
return sql[: max_length - 3] + "..."
def _create_sync_db_span(
    operation_type: str,
    sql_statement: Optional[str] = None,
    loader_type: Optional[str] = None,
    relationship_key: Optional[str] = None,
    is_joined: bool = False,
    additional_attrs: Optional[Dict[str, Any]] = None,
) -> Any:
    """
    Create an OpenTelemetry span for a synchronous database operation.

    A span is only created when instrumentation is enabled and an asyncio
    event loop is running in the current thread. The span is started but
    not ended; the caller is responsible for calling ``end()`` on it.

    Args:
        operation_type: Type of database operation
        sql_statement: SQL statement being executed
        loader_type: Type of SQLAlchemy loader (selectin, joined, lazy, etc.)
        relationship_key: Name of relationship attribute if applicable
        is_joined: Whether this is from joined loading
        additional_attrs: Additional attributes to add to the span. Entries
            whose value is ``None`` are skipped.

    Returns:
        The started OpenTelemetry span, or ``None`` when no span was created.
    """
    if not _config["enabled"]:
        return None

    # Only create spans for potentially problematic operations
    if not _is_event_loop_running():
        return None

    tracer = _get_tracer()
    span = tracer.start_span("db_operation")

    # Set core attributes
    span.set_attribute("db.operation.type", operation_type)

    # SQL statement (truncated to the configured length)
    if sql_statement:
        span.set_attribute("db.statement", _truncate_sql(sql_statement, _config["sql_truncate_length"]))

    # Loader information
    if loader_type:
        span.set_attribute("sqlalchemy.loader.type", loader_type)
        span.set_attribute("sqlalchemy.loader.is_joined", is_joined)

    # Relationship information
    if relationship_key:
        span.set_attribute("sqlalchemy.relationship.key", relationship_key)

    # Additional attributes
    if additional_attrs:
        for key, value in additional_attrs.items():
            # None is not a valid OpenTelemetry attribute value, so an
            # unknown value (e.g. an undetected statement_timeout) is
            # omitted rather than handed to the SDK.
            if value is None:
                continue
            span.set_attribute(key, value)

    return span
def _instrument_engine_events(engine: Engine) -> None:
    """Instrument SQLAlchemy engine events to detect sync operations.

    Registers a pool ``checkout`` listener plus ``before_cursor_execute``,
    ``after_cursor_execute`` and ``handle_error`` listeners on the engine,
    and records them in ``_instrumentation_state["engine_listeners"]`` so
    they can be removed on teardown.

    Args:
        engine: Engine to instrument. An ``AsyncEngine`` is replaced by its
            underlying ``sync_engine`` before listeners are attached.
    """
    # Check if this is an AsyncEngine and get its sync_engine if it is
    from sqlalchemy.ext.asyncio import AsyncEngine

    if isinstance(engine, AsyncEngine):
        engine = engine.sync_engine

    def checkout(dbapi_conn, connection_record, connection_proxy):
        """Query statement_timeout on checkout and store it on the connection record."""
        timeout = None
        try:
            cursor = dbapi_conn.cursor()
            try:
                cursor.execute("SHOW statement_timeout")
                row = cursor.fetchone()
                timeout = row[0] if row else None
            finally:
                # Close the cursor even when the probe fails (for example on
                # a backend without SHOW statement_timeout) so it is not leaked.
                cursor.close()
        except Exception:
            # Best effort only: the probe must never break a checkout.
            timeout = None
        connection_record.info["statement_timeout"] = timeout

    def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
        """Track cursor execution start."""
        if not _config["enabled"]:
            return

        span_attrs = {
            "db.executemany": executemany,
            "db.connection.info": str(conn.info),
        }
        statement_timeout = conn.info.get("statement_timeout")
        if statement_timeout is not None:
            # None is not a valid span attribute value; leave the attribute
            # off when the timeout is unknown.
            span_attrs["db.statement_timeout"] = statement_timeout

        # Store the span on the execution context for the after/error events
        context._sync_instrumentation_span = _create_sync_db_span(
            operation_type="cursor_execute",
            sql_statement=statement,
            additional_attrs=span_attrs,
        )

    def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
        """Track cursor execution completion."""
        if not _config["enabled"]:
            return

        span = getattr(context, "_sync_instrumentation_span", None)
        if span:
            span.set_status(Status(StatusCode.OK))
            span.end()
            context._sync_instrumentation_span = None

    def handle_cursor_error(exception_context):
        """Handle cursor execution errors."""
        if not _config["enabled"]:
            return

        # Extract the execution context from exception_context
        context = getattr(exception_context, "execution_context", None)
        if not context:
            return

        span = getattr(context, "_sync_instrumentation_span", None)
        if span:
            span.set_status(Status(StatusCode.ERROR, "Database operation failed"))
            span.end()
            context._sync_instrumentation_span = None

    # Register engine events
    event.listen(engine.pool, "checkout", checkout)
    event.listen(engine, "before_cursor_execute", before_cursor_execute)
    event.listen(engine, "after_cursor_execute", after_cursor_execute)
    event.listen(engine, "handle_error", handle_cursor_error)

    # Store listeners for cleanup
    _instrumentation_state["engine_listeners"].extend(
        [
            (engine.pool, "checkout", checkout),
            (engine, "before_cursor_execute", before_cursor_execute),
            (engine, "after_cursor_execute", after_cursor_execute),
            (engine, "handle_error", handle_cursor_error),
        ]
    )
def _instrument_loader_strategies() -> None:
    """Instrument SQLAlchemy loader strategies to detect lazy loading.

    Wraps selected methods of the loader strategy classes so that each call
    is surrounded by a span from ``_create_sync_db_span``. The unwrapped
    methods are saved in ``_instrumentation_state["original_methods"]`` under
    ``"<ClassName>.<method>"`` keys so teardown can restore them.

    A method whose key is already recorded is left alone. Wrapping it again
    would replace the saved original with the wrapper and make it
    impossible to restore.
    """
    originals = _instrumentation_state["original_methods"]

    def create_loader_wrapper(loader_class: type, loader_type: str, is_joined: bool = False):
        """Create a wrapper for loader strategy methods."""

        def wrapper(original_method: Callable):
            @wraps(original_method)
            def instrumented_method(self, *args, **kwargs):
                # Extract relationship information if available
                relationship_key = getattr(self, "key", None)
                if hasattr(self, "parent_property"):
                    relationship_key = getattr(self.parent_property, "key", relationship_key)

                span = _create_sync_db_span(
                    operation_type="loader_strategy",
                    loader_type=loader_type,
                    relationship_key=relationship_key,
                    is_joined=is_joined,
                    additional_attrs={
                        "sqlalchemy.loader.class": loader_class.__name__,
                        "sqlalchemy.loader.method": original_method.__name__,
                    },
                )
                try:
                    result = original_method(self, *args, **kwargs)
                    if span:
                        span.set_status(Status(StatusCode.OK))
                    return result
                except Exception as e:
                    if span:
                        span.set_status(Status(StatusCode.ERROR, str(e)))
                    raise
                finally:
                    # The span is ended on both the success and error paths.
                    if span:
                        span.end()

            return instrumented_method

        return wrapper

    def patch_method(loader_class: type, method_name: str, wrapper: Callable) -> None:
        """Wrap one method of a loader class, remembering the original."""
        if not hasattr(loader_class, method_name):
            return
        key = f"{loader_class.__name__}.{method_name}"
        if key in originals:
            # Already instrumented; keep the true original that was saved.
            return
        original_method = getattr(loader_class, method_name)
        # Store original method for cleanup, then apply the wrapper
        originals[key] = original_method
        setattr(loader_class, method_name, wrapper(original_method))

    # Instrument different loader strategies
    loaders_to_instrument = [
        (SelectInLoader, "selectin", False),
        (JoinedLoader, "joined", True),
        (LazyLoader, "lazy", False),
        (SubqueryLoader, "subquery", False),
        (ImmediateLoader, "immediate", False),
    ]
    for loader_class, loader_type, is_joined in loaders_to_instrument:
        # Skip if monitoring joined loading is disabled
        if is_joined and not _config["monitor_joined_loading"]:
            continue

        wrapper = create_loader_wrapper(loader_class, loader_type, is_joined)

        # Instrument key methods (whichever of these the class defines)
        for method_name in ("_load_for_path", "load_for_path"):
            patch_method(loader_class, method_name, wrapper)

    # Instrument additional joined loading specific methods
    if _config["monitor_joined_loading"]:
        joined_wrapper = create_loader_wrapper(JoinedLoader, "joined", True)
        for method_name in ("_create_eager_join", "_generate_cache_key"):
            patch_method(JoinedLoader, method_name, joined_wrapper)
def _instrument_loading_functions() -> None:
    """Replace selected ``sqlalchemy.orm.loading`` functions with span-emitting wrappers.

    Each listed function is saved in
    ``_instrumentation_state["original_methods"]`` under a
    ``"loading.<name>"`` key and then replaced on the module by a wrapper
    that surrounds the call with a span from ``_create_sync_db_span``.
    """
    import sqlalchemy.orm.loading as loading_module

    def traced(func_name: str, original_func: Callable):
        """Build the span-emitting replacement for one loading function."""

        @wraps(original_func)
        def instrumented_func(*args, **kwargs):
            span = _create_sync_db_span(
                operation_type="loading_function",
                additional_attrs={"sqlalchemy.loading.function": func_name},
            )
            try:
                outcome = original_func(*args, **kwargs)
            except Exception as exc:
                if span:
                    span.set_status(Status(StatusCode.ERROR, str(exc)))
                raise
            else:
                if span:
                    span.set_status(Status(StatusCode.OK))
                return outcome
            finally:
                # Runs on both paths so the span is always closed.
                if span:
                    span.end()

        return instrumented_func

    # The originals are the objects imported at the top of this file.
    targets = {
        "load_on_ident": load_on_ident,
        "load_on_pk_identity": load_on_pk_identity,
    }
    saved = _instrumentation_state["original_methods"]
    for func_name, original_func in targets.items():
        # Remember the original for cleanup before swapping in the wrapper.
        saved[f"loading.{func_name}"] = original_func
        setattr(loading_module, func_name, traced(func_name, original_func))
def _instrument_session_operations() -> None:
    """Instrument SQLAlchemy session operations.

    Registers ``before_flush``, ``after_flush`` and ``after_flush_postexec``
    listeners on ``Session`` and records them in
    ``_instrumentation_state["session_listeners"]`` for teardown. A span is
    opened in ``before_flush``, kept on the session object, and closed by
    whichever completion event sees it first.
    """
    span_attr = "_sync_instrumentation_flush_span"

    def _end_flush_span(session, status) -> None:
        """End the flush span stored on the session, if there is one."""
        span = getattr(session, span_attr, None)
        if span:
            span.set_status(status)
            span.end()
            setattr(session, span_attr, None)

    def before_flush(session, flush_context, instances):
        """Track session flush operations."""
        if not _config["enabled"]:
            return

        # The completion events are what close the span. A span still
        # present here means an earlier flush never reached them, so it is
        # closed as an error instead of being overwritten and left un-ended.
        if getattr(session, span_attr, None):
            _end_flush_span(session, Status(StatusCode.ERROR, "Previous flush did not complete"))

        span = _create_sync_db_span(
            operation_type="session_flush",
            additional_attrs={
                "sqlalchemy.session.new_count": len(session.new),
                "sqlalchemy.session.dirty_count": len(session.dirty),
                "sqlalchemy.session.deleted_count": len(session.deleted),
            },
        )
        # Store span in session for cleanup
        session._sync_instrumentation_flush_span = span

    def after_flush(session, flush_context):
        """Track session flush completion."""
        if not _config["enabled"]:
            return
        _end_flush_span(session, Status(StatusCode.OK))

    def after_flush_postexec(session, flush_context):
        """Track session flush post-execution."""
        if not _config["enabled"]:
            return
        # No-op when after_flush already closed the span.
        _end_flush_span(session, Status(StatusCode.OK))

    # Register session events
    event.listen(Session, "before_flush", before_flush)
    event.listen(Session, "after_flush", after_flush)
    event.listen(Session, "after_flush_postexec", after_flush_postexec)

    # Store listeners for cleanup
    _instrumentation_state["session_listeners"].extend(
        [
            (Session, "before_flush", before_flush),
            (Session, "after_flush", after_flush),
            (Session, "after_flush_postexec", after_flush_postexec),
        ]
    )
def setup_sqlalchemy_sync_instrumentation(
    engines: Optional[List[Engine]] = None,
    config_overrides: Optional[Dict[str, Any]] = None,
    lazy_loading_only: bool = True,
) -> None:
    """
    Set up SQLAlchemy synchronous operation instrumentation.

    Does nothing when instrumentation is already active. If setup fails
    partway, whatever was installed so far is removed again before the
    error is re-raised, so a later call starts from a clean state.

    Args:
        engines: List of SQLAlchemy engines to instrument. If None, will attempt
            to discover engines automatically.
        config_overrides: Dictionary of configuration overrides.
        lazy_loading_only: If True, only instrument lazy loading operations.
            This forces ``monitor_joined_loading`` to False, even when
            ``config_overrides`` sets it.

    Raises:
        Exception: Any error raised while installing the instrumentation is
            re-raised. Failures for individual engines are only logged.
    """
    if _instrumentation_state["active"]:
        return  # Already active

    try:
        # Apply configuration overrides
        if config_overrides:
            _config.update(config_overrides)

        # If lazy_loading_only is True, update config to focus on lazy loading
        if lazy_loading_only:
            _config.update(
                {
                    "monitor_joined_loading": False,  # Don't monitor joined loading
                }
            )

        # Discover engines if not provided
        if engines is None:
            engines = []
            # Try to find engines from the database registry
            try:
                from letta.server.db import db_registry

                if hasattr(db_registry, "_async_engines"):
                    engines.extend(db_registry._async_engines.values())
                if hasattr(db_registry, "_sync_engines"):
                    engines.extend(db_registry._sync_engines.values())
            except ImportError:
                pass

        # Instrument loader strategies (focus on lazy loading if specified)
        _instrument_loader_strategies()

        # Instrument loading functions
        _instrument_loading_functions()

        # Instrument session operations
        _instrument_session_operations()

        # Instrument engines last to avoid potential errors with async engines
        for engine in engines:
            try:
                _instrument_engine_events(engine)
            except Exception as e:
                if _config["log_instrumentation_errors"]:
                    print(f"Error instrumenting engine {engine}: {e}")
                # Continue with other engines

        _instrumentation_state["active"] = True

    except Exception as e:
        if _config["log_instrumentation_errors"]:
            print(f"Error setting up SQLAlchemy instrumentation: {e}")
            traceback.print_exc()

        # Undo any partial instrumentation. Teardown is a no-op unless the
        # state is marked active, so mark it active for the rollback only.
        _instrumentation_state["active"] = True
        try:
            teardown_sqlalchemy_sync_instrumentation()
        except Exception:
            # Rollback is best effort; the original setup error is what the
            # caller needs to see.
            _instrumentation_state["active"] = False
        raise
def teardown_sqlalchemy_sync_instrumentation() -> None:
    """Remove all installed instrumentation and reset the bookkeeping state.

    Does nothing when instrumentation is not active. Otherwise removes the
    recorded engine and session listeners, restores every saved original
    function or method, clears the recorded state and marks instrumentation
    inactive. Any error is optionally logged and then re-raised.
    """
    state = _instrumentation_state
    if not state["active"]:
        return  # Not active

    try:
        # Engine listeners first, then session listeners.
        for registry in ("engine_listeners", "session_listeners"):
            for target, event_name, listener in state[registry]:
                event.remove(target, event_name, listener)

        # Loader classes addressable by the class-name half of a saved key.
        loader_classes = {
            cls.__name__: cls
            for cls in (SelectInLoader, JoinedLoader, LazyLoader, SubqueryLoader, ImmediateLoader)
        }

        # Restore original methods
        for key, original_method in state["original_methods"].items():
            if "." not in key:
                continue
            owner_name, attr_name = key.rsplit(".", 1)

            if key.startswith("loading."):
                # Module-level loading function
                import sqlalchemy.orm.loading as loading_module

                setattr(loading_module, attr_name, original_method)
                continue

            # Loader strategy method; unknown class names are ignored.
            owner = loader_classes.get(owner_name)
            if owner is not None:
                setattr(owner, attr_name, original_method)

        # Clear state
        for registry in ("engine_listeners", "session_listeners", "original_methods"):
            state[registry].clear()
        state["active"] = False

    except Exception as e:
        if _config["log_instrumentation_errors"]:
            print(f"Error tearing down SQLAlchemy instrumentation: {e}")
            traceback.print_exc()
        raise
def configure_instrumentation(**kwargs) -> None:
    """
    Update the instrumentation configuration in place.

    Args:
        **kwargs: Configuration options to set; each keyword becomes (or
            overwrites) an entry in the module configuration.
    """
    for option, value in kwargs.items():
        _config[option] = value
def get_instrumentation_config() -> Dict[str, Any]:
    """Return a shallow copy of the current instrumentation configuration."""
    snapshot = dict(_config)
    return snapshot
def is_instrumentation_active() -> bool:
    """Return the ``active`` flag recorded in the instrumentation state."""
    state = _instrumentation_state
    return state["active"]
# Context manager for temporary instrumentation
@contextmanager
def temporary_instrumentation(**config_overrides):
    """
    Context manager for temporary SQLAlchemy instrumentation.

    If instrumentation is already active when the context is entered, it is
    left untouched: the overrides are not applied and nothing is torn down
    on exit. Otherwise instrumentation is set up on entry and torn down on
    exit, and the configuration is restored to its value from before entry.

    Args:
        **config_overrides: Configuration overrides for the instrumentation.
    """
    if _instrumentation_state["active"]:
        yield
        return

    # Setup writes the overrides (and its own lazy-loading default) into the
    # shared configuration. Snapshot it so nothing outlives this context.
    saved_config = dict(_config)
    try:
        setup_sqlalchemy_sync_instrumentation(config_overrides=config_overrides)
        try:
            yield
        finally:
            teardown_sqlalchemy_sync_instrumentation()
    finally:
        # Restore in place so other holders of the dict see the old values.
        _config.clear()
        _config.update(saved_config)
# FastAPI integration helper
def setup_fastapi_instrumentation(app):
    """
    Hook SQLAlchemy instrumentation into a FastAPI application's lifecycle.

    Registers a "startup" handler that sets the instrumentation up and a
    "shutdown" handler that tears it down.

    Args:
        app: FastAPI application instance
    """

    async def startup_instrumentation():
        setup_sqlalchemy_sync_instrumentation()

    async def shutdown_instrumentation():
        teardown_sqlalchemy_sync_instrumentation()

    # Same registration as the decorator form, in the same order.
    app.on_event("startup")(startup_instrumentation)
    app.on_event("shutdown")(shutdown_instrumentation)