fix(core): pass org_id to dulwich via header for git HTTP (#9291)
This commit is contained in:
committed by
Caren Thomas
parent
09d7940090
commit
49d354bac1
@@ -29,6 +29,7 @@ successful `git-receive-pack`.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
@@ -55,11 +56,12 @@ except ImportError: # pragma: no cover
|
||||
make_server = None # type: ignore[assignment]
|
||||
_DULWICH_AVAILABLE = False
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from starlette.background import BackgroundTask
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -73,6 +75,9 @@ _server_instance = None
|
||||
_repo_cache: Dict[str, str] = {}
|
||||
_repo_locks: Dict[str, threading.Lock] = {}
|
||||
|
||||
# org_id for the currently-handled dulwich request (set by a WSGI wrapper).
|
||||
_current_org_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar("letta_git_http_org_id", default=None)
|
||||
|
||||
# Dulwich server globals
|
||||
_dulwich_server = None
|
||||
_dulwich_thread: Optional[threading.Thread] = None
|
||||
@@ -143,13 +148,12 @@ def stop_dulwich_server() -> None:
|
||||
logger.exception("Failed to shutdown dulwich server")
|
||||
|
||||
|
||||
def _default_org_id() -> str:
|
||||
if _server_instance is None:
|
||||
raise RuntimeError("Server instance not set")
|
||||
default_user = getattr(_server_instance, "default_user", None)
|
||||
org_id = getattr(default_user, "organization_id", None)
|
||||
def _require_current_org_id() -> str:
|
||||
"""Read the org_id set by the WSGI wrapper for the current request."""
|
||||
|
||||
org_id = _current_org_id.get()
|
||||
if not org_id:
|
||||
raise RuntimeError("Unable to infer org_id for git HTTP path")
|
||||
raise RuntimeError("Missing org_id for git HTTP request")
|
||||
return org_id
|
||||
|
||||
|
||||
@@ -185,7 +189,7 @@ class GCSBackend(Backend):
|
||||
raise ValueError(f"Invalid repository path (expected /{{agent_id}}/state.git): {path}")
|
||||
|
||||
agent_id = parts[0]
|
||||
org_id = _default_org_id()
|
||||
org_id = _require_current_org_id()
|
||||
|
||||
cache_key = f"{org_id}/{agent_id}"
|
||||
logger.info("GCSBackend.open_repository: org=%s agent=%s", org_id, agent_id)
|
||||
@@ -398,13 +402,11 @@ async def _sync_after_push(org_id: str, agent_id: str) -> None:
|
||||
shutil.rmtree(os.path.dirname(repo_path), ignore_errors=True)
|
||||
|
||||
|
||||
def _parse_org_agent_from_repo_path(path: str) -> Optional[tuple[str, str]]:
|
||||
"""Extract (org_id, agent_id) from a git HTTP path.
|
||||
def _parse_agent_id_from_repo_path(path: str) -> Optional[str]:
|
||||
"""Extract agent_id from a git HTTP path.
|
||||
|
||||
Expected path form:
|
||||
- {agent_id}/state.git/...
|
||||
|
||||
org_id is inferred from the running server instance.
|
||||
"""
|
||||
|
||||
parts = path.strip("/").split("/")
|
||||
@@ -414,7 +416,7 @@ def _parse_org_agent_from_repo_path(path: str) -> Optional[tuple[str, str]]:
|
||||
if parts[1] != "state.git":
|
||||
return None
|
||||
|
||||
return _default_org_id(), parts[0]
|
||||
return parts[0]
|
||||
|
||||
|
||||
def _filter_out_hop_by_hop_headers(headers: Iterable[tuple[str, str]]) -> Dict[str, str]:
|
||||
@@ -440,7 +442,12 @@ def _filter_out_hop_by_hop_headers(headers: Iterable[tuple[str, str]]) -> Dict[s
|
||||
|
||||
|
||||
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"]) # pragma: no cover
|
||||
async def proxy_git_http(path: str, request: Request):
|
||||
async def proxy_git_http(
|
||||
path: str,
|
||||
request: Request,
|
||||
server=Depends(get_letta_server),
|
||||
headers: HeaderParams = Depends(get_headers),
|
||||
):
|
||||
"""Proxy `/v1/git/*` requests to the local dulwich WSGI server."""
|
||||
|
||||
if not _DULWICH_AVAILABLE:
|
||||
@@ -462,6 +469,13 @@ async def proxy_git_http(path: str, request: Request):
|
||||
req_headers.pop("host", None)
|
||||
req_headers.pop("content-length", None)
|
||||
|
||||
# Resolve org_id from the authenticated actor + agent and forward to dulwich.
|
||||
agent_id = _parse_agent_id_from_repo_path(path)
|
||||
if agent_id is not None:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor, include_relationships=[])
|
||||
req_headers["x-organization-id"] = agent.organization_id
|
||||
|
||||
async def _body_iter():
|
||||
async for chunk in request.stream():
|
||||
yield chunk
|
||||
@@ -480,11 +494,15 @@ async def proxy_git_http(path: str, request: Request):
|
||||
|
||||
# If this was a push, trigger our sync.
|
||||
if request.method == "POST" and path.endswith("git-receive-pack") and upstream.status_code < 400:
|
||||
parsed = _parse_org_agent_from_repo_path(path)
|
||||
if parsed is not None:
|
||||
org_id, agent_id = parsed
|
||||
# Fire-and-forget; do not block git client response.
|
||||
asyncio.create_task(_sync_after_push(org_id, agent_id))
|
||||
agent_id = _parse_agent_id_from_repo_path(path)
|
||||
if agent_id is not None:
|
||||
try:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor, include_relationships=[])
|
||||
# Fire-and-forget; do not block git client response.
|
||||
asyncio.create_task(_sync_after_push(agent.organization_id, agent_id))
|
||||
except Exception:
|
||||
logger.exception("Failed to trigger post-push sync (agent_id=%s)", agent_id)
|
||||
|
||||
async def _aclose_upstream_and_client() -> None:
|
||||
try:
|
||||
@@ -501,6 +519,27 @@ async def proxy_git_http(path: str, request: Request):
|
||||
)
|
||||
|
||||
|
||||
def _org_header_middleware(app):
    """WSGI wrapper that exposes the proxied request's org_id to the Backend.

    FastAPI proxies requests to the dulwich server and injects `X-Organization-Id`.
    Dulwich itself only passes repository *paths* into the Backend, so we capture
    the org_id from the WSGI environ and stash it in a contextvar.

    The contextvar has to stay set while the *response iterable is consumed*,
    not merely while ``app(environ, start_response)`` is called: a WSGI app may
    return a lazy iterable (for example a generator) whose body only runs once
    the server starts iterating it. Setting and resetting the contextvar around
    the bare call would therefore clear the org_id before any lazily executed
    handler code gets to read it.
    NOTE(review): dulwich's HTTP handlers appear to be generator functions, which
    is the case this guards against -- confirm against the installed version.

    Args:
        app: The wrapped WSGI application.

    Returns:
        A WSGI callable that yields the wrapped app's response while the
        org_id contextvar is set to the request's `X-Organization-Id` value
        (or ``None`` when the header is absent).
    """

    def _wrapped(environ, start_response):
        org_id = environ.get("HTTP_X_ORGANIZATION_ID")

        def _iter_response():
            # Set the contextvar inside the generator so it is active for every
            # step of iteration, including lazily executed handler bodies.
            token = _current_org_id.set(org_id)
            try:
                result = app(environ, start_response)
                try:
                    yield from result
                finally:
                    # PEP 3333: middleware must forward close() to the wrapped
                    # app's iterable when it provides one.
                    close = getattr(result, "close", None)
                    if close is not None:
                        close()
            finally:
                _current_org_id.reset(token)

        return _iter_response()

    return _wrapped
|
||||
|
||||
|
||||
# dulwich WSGI app (optional)
|
||||
_backend = GCSBackend()
|
||||
_git_wsgi_app = HTTPGitApplication(_backend) if _DULWICH_AVAILABLE and HTTPGitApplication is not None else None
|
||||
_git_wsgi_app = _org_header_middleware(HTTPGitApplication(_backend)) if _DULWICH_AVAILABLE and HTTPGitApplication is not None else None
|
||||
|
||||
Reference in New Issue
Block a user