diff --git a/letta/agent.py b/letta/agent.py
index a49cbe36..6d7e4e2a 100644
--- a/letta/agent.py
+++ b/letta/agent.py
@@ -60,7 +60,7 @@ from letta.services.tool_manager import ToolManager
 from letta.settings import summarizer_settings
 from letta.streaming_interface import StreamingRefreshCLIInterface
 from letta.system import get_heartbeat, get_token_limit_warning, package_function_response, package_summarize_message, package_user_message
-from letta.tracing import trace_method
+from letta.tracing import log_event, trace_method
 from letta.utils import count_tokens, get_friendly_error_msg, get_tool_call_id, log_telemetry, parse_json, validate_function_response
 
 logger = get_logger(__name__)
@@ -307,7 +307,7 @@ class Agent(BaseAgent):
         # Return updated messages
         return messages
 
-    @trace_method("Get AI Reply")
+    @trace_method
     def _get_ai_reply(
         self,
         message_sequence: List[Message],
@@ -400,7 +400,7 @@ class Agent(BaseAgent):
             log_telemetry(self.logger, "_handle_ai_response finish catch-all exception")
             raise Exception("Retries exhausted and no valid response received.")
 
-    @trace_method("Handle AI Response")
+    @trace_method
     def _handle_ai_response(
         self,
         response_message: ChatCompletionMessage,  # TODO should we eventually move the Message creation outside of this function?
@@ -538,7 +538,24 @@ class Agent(BaseAgent):
                 log_telemetry(
                     self.logger, "_handle_ai_response execute tool start", function_name=function_name, function_args=function_args
                 )
+                log_event(
+                    "tool_call_initiated",
+                    attributes={
+                        "function_name": function_name,
+                        "target_letta_tool": target_letta_tool.model_dump(),
+                        **{f"function_args.{k}": v for k, v in function_args.items()},
+                    },
+                )
+
                 function_response, sandbox_run_result = self.execute_tool_and_persist_state(function_name, function_args, target_letta_tool)
+
+                log_event(
+                    "tool_call_ended",
+                    attributes={
+                        "function_response": function_response,
+                        "sandbox_run_result": sandbox_run_result.model_dump() if sandbox_run_result else None,
+                    },
+                )
                 log_telemetry(
                     self.logger, "_handle_ai_response execute tool finish", function_name=function_name, function_args=function_args
                 )
@@ -640,7 +657,7 @@ class Agent(BaseAgent):
         log_telemetry(self.logger, "_handle_ai_response finish")
         return messages, heartbeat_request, function_failed
 
-    @trace_method("Agent Step")
+    @trace_method
     def step(
         self,
         messages: Union[Message, List[Message]],
@@ -828,6 +845,13 @@ class Agent(BaseAgent):
                     f"{CLI_WARNING_PREFIX}last response total_tokens ({current_total_tokens}) > {summarizer_settings.memory_warning_threshold * int(self.agent_state.llm_config.context_window)}"
                 )
 
+                log_event(
+                    name="memory_pressure_warning",
+                    attributes={
+                        "current_total_tokens": current_total_tokens,
+                        "context_window_limit": self.agent_state.llm_config.context_window,
+                    },
+                )
                 # Only deliver the alert if we haven't already (this period)
                 if not self.agent_alerted_about_memory_pressure:
                     active_memory_warning = True
@@ -1029,9 +1053,18 @@ class Agent(BaseAgent):
             self.agent_alerted_about_memory_pressure = False
 
         curr_in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
+        current_token_count = sum(get_token_counts_for_messages(curr_in_context_messages))
         logger.info(f"Ran summarizer, messages length {prior_len} -> {len(curr_in_context_messages)}")
-        logger.info(
-            f"Summarizer brought down total token count from {sum(token_counts)} -> {sum(get_token_counts_for_messages(curr_in_context_messages))}"
+        logger.info(f"Summarizer brought down total token count from {sum(token_counts)} -> {current_token_count}")
+        log_event(
+            name="summarization",
+            attributes={
+                "prior_length": prior_len,
+                "current_length": len(curr_in_context_messages),
+                "prior_token_count": sum(token_counts),
+                "current_token_count": current_token_count,
+                "context_window_limit": self.agent_state.llm_config.context_window,
+            },
         )
 
     def add_function(self, function_name: str) -> str:
diff --git a/letta/llm_api/anthropic.py b/letta/llm_api/anthropic.py
index 205f4cb7..af3780c8 100644
--- a/letta/llm_api/anthropic.py
+++ b/letta/llm_api/anthropic.py
@@ -40,6 +40,7 @@ from letta.schemas.openai.chat_completion_response import MessageDelta, ToolCall
 from letta.services.provider_manager import ProviderManager
 from letta.settings import model_settings
 from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
+from letta.tracing import log_event
 
 BASE_URL = "https://api.anthropic.com/v1"
 
@@ -677,10 +678,12 @@ def anthropic_chat_completions_request(
         inner_thoughts_xml_tag=inner_thoughts_xml_tag,
         put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
     )
+    log_event(name="llm_request_sent", attributes=data)
     response = anthropic_client.beta.messages.create(
         **data,
         betas=betas,
     )
+    log_event(name="llm_response_received", attributes={"response": response.json()})
     return convert_anthropic_response_to_chatcompletion(response=response, inner_thoughts_xml_tag=inner_thoughts_xml_tag)
 
 
@@ -698,8 +701,9 @@ def anthropic_bedrock_chat_completions_request(
     try:
         # bedrock does not support certain args
         data["tool_choice"] = {"type": "any"}
-
+        log_event(name="llm_request_sent", attributes=data)
         response = client.messages.create(**data)
+        log_event(name="llm_response_received", attributes={"response": response.json()})
         return convert_anthropic_response_to_chatcompletion(response=response, inner_thoughts_xml_tag=inner_thoughts_xml_tag)
     except PermissionDeniedError:
         raise BedrockPermissionError(f"User does not have access to the Bedrock model with the specified ID. {data['model']}")
@@ -839,6 +843,8 @@ def anthropic_chat_completions_process_stream(
         ),
     )
 
+    log_event(name="llm_request_sent", attributes=chat_completion_request.model_dump())
+
     if stream_interface:
         stream_interface.stream_start()
 
@@ -987,4 +993,6 @@ def anthropic_chat_completions_process_stream(
 
     assert len(chat_completion_response.choices) > 0, chat_completion_response
 
+    log_event(name="llm_response_received", attributes=chat_completion_response.model_dump())
+
     return chat_completion_response
diff --git a/letta/llm_api/azure_openai.py b/letta/llm_api/azure_openai.py
index e60b547b..368850ec 100644
--- a/letta/llm_api/azure_openai.py
+++ b/letta/llm_api/azure_openai.py
@@ -8,6 +8,7 @@ from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
 from letta.schemas.openai.chat_completions import ChatCompletionRequest
 from letta.schemas.openai.embedding_response import EmbeddingResponse
 from letta.settings import ModelSettings
+from letta.tracing import log_event
 
 
 def get_azure_chat_completions_endpoint(base_url: str, model: str, api_version: str):
@@ -120,10 +121,12 @@ def azure_openai_chat_completions_request(
     data.pop("tool_choice", None)  # extra safe, should exist always (default="auto")
 
     url = get_azure_chat_completions_endpoint(model_settings.azure_base_url, llm_config.model, model_settings.azure_api_version)
+    log_event(name="llm_request_sent", attributes=data)
     response_json = make_post_request(url, headers, data)
     # NOTE: azure openai does not include "content" in the response when it is None, so we need to add it
     if "content" not in response_json["choices"][0].get("message"):
         response_json["choices"][0]["message"]["content"] = None
+    log_event(name="llm_response_received", attributes=response_json)
     response = ChatCompletionResponse(**response_json)  # convert to 'dot-dict' style which is the openai python client default
     return response
diff --git a/letta/llm_api/google_ai.py b/letta/llm_api/google_ai.py
index 6c2fca8c..abf707d0 100644
--- a/letta/llm_api/google_ai.py
+++ b/letta/llm_api/google_ai.py
@@ -11,6 +11,7 @@ from letta.local_llm.json_parser import clean_json_string_extra_backslash
 from letta.local_llm.utils import count_tokens
 from letta.schemas.openai.chat_completion_request import Tool
 from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics
+from letta.tracing import log_event
 from letta.utils import get_tool_call_id
 
 
@@ -422,7 +423,9 @@ def google_ai_chat_completions_request(
     if add_postfunc_model_messages:
         data["contents"] = add_dummy_model_messages(data["contents"])
 
+    log_event(name="llm_request_sent", attributes=data)
     response_json = make_post_request(url, headers, data)
+    log_event(name="llm_response_received", attributes=response_json)
 
     try:
         return convert_google_ai_response_to_chatcompletion(
             response_json=response_json,
diff --git a/letta/llm_api/google_vertex.py b/letta/llm_api/google_vertex.py
index c7c57729..4e85abf3 100644
--- a/letta/llm_api/google_vertex.py
+++ b/letta/llm_api/google_vertex.py
@@ -8,6 +8,7 @@ from letta.local_llm.json_parser import clean_json_string_extra_backslash
 from letta.local_llm.utils import count_tokens
 from letta.schemas.openai.chat_completion_request import Tool
 from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics
+from letta.tracing import log_event
 from letta.utils import get_tool_call_id
 
 
@@ -323,6 +324,9 @@ def google_vertex_chat_completions_request(
config["tool_config"] = tool_config.model_dump() # make request to client + attributes = config if isinstance(config, dict) else {"config": config} + attributes.update({"contents": contents}) + log_event(name="llm_request_sent", attributes={"contents": contents, "config": config}) response = client.models.generate_content( model=model, contents=contents, diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index 32a07136..1d8a4af7 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -120,7 +120,7 @@ def retry_with_exponential_backoff( return wrapper -@trace_method("LLM Request") +@trace_method @retry_with_exponential_backoff def create( # agent_state: AgentState, diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index 56710682..d8ad521b 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -25,6 +25,7 @@ from letta.schemas.openai.chat_completion_response import ( ) from letta.schemas.openai.embedding_response import EmbeddingResponse from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface +from letta.tracing import log_event from letta.utils import get_tool_call_id, smart_urljoin logger = get_logger(__name__) @@ -243,6 +244,8 @@ def openai_chat_completions_process_stream( ), ) + log_event(name="llm_request_sent", attributes=chat_completion_request.model_dump()) + if stream_interface: stream_interface.stream_start() @@ -406,6 +409,7 @@ def openai_chat_completions_process_stream( assert len(chat_completion_response.choices) > 0, f"No response from provider {chat_completion_response}" # printd(chat_completion_response) + log_event(name="llm_response_received", attributes=chat_completion_response.model_dump()) return chat_completion_response @@ -437,7 +441,9 @@ def openai_chat_completions_request( """ data = prepare_openai_payload(chat_completion_request) client = OpenAI(api_key=api_key, base_url=url, max_retries=0) + log_event(name="llm_request_sent", attributes=data) chat_completion = client.chat.completions.create(**data) + log_event(name="llm_response_received", attributes=chat_completion.model_dump()) return ChatCompletionResponse(**chat_completion.model_dump()) diff --git a/letta/local_llm/chat_completion_proxy.py b/letta/local_llm/chat_completion_proxy.py index c5e7d025..4abc01ee 100644 --- a/letta/local_llm/chat_completion_proxy.py +++ b/letta/local_llm/chat_completion_proxy.py @@ -22,6 +22,7 @@ from letta.local_llm.webui.api import get_webui_completion from letta.local_llm.webui.legacy_api import get_webui_completion as get_webui_completion_legacy from letta.prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, Message, ToolCall, UsageStatistics +from letta.tracing import log_event from letta.utils import get_tool_call_id has_shown_warning = False @@ -149,7 +150,7 @@ def get_chat_completion( else: model_schema = None """ - + log_event(name="llm_request_sent", attributes={"prompt": prompt, "grammar": grammar}) # Run the LLM try: result_reasoning = None @@ -178,6 +179,10 @@ def get_chat_completion( except requests.exceptions.ConnectionError as e: raise LocalLLMConnectionError(f"Unable to connect to endpoint {endpoint}") + attributes = usage if isinstance(usage, dict) else {"usage": usage} + attributes.update({"result": result}) + log_event(name="llm_request_sent", attributes=attributes) + if result is None or result == "": raise LocalLLMError(f"Got back an empty 
response string from {endpoint}") printd(f"Raw LLM output:\n====\n{result}\n====") diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py index de1eb2ff..d6a4ff60 100644 --- a/letta/server/rest_api/app.py +++ b/letta/server/rest_api/app.py @@ -237,7 +237,11 @@ def create_application() -> "FastAPI": print(f"▶ Using OTLP tracing with endpoint: {endpoint}") from letta.tracing import setup_tracing - setup_tracing(endpoint=endpoint, service_name="memgpt-server") + setup_tracing( + endpoint=endpoint, + app=app, + service_name="memgpt-server", + ) for route in v1_routes: app.include_router(route, prefix=API_PREFIX) diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index a5b8c324..7afcc22b 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -24,7 +24,6 @@ from letta.schemas.tool import Tool from letta.schemas.user import User from letta.server.rest_api.utils import get_letta_server from letta.server.server import SyncServer -from letta.tracing import trace_method # These can be forward refs, but because Fastapi needs them at runtime the must be imported normally @@ -486,7 +485,6 @@ def modify_message( response_model=LettaResponse, operation_id="send_message", ) -@trace_method("POST /v1/agents/{agent_id}/messages") async def send_message( agent_id: str, server: SyncServer = Depends(get_letta_server), @@ -525,7 +523,6 @@ async def send_message( } }, ) -@trace_method("POST /v1/agents/{agent_id}/messages/stream") async def send_message_streaming( agent_id: str, server: SyncServer = Depends(get_letta_server), @@ -601,7 +598,6 @@ async def process_message_background( response_model=Run, operation_id="create_agent_message_async", ) -@trace_method("POST /v1/agents/{agent_id}/messages/async") async def send_message_async( agent_id: str, background_tasks: BackgroundTasks, diff --git a/letta/server/server.py b/letta/server/server.py index 796af566..8afee322 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1172,7 +1172,7 @@ class SyncServer(Server): actions = self.get_composio_client(api_key=api_key).actions.get(apps=[composio_app_name]) return actions - @trace_method("Send Message") + @trace_method async def send_message_to_agent( self, agent_id: str, diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index ada1f7c1..07dade40 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -50,6 +50,7 @@ from letta.services.message_manager import MessageManager from letta.services.source_manager import SourceManager from letta.services.tool_manager import ToolManager from letta.settings import settings +from letta.tracing import trace_method from letta.utils import enforce_types, united_diff logger = get_logger(__name__) @@ -72,6 +73,7 @@ class AgentManager: # ====================================================================================================================== # Basic CRUD operations # ====================================================================================================================== + @trace_method @enforce_types def create_agent( self, @@ -368,6 +370,7 @@ class AgentManager: agent = AgentModel.read(db_session=session, name=agent_name, actor=actor) return agent.to_pydantic() + @trace_method @enforce_types def delete_agent(self, agent_id: str, actor: PydanticUser) -> None: """ diff --git a/letta/tracing.py b/letta/tracing.py index 746f468c..6971cad1 100644 --- a/letta/tracing.py +++ 
b/letta/tracing.py @@ -1,205 +1,225 @@ -import asyncio import inspect +import re import sys import time from functools import wraps -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional -from fastapi import Request +from fastapi import Depends, FastAPI, HTTPException, Request +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse from opentelemetry import trace -from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter from opentelemetry.instrumentation.requests import RequestsInstrumentor from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor -from opentelemetry.trace import Span, Status, StatusCode +from opentelemetry.trace import Status, StatusCode -# Get a tracer instance - will be no-op until setup_tracing is called tracer = trace.get_tracer(__name__) - -# Track if tracing has been initialized _is_tracing_initialized = False +_excluded_v1_endpoints_regex: List[str] = [ + "^GET /v1/agents/(?P[^/]+)/messages$", + "^GET /v1/agents/(?P[^/]+)/context$", + "^GET /v1/agents/(?P[^/]+)/archival-memory$", + "^GET /v1/agents/(?P[^/]+)/sources$", +] def is_pytest_environment(): - """Check if we're running in pytest""" return "pytest" in sys.modules -def trace_method(name=None): - """Decorator to add tracing to a method""" +async def trace_request_middleware(request: Request, call_next): + if not _is_tracing_initialized: + return await call_next(request) + initial_span_name = f"{request.method} {request.url.path}" + if any(re.match(regex, initial_span_name) for regex in _excluded_v1_endpoints_regex): + return await call_next(request) - def decorator(func): - @wraps(func) - async def async_wrapper(*args, **kwargs): - # Skip tracing if not initialized - if not _is_tracing_initialized: - return await func(*args, **kwargs) + with tracer.start_as_current_span( + initial_span_name, + kind=trace.SpanKind.SERVER, + ) as span: + try: + response = await call_next(request) + span.set_attribute("http.status_code", response.status_code) + span.set_status(Status(StatusCode.OK if response.status_code < 400 else StatusCode.ERROR)) + return response + except Exception as e: + span.set_status(Status(StatusCode.ERROR)) + span.record_exception(e) + raise - span_name = name or func.__name__ - with tracer.start_as_current_span(span_name) as span: - span.set_attribute("code.namespace", inspect.getmodule(func).__name__) - span.set_attribute("code.function", func.__name__) - if len(args) > 0 and hasattr(args[0], "__class__"): - span.set_attribute("code.class", args[0].__class__.__name__) +async def update_trace_attributes(request: Request): + """Dependency to update trace attributes after FastAPI has processed the request""" + if not _is_tracing_initialized: + return - request = _extract_request_info(args, span) - if request and len(request) > 0: - span.set_attribute("agent.id", kwargs.get("agent_id")) - span.set_attribute("actor.id", request.get("http.user_id")) + span = trace.get_current_span() + if not span: + return - try: - result = await func(*args, **kwargs) - span.set_status(Status(StatusCode.OK)) - return result - except Exception as e: - span.set_status(Status(StatusCode.ERROR)) - span.record_exception(e) - raise + # Update span name with route pattern + route = request.scope.get("route") + if route and hasattr(route, 
"path"): + span.update_name(f"{request.method} {route.path}") - @wraps(func) - def sync_wrapper(*args, **kwargs): - # Skip tracing if not initialized - if not _is_tracing_initialized: - return func(*args, **kwargs) + # Add request info + span.set_attribute("http.method", request.method) + span.set_attribute("http.url", str(request.url)) - span_name = name or func.__name__ - with tracer.start_as_current_span(span_name) as span: - span.set_attribute("code.namespace", inspect.getmodule(func).__name__) - span.set_attribute("code.function", func.__name__) + # Add path params + for key, value in request.path_params.items(): + span.set_attribute(f"http.{key}", value) - if len(args) > 0 and hasattr(args[0], "__class__"): - span.set_attribute("code.class", args[0].__class__.__name__) + # Add request body if available + try: + body = await request.json() + for key, value in body.items(): + span.set_attribute(f"http.request.body.{key}", str(value)) + except Exception: + pass - request = _extract_request_info(args, span) - if request and len(request) > 0: - span.set_attribute("agent.id", kwargs.get("agent_id")) - span.set_attribute("actor.id", request.get("http.user_id")) - try: - result = func(*args, **kwargs) - span.set_status(Status(StatusCode.OK)) - return result - except Exception as e: - span.set_status(Status(StatusCode.ERROR)) - span.record_exception(e) - raise +async def trace_error_handler(_request: Request, exc: Exception) -> JSONResponse: + status_code = getattr(exc, "status_code", 500) + error_msg = str(exc) - return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper + # Add error details to current span + span = trace.get_current_span() + if span: + span.add_event( + name="exception", + attributes={ + "exception.message": error_msg, + "exception.type": type(exc).__name__, + }, + ) - return decorator + return JSONResponse(status_code=status_code, content={"detail": error_msg, "trace_id": get_trace_id() or ""}) + + +def setup_tracing( + endpoint: str, + app: Optional[FastAPI] = None, + service_name: str = "memgpt-server", +) -> None: + if is_pytest_environment(): + return + + global _is_tracing_initialized + + provider = TracerProvider(resource=Resource.create({"service.name": service_name})) + if endpoint: + provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter(endpoint=endpoint))) + _is_tracing_initialized = True + trace.set_tracer_provider(provider) + + def requests_callback(span: trace.Span, _: Any, response: Any) -> None: + if hasattr(response, "status_code"): + span.set_status(Status(StatusCode.OK if response.status_code < 400 else StatusCode.ERROR)) + + RequestsInstrumentor().instrument(response_hook=requests_callback) + + if app: + # Add middleware first + app.middleware("http")(trace_request_middleware) + + # Add dependency to v1 routes + from letta.server.rest_api.routers.v1 import ROUTERS as v1_routes + + for router in v1_routes: + for route in router.routes: + full_path = ((next(iter(route.methods)) + " ") if route.methods else "") + "/v1" + route.path + if not any(re.match(regex, full_path) for regex in _excluded_v1_endpoints_regex): + route.dependencies.append(Depends(update_trace_attributes)) + + # Register exception handlers + app.exception_handler(HTTPException)(trace_error_handler) + app.exception_handler(RequestValidationError)(trace_error_handler) + app.exception_handler(Exception)(trace_error_handler) + + +def trace_method(func): + """Decorator that traces function execution with OpenTelemetry""" + + def _get_span_name(func, args): + if args 
and hasattr(args[0], "__class__"): + class_name = args[0].__class__.__name__ + else: + class_name = func.__module__ + return f"{class_name}.{func.__name__}" + + def _add_parameters_to_span(span, func, args, kwargs): + try: + # Add method parameters as span attributes + sig = inspect.signature(func) + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + + # Skip 'self' when adding parameters if it exists + param_items = list(bound_args.arguments.items()) + if args and hasattr(args[0], "__class__"): + param_items = param_items[1:] + + for name, value in param_items: + # Convert value to string to avoid serialization issues + span.set_attribute(f"parameter.{name}", str(value)) + except: + pass + + @wraps(func) + async def async_wrapper(*args, **kwargs): + if not _is_tracing_initialized: + return await func(*args, **kwargs) + + with tracer.start_as_current_span(_get_span_name(func, args)) as span: + _add_parameters_to_span(span, func, args, kwargs) + + result = await func(*args, **kwargs) + span.set_status(Status(StatusCode.OK)) + return result + + @wraps(func) + def sync_wrapper(*args, **kwargs): + if not _is_tracing_initialized: + return func(*args, **kwargs) + + with tracer.start_as_current_span(_get_span_name(func, args)) as span: + _add_parameters_to_span(span, func, args, kwargs) + + result = func(*args, **kwargs) + span.set_status(Status(StatusCode.OK)) + return result + + return async_wrapper if inspect.iscoroutinefunction(func) else sync_wrapper def log_attributes(attributes: Dict[str, Any]) -> None: - """ - Log multiple attributes to the current active span. - - Args: - attributes: Dictionary of attribute key-value pairs - """ current_span = trace.get_current_span() if current_span: current_span.set_attributes(attributes) def log_event(name: str, attributes: Optional[Dict[str, Any]] = None, timestamp: Optional[int] = None) -> None: - """ - Log an event to the current active span. 
-
-    Args:
-        name: Name of the event
-        attributes: Optional dictionary of event attributes
-        timestamp: Optional timestamp in nanoseconds
-    """
     current_span = trace.get_current_span()
     if current_span:
         if timestamp is None:
             timestamp = int(time.perf_counter_ns())
+
+        def _safe_convert(v):
+            if isinstance(v, (str, bool, int, float)):
+                return v
+            return str(v)
+
+        attributes = {k: _safe_convert(v) for k, v in attributes.items()} if attributes else None
         current_span.add_event(name=name, attributes=attributes, timestamp=timestamp)
 
 
-def get_trace_id() -> str:
-    current_span = trace.get_current_span()
-    if current_span:
-        return format(current_span.get_span_context().trace_id, "032x")
-    else:
-        return ""
-
-
-def request_hook(span: Span, _request_context: Optional[Dict] = None, response: Optional[Any] = None):
-    """Hook to update span based on response status code"""
-    if response is not None:
-        if hasattr(response, "status_code"):
-            span.set_attribute("http.status_code", response.status_code)
-            if response.status_code >= 400:
-                span.set_status(Status(StatusCode.ERROR))
-            elif 200 <= response.status_code < 300:
-                span.set_status(Status(StatusCode.OK))
-
-
-def setup_tracing(endpoint: str, service_name: str = "memgpt-server") -> None:
-    """
-    Sets up OpenTelemetry tracing with OTLP exporter for specific endpoints
-
-    Args:
-        endpoint: OTLP endpoint URL
-        service_name: Name of the service for tracing
-    """
-    global _is_tracing_initialized
-
-    # Skip tracing in pytest environment
-    if is_pytest_environment():
-        print("ℹ️ Skipping tracing setup in pytest environment")
-        return
-
-    # Create a Resource to identify our service
-    resource = Resource.create({"service.name": service_name, "service.namespace": "default", "deployment.environment": "production"})
-
-    # Initialize the TracerProvider with the resource
-    provider = TracerProvider(resource=resource)
-
-    # Only set up OTLP export if endpoint is provided
-    if endpoint:
-        otlp_exporter = OTLPSpanExporter(endpoint=endpoint)
-        processor = BatchSpanProcessor(otlp_exporter)
-        provider.add_span_processor(processor)
-        _is_tracing_initialized = True
-    else:
-        print("⚠️ Warning: Tracing endpoint not provided, tracing will be disabled")
-
-    # Set the global TracerProvider
-    trace.set_tracer_provider(provider)
-
-    # Initialize automatic instrumentation for the requests library with response hook
-    if _is_tracing_initialized:
-        RequestsInstrumentor().instrument(response_hook=request_hook)
-
-
-def _extract_request_info(args: tuple, span: Span) -> Dict[str, Any]:
-    """
-    Safely extracts request information from function arguments.
-    Works with both FastAPI route handlers and inner functions.
-    """
-    attributes = {}
-
-    # Look for FastAPI Request object in args
-    request = next((arg for arg in args if isinstance(arg, Request)), None)
-
-    if request:
-        attributes.update(
-            {
-                "http.route": request.url.path,
-                "http.method": request.method,
-                "http.scheme": request.url.scheme,
-                "http.target": str(request.url.path),
-                "http.url": str(request.url),
-                "http.flavor": request.scope.get("http_version", ""),
-                "http.client_ip": request.client.host if request.client else None,
-                "http.user_id": request.headers.get("user_id"),
-            }
-        )
-
-    span.set_attributes(attributes)
-    return attributes
+def get_trace_id() -> Optional[str]:
+    span = trace.get_current_span()
+    if span and span.get_span_context().trace_id:
+        return format(span.get_span_context().trace_id, "032x")
+    return None
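
A minimal usage sketch of the reworked tracing API, kept outside the patch so the hunks above stay intact. MyService, do_work, and the localhost endpoint are hypothetical placeholders, not part of this diff; port 4317 assumes the default OTLP/gRPC collector port, consistent with the switch to the grpc exporter above.

    # sketch.py -- assumes the post-diff letta.tracing module
    from letta.tracing import log_event, setup_tracing, trace_method

    # Until setup_tracing() runs with an endpoint, decorated functions
    # execute untraced because _is_tracing_initialized stays False.
    setup_tracing(endpoint="http://localhost:4317", service_name="memgpt-server")


    class MyService:
        @trace_method  # bare decorator now; the span is named "MyService.do_work"
        def do_work(self, item_id: str) -> str:
            # Attaches an event to the current span; non-primitive attribute
            # values are stringified by log_event's _safe_convert helper.
            log_event("work_started", attributes={"item_id": item_id})
            return item_id.upper()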