From 4d8d9757aac7c8a4bad85e60ca3e76f40b4b41f3 Mon Sep 17 00:00:00 2001 From: Ari Webb Date: Thu, 18 Dec 2025 14:59:00 -0800 Subject: [PATCH] feat: add request-id for steps [LET-6587] (#7349) * feat: add request-id for steps * order revisions correctly * stage publish api --- ...b43eea55e_add_request_id_to_steps_table.py | 31 +++++++++ fern/openapi.json | 12 ++++ letta/orm/step.py | 3 + letta/schemas/step.py | 1 + letta/server/rest_api/app.py | 6 +- letta/server/rest_api/middleware/__init__.py | 3 +- .../server/rest_api/middleware/request_id.py | 63 +++++++++++++++++++ letta/services/step_manager.py | 3 + 8 files changed, 120 insertions(+), 2 deletions(-) create mode 100644 alembic/versions/ee2b43eea55e_add_request_id_to_steps_table.py create mode 100644 letta/server/rest_api/middleware/request_id.py diff --git a/alembic/versions/ee2b43eea55e_add_request_id_to_steps_table.py b/alembic/versions/ee2b43eea55e_add_request_id_to_steps_table.py new file mode 100644 index 00000000..5bf92401 --- /dev/null +++ b/alembic/versions/ee2b43eea55e_add_request_id_to_steps_table.py @@ -0,0 +1,31 @@ +"""add request_id to steps table + +Revision ID: ee2b43eea55e +Revises: 39577145c45d +Create Date: 2025-12-17 13:48:08.642245 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "ee2b43eea55e" +down_revision: Union[str, None] = "39577145c45d" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("steps", sa.Column("request_id", sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("steps", "request_id") + # ### end Alembic commands ### diff --git a/fern/openapi.json b/fern/openapi.json index ad21b85d..762bf34a 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -36560,6 +36560,18 @@ "title": "Trace Id", "description": "The trace id of the agent step." }, + "request_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Request Id", + "description": "The API request log ID from cloud-api for correlating steps with API requests." + }, "messages": { "items": { "$ref": "#/components/schemas/Message" diff --git a/letta/orm/step.py b/letta/orm/step.py index eca32ed5..00f7a37b 100644 --- a/letta/orm/step.py +++ b/letta/orm/step.py @@ -60,6 +60,9 @@ class Step(SqlalchemyBase, ProjectMixin): tags: Mapped[Optional[List]] = mapped_column(JSON, doc="Metadata tags.") tid: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="Transaction ID that processed the step.") trace_id: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The trace id of the agent step.") + request_id: Mapped[Optional[str]] = mapped_column( + None, nullable=True, doc="The API request log ID from cloud-api for correlating steps with API requests." + ) feedback: Mapped[Optional[str]] = mapped_column( None, nullable=True, doc="The feedback for this step. Must be either 'positive' or 'negative'." ) diff --git a/letta/schemas/step.py b/letta/schemas/step.py index a1609b5e..83126a96 100644 --- a/letta/schemas/step.py +++ b/letta/schemas/step.py @@ -38,6 +38,7 @@ class Step(StepBase): tags: List[str] = Field([], description="Metadata tags.") tid: Optional[str] = Field(None, description="The unique identifier of the transaction that processed this step.") trace_id: Optional[str] = Field(None, description="The trace id of the agent step.") + request_id: Optional[str] = Field(None, description="The API request log ID from cloud-api for correlating steps with API requests.") messages: List[Message] = Field( [], description="The messages generated during this step. Deprecated: use `GET /v1/steps/{step_id}/messages` endpoint instead", diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py index be4b821d..b40e9c52 100644 --- a/letta/server/rest_api/app.py +++ b/letta/server/rest_api/app.py @@ -69,7 +69,7 @@ from letta.server.global_exception_handler import setup_global_exception_handler # NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests from letta.server.rest_api.auth.index import setup_auth_router # TODO: probably remove right? from letta.server.rest_api.interface import StreamingServerInterface -from letta.server.rest_api.middleware import CheckPasswordMiddleware, LoggingMiddleware +from letta.server.rest_api.middleware import CheckPasswordMiddleware, LoggingMiddleware, RequestIdMiddleware from letta.server.rest_api.routers.v1 import ROUTERS as v1_routes from letta.server.rest_api.routers.v1.organizations import router as organizations_router from letta.server.rest_api.routers.v1.users import router as users_router # TODO: decide on admin @@ -591,6 +591,10 @@ def create_application() -> "FastAPI": # Add unified logging middleware - enriches log context and logs exceptions app.add_middleware(LoggingMiddleware) + # Add request ID middleware - extracts x-api-request-log-id header and sets it in contextvar + # This is a pure ASGI middleware to properly propagate contextvars to streaming responses + app.add_middleware(RequestIdMiddleware) + app.add_middleware( CORSMiddleware, allow_origins=settings.cors_origins, diff --git a/letta/server/rest_api/middleware/__init__.py b/letta/server/rest_api/middleware/__init__.py index 50560577..aa3b20ff 100644 --- a/letta/server/rest_api/middleware/__init__.py +++ b/letta/server/rest_api/middleware/__init__.py @@ -1,4 +1,5 @@ from letta.server.rest_api.middleware.check_password import CheckPasswordMiddleware from letta.server.rest_api.middleware.logging import LoggingMiddleware +from letta.server.rest_api.middleware.request_id import RequestIdMiddleware -__all__ = ["CheckPasswordMiddleware", "LoggingMiddleware"] +__all__ = ["CheckPasswordMiddleware", "LoggingMiddleware", "RequestIdMiddleware"] diff --git a/letta/server/rest_api/middleware/request_id.py b/letta/server/rest_api/middleware/request_id.py new file mode 100644 index 00000000..b147ee61 --- /dev/null +++ b/letta/server/rest_api/middleware/request_id.py @@ -0,0 +1,63 @@ +""" +Middleware for extracting and propagating API request IDs from cloud-api. + +Uses a pure ASGI middleware pattern to properly propagate the request_id +to streaming responses. BaseHTTPMiddleware has a known limitation where +contextvars are not propagated to streaming response generators. +See: https://github.com/encode/starlette/discussions/1729 + +This middleware: +1. Extracts the x-api-request-log-id header from cloud-api +2. Sets it in the contextvar (for non-streaming code) +3. Stores it in request.state (for streaming responses where contextvars don't propagate) +""" + +from contextvars import ContextVar +from typing import Optional + +from starlette.requests import Request +from starlette.types import ASGIApp, Receive, Scope, Send + +# Contextvar for storing the request ID across async boundaries +request_id_var: ContextVar[Optional[str]] = ContextVar("request_id", default=None) + + +def get_request_id() -> Optional[str]: + """Get the request ID from the current context.""" + return request_id_var.get() + + +class RequestIdMiddleware: + """ + Pure ASGI middleware that extracts and propagates the API request ID. + + The request ID comes from cloud-api via the x-api-request-log-id header + and is used to correlate steps with API request logs. + + This middleware stores the request_id in: + - The request_id_var contextvar (works for non-streaming responses) + - request.state.request_id (works for streaming responses where contextvars may not propagate) + """ + + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + # Create a Request object for easier header access + request = Request(scope) + + # Extract request_id from header + request_id = request.headers.get("x-api-request-log-id") + + # Set in contextvar (for non-streaming code paths) + request_id_var.set(request_id) + + # Also store in request.state for streaming responses where contextvars don't propagate + # This is accessible via request.state.request_id throughout the request lifecycle + request.state.request_id = request_id + + await self.app(scope, receive, send) diff --git a/letta/services/step_manager.py b/letta/services/step_manager.py index cd57bf5f..64a1e4d3 100644 --- a/letta/services/step_manager.py +++ b/letta/services/step_manager.py @@ -21,6 +21,7 @@ from letta.schemas.step import Step as PydanticStep from letta.schemas.step_metrics import StepMetrics as PydanticStepMetrics from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry +from letta.server.rest_api.middleware.request_id import get_request_id from letta.services.webhook_service import WebhookService from letta.utils import enforce_types from letta.validators import raise_on_invalid_id @@ -123,6 +124,7 @@ class StepManager: "tags": [], "tid": None, "trace_id": get_trace_id(), # Get the current trace ID + "request_id": get_request_id(), # Get the API request log ID from cloud-api "project_id": project_id, "status": status if status else StepStatus.PENDING, "error_type": error_type, @@ -182,6 +184,7 @@ class StepManager: "tags": [], "tid": None, "trace_id": get_trace_id(), # Get the current trace ID + "request_id": get_request_id(), # Get the API request log ID from cloud-api "project_id": project_id, "status": status if status else StepStatus.PENDING, "error_type": error_type,