206 lines
7.3 KiB
Python
206 lines
7.3 KiB
Python
import asyncio
|
||
import inspect
|
||
import sys
|
||
import time
|
||
from functools import wraps
|
||
from typing import Any, Dict, Optional
|
||
|
||
from fastapi import Request
|
||
from opentelemetry import trace
|
||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||
from opentelemetry.instrumentation.requests import RequestsInstrumentor
|
||
from opentelemetry.sdk.resources import Resource
|
||
from opentelemetry.sdk.trace import TracerProvider
|
||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||
from opentelemetry.trace import Span, Status, StatusCode
|
||
|
||
# Get a tracer instance - will be no-op until setup_tracing is called
|
||
tracer = trace.get_tracer(__name__)
|
||
|
||
# Track if tracing has been initialized
|
||
_is_tracing_initialized = False
|
||
|
||
|
||
def is_pytest_environment():
|
||
"""Check if we're running in pytest"""
|
||
return "pytest" in sys.modules
|
||
|
||
|
||
def trace_method(name=None):
|
||
"""Decorator to add tracing to a method"""
|
||
|
||
def decorator(func):
|
||
@wraps(func)
|
||
async def async_wrapper(*args, **kwargs):
|
||
# Skip tracing if not initialized
|
||
if not _is_tracing_initialized:
|
||
return await func(*args, **kwargs)
|
||
|
||
span_name = name or func.__name__
|
||
with tracer.start_as_current_span(span_name) as span:
|
||
span.set_attribute("code.namespace", inspect.getmodule(func).__name__)
|
||
span.set_attribute("code.function", func.__name__)
|
||
|
||
if len(args) > 0 and hasattr(args[0], "__class__"):
|
||
span.set_attribute("code.class", args[0].__class__.__name__)
|
||
|
||
request = _extract_request_info(args, span)
|
||
if request and len(request) > 0:
|
||
span.set_attribute("agent.id", kwargs.get("agent_id"))
|
||
span.set_attribute("actor.id", request.get("http.user_id"))
|
||
|
||
try:
|
||
result = await func(*args, **kwargs)
|
||
span.set_status(Status(StatusCode.OK))
|
||
return result
|
||
except Exception as e:
|
||
span.set_status(Status(StatusCode.ERROR))
|
||
span.record_exception(e)
|
||
raise
|
||
|
||
@wraps(func)
|
||
def sync_wrapper(*args, **kwargs):
|
||
# Skip tracing if not initialized
|
||
if not _is_tracing_initialized:
|
||
return func(*args, **kwargs)
|
||
|
||
span_name = name or func.__name__
|
||
with tracer.start_as_current_span(span_name) as span:
|
||
span.set_attribute("code.namespace", inspect.getmodule(func).__name__)
|
||
span.set_attribute("code.function", func.__name__)
|
||
|
||
if len(args) > 0 and hasattr(args[0], "__class__"):
|
||
span.set_attribute("code.class", args[0].__class__.__name__)
|
||
|
||
request = _extract_request_info(args, span)
|
||
if request and len(request) > 0:
|
||
span.set_attribute("agent.id", kwargs.get("agent_id"))
|
||
span.set_attribute("actor.id", request.get("http.user_id"))
|
||
|
||
try:
|
||
result = func(*args, **kwargs)
|
||
span.set_status(Status(StatusCode.OK))
|
||
return result
|
||
except Exception as e:
|
||
span.set_status(Status(StatusCode.ERROR))
|
||
span.record_exception(e)
|
||
raise
|
||
|
||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||
|
||
return decorator
|
||
|
||
|
||
def log_attributes(attributes: Dict[str, Any]) -> None:
|
||
"""
|
||
Log multiple attributes to the current active span.
|
||
|
||
Args:
|
||
attributes: Dictionary of attribute key-value pairs
|
||
"""
|
||
current_span = trace.get_current_span()
|
||
if current_span:
|
||
current_span.set_attributes(attributes)
|
||
|
||
|
||
def log_event(name: str, attributes: Optional[Dict[str, Any]] = None, timestamp: Optional[int] = None) -> None:
|
||
"""
|
||
Log an event to the current active span.
|
||
|
||
Args:
|
||
name: Name of the event
|
||
attributes: Optional dictionary of event attributes
|
||
timestamp: Optional timestamp in nanoseconds
|
||
"""
|
||
current_span = trace.get_current_span()
|
||
if current_span:
|
||
if timestamp is None:
|
||
timestamp = int(time.perf_counter_ns())
|
||
|
||
current_span.add_event(name=name, attributes=attributes, timestamp=timestamp)
|
||
|
||
|
||
def get_trace_id() -> str:
|
||
current_span = trace.get_current_span()
|
||
if current_span:
|
||
return format(current_span.get_span_context().trace_id, "032x")
|
||
else:
|
||
return ""
|
||
|
||
|
||
def request_hook(span: Span, _request_context: Optional[Dict] = None, response: Optional[Any] = None):
|
||
"""Hook to update span based on response status code"""
|
||
if response is not None:
|
||
if hasattr(response, "status_code"):
|
||
span.set_attribute("http.status_code", response.status_code)
|
||
if response.status_code >= 400:
|
||
span.set_status(Status(StatusCode.ERROR))
|
||
elif 200 <= response.status_code < 300:
|
||
span.set_status(Status(StatusCode.OK))
|
||
|
||
|
||
def setup_tracing(endpoint: str, service_name: str = "memgpt-server") -> None:
|
||
"""
|
||
Sets up OpenTelemetry tracing with OTLP exporter for specific endpoints
|
||
|
||
Args:
|
||
endpoint: OTLP endpoint URL
|
||
service_name: Name of the service for tracing
|
||
"""
|
||
global _is_tracing_initialized
|
||
|
||
# Skip tracing in pytest environment
|
||
if is_pytest_environment():
|
||
print("ℹ️ Skipping tracing setup in pytest environment")
|
||
return
|
||
|
||
# Create a Resource to identify our service
|
||
resource = Resource.create({"service.name": service_name, "service.namespace": "default", "deployment.environment": "production"})
|
||
|
||
# Initialize the TracerProvider with the resource
|
||
provider = TracerProvider(resource=resource)
|
||
|
||
# Only set up OTLP export if endpoint is provided
|
||
if endpoint:
|
||
otlp_exporter = OTLPSpanExporter(endpoint=endpoint)
|
||
processor = BatchSpanProcessor(otlp_exporter)
|
||
provider.add_span_processor(processor)
|
||
_is_tracing_initialized = True
|
||
else:
|
||
print("⚠️ Warning: Tracing endpoint not provided, tracing will be disabled")
|
||
|
||
# Set the global TracerProvider
|
||
trace.set_tracer_provider(provider)
|
||
|
||
# Initialize automatic instrumentation for the requests library with response hook
|
||
if _is_tracing_initialized:
|
||
RequestsInstrumentor().instrument(response_hook=request_hook)
|
||
|
||
|
||
def _extract_request_info(args: tuple, span: Span) -> Dict[str, Any]:
|
||
"""
|
||
Safely extracts request information from function arguments.
|
||
Works with both FastAPI route handlers and inner functions.
|
||
"""
|
||
attributes = {}
|
||
|
||
# Look for FastAPI Request object in args
|
||
request = next((arg for arg in args if isinstance(arg, Request)), None)
|
||
|
||
if request:
|
||
attributes.update(
|
||
{
|
||
"http.route": request.url.path,
|
||
"http.method": request.method,
|
||
"http.scheme": request.url.scheme,
|
||
"http.target": str(request.url.path),
|
||
"http.url": str(request.url),
|
||
"http.flavor": request.scope.get("http_version", ""),
|
||
"http.client_ip": request.client.host if request.client else None,
|
||
"http.user_id": request.headers.get("user_id"),
|
||
}
|
||
)
|
||
|
||
span.set_attributes(attributes)
|
||
return attributes
|