* Wait — I forgot to commit locally. * `cp` the entire core directory, then `rm` the `.git` subdirectory.
554 lines
18 KiB
Python
554 lines
18 KiB
Python
import asyncio
|
|
import threading
|
|
import traceback
|
|
from contextlib import contextmanager
|
|
from functools import wraps
|
|
from typing import Any, Callable, Dict, List, Optional
|
|
|
|
from opentelemetry import trace
|
|
from opentelemetry.trace import Status, StatusCode
|
|
from sqlalchemy import Engine, event
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy.orm.loading import load_on_ident, load_on_pk_identity
|
|
from sqlalchemy.orm.strategies import ImmediateLoader, JoinedLoader, LazyLoader, SelectInLoader, SubqueryLoader
|
|
|
|
# Runtime configuration. Updated in place by configure_instrumentation() and by
# setup_sqlalchemy_sync_instrumentation().
_config = dict(
    enabled=True,  # master switch checked before any span is created
    sql_truncate_length=1000,  # max characters of SQL recorded on a span
    monitor_joined_loading=True,  # also patch JoinedLoader methods at setup
    log_instrumentation_errors=True,  # print setup/teardown failures
)

# Bookkeeping for everything setup installs, so teardown can undo it.
_instrumentation_state = dict(
    engine_listeners=[],  # (engine, event_name, listener) tuples
    session_listeners=[],  # (target, event_name, listener) tuples
    original_methods={},  # "<Owner>.<name>" -> original callable
    active=False,  # True once setup has completed successfully
)

# Thread-local scratch space.
# NOTE(review): not referenced anywhere else in this module — possibly dead;
# confirm no external importer uses it before removing.
_context = threading.local()
|
|
|
|
|
|
def _get_tracer():
    """Return the tracer that every span in this module is started from.

    The tracer is obtained through ``trace.get_tracer`` on each call; it is not
    cached at module level.
    """
    instrumentation_name = "sqlalchemy_sync_instrumentation"
    instrumentation_version = "1.0.0"
    return trace.get_tracer(instrumentation_name, instrumentation_version)
|
|
|
|
|
|
def _is_event_loop_running() -> bool:
    """Report whether the calling thread is currently inside a running asyncio loop.

    ``asyncio.get_running_loop`` raises ``RuntimeError`` when no loop is running
    in this thread; that case is reported as ``False``.
    """
    try:
        current_loop = asyncio.get_running_loop()
    except RuntimeError:
        return False
    return current_loop.is_running()
|
|
|
|
|
|
def _is_main_thread() -> bool:
    """Return True when called from the interpreter's main thread."""
    current = threading.current_thread()
    main = threading.main_thread()
    return current is main
|
|
|
|
|
|
def _truncate_sql(sql: str, max_length: int = 1000) -> str:
    """Shorten an SQL statement so it fits in ``max_length`` characters.

    Statements longer than the limit are cut and suffixed with ``"..."``. The
    returned string never exceeds ``max_length`` characters.

    Args:
        sql: SQL text to shorten.
        max_length: Maximum length of the returned string. When it is smaller
            than the ellipsis itself (including zero or negative values) the
            text is hard-cut with no ellipsis, since there is no room for one.

    Returns:
        ``sql`` unchanged if it already fits, otherwise a shortened copy.
    """
    ellipsis = "..."
    if len(sql) <= max_length:
        return sql
    if max_length < len(ellipsis):
        # sql[: max_length - 3] would use a negative stop index here and keep
        # nearly the whole string, blowing past the limit.
        return sql[: max(max_length, 0)]
    return sql[: max_length - len(ellipsis)] + ellipsis
