feat(logs): Enrich logs with context-aware primitive types (#5949)

* enrich logs with context-aware primitive types

* Delete apps/core/docs/LOG_CONTEXT.md
This commit is contained in:
Kian Jones
2025-11-05 16:46:24 -08:00
committed by Caren Thomas
parent e2774c07c6
commit ea3248593c
7 changed files with 389 additions and 4 deletions

View File

@@ -7,6 +7,7 @@ from pathlib import Path
from sys import stdout
from typing import Any, Optional
from letta.log_context import get_log_context
from letta.settings import log_settings, settings, telemetry_settings
selected_log_level = logging.DEBUG if settings.debug else logging.INFO
@@ -131,6 +132,40 @@ class DatadogEnvFilter(logging.Filter):
return True
class LogContextFilter(logging.Filter):
    """
    Logging filter that copies the current log context onto each record.

    Every key/value pair returned by ``get_log_context()`` (for example
    ``actor_id``, ``agent_id`` or ``org_id``) is attached to the record as an
    attribute, so formatters downstream can emit it alongside the message.
    Attributes that already exist on the record are left untouched.

    Usage:
        from letta.log_context import set_log_context, update_log_context

        # Set a single context value
        set_log_context("agent_id", "agent-123")

        # Set multiple context values
        update_log_context(agent_id="agent-123", actor_id="user-456")

        # All subsequent logs will include these attributes
        logger.info("Processing request")
        # Output: {"message": "Processing request", "agent_id": "agent-123", "actor_id": "user-456", ...}
    """

    def filter(self, record: logging.LogRecord) -> bool:
        """Attach context attributes to ``record``; never drops the record."""
        for attr_name, attr_value in get_log_context().items():
            if hasattr(record, attr_name):
                # Values already present on the record take precedence over
                # the ambient context.
                continue
            setattr(record, attr_name, attr_value)
        return True
def _setup_logfile() -> "Path":
"""ensure the logger filepath is in place
@@ -184,6 +219,9 @@ DEVELOPMENT_LOGGING = {
"datadog_env": {
"()": DatadogEnvFilter,
},
"log_context": {
"()": LogContextFilter,
},
},
"handlers": {
"console": {
@@ -191,7 +229,7 @@ DEVELOPMENT_LOGGING = {
"class": "logging.StreamHandler",
"stream": stdout,
"formatter": _get_console_formatter(),
"filters": ["datadog_env"] if telemetry_settings.enable_datadog and not log_settings.json_logging else [],
"filters": (["datadog_env"] if telemetry_settings.enable_datadog and not log_settings.json_logging else []) + ["log_context"],
},
"file": {
"level": "DEBUG",
@@ -200,7 +238,7 @@ DEVELOPMENT_LOGGING = {
"maxBytes": 1024**2 * 10, # 10 MB per file
"backupCount": 3, # Keep 3 backup files
"formatter": _get_file_formatter(),
"filters": ["datadog_env"] if telemetry_settings.enable_datadog and not log_settings.json_logging else [],
"filters": (["datadog_env"] if telemetry_settings.enable_datadog and not log_settings.json_logging else []) + ["log_context"],
},
},
"root": { # Root logger handles all logs

33
letta/log_context.py Normal file
View File

@@ -0,0 +1,33 @@
from contextvars import ContextVar
from typing import Any, Optional
# The default dict is shared by every context that has not set its own value,
# so it must never be mutated in place: writers copy before modifying and
# get_log_context() hands out copies rather than the stored object.
_log_context: ContextVar[dict[str, Any]] = ContextVar("log_context", default={})


def set_log_context(key: str, value: Any) -> None:
    """Bind ``key`` to ``value`` in the current context's log attributes.

    Args:
        key: Attribute name to set.
        value: Value to associate with ``key``.
    """
    ctx = _log_context.get().copy()
    ctx[key] = value
    _log_context.set(ctx)


def get_log_context(key: Optional[str] = None) -> Any:
    """Read the current context's log attributes.

    Args:
        key: Attribute name to look up, or ``None`` to fetch everything.

    Returns:
        When ``key`` is ``None``, a shallow copy of the whole attribute dict
        (mutating it does not affect the stored context). Otherwise the value
        stored under ``key``, or ``None`` if the key is absent.
    """
    ctx = _log_context.get()
    if key is None:
        # Return a copy so callers cannot alter the stored context (or the
        # shared default dict) by mutating the result.
        return ctx.copy()
    return ctx.get(key)


def clear_log_context() -> None:
    """Replace the current context's log attributes with an empty dict."""
    _log_context.set({})


def update_log_context(**kwargs: Any) -> None:
    """Merge the given keyword arguments into the current log attributes.

    Args:
        **kwargs: Attribute names and values to add or overwrite.
    """
    ctx = _log_context.get().copy()
    ctx.update(kwargs)
    _log_context.set(ctx)


def remove_log_context(key: str) -> None:
    """Drop ``key`` from the current log attributes; a missing key is ignored.

    Args:
        key: Attribute name to remove.
    """
    ctx = _log_context.get().copy()
    ctx.pop(key, None)
    _log_context.set(ctx)

View File

@@ -62,7 +62,7 @@ from letta.server.db import db_registry
# NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests
from letta.server.rest_api.auth.index import setup_auth_router # TODO: probably remove right?
from letta.server.rest_api.interface import StreamingServerInterface
from letta.server.rest_api.middleware import CheckPasswordMiddleware, ProfilerContextMiddleware
from letta.server.rest_api.middleware import CheckPasswordMiddleware, LogContextMiddleware, ProfilerContextMiddleware
from letta.server.rest_api.routers.v1 import ROUTERS as v1_routes
from letta.server.rest_api.routers.v1.organizations import router as organizations_router
from letta.server.rest_api.routers.v1.users import router as users_router # TODO: decide on admin
@@ -519,6 +519,8 @@ def create_application() -> "FastAPI":
if telemetry_settings.profiler:
app.add_middleware(ProfilerContextMiddleware)
app.add_middleware(LogContextMiddleware)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,

View File

@@ -1,4 +1,5 @@
from letta.server.rest_api.middleware.check_password import CheckPasswordMiddleware
from letta.server.rest_api.middleware.log_context import LogContextMiddleware
from letta.server.rest_api.middleware.profiler_context import ProfilerContextMiddleware
__all__ = ["CheckPasswordMiddleware", "ProfilerContextMiddleware"]
__all__ = ["CheckPasswordMiddleware", "LogContextMiddleware", "ProfilerContextMiddleware"]

View File

@@ -0,0 +1,63 @@
import re
from starlette.middleware.base import BaseHTTPMiddleware
from letta.log_context import clear_log_context, update_log_context
from letta.schemas.enums import PrimitiveType
from letta.validators import PRIMITIVE_ID_PATTERNS
class LogContextMiddleware(BaseHTTPMiddleware):
    """
    Middleware that populates the log context from each incoming request.

    For every request it collects:
    - actor_id: the value of the 'user_id' header, when present and non-empty
    - Letta primitive IDs: any URL path segment matching one of the
      PRIMITIVE_ID_PATTERNS, stored as '<primitive>_id' (organizations are
      stored as 'org_id', users as 'user_id')

    The context is cleared before the request is processed and again once it
    finishes, so attributes do not carry over from one request to the next.
    """

    async def dispatch(self, request, call_next):
        clear_log_context()
        try:
            attributes = {}

            header_actor = request.headers.get("user_id")
            if header_actor:
                attributes["actor_id"] = header_actor

            seen_segments = set()
            for segment in request.url.path.split("/"):
                # Skip empty segments and ones already recognised as an ID.
                if not segment or segment in seen_segments:
                    continue
                for primitive in PrimitiveType:
                    id_pattern = PRIMITIVE_ID_PATTERNS.get(primitive.value)
                    if not id_pattern or not id_pattern.match(segment):
                        continue
                    if primitive == PrimitiveType.ORGANIZATION:
                        attribute_name = "org_id"
                    elif primitive == PrimitiveType.USER:
                        attribute_name = "user_id"
                    else:
                        attribute_name = f"{primitive.name.lower()}_id"
                    attributes[attribute_name] = segment
                    seen_segments.add(segment)
                    # First matching primitive type wins for this segment.
                    break

            if attributes:
                update_log_context(**attributes)

            return await call_next(request)
        finally:
            clear_log_context()

153
tests/test_log_context.py Normal file
View File

@@ -0,0 +1,153 @@
import json
import logging
from io import StringIO
import pytest
from letta.log import JSONFormatter, LogContextFilter
from letta.log_context import clear_log_context, get_log_context, remove_log_context, set_log_context, update_log_context
class TestLogContext:
    """Unit tests for the helper functions in ``letta.log_context``."""

    def test_set_log_context(self):
        """A single key written with set_log_context can be read back."""
        clear_log_context()

        set_log_context("agent_id", "agent-123")
        stored = get_log_context("agent_id")
        assert stored == "agent-123"

        clear_log_context()

    def test_update_log_context(self):
        """update_log_context stores every keyword argument it is given."""
        clear_log_context()

        update_log_context(agent_id="agent-123", actor_id="user-456")
        snapshot = get_log_context()
        assert snapshot["agent_id"] == "agent-123"
        assert snapshot["actor_id"] == "user-456"

        clear_log_context()

    def test_remove_log_context(self):
        """Removing one key leaves the remaining keys in place."""
        clear_log_context()

        update_log_context(agent_id="agent-123", actor_id="user-456")
        remove_log_context("agent_id")
        snapshot = get_log_context()
        assert "agent_id" not in snapshot
        assert snapshot["actor_id"] == "user-456"

        clear_log_context()

    def test_clear_log_context(self):
        """clear_log_context leaves an empty context behind."""
        update_log_context(agent_id="agent-123", actor_id="user-456")
        clear_log_context()

        snapshot = get_log_context()
        assert snapshot == {}

    def test_get_log_context_all(self):
        """Calling get_log_context without a key returns the whole dict."""
        clear_log_context()

        update_log_context(agent_id="agent-123", actor_id="user-456")
        snapshot = get_log_context()
        assert isinstance(snapshot, dict)
        assert len(snapshot) == 2

        clear_log_context()
class TestLogContextFilter:
    """Tests for ``LogContextFilter`` applied directly to log records."""

    @staticmethod
    def _make_record():
        """Build a minimal INFO-level record for the filter to process."""
        return logging.LogRecord("test", logging.INFO, "test.py", 1, "Test message", (), None)

    def test_filter_adds_context_to_record(self):
        """Context values appear as attributes on a filtered record."""
        clear_log_context()
        update_log_context(agent_id="agent-123", actor_id="user-456")

        context_filter = LogContextFilter()
        record = self._make_record()
        outcome = context_filter.filter(record)

        assert outcome is True
        assert hasattr(record, "agent_id")
        assert record.agent_id == "agent-123"
        assert hasattr(record, "actor_id")
        assert record.actor_id == "user-456"

        clear_log_context()

    def test_filter_does_not_override_existing_attributes(self):
        """An attribute already set on the record wins over the context."""
        clear_log_context()
        update_log_context(agent_id="agent-123")

        context_filter = LogContextFilter()
        record = self._make_record()
        record.agent_id = "agent-999"
        context_filter.filter(record)

        assert record.agent_id == "agent-999"

        clear_log_context()
class TestLogContextIntegration:
    """End-to-end tests: logger -> LogContextFilter -> JSONFormatter."""

    def test_json_formatter_includes_context(self):
        """Context attributes show up in the JSON emitted for a log call."""
        clear_log_context()
        update_log_context(agent_id="agent-123", actor_id="user-456")

        buffer = StringIO()
        handler = logging.StreamHandler(buffer)
        handler.setFormatter(JSONFormatter())
        handler.addFilter(LogContextFilter())

        logger = logging.getLogger("test_logger")
        logger.setLevel(logging.INFO)
        logger.addHandler(handler)

        logger.info("Test message")

        payload = json.loads(buffer.getvalue())
        assert payload["message"] == "Test message"
        assert payload["agent_id"] == "agent-123"
        assert payload["actor_id"] == "user-456"

        logger.removeHandler(handler)
        clear_log_context()

    def test_multiple_log_calls_with_changing_context(self):
        """Each log line reflects the context at the time it was emitted."""
        clear_log_context()

        buffer = StringIO()
        handler = logging.StreamHandler(buffer)
        handler.setFormatter(JSONFormatter())
        handler.addFilter(LogContextFilter())

        logger = logging.getLogger("test_logger_2")
        logger.setLevel(logging.INFO)
        logger.addHandler(handler)

        update_log_context(agent_id="agent-123")
        logger.info("First message")

        update_log_context(actor_id="user-456")
        logger.info("Second message")

        buffer.seek(0)
        emitted = buffer.readlines()
        assert len(emitted) == 2

        first_entry = json.loads(emitted[0])
        assert first_entry["agent_id"] == "agent-123"
        assert "actor_id" not in first_entry

        second_entry = json.loads(emitted[1])
        assert second_entry["agent_id"] == "agent-123"
        assert second_entry["actor_id"] == "user-456"

        logger.removeHandler(handler)
        clear_log_context()

View File

@@ -0,0 +1,95 @@
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from letta.log_context import get_log_context
from letta.server.rest_api.middleware import LogContextMiddleware
@pytest.fixture
def app():
    """FastAPI app with LogContextMiddleware and routes that echo the log context."""
    application = FastAPI()
    application.add_middleware(LogContextMiddleware)

    @application.get("/v1/agents/{agent_id}")
    async def get_agent(agent_id: str):
        # Echo the context seen inside the handler so tests can inspect it.
        return {"agent_id": agent_id, "context": get_log_context()}

    @application.get("/v1/agents/{agent_id}/tools/{tool_id}")
    async def get_agent_tool(agent_id: str, tool_id: str):
        return {"agent_id": agent_id, "tool_id": tool_id, "context": get_log_context()}

    @application.get("/v1/organizations/{org_id}/users/{user_id}")
    async def get_org_user(org_id: str, user_id: str):
        return {"org_id": org_id, "user_id": user_id, "context": get_log_context()}

    return application
@pytest.fixture
def client(app):
    """Test client bound to the ``app`` fixture."""
    test_client = TestClient(app)
    return test_client
class TestLogContextMiddleware:
    """Request-level tests for ``LogContextMiddleware``."""

    def test_extracts_actor_id_from_headers(self, client):
        """The 'user_id' header is exposed as actor_id in the context."""
        url = "/v1/agents/agent-123e4567-e89b-42d3-8456-426614174000"
        response = client.get(url, headers={"user_id": "user-abc123"})
        assert response.status_code == 200

        body = response.json()
        assert body["context"]["actor_id"] == "user-abc123"

    def test_extracts_agent_id_from_path(self, client):
        """An agent ID in the URL path is captured as agent_id."""
        agent_id = "agent-123e4567-e89b-42d3-8456-426614174000"
        response = client.get(f"/v1/agents/{agent_id}")
        assert response.status_code == 200

        body = response.json()
        assert body["context"]["agent_id"] == agent_id

    def test_extracts_multiple_primitive_ids_from_path(self, client):
        """Several primitive IDs in one path are all captured."""
        agent_id = "agent-123e4567-e89b-42d3-8456-426614174000"
        tool_id = "tool-987e6543-e21c-42d3-9456-426614174000"
        response = client.get(f"/v1/agents/{agent_id}/tools/{tool_id}")
        assert response.status_code == 200

        body = response.json()
        assert body["context"]["agent_id"] == agent_id
        assert body["context"]["tool_id"] == tool_id

    def test_extracts_org_id_with_custom_mapping(self, client):
        """Organization IDs are stored under 'org_id', user IDs under 'user_id'."""
        org_id = "org-123e4567-e89b-42d3-8456-426614174000"
        user_id = "user-987e6543-e21c-42d3-9456-426614174000"
        response = client.get(f"/v1/organizations/{org_id}/users/{user_id}")
        assert response.status_code == 200

        body = response.json()
        assert body["context"]["org_id"] == org_id
        assert body["context"]["user_id"] == user_id

    def test_extracts_both_header_and_path_context(self, client):
        """Header-derived and path-derived attributes coexist."""
        agent_id = "agent-123e4567-e89b-42d3-8456-426614174000"
        response = client.get(f"/v1/agents/{agent_id}", headers={"user_id": "user-abc123"})
        assert response.status_code == 200

        body = response.json()
        assert body["context"]["actor_id"] == "user-abc123"
        assert body["context"]["agent_id"] == agent_id

    def test_handles_request_without_context(self, client):
        """A request to an unregistered route still completes, with a 404."""
        response = client.get("/v1/health")
        assert response.status_code == 404

    def test_context_cleared_between_requests(self, client):
        """Each request sees only its own context values."""
        first_agent = "agent-111e4567-e89b-42d3-8456-426614174000"
        second_agent = "agent-222e4567-e89b-42d3-8456-426614174000"

        first_response = client.get(f"/v1/agents/{first_agent}", headers={"user_id": "user-1"})
        assert first_response.status_code == 200
        first_body = first_response.json()
        assert first_body["context"]["agent_id"] == first_agent
        assert first_body["context"]["actor_id"] == "user-1"

        second_response = client.get(f"/v1/agents/{second_agent}", headers={"user_id": "user-2"})
        assert second_response.status_code == 200
        second_body = second_response.json()
        assert second_body["context"]["agent_id"] == second_agent
        assert second_body["context"]["actor_id"] == "user-2"