Files
letta-server/letta/tracing.py
2025-02-18 22:18:48 -08:00

206 lines
7.3 KiB
Python
Raw Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import inspect
import sys
import time
from functools import wraps
from typing import Any, Dict, Optional
from fastapi import Request
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.requests import RequestsInstrumentor
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.trace import Span, Status, StatusCode
# Module-level tracer shared by every helper in this file. It is obtained at
# import time and will be no-op until setup_tracing is called.
tracer = trace.get_tracer(__name__)
# Set to True by setup_tracing() once a span exporter has been attached.
# trace_method wrappers check this flag and call straight through while it is False.
_is_tracing_initialized = False
def is_pytest_environment():
    """Report whether the ``pytest`` module has been loaded into this interpreter."""
    loaded_modules = sys.modules
    return "pytest" in loaded_modules
def trace_method(name=None):
    """Decorator factory that records each call to the wrapped callable as a span.

    Coroutine functions get an async wrapper, everything else a sync wrapper.
    While ``_is_tracing_initialized`` is false the wrapper calls straight
    through to the wrapped callable and creates no span.

    Args:
        name: Optional span name. Defaults to the wrapped callable's ``__name__``.

    Returns:
        A decorator producing the traced wrapper.

    Raises:
        Whatever the wrapped callable raises; exceptions are recorded on the
        span and then propagated unchanged.
    """

    def decorator(func):
        def _annotate_span(span, args, kwargs):
            """Attach code-location and request-derived attributes to ``span``."""
            # inspect.getmodule() returns None when the defining module cannot
            # be determined; fall back to __module__ instead of raising
            # AttributeError inside the traced call.
            module = inspect.getmodule(func)
            namespace = module.__name__ if module is not None else getattr(func, "__module__", None)
            if namespace:
                span.set_attribute("code.namespace", namespace)
            span.set_attribute("code.function", func.__name__)
            if args:
                # Every object has __class__, so this is the owning class for
                # methods and the type of the first positional argument otherwise.
                span.set_attribute("code.class", args[0].__class__.__name__)

            request = _extract_request_info(args, span)
            if request:
                identifiers = (
                    ("agent.id", kwargs.get("agent_id")),
                    ("actor.id", request.get("http.user_id")),
                )
                for key, value in identifiers:
                    # None is not a valid span attribute value; skip absent ids.
                    if value is not None:
                        span.set_attribute(key, value)

        # start_as_current_span already records an escaping exception and marks
        # the span as ERROR. Handling it manually as well produced a duplicate
        # exception event on every failure, so the context manager is the single
        # place where failures are recorded.
        span_options = {"record_exception": True, "set_status_on_exception": True}

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            # Skip tracing if not initialized
            if not _is_tracing_initialized:
                return await func(*args, **kwargs)

            with tracer.start_as_current_span(name or func.__name__, **span_options) as span:
                _annotate_span(span, args, kwargs)
                result = await func(*args, **kwargs)
                span.set_status(Status(StatusCode.OK))
                return result

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            # Skip tracing if not initialized
            if not _is_tracing_initialized:
                return func(*args, **kwargs)

            with tracer.start_as_current_span(name or func.__name__, **span_options) as span:
                _annotate_span(span, args, kwargs)
                result = func(*args, **kwargs)
                span.set_status(Status(StatusCode.OK))
                return result

        return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper

    return decorator
def log_attributes(attributes: Dict[str, Any]) -> None:
    """
    Attach several attributes at once to whichever span is currently active.

    Args:
        attributes: Mapping of attribute names to values
    """
    active_span = trace.get_current_span()
    if not active_span:
        return
    active_span.set_attributes(attributes)