|
|
|
|
|
|
def _create_sync_db_span(
    operation_type: str,
    sql_statement: Optional[str] = None,
    loader_type: Optional[str] = None,
    relationship_key: Optional[str] = None,
    is_joined: bool = False,
    additional_attrs: Optional[Dict[str, Any]] = None,
) -> Any:
    """
    Start a span for a synchronous database operation, when one is warranted.

    A span is only started if instrumentation is enabled AND an asyncio event
    loop is running in the calling thread — i.e. when a blocking database call
    may be stalling the loop. The span is started with ``start_span`` (it is
    not made the current span); the caller must end it.

    Args:
        operation_type: Value for the ``db.operation.type`` attribute.
        sql_statement: SQL text; recorded (truncated to the configured
            ``sql_truncate_length``) as ``db.statement`` when non-empty.
        loader_type: Loader strategy name (selectin, joined, lazy, ...). When
            set, ``sqlalchemy.loader.is_joined`` is recorded alongside it.
        relationship_key: Relationship attribute name, recorded when non-empty.
        is_joined: Whether the operation comes from joined loading.
        additional_attrs: Extra attributes, set after the ones above.

    Returns:
        The started span, or None when no span was created.
    """
    if not _config["enabled"] or not _is_event_loop_running():
        return None

    span = _get_tracer().start_span("db_operation")

    # Collected in the order they are applied to the span.
    attributes = [("db.operation.type", operation_type)]
    if sql_statement:
        attributes.append(
            ("db.statement", _truncate_sql(sql_statement, _config["sql_truncate_length"]))
        )
    if loader_type:
        attributes.append(("sqlalchemy.loader.type", loader_type))
        attributes.append(("sqlalchemy.loader.is_joined", is_joined))
    if relationship_key:
        attributes.append(("sqlalchemy.relationship.key", relationship_key))
    if additional_attrs:
        attributes.extend(additional_attrs.items())

    for attr_name, attr_value in attributes:
        span.set_attribute(attr_name, attr_value)
    return span
|
|
|
|
|
|
def _instrument_engine_events(engine: Engine) -> None:
    """Attach cursor-level listeners to ``engine`` that wrap each statement in a span.

    Spans come from ``_create_sync_db_span``, so they are only created while
    instrumentation is enabled and an asyncio event loop is running in the
    executing thread. Each registered listener is recorded in
    ``_instrumentation_state["engine_listeners"]`` as soon as it is registered,
    so teardown can undo even a partially completed registration.

    Args:
        engine: A sync ``Engine`` or an ``AsyncEngine``. Cursor events are
            dispatched by the sync engine, so for an ``AsyncEngine`` the
            listeners are attached to its ``sync_engine``.
    """
    from sqlalchemy.ext.asyncio import AsyncEngine

    if isinstance(engine, AsyncEngine):
        engine = engine.sync_engine

    def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
        """Start a span for the statement and stash it on the execution context."""
        if not _config["enabled"]:
            return
        # SQLAlchemy documents ``context`` as possibly None for this event.
        # There is then nowhere to keep the span, and an AttributeError raised
        # from a listener would fail the user's query.
        if context is None:
            return

        context._sync_instrumentation_span = _create_sync_db_span(
            operation_type="cursor_execute",
            sql_statement=statement,
            additional_attrs={
                "db.executemany": executemany,
                "db.connection.info": str(conn.info),
            },
        )

    def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
        """End the statement's span (if one was started) with an OK status."""
        # No ``enabled`` check: a span started before instrumentation was
        # disabled must still be ended, otherwise it is never exported.
        span = getattr(context, "_sync_instrumentation_span", None)
        if span is None:
            return
        span.set_status(Status(StatusCode.OK))
        span.end()
        context._sync_instrumentation_span = None

    def handle_cursor_error(exception_context):
        """End the statement's span with an ERROR status, recording the exception."""
        # No ``enabled`` check, for the same reason as after_cursor_execute.
        context = getattr(exception_context, "execution_context", None)
        span = getattr(context, "_sync_instrumentation_span", None)
        if span is None:
            return
        error = getattr(exception_context, "original_exception", None)
        if error is not None:
            span.record_exception(error)
        span.set_status(Status(StatusCode.ERROR, "Database operation failed"))
        span.end()
        context._sync_instrumentation_span = None

    for target, event_name, listener in (
        (engine, "before_cursor_execute", before_cursor_execute),
        (engine, "after_cursor_execute", after_cursor_execute),
        (engine, "handle_error", handle_cursor_error),
    ):
        event.listen(target, event_name, listener)
        _instrumentation_state["engine_listeners"].append((target, event_name, listener))
