feat: log request data to otel (#1149)
This commit is contained in:
@@ -60,7 +60,7 @@ from letta.services.tool_manager import ToolManager
|
||||
from letta.settings import summarizer_settings
|
||||
from letta.streaming_interface import StreamingRefreshCLIInterface
|
||||
from letta.system import get_heartbeat, get_token_limit_warning, package_function_response, package_summarize_message, package_user_message
|
||||
from letta.tracing import trace_method
|
||||
from letta.tracing import log_event, trace_method
|
||||
from letta.utils import count_tokens, get_friendly_error_msg, get_tool_call_id, log_telemetry, parse_json, validate_function_response
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -307,7 +307,7 @@ class Agent(BaseAgent):
|
||||
# Return updated messages
|
||||
return messages
|
||||
|
||||
@trace_method("Get AI Reply")
|
||||
@trace_method
|
||||
def _get_ai_reply(
|
||||
self,
|
||||
message_sequence: List[Message],
|
||||
@@ -400,7 +400,7 @@ class Agent(BaseAgent):
|
||||
log_telemetry(self.logger, "_handle_ai_response finish catch-all exception")
|
||||
raise Exception("Retries exhausted and no valid response received.")
|
||||
|
||||
@trace_method("Handle AI Response")
|
||||
@trace_method
|
||||
def _handle_ai_response(
|
||||
self,
|
||||
response_message: ChatCompletionMessage, # TODO should we eventually move the Message creation outside of this function?
|
||||
@@ -538,7 +538,24 @@ class Agent(BaseAgent):
|
||||
log_telemetry(
|
||||
self.logger, "_handle_ai_response execute tool start", function_name=function_name, function_args=function_args
|
||||
)
|
||||
log_event(
|
||||
"tool_call_initiated",
|
||||
attributes={
|
||||
"function_name": function_name,
|
||||
"target_letta_tool": target_letta_tool.model_dump(),
|
||||
**{f"function_args.{k}": v for k, v in function_args.items()},
|
||||
},
|
||||
)
|
||||
|
||||
function_response, sandbox_run_result = self.execute_tool_and_persist_state(function_name, function_args, target_letta_tool)
|
||||
|
||||
log_event(
|
||||
"tool_call_ended",
|
||||
attributes={
|
||||
"function_response": function_response,
|
||||
"sandbox_run_result": sandbox_run_result.model_dump() if sandbox_run_result else None,
|
||||
},
|
||||
)
|
||||
log_telemetry(
|
||||
self.logger, "_handle_ai_response execute tool finish", function_name=function_name, function_args=function_args
|
||||
)
|
||||
@@ -640,7 +657,7 @@ class Agent(BaseAgent):
|
||||
log_telemetry(self.logger, "_handle_ai_response finish")
|
||||
return messages, heartbeat_request, function_failed
|
||||
|
||||
@trace_method("Agent Step")
|
||||
@trace_method
|
||||
def step(
|
||||
self,
|
||||
messages: Union[Message, List[Message]],
|
||||
@@ -828,6 +845,13 @@ class Agent(BaseAgent):
|
||||
f"{CLI_WARNING_PREFIX}last response total_tokens ({current_total_tokens}) > {summarizer_settings.memory_warning_threshold * int(self.agent_state.llm_config.context_window)}"
|
||||
)
|
||||
|
||||
log_event(
|
||||
name="memory_pressure_warning",
|
||||
attributes={
|
||||
"current_total_tokens": current_total_tokens,
|
||||
"context_window_limit": self.agent_state.llm_config.context_window,
|
||||
},
|
||||
)
|
||||
# Only deliver the alert if we haven't already (this period)
|
||||
if not self.agent_alerted_about_memory_pressure:
|
||||
active_memory_warning = True
|
||||
@@ -1029,9 +1053,18 @@ class Agent(BaseAgent):
|
||||
self.agent_alerted_about_memory_pressure = False
|
||||
curr_in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
|
||||
|
||||
current_token_count = sum(get_token_counts_for_messages(curr_in_context_messages))
|
||||
logger.info(f"Ran summarizer, messages length {prior_len} -> {len(curr_in_context_messages)}")
|
||||
logger.info(
|
||||
f"Summarizer brought down total token count from {sum(token_counts)} -> {sum(get_token_counts_for_messages(curr_in_context_messages))}"
|
||||
logger.info(f"Summarizer brought down total token count from {sum(token_counts)} -> {current_token_count}")
|
||||
log_event(
|
||||
name="summarization",
|
||||
attributes={
|
||||
"prior_length": prior_len,
|
||||
"current_length": len(curr_in_context_messages),
|
||||
"prior_token_count": sum(token_counts),
|
||||
"current_token_count": current_token_count,
|
||||
"context_window_limit": self.agent_state.llm_config.context_window,
|
||||
},
|
||||
)
|
||||
|
||||
def add_function(self, function_name: str) -> str:
|
||||
|
||||
@@ -40,6 +40,7 @@ from letta.schemas.openai.chat_completion_response import MessageDelta, ToolCall
|
||||
from letta.services.provider_manager import ProviderManager
|
||||
from letta.settings import model_settings
|
||||
from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
|
||||
from letta.tracing import log_event
|
||||
|
||||
BASE_URL = "https://api.anthropic.com/v1"
|
||||
|
||||
@@ -677,10 +678,12 @@ def anthropic_chat_completions_request(
|
||||
inner_thoughts_xml_tag=inner_thoughts_xml_tag,
|
||||
put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
|
||||
)
|
||||
log_event(name="llm_request_sent", attributes=data)
|
||||
response = anthropic_client.beta.messages.create(
|
||||
**data,
|
||||
betas=betas,
|
||||
)
|
||||
log_event(name="llm_response_received", attributes={"response": response.json()})
|
||||
return convert_anthropic_response_to_chatcompletion(response=response, inner_thoughts_xml_tag=inner_thoughts_xml_tag)
|
||||
|
||||
|
||||
@@ -698,8 +701,9 @@ def anthropic_bedrock_chat_completions_request(
|
||||
try:
|
||||
# bedrock does not support certain args
|
||||
data["tool_choice"] = {"type": "any"}
|
||||
|
||||
log_event(name="llm_request_sent", attributes=data)
|
||||
response = client.messages.create(**data)
|
||||
log_event(name="llm_response_received", attributes={"response": response.json()})
|
||||
return convert_anthropic_response_to_chatcompletion(response=response, inner_thoughts_xml_tag=inner_thoughts_xml_tag)
|
||||
except PermissionDeniedError:
|
||||
raise BedrockPermissionError(f"User does not have access to the Bedrock model with the specified ID. {data['model']}")
|
||||
@@ -839,6 +843,8 @@ def anthropic_chat_completions_process_stream(
|
||||
),
|
||||
)
|
||||
|
||||
log_event(name="llm_request_sent", attributes=chat_completion_request.model_dump())
|
||||
|
||||
if stream_interface:
|
||||
stream_interface.stream_start()
|
||||
|
||||
@@ -987,4 +993,6 @@ def anthropic_chat_completions_process_stream(
|
||||
|
||||
assert len(chat_completion_response.choices) > 0, chat_completion_response
|
||||
|
||||
log_event(name="llm_response_received", attributes=chat_completion_response.model_dump())
|
||||
|
||||
return chat_completion_response
|
||||
|
||||
@@ -8,6 +8,7 @@ from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
from letta.schemas.openai.chat_completions import ChatCompletionRequest
|
||||
from letta.schemas.openai.embedding_response import EmbeddingResponse
|
||||
from letta.settings import ModelSettings
|
||||
from letta.tracing import log_event
|
||||
|
||||
|
||||
def get_azure_chat_completions_endpoint(base_url: str, model: str, api_version: str):
|
||||
@@ -120,10 +121,12 @@ def azure_openai_chat_completions_request(
|
||||
data.pop("tool_choice", None) # extra safe, should exist always (default="auto")
|
||||
|
||||
url = get_azure_chat_completions_endpoint(model_settings.azure_base_url, llm_config.model, model_settings.azure_api_version)
|
||||
log_event(name="llm_request_sent", attributes=data)
|
||||
response_json = make_post_request(url, headers, data)
|
||||
# NOTE: azure openai does not include "content" in the response when it is None, so we need to add it
|
||||
if "content" not in response_json["choices"][0].get("message"):
|
||||
response_json["choices"][0]["message"]["content"] = None
|
||||
log_event(name="llm_response_received", attributes=response_json)
|
||||
response = ChatCompletionResponse(**response_json) # convert to 'dot-dict' style which is the openai python client default
|
||||
return response
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from letta.local_llm.json_parser import clean_json_string_extra_backslash
|
||||
from letta.local_llm.utils import count_tokens
|
||||
from letta.schemas.openai.chat_completion_request import Tool
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics
|
||||
from letta.tracing import log_event
|
||||
from letta.utils import get_tool_call_id
|
||||
|
||||
|
||||
@@ -422,7 +423,9 @@ def google_ai_chat_completions_request(
|
||||
if add_postfunc_model_messages:
|
||||
data["contents"] = add_dummy_model_messages(data["contents"])
|
||||
|
||||
log_event(name="llm_request_sent", attributes=data)
|
||||
response_json = make_post_request(url, headers, data)
|
||||
log_event(name="llm_response_received", attributes=response_json)
|
||||
try:
|
||||
return convert_google_ai_response_to_chatcompletion(
|
||||
response_json=response_json,
|
||||
|
||||
@@ -8,6 +8,7 @@ from letta.local_llm.json_parser import clean_json_string_extra_backslash
|
||||
from letta.local_llm.utils import count_tokens
|
||||
from letta.schemas.openai.chat_completion_request import Tool
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics
|
||||
from letta.tracing import log_event
|
||||
from letta.utils import get_tool_call_id
|
||||
|
||||
|
||||
@@ -323,6 +324,9 @@ def google_vertex_chat_completions_request(
|
||||
config["tool_config"] = tool_config.model_dump()
|
||||
|
||||
# make request to client
|
||||
attributes = config if isinstance(config, dict) else {"config": config}
|
||||
attributes.update({"contents": contents})
|
||||
log_event(name="llm_request_sent", attributes={"contents": contents, "config": config})
|
||||
response = client.models.generate_content(
|
||||
model=model,
|
||||
contents=contents,
|
||||
|
||||
@@ -120,7 +120,7 @@ def retry_with_exponential_backoff(
|
||||
return wrapper
|
||||
|
||||
|
||||
@trace_method("LLM Request")
|
||||
@trace_method
|
||||
@retry_with_exponential_backoff
|
||||
def create(
|
||||
# agent_state: AgentState,
|
||||
|
||||
@@ -25,6 +25,7 @@ from letta.schemas.openai.chat_completion_response import (
|
||||
)
|
||||
from letta.schemas.openai.embedding_response import EmbeddingResponse
|
||||
from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
|
||||
from letta.tracing import log_event
|
||||
from letta.utils import get_tool_call_id, smart_urljoin
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -243,6 +244,8 @@ def openai_chat_completions_process_stream(
|
||||
),
|
||||
)
|
||||
|
||||
log_event(name="llm_request_sent", attributes=chat_completion_request.model_dump())
|
||||
|
||||
if stream_interface:
|
||||
stream_interface.stream_start()
|
||||
|
||||
@@ -406,6 +409,7 @@ def openai_chat_completions_process_stream(
|
||||
assert len(chat_completion_response.choices) > 0, f"No response from provider {chat_completion_response}"
|
||||
|
||||
# printd(chat_completion_response)
|
||||
log_event(name="llm_response_received", attributes=chat_completion_response.model_dump())
|
||||
return chat_completion_response
|
||||
|
||||
|
||||
@@ -437,7 +441,9 @@ def openai_chat_completions_request(
|
||||
"""
|
||||
data = prepare_openai_payload(chat_completion_request)
|
||||
client = OpenAI(api_key=api_key, base_url=url, max_retries=0)
|
||||
log_event(name="llm_request_sent", attributes=data)
|
||||
chat_completion = client.chat.completions.create(**data)
|
||||
log_event(name="llm_response_received", attributes=chat_completion.model_dump())
|
||||
return ChatCompletionResponse(**chat_completion.model_dump())
|
||||
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ from letta.local_llm.webui.api import get_webui_completion
|
||||
from letta.local_llm.webui.legacy_api import get_webui_completion as get_webui_completion_legacy
|
||||
from letta.prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, Message, ToolCall, UsageStatistics
|
||||
from letta.tracing import log_event
|
||||
from letta.utils import get_tool_call_id
|
||||
|
||||
has_shown_warning = False
|
||||
@@ -149,7 +150,7 @@ def get_chat_completion(
|
||||
else:
|
||||
model_schema = None
|
||||
"""
|
||||
|
||||
log_event(name="llm_request_sent", attributes={"prompt": prompt, "grammar": grammar})
|
||||
# Run the LLM
|
||||
try:
|
||||
result_reasoning = None
|
||||
@@ -178,6 +179,10 @@ def get_chat_completion(
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
raise LocalLLMConnectionError(f"Unable to connect to endpoint {endpoint}")
|
||||
|
||||
attributes = usage if isinstance(usage, dict) else {"usage": usage}
|
||||
attributes.update({"result": result})
|
||||
log_event(name="llm_request_sent", attributes=attributes)
|
||||
|
||||
if result is None or result == "":
|
||||
raise LocalLLMError(f"Got back an empty response string from {endpoint}")
|
||||
printd(f"Raw LLM output:\n====\n{result}\n====")
|
||||
|
||||
@@ -237,7 +237,11 @@ def create_application() -> "FastAPI":
|
||||
print(f"▶ Using OTLP tracing with endpoint: {endpoint}")
|
||||
from letta.tracing import setup_tracing
|
||||
|
||||
setup_tracing(endpoint=endpoint, service_name="memgpt-server")
|
||||
setup_tracing(
|
||||
endpoint=endpoint,
|
||||
app=app,
|
||||
service_name="memgpt-server",
|
||||
)
|
||||
|
||||
for route in v1_routes:
|
||||
app.include_router(route, prefix=API_PREFIX)
|
||||
|
||||
@@ -24,7 +24,6 @@ from letta.schemas.tool import Tool
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
from letta.tracing import trace_method
|
||||
|
||||
# These can be forward refs, but because Fastapi needs them at runtime the must be imported normally
|
||||
|
||||
@@ -486,7 +485,6 @@ def modify_message(
|
||||
response_model=LettaResponse,
|
||||
operation_id="send_message",
|
||||
)
|
||||
@trace_method("POST /v1/agents/{agent_id}/messages")
|
||||
async def send_message(
|
||||
agent_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
@@ -525,7 +523,6 @@ async def send_message(
|
||||
}
|
||||
},
|
||||
)
|
||||
@trace_method("POST /v1/agents/{agent_id}/messages/stream")
|
||||
async def send_message_streaming(
|
||||
agent_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
@@ -601,7 +598,6 @@ async def process_message_background(
|
||||
response_model=Run,
|
||||
operation_id="create_agent_message_async",
|
||||
)
|
||||
@trace_method("POST /v1/agents/{agent_id}/messages/async")
|
||||
async def send_message_async(
|
||||
agent_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
|
||||
@@ -1172,7 +1172,7 @@ class SyncServer(Server):
|
||||
actions = self.get_composio_client(api_key=api_key).actions.get(apps=[composio_app_name])
|
||||
return actions
|
||||
|
||||
@trace_method("Send Message")
|
||||
@trace_method
|
||||
async def send_message_to_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
|
||||
@@ -50,6 +50,7 @@ from letta.services.message_manager import MessageManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.settings import settings
|
||||
from letta.tracing import trace_method
|
||||
from letta.utils import enforce_types, united_diff
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -72,6 +73,7 @@ class AgentManager:
|
||||
# ======================================================================================================================
|
||||
# Basic CRUD operations
|
||||
# ======================================================================================================================
|
||||
@trace_method
|
||||
@enforce_types
|
||||
def create_agent(
|
||||
self,
|
||||
@@ -368,6 +370,7 @@ class AgentManager:
|
||||
agent = AgentModel.read(db_session=session, name=agent_name, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
def delete_agent(self, agent_id: str, actor: PydanticUser) -> None:
|
||||
"""
|
||||
|
||||
334
letta/tracing.py
334
letta/tracing.py
@@ -1,205 +1,225 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.instrumentation.requests import RequestsInstrumentor
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.trace import Span, Status, StatusCode
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
|
||||
# Get a tracer instance - will be no-op until setup_tracing is called
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
# Track if tracing has been initialized
|
||||
_is_tracing_initialized = False
|
||||
_excluded_v1_endpoints_regex: List[str] = [
|
||||
"^GET /v1/agents/(?P<agent_id>[^/]+)/messages$",
|
||||
"^GET /v1/agents/(?P<agent_id>[^/]+)/context$",
|
||||
"^GET /v1/agents/(?P<agent_id>[^/]+)/archival-memory$",
|
||||
"^GET /v1/agents/(?P<agent_id>[^/]+)/sources$",
|
||||
]
|
||||
|
||||
|
||||
def is_pytest_environment():
|
||||
"""Check if we're running in pytest"""
|
||||
return "pytest" in sys.modules
|
||||
|
||||
|
||||
def trace_method(name=None):
|
||||
"""Decorator to add tracing to a method"""
|
||||
async def trace_request_middleware(request: Request, call_next):
|
||||
if not _is_tracing_initialized:
|
||||
return await call_next(request)
|
||||
initial_span_name = f"{request.method} {request.url.path}"
|
||||
if any(re.match(regex, initial_span_name) for regex in _excluded_v1_endpoints_regex):
|
||||
return await call_next(request)
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
# Skip tracing if not initialized
|
||||
if not _is_tracing_initialized:
|
||||
return await func(*args, **kwargs)
|
||||
with tracer.start_as_current_span(
|
||||
initial_span_name,
|
||||
kind=trace.SpanKind.SERVER,
|
||||
) as span:
|
||||
try:
|
||||
response = await call_next(request)
|
||||
span.set_attribute("http.status_code", response.status_code)
|
||||
span.set_status(Status(StatusCode.OK if response.status_code < 400 else StatusCode.ERROR))
|
||||
return response
|
||||
except Exception as e:
|
||||
span.set_status(Status(StatusCode.ERROR))
|
||||
span.record_exception(e)
|
||||
raise
|
||||
|
||||
span_name = name or func.__name__
|
||||
with tracer.start_as_current_span(span_name) as span:
|
||||
span.set_attribute("code.namespace", inspect.getmodule(func).__name__)
|
||||
span.set_attribute("code.function", func.__name__)
|
||||
|
||||
if len(args) > 0 and hasattr(args[0], "__class__"):
|
||||
span.set_attribute("code.class", args[0].__class__.__name__)
|
||||
async def update_trace_attributes(request: Request):
|
||||
"""Dependency to update trace attributes after FastAPI has processed the request"""
|
||||
if not _is_tracing_initialized:
|
||||
return
|
||||
|
||||
request = _extract_request_info(args, span)
|
||||
if request and len(request) > 0:
|
||||
span.set_attribute("agent.id", kwargs.get("agent_id"))
|
||||
span.set_attribute("actor.id", request.get("http.user_id"))
|
||||
span = trace.get_current_span()
|
||||
if not span:
|
||||
return
|
||||
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
return result
|
||||
except Exception as e:
|
||||
span.set_status(Status(StatusCode.ERROR))
|
||||
span.record_exception(e)
|
||||
raise
|
||||
# Update span name with route pattern
|
||||
route = request.scope.get("route")
|
||||
if route and hasattr(route, "path"):
|
||||
span.update_name(f"{request.method} {route.path}")
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
# Skip tracing if not initialized
|
||||
if not _is_tracing_initialized:
|
||||
return func(*args, **kwargs)
|
||||
# Add request info
|
||||
span.set_attribute("http.method", request.method)
|
||||
span.set_attribute("http.url", str(request.url))
|
||||
|
||||
span_name = name or func.__name__
|
||||
with tracer.start_as_current_span(span_name) as span:
|
||||
span.set_attribute("code.namespace", inspect.getmodule(func).__name__)
|
||||
span.set_attribute("code.function", func.__name__)
|
||||
# Add path params
|
||||
for key, value in request.path_params.items():
|
||||
span.set_attribute(f"http.{key}", value)
|
||||
|
||||
if len(args) > 0 and hasattr(args[0], "__class__"):
|
||||
span.set_attribute("code.class", args[0].__class__.__name__)
|
||||
# Add request body if available
|
||||
try:
|
||||
body = await request.json()
|
||||
for key, value in body.items():
|
||||
span.set_attribute(f"http.request.body.{key}", str(value))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
request = _extract_request_info(args, span)
|
||||
if request and len(request) > 0:
|
||||
span.set_attribute("agent.id", kwargs.get("agent_id"))
|
||||
span.set_attribute("actor.id", request.get("http.user_id"))
|
||||
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
return result
|
||||
except Exception as e:
|
||||
span.set_status(Status(StatusCode.ERROR))
|
||||
span.record_exception(e)
|
||||
raise
|
||||
async def trace_error_handler(_request: Request, exc: Exception) -> JSONResponse:
|
||||
status_code = getattr(exc, "status_code", 500)
|
||||
error_msg = str(exc)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
# Add error details to current span
|
||||
span = trace.get_current_span()
|
||||
if span:
|
||||
span.add_event(
|
||||
name="exception",
|
||||
attributes={
|
||||
"exception.message": error_msg,
|
||||
"exception.type": type(exc).__name__,
|
||||
},
|
||||
)
|
||||
|
||||
return decorator
|
||||
return JSONResponse(status_code=status_code, content={"detail": error_msg, "trace_id": get_trace_id() or ""})
|
||||
|
||||
|
||||
def setup_tracing(
|
||||
endpoint: str,
|
||||
app: Optional[FastAPI] = None,
|
||||
service_name: str = "memgpt-server",
|
||||
) -> None:
|
||||
if is_pytest_environment():
|
||||
return
|
||||
|
||||
global _is_tracing_initialized
|
||||
|
||||
provider = TracerProvider(resource=Resource.create({"service.name": service_name}))
|
||||
if endpoint:
|
||||
provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter(endpoint=endpoint)))
|
||||
_is_tracing_initialized = True
|
||||
trace.set_tracer_provider(provider)
|
||||
|
||||
def requests_callback(span: trace.Span, _: Any, response: Any) -> None:
|
||||
if hasattr(response, "status_code"):
|
||||
span.set_status(Status(StatusCode.OK if response.status_code < 400 else StatusCode.ERROR))
|
||||
|
||||
RequestsInstrumentor().instrument(response_hook=requests_callback)
|
||||
|
||||
if app:
|
||||
# Add middleware first
|
||||
app.middleware("http")(trace_request_middleware)
|
||||
|
||||
# Add dependency to v1 routes
|
||||
from letta.server.rest_api.routers.v1 import ROUTERS as v1_routes
|
||||
|
||||
for router in v1_routes:
|
||||
for route in router.routes:
|
||||
full_path = ((next(iter(route.methods)) + " ") if route.methods else "") + "/v1" + route.path
|
||||
if not any(re.match(regex, full_path) for regex in _excluded_v1_endpoints_regex):
|
||||
route.dependencies.append(Depends(update_trace_attributes))
|
||||
|
||||
# Register exception handlers
|
||||
app.exception_handler(HTTPException)(trace_error_handler)
|
||||
app.exception_handler(RequestValidationError)(trace_error_handler)
|
||||
app.exception_handler(Exception)(trace_error_handler)
|
||||
|
||||
|
||||
def trace_method(func):
|
||||
"""Decorator that traces function execution with OpenTelemetry"""
|
||||
|
||||
def _get_span_name(func, args):
|
||||
if args and hasattr(args[0], "__class__"):
|
||||
class_name = args[0].__class__.__name__
|
||||
else:
|
||||
class_name = func.__module__
|
||||
return f"{class_name}.{func.__name__}"
|
||||
|
||||
def _add_parameters_to_span(span, func, args, kwargs):
|
||||
try:
|
||||
# Add method parameters as span attributes
|
||||
sig = inspect.signature(func)
|
||||
bound_args = sig.bind(*args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
|
||||
# Skip 'self' when adding parameters if it exists
|
||||
param_items = list(bound_args.arguments.items())
|
||||
if args and hasattr(args[0], "__class__"):
|
||||
param_items = param_items[1:]
|
||||
|
||||
for name, value in param_items:
|
||||
# Convert value to string to avoid serialization issues
|
||||
span.set_attribute(f"parameter.{name}", str(value))
|
||||
except:
|
||||
pass
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
if not _is_tracing_initialized:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
with tracer.start_as_current_span(_get_span_name(func, args)) as span:
|
||||
_add_parameters_to_span(span, func, args, kwargs)
|
||||
|
||||
result = await func(*args, **kwargs)
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
return result
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
if not _is_tracing_initialized:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
with tracer.start_as_current_span(_get_span_name(func, args)) as span:
|
||||
_add_parameters_to_span(span, func, args, kwargs)
|
||||
|
||||
result = func(*args, **kwargs)
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
return result
|
||||
|
||||
return async_wrapper if inspect.iscoroutinefunction(func) else sync_wrapper
|
||||
|
||||
|
||||
def log_attributes(attributes: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Log multiple attributes to the current active span.
|
||||
|
||||
Args:
|
||||
attributes: Dictionary of attribute key-value pairs
|
||||
"""
|
||||
current_span = trace.get_current_span()
|
||||
if current_span:
|
||||
current_span.set_attributes(attributes)
|
||||
|
||||
|
||||
def log_event(name: str, attributes: Optional[Dict[str, Any]] = None, timestamp: Optional[int] = None) -> None:
|
||||
"""
|
||||
Log an event to the current active span.
|
||||
|
||||
Args:
|
||||
name: Name of the event
|
||||
attributes: Optional dictionary of event attributes
|
||||
timestamp: Optional timestamp in nanoseconds
|
||||
"""
|
||||
current_span = trace.get_current_span()
|
||||
if current_span:
|
||||
if timestamp is None:
|
||||
timestamp = int(time.perf_counter_ns())
|
||||
|
||||
def _safe_convert(v):
|
||||
if isinstance(v, (str, bool, int, float)):
|
||||
return v
|
||||
return str(v)
|
||||
|
||||
attributes = {k: _safe_convert(v) for k, v in attributes.items()} if attributes else None
|
||||
current_span.add_event(name=name, attributes=attributes, timestamp=timestamp)
|
||||
|
||||
|
||||
def get_trace_id() -> str:
|
||||
current_span = trace.get_current_span()
|
||||
if current_span:
|
||||
return format(current_span.get_span_context().trace_id, "032x")
|
||||
else:
|
||||
return ""
|
||||
|
||||
|
||||
def request_hook(span: Span, _request_context: Optional[Dict] = None, response: Optional[Any] = None):
|
||||
"""Hook to update span based on response status code"""
|
||||
if response is not None:
|
||||
if hasattr(response, "status_code"):
|
||||
span.set_attribute("http.status_code", response.status_code)
|
||||
if response.status_code >= 400:
|
||||
span.set_status(Status(StatusCode.ERROR))
|
||||
elif 200 <= response.status_code < 300:
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
|
||||
|
||||
def setup_tracing(endpoint: str, service_name: str = "memgpt-server") -> None:
|
||||
"""
|
||||
Sets up OpenTelemetry tracing with OTLP exporter for specific endpoints
|
||||
|
||||
Args:
|
||||
endpoint: OTLP endpoint URL
|
||||
service_name: Name of the service for tracing
|
||||
"""
|
||||
global _is_tracing_initialized
|
||||
|
||||
# Skip tracing in pytest environment
|
||||
if is_pytest_environment():
|
||||
print("ℹ️ Skipping tracing setup in pytest environment")
|
||||
return
|
||||
|
||||
# Create a Resource to identify our service
|
||||
resource = Resource.create({"service.name": service_name, "service.namespace": "default", "deployment.environment": "production"})
|
||||
|
||||
# Initialize the TracerProvider with the resource
|
||||
provider = TracerProvider(resource=resource)
|
||||
|
||||
# Only set up OTLP export if endpoint is provided
|
||||
if endpoint:
|
||||
otlp_exporter = OTLPSpanExporter(endpoint=endpoint)
|
||||
processor = BatchSpanProcessor(otlp_exporter)
|
||||
provider.add_span_processor(processor)
|
||||
_is_tracing_initialized = True
|
||||
else:
|
||||
print("⚠️ Warning: Tracing endpoint not provided, tracing will be disabled")
|
||||
|
||||
# Set the global TracerProvider
|
||||
trace.set_tracer_provider(provider)
|
||||
|
||||
# Initialize automatic instrumentation for the requests library with response hook
|
||||
if _is_tracing_initialized:
|
||||
RequestsInstrumentor().instrument(response_hook=request_hook)
|
||||
|
||||
|
||||
def _extract_request_info(args: tuple, span: Span) -> Dict[str, Any]:
|
||||
"""
|
||||
Safely extracts request information from function arguments.
|
||||
Works with both FastAPI route handlers and inner functions.
|
||||
"""
|
||||
attributes = {}
|
||||
|
||||
# Look for FastAPI Request object in args
|
||||
request = next((arg for arg in args if isinstance(arg, Request)), None)
|
||||
|
||||
if request:
|
||||
attributes.update(
|
||||
{
|
||||
"http.route": request.url.path,
|
||||
"http.method": request.method,
|
||||
"http.scheme": request.url.scheme,
|
||||
"http.target": str(request.url.path),
|
||||
"http.url": str(request.url),
|
||||
"http.flavor": request.scope.get("http_version", ""),
|
||||
"http.client_ip": request.client.host if request.client else None,
|
||||
"http.user_id": request.headers.get("user_id"),
|
||||
}
|
||||
)
|
||||
|
||||
span.set_attributes(attributes)
|
||||
return attributes
|
||||
def get_trace_id() -> Optional[str]:
|
||||
span = trace.get_current_span()
|
||||
if span and span.get_span_context().trace_id:
|
||||
return format(span.get_span_context().trace_id, "032x")
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user