def log_event(name: str, attributes: Optional[Dict[str, Any]] = None, timestamp: Optional[int] = None) -> None:
    """
    Log an event to the current active span.

    Args:
        name: Name of the event
        attributes: Optional dictionary of event attributes
        timestamp: Optional timestamp in nanoseconds since the Unix epoch.
            Defaults to the current wall-clock time.
    """
    current_span = trace.get_current_span()
    if not current_span:
        return

    if timestamp is None:
        # Event timestamps are epoch-based. perf_counter_ns() is a monotonic
        # clock with an arbitrary origin, so it would misplace the event on the
        # trace timeline; time_ns() gives nanoseconds since the epoch.
        timestamp = time.time_ns()
    current_span.add_event(name=name, attributes=attributes, timestamp=timestamp)
def get_trace_id() -> str:
    """Return the active span's trace id as a 32-character zero-padded hex string, or "" without a span."""
    active_span = trace.get_current_span()
    if not active_span:
        return ""
    span_context = active_span.get_span_context()
    return f"{span_context.trace_id:032x}"
def request_hook(span: Span, _request_context: Optional[Dict] = None, response: Optional[Any] = None):
    """Copy the response status code onto the span and derive the span status from it.

    Status codes of 400 and above mark the span as ERROR, 2xx codes mark it
    as OK, and any other code leaves the span status untouched. Nothing
    happens when there is no response or it has no ``status_code``.
    """
    if response is None or not hasattr(response, "status_code"):
        return

    status_code = response.status_code
    span.set_attribute("http.status_code", status_code)
    if status_code >= 400:
        span.set_status(Status(StatusCode.ERROR))
    elif 200 <= status_code < 300:
        span.set_status(Status(StatusCode.OK))
def setup_tracing(endpoint: str, service_name: str = "memgpt-server") -> None:
    """
    Install a global TracerProvider and, when an endpoint is given, export spans over OTLP.

    Does nothing under pytest. Without an endpoint the provider is still
    installed, but no exporter is attached and the module stays flagged as
    uninitialized, so decorated functions are not traced.

    Args:
        endpoint: OTLP endpoint URL
        service_name: Name of the service for tracing
    """
    global _is_tracing_initialized

    # Test runs never export traces.
    if is_pytest_environment():
        print(" Skipping tracing setup in pytest environment")
        return

    # The resource identifies this service on every exported span.
    service_resource = Resource.create(
        {
            "service.name": service_name,
            "service.namespace": "default",
            "deployment.environment": "production",
        }
    )
    tracer_provider = TracerProvider(resource=service_resource)

    if not endpoint:
        print("⚠️ Warning: Tracing endpoint not provided, tracing will be disabled")
    else:
        # Batch spans and ship them to the OTLP endpoint.
        tracer_provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter(endpoint=endpoint)))
        _is_tracing_initialized = True

    # Register the provider globally regardless of whether an exporter was attached.
    trace.set_tracer_provider(tracer_provider)

    # Auto-instrument the requests library only when spans are actually exported.
    if _is_tracing_initialized:
        RequestsInstrumentor().instrument(response_hook=request_hook)
def _extract_request_info(args: tuple, span: Span) -> Dict[str, Any]:
    """
    Safely extracts request information from function arguments.
    Works with both FastAPI route handlers and inner functions.

    Args:
        args: Positional arguments of the traced call; the first FastAPI
            ``Request`` found among them is used.
        span: Span that receives the extracted attributes.

    Returns:
        The attributes that were set on the span. Empty when no ``Request``
        is present. Keys whose value is unavailable (for example a missing
        ``user_id`` header or an unknown client address) are omitted.
    """
    # Look for FastAPI Request object in args
    request = next((arg for arg in args if isinstance(arg, Request)), None)
    if request is None:
        return {}

    path = request.url.path
    candidates = {
        "http.route": path,
        "http.method": request.method,
        "http.scheme": request.url.scheme,
        "http.target": str(path),
        "http.url": str(request.url),
        "http.flavor": request.scope.get("http_version", ""),
        "http.client_ip": request.client.host if request.client else None,
        "http.user_id": request.headers.get("user_id"),
    }
    # None is not a valid span attribute value, so drop unavailable fields
    # instead of handing them to the span.
    attributes = {key: value for key, value in candidates.items() if value is not None}
    span.set_attributes(attributes)
    return attributes