|
|
|
|
|
|
def _instrument_loader_strategies() -> None:
    """Monkey-patch SQLAlchemy relationship loader strategies so they emit spans.

    For each loader class, the methods listed in ``methods_to_instrument`` (and,
    when ``monitor_joined_loading`` is on, extra ``JoinedLoader`` methods) are
    replaced by wrappers that run the original inside a span from
    ``_create_sync_db_span``. Originals are saved in
    ``_instrumentation_state["original_methods"]`` under ``"<Class>.<method>"``
    so teardown can restore them. Methods absent from the installed SQLAlchemy
    version are skipped.

    A callable that is already instrumented is never wrapped again. Otherwise a
    repeated call (e.g. retrying a setup that failed part-way), or a subclass
    inheriting an already-patched method, would double-wrap it and overwrite the
    saved original with a wrapper, making it unrestorable.
    """
    # Attribute set on every wrapper produced here, used to detect re-patching.
    instrumented_marker = "_sync_instrumented"

    def create_loader_wrapper(loader_class: type, loader_type: str, is_joined: bool = False):
        """Return a decorator that runs a loader method inside a span."""

        def wrapper(original_method: Callable):
            @wraps(original_method)
            def instrumented_method(self, *args, **kwargs):
                # Prefer the relationship property's key; fall back to the strategy's own.
                relationship_key = getattr(self, "key", None)
                if hasattr(self, "parent_property"):
                    relationship_key = getattr(self.parent_property, "key", relationship_key)

                span = _create_sync_db_span(
                    operation_type="loader_strategy",
                    loader_type=loader_type,
                    relationship_key=relationship_key,
                    is_joined=is_joined,
                    additional_attrs={
                        "sqlalchemy.loader.class": loader_class.__name__,
                        "sqlalchemy.loader.method": original_method.__name__,
                    },
                )

                try:
                    result = original_method(self, *args, **kwargs)
                    if span is not None:
                        span.set_status(Status(StatusCode.OK))
                    return result
                except Exception as e:
                    if span is not None:
                        span.set_status(Status(StatusCode.ERROR, str(e)))
                    raise
                finally:
                    if span is not None:
                        span.end()

            setattr(instrumented_method, instrumented_marker, True)
            return instrumented_method

        return wrapper

    def patch(loader_class: type, method_name: str, wrapper: Callable) -> None:
        """Replace ``loader_class.<method_name>`` with its wrapped form, at most once."""
        if not hasattr(loader_class, method_name):
            return
        original_method = getattr(loader_class, method_name)
        if getattr(original_method, instrumented_marker, False):
            return  # already wrapped (here or via a patched base class)
        key = f"{loader_class.__name__}.{method_name}"
        _instrumentation_state["original_methods"][key] = original_method
        setattr(loader_class, method_name, wrapper(original_method))

    loaders_to_instrument = [
        (SelectInLoader, "selectin", False),
        (JoinedLoader, "joined", True),
        (LazyLoader, "lazy", False),
        (SubqueryLoader, "subquery", False),
        (ImmediateLoader, "immediate", False),
    ]
    methods_to_instrument = ["_load_for_path", "load_for_path"]

    for loader_class, loader_type, is_joined in loaders_to_instrument:
        # Skip joined loading entirely when its monitoring is disabled.
        if is_joined and not _config["monitor_joined_loading"]:
            continue

        wrapper = create_loader_wrapper(loader_class, loader_type, is_joined)
        for method_name in methods_to_instrument:
            patch(loader_class, method_name, wrapper)

    # Additional JoinedLoader-specific methods.
    # NOTE(review): _generate_cache_key builds a cache key rather than doing
    # database I/O — confirm it belongs in a "db_operation" span.
    if _config["monitor_joined_loading"]:
        joined_wrapper = create_loader_wrapper(JoinedLoader, "joined", True)
        for method_name in ("_create_eager_join", "_generate_cache_key"):
            patch(JoinedLoader, method_name, joined_wrapper)
|
|
|
|
|
|
def _instrument_loading_functions() -> None:
    """Wrap module-level ORM loading helpers so each call runs inside a span.

    ``load_on_ident`` and ``load_on_pk_identity`` on ``sqlalchemy.orm.loading``
    are replaced with traced versions. The functions imported at the top of this
    module are the ones wrapped and are saved under ``"loading.<name>"`` in
    ``_instrumentation_state["original_methods"]`` for teardown.
    """
    import sqlalchemy.orm.loading as loading_module

    def make_traced(func_name: str, original_func: Callable) -> Callable:
        """Return ``original_func`` wrapped in a ``loading_function`` span."""

        @wraps(original_func)
        def instrumented_func(*args, **kwargs):
            span = _create_sync_db_span(
                operation_type="loading_function",
                additional_attrs={"sqlalchemy.loading.function": func_name},
            )
            try:
                outcome = original_func(*args, **kwargs)
                if span:
                    span.set_status(Status(StatusCode.OK))
                return outcome
            except Exception as exc:
                if span:
                    span.set_status(Status(StatusCode.ERROR, str(exc)))
                raise
            finally:
                if span:
                    span.end()

        return instrumented_func

    saved_originals = _instrumentation_state["original_methods"]
    for func_name, original_func in (
        ("load_on_ident", load_on_ident),
        ("load_on_pk_identity", load_on_pk_identity),
    ):
        saved_originals[f"loading.{func_name}"] = original_func
        setattr(loading_module, func_name, make_traced(func_name, original_func))
|
|
|
|
|
|
def _instrument_session_operations() -> None:
    """Trace ORM ``Session`` flushes via class-level session event listeners.

    ``before_flush`` starts a span (via ``_create_sync_db_span``) and stores it
    on the session. ``after_flush`` (or, failing that, ``after_flush_postexec``)
    ends it with an OK status. ``after_soft_rollback`` ends a still-pending span
    with an ERROR status: when a flush raises, SQLAlchemy rolls the flush back
    instead of firing ``after_flush``, so without this hook the span would never
    be ended. Every registered listener is recorded in
    ``_instrumentation_state["session_listeners"]`` for teardown.
    """

    def _finish_flush_span(session, status) -> None:
        """End the pending flush span on ``session`` (if any) with ``status``."""
        span = getattr(session, "_sync_instrumentation_flush_span", None)
        if span is None:
            return
        span.set_status(status)
        span.end()
        session._sync_instrumentation_flush_span = None

    def before_flush(session, flush_context, instances):
        """Start a span for the flush that is about to run."""
        if not _config["enabled"]:
            return

        # A span from an earlier flush that reached neither after_flush nor a
        # rollback would otherwise be overwritten below and never ended.
        _finish_flush_span(session, Status(StatusCode.ERROR, "Flush did not complete"))

        session._sync_instrumentation_flush_span = _create_sync_db_span(
            operation_type="session_flush",
            additional_attrs={
                "sqlalchemy.session.new_count": len(session.new),
                "sqlalchemy.session.dirty_count": len(session.dirty),
                "sqlalchemy.session.deleted_count": len(session.deleted),
            },
        )

    def after_flush(session, flush_context):
        """End the flush span with an OK status."""
        # No ``enabled`` check: a span started while enabled must still be
        # ended if instrumentation was disabled mid-flush.
        _finish_flush_span(session, Status(StatusCode.OK))

    def after_flush_postexec(session, flush_context):
        """End the flush span with OK if after_flush has not already done so."""
        # Normally a no-op: SQLAlchemy fires after_flush first, which clears the span.
        _finish_flush_span(session, Status(StatusCode.OK))

    def after_soft_rollback(session, previous_transaction):
        """End a flush span left pending by a failed (rolled-back) flush."""
        _finish_flush_span(session, Status(StatusCode.ERROR, "Flush failed and was rolled back"))

    for target, event_name, listener in (
        (Session, "before_flush", before_flush),
        (Session, "after_flush", after_flush),
        (Session, "after_flush_postexec", after_flush_postexec),
        (Session, "after_soft_rollback", after_soft_rollback),
    ):
        event.listen(target, event_name, listener)
        _instrumentation_state["session_listeners"].append((target, event_name, listener))
|
|
|
|
|
|
def _rollback_partial_setup() -> None:
    """Best-effort undo of whatever a failed setup call managed to install."""
    # Teardown is a no-op unless the state is flagged active, so flag it first.
    _instrumentation_state["active"] = True
    try:
        teardown_sqlalchemy_sync_instrumentation()
    except Exception:
        # Teardown reports its own failure (when logging is enabled). Swallow it
        # so the caller sees the original setup error, and mark the state
        # inactive because setup never completed.
        _instrumentation_state["active"] = False


def setup_sqlalchemy_sync_instrumentation(
    engines: Optional[List[Engine]] = None,
    config_overrides: Optional[Dict[str, Any]] = None,
    lazy_loading_only: bool = True,
) -> None:
    """
    Set up SQLAlchemy synchronous operation instrumentation.

    Patches loader strategies and ORM loading functions, registers session
    flush listeners, then attaches cursor listeners to each engine. Calling
    this while instrumentation is already active is a no-op. If setup fails
    part-way, whatever was already installed is rolled back (best effort)
    before the error is re-raised, so a later retry starts from a clean state.

    Args:
        engines: Engines (sync or async) to instrument. If None, engines are
            collected from ``letta.server.db.db_registry`` (its
            ``_async_engines`` and ``_sync_engines`` mappings) when that module
            is importable; otherwise no engines are instrumented.
        config_overrides: Options merged into the module configuration. They
            are applied after the ``lazy_loading_only`` preset, so an explicit
            ``monitor_joined_loading`` value here takes precedence.
        lazy_loading_only: If True, preset ``monitor_joined_loading`` to False
            so ``JoinedLoader`` methods are not patched. The other loader
            strategies are still instrumented.

    Raises:
        Exception: any error from loader, loading-function, or session
            instrumentation. Failures on individual engines are reported (when
            ``log_instrumentation_errors``) and skipped instead of raised.
    """
    if _instrumentation_state["active"]:
        return  # Already active

    try:
        # Preset first, then explicit overrides, so the caller's values win.
        if lazy_loading_only:
            _config["monitor_joined_loading"] = False
        if config_overrides:
            _config.update(config_overrides)

        if engines is None:
            engines = []
            try:
                from letta.server.db import db_registry

                if hasattr(db_registry, "_async_engines"):
                    engines.extend(db_registry._async_engines.values())
                if hasattr(db_registry, "_sync_engines"):
                    engines.extend(db_registry._sync_engines.values())
            except ImportError:
                pass  # No registry available; nothing to discover.

        _instrument_loader_strategies()
        _instrument_loading_functions()
        _instrument_session_operations()

        # Engines last; a failure on one engine is reported and skipped.
        for engine in engines:
            try:
                _instrument_engine_events(engine)
            except Exception as e:
                if _config["log_instrumentation_errors"]:
                    print(f"Error instrumenting engine {engine}: {e}")

        _instrumentation_state["active"] = True

    except Exception as e:
        if _config["log_instrumentation_errors"]:
            print(f"Error setting up SQLAlchemy instrumentation: {e}")
            traceback.print_exc()
        _rollback_partial_setup()
        raise
|
|
|
|
|
|
def teardown_sqlalchemy_sync_instrumentation() -> None:
    """Undo everything ``setup_sqlalchemy_sync_instrumentation`` installed.

    Removes the recorded engine and session listeners and restores every
    patched loader method / loading function saved in
    ``_instrumentation_state["original_methods"]``. Each step is attempted
    independently, so one failure does not leave the remaining patches in
    place. The bookkeeping is always cleared and the state marked inactive.
    Does nothing if instrumentation is not active.

    Raises:
        Exception: the first error encountered, re-raised after every step has
            been attempted (each error is printed when
            ``log_instrumentation_errors`` is set).
    """
    if not _instrumentation_state["active"]:
        return  # Not active

    import sqlalchemy.orm.loading as loading_module

    # Owners of patched loader methods, keyed by the "<Class>" prefix used in
    # ``original_methods`` keys.
    loader_classes = {
        cls.__name__: cls
        for cls in (SelectInLoader, JoinedLoader, LazyLoader, SubqueryLoader, ImmediateLoader)
    }
    errors: List[Exception] = []

    def attempt(action: Callable[..., Any], *args: Any) -> None:
        """Run one undo step, recording (and optionally reporting) any failure."""
        try:
            action(*args)
        except Exception as exc:
            errors.append(exc)
            if _config["log_instrumentation_errors"]:
                print(f"Error tearing down SQLAlchemy instrumentation: {exc}")
                traceback.print_exc()

    for target, event_name, listener in _instrumentation_state["engine_listeners"]:
        attempt(event.remove, target, event_name, listener)
    for target, event_name, listener in _instrumentation_state["session_listeners"]:
        attempt(event.remove, target, event_name, listener)

    for key, original in _instrumentation_state["original_methods"].items():
        owner_name, _, attr_name = key.rpartition(".")
        # "loading.<func>" keys belong to sqlalchemy.orm.loading; all others
        # name one of the loader strategy classes. Unknown owners are skipped.
        owner = loading_module if owner_name == "loading" else loader_classes.get(owner_name)
        if owner is not None:
            attempt(setattr, owner, attr_name, original)

    # Always reset bookkeeping so a later call cannot try to undo steps twice.
    _instrumentation_state["engine_listeners"].clear()
    _instrumentation_state["session_listeners"].clear()
    _instrumentation_state["original_methods"].clear()
    _instrumentation_state["active"] = False

    if errors:
        raise errors[0]
|
|
|
|
|
|
def configure_instrumentation(**kwargs) -> None:
    """
    Merge keyword options into the module-level instrumentation configuration.

    Keys are stored as given; they are not validated against the known option
    names. ``monitor_joined_loading`` is only read while setup patches the
    loader strategies, so changing it takes effect on the next setup.

    Args:
        **kwargs: Option name -> new value.
    """
    for option, value in kwargs.items():
        _config[option] = value
|
|
|
|
|
|
def get_instrumentation_config() -> Dict[str, Any]:
    """Return a shallow snapshot of the current instrumentation configuration.

    Mutating the returned dict does not change the live configuration.
    """
    snapshot = dict(_config)
    return snapshot
|
|
|
|
|
|
def is_instrumentation_active() -> bool:
    """Return True between a successful setup and the matching teardown."""
    active = _instrumentation_state["active"]
    return active
|
|
|
|
|
|
# Context manager for temporary instrumentation
|
|
@contextmanager
def temporary_instrumentation(**config_overrides):
    """
    Context manager that enables SQLAlchemy instrumentation for a ``with`` block.

    If instrumentation is not already active, it is set up with
    ``config_overrides`` on entry and torn down on exit. The module
    configuration is then restored to its pre-entry values, because setup
    merges the overrides (and its ``lazy_loading_only`` preset) into the global
    config permanently. If instrumentation was already active, nothing is
    changed and ``config_overrides`` is ignored.

    Args:
        **config_overrides: Configuration overrides for the instrumentation.
    """
    if _instrumentation_state["active"]:
        yield
        return

    saved_config = _config.copy()
    try:
        setup_sqlalchemy_sync_instrumentation(config_overrides=config_overrides)
        yield
    finally:
        try:
            # A no-op if setup itself failed before marking the state active.
            teardown_sqlalchemy_sync_instrumentation()
        finally:
            # Restore in place rather than rebinding the module-level dict.
            _config.clear()
            _config.update(saved_config)
|
|
|
|
|
|
# FastAPI integration helper
|
|
def setup_fastapi_instrumentation(app):
    """
    Tie SQLAlchemy instrumentation to a FastAPI application's lifecycle.

    Registers a "startup" handler that calls
    ``setup_sqlalchemy_sync_instrumentation()`` and a "shutdown" handler that
    calls ``teardown_sqlalchemy_sync_instrumentation()``, via ``app.on_event``.

    Args:
        app: FastAPI application instance
    """

    async def _on_startup():
        setup_sqlalchemy_sync_instrumentation()

    async def _on_shutdown():
        teardown_sqlalchemy_sync_instrumentation()

    app.on_event("startup")(_on_startup)
    app.on_event("shutdown")(_on_shutdown)