fix: handle transient network errors in ChatGPT OAuth client (#9462)

- Map httpx.ReadError/WriteError/ConnectError to LLMConnectionError in
  handle_llm_error so Temporal correctly classifies them as retryable
  (previously fell through to generic non-retryable LLMError)
- Add client-level retry with exponential backoff (up to 3 attempts) on
  request_async and stream_async for transient transport errors
- Guard stream retries with a has_yielded flag: once partial data has been
  yielded to the caller, errors propagate instead of retrying, so output the
  caller already consumed is never replayed or corrupted
Author: jnjpng
Date: 2026-02-12 12:43:53 -08:00
Committed by: Caren Thomas
Parent: 4126fdadea
Commit: 778f28ccf3


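Note on the retry policy before the diff: it is plain capped exponential backoff, up to three attempts with waits of 2**attempt seconds (1s, then 2s) between them. A minimal standalone sketch of the same pattern; the helper name `send` and the `TRANSIENT` tuple are illustrative, not from this codebase:

import asyncio

import httpx

TRANSIENT = (httpx.ReadError, httpx.WriteError, httpx.ConnectError, httpx.RemoteProtocolError)
MAX_RETRIES = 3

async def send(client: httpx.AsyncClient, url: str) -> httpx.Response:
    for attempt in range(MAX_RETRIES):
        try:
            return await client.get(url)
        except TRANSIENT:
            if attempt == MAX_RETRIES - 1:
                raise  # out of attempts, let the caller see the transport error
            await asyncio.sleep(2**attempt)  # 1s after the first failure, 2s after the second
    raise AssertionError("unreachable")  # loop always returns or raises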
@@ -1,5 +1,6 @@
"""ChatGPT OAuth Client - handles requests to chatgpt.com/backend-api/codex/responses."""

import asyncio
import json
from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union
@@ -102,6 +103,10 @@ class ChatGPTOAuthClient(LLMClientBase):
    4. Transforms responses back to OpenAI ChatCompletion format
    """

    MAX_RETRIES = 3
    # Transient httpx errors that are safe to retry (connection drops, transport-level failures)
    _RETRYABLE_ERRORS = (httpx.ReadError, httpx.WriteError, httpx.ConnectError, httpx.RemoteProtocolError)

    @trace_method
    async def _get_provider_and_credentials_async(self, llm_config: LLMConfig) -> tuple[ChatGPTOAuthProvider, ChatGPTOAuthCredentials]:
        """Get the ChatGPT OAuth provider and credentials with automatic refresh if needed.
@@ -371,18 +376,20 @@ class ChatGPTOAuthClient(LLMClientBase):
        endpoint = llm_config.model_endpoint or CHATGPT_CODEX_ENDPOINT

        # ChatGPT backend requires streaming, so we use client.stream() to handle SSE
        async with httpx.AsyncClient() as client:
            # Retry on transient network errors with exponential backoff
            for attempt in range(self.MAX_RETRIES):
                try:
                    async with client.stream(
                        "POST",
                        endpoint,
                        json=request_data,
                        headers=headers,
                        timeout=120.0,
                    ) as response:
                        response.raise_for_status()
                        # Accumulate SSE events into a final response
                        return await self._accumulate_sse_response(response)
                except httpx.HTTPStatusError as e:
                    raise self._handle_http_error(e)
@@ -391,12 +398,29 @@ class ChatGPTOAuthClient(LLMClientBase):
message="ChatGPT backend request timed out",
code=ErrorCode.TIMEOUT,
)
except self._RETRYABLE_ERRORS as e:
if attempt < self.MAX_RETRIES - 1:
wait = 2**attempt
logger.warning(
f"[ChatGPT] Transient error on request (attempt {attempt + 1}/{self.MAX_RETRIES}), "
f"retrying in {wait}s: {type(e).__name__}: {e}"
)
await asyncio.sleep(wait)
continue
raise LLMConnectionError(
message=f"Failed to connect to ChatGPT backend after {self.MAX_RETRIES} attempts: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__},
)
except httpx.RequestError as e:
raise LLMConnectionError(
message=f"Failed to connect to ChatGPT backend: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
)
# Should not be reached, but satisfy type checker
raise LLMConnectionError(message="ChatGPT request failed after all retries", code=ErrorCode.INTERNAL_SERVER_ERROR)
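A hedged way to exercise this retry path in isolation, using httpx.MockTransport to fake one dropped connection; this is a sketch, not a test from this PR, and the real client constructs its own AsyncClient internally:

import asyncio

import httpx

calls = {"n": 0}

def flaky(request: httpx.Request) -> httpx.Response:
    # First call simulates a connection drop; second succeeds.
    calls["n"] += 1
    if calls["n"] == 1:
        raise httpx.ConnectError("simulated drop")
    return httpx.Response(200, json={"ok": True})

async def main() -> None:
    async with httpx.AsyncClient(transport=httpx.MockTransport(flaky)) as client:
        for attempt in range(3):
            try:
                resp = await client.get("https://example.test/")
                print(resp.json())  # {'ok': True} on the second attempt
                return
            except httpx.ConnectError:
                await asyncio.sleep(2**attempt)

asyncio.run(main())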
async def _accumulate_sse_response(self, response: httpx.Response) -> dict:
"""Accumulate SSE stream into a final response.
@@ -572,69 +596,89 @@ class ChatGPTOAuthClient(LLMClientBase):
            # Track sequence_number in case backend doesn't provide it
            # (OpenAI SDK expects incrementing sequence numbers starting at 0)
            sequence_counter = 0
            # Track whether we've yielded any events — once we have, we can't
            # transparently retry because the caller has already consumed partial data.
            has_yielded = False

            for attempt in range(self.MAX_RETRIES):
                try:
                    async with httpx.AsyncClient() as client:
                        async with client.stream(
                            "POST",
                            endpoint,
                            json=request_data,
                            headers=headers,
                            timeout=120.0,
                        ) as response:
                            # Check for error status
                            if response.status_code != 200:
                                error_body = await response.aread()
                                logger.error(f"ChatGPT SSE error: {response.status_code} - {error_body}")
                                raise self._handle_http_error_from_status(response.status_code, error_body.decode())

                            async for line in response.aiter_lines():
                                if not line or not line.startswith("data: "):
                                    continue

                                data_str = line[6:]
                                if data_str == "[DONE]":
                                    break

                                try:
                                    raw_event = json.loads(data_str)
                                    event_type = raw_event.get("type")

                                    # Check for error events from the API (context window, rate limit, etc.)
                                    if event_type == "error":
                                        logger.error(f"ChatGPT SSE error event: {json.dumps(raw_event, default=str)[:1000]}")
                                        raise self._handle_sse_error_event(raw_event)

                                    # Check for response.failed or response.incomplete events
                                    if event_type in ("response.failed", "response.incomplete"):
                                        logger.error(f"ChatGPT SSE {event_type} event: {json.dumps(raw_event, default=str)[:1000]}")
                                        resp_obj = raw_event.get("response", {})
                                        error_info = resp_obj.get("error", {})
                                        if error_info:
                                            raise self._handle_sse_error_event({"error": error_info, "type": event_type})
                                        else:
                                            raise LLMBadRequestError(
                                                message=f"ChatGPT request failed with status '{event_type}' (no error details provided)",
                                                code=ErrorCode.INTERNAL_SERVER_ERROR,
                                            )

                                    # Use backend-provided sequence_number if available, else use counter
                                    # This ensures proper ordering even if backend doesn't provide it
                                    if "sequence_number" not in raw_event:
                                        raw_event["sequence_number"] = sequence_counter
                                    sequence_counter = raw_event["sequence_number"] + 1

                                    # Track output index for output_item.added events
                                    if event_type == "response.output_item.added":
                                        output_index = raw_event.get("output_index", output_index)

                                    # Convert to OpenAI SDK ResponseStreamEvent
                                    sdk_event = self._convert_to_sdk_event(raw_event, output_index)
                                    if sdk_event:
                                        yield sdk_event
                                        has_yielded = True
                                except json.JSONDecodeError:
                                    logger.warning(f"Failed to parse SSE event: {data_str[:100]}")
                                    continue

                    # Stream completed successfully
                    return
                except self._RETRYABLE_ERRORS as e:
                    if has_yielded or attempt >= self.MAX_RETRIES - 1:
                        # Already yielded partial data or exhausted retries — must propagate
                        raise
                    wait = 2**attempt
                    logger.warning(
                        f"[ChatGPT] Transient error on stream (attempt {attempt + 1}/{self.MAX_RETRIES}), "
                        f"retrying in {wait}s: {type(e).__name__}: {e}"
                    )
                    await asyncio.sleep(wait)

        # Wrap the async generator in AsyncStreamWrapper to provide context manager protocol
        return AsyncStreamWrapper(stream_generator())
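The has_yielded guard matters because an async generator cannot rewind: if the transport dies after events have been yielded, a silent retry would replay the stream from the start and the caller would see duplicates. A contrived sketch of the failure mode the flag prevents (all names here are illustrative, not from the codebase):

import asyncio
from typing import AsyncIterator

async def flaky_events(fail_after: int) -> AsyncIterator[int]:
    # Pretend the connection drops after `fail_after` events.
    for i in range(5):
        if i == fail_after:
            raise ConnectionError("mid-stream drop")
        yield i

async def naive_retry() -> None:
    seen: list[int] = []
    for _ in range(2):  # retry once, with no has_yielded guard
        try:
            async for event in flaky_events(fail_after=3):
                seen.append(event)
            break
        except ConnectionError:
            continue
    print(seen)  # [0, 1, 2, 0, 1, 2] -- events 0-2 delivered twice

asyncio.run(naive_retry())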
@@ -1038,6 +1082,16 @@ class ChatGPTOAuthClient(LLMClientBase):
        if isinstance(e, httpx.HTTPStatusError):
            return self._handle_http_error(e, is_byok=is_byok)

        # Handle httpx network errors which can occur during streaming
        # when the connection is unexpectedly closed while reading/writing
        if isinstance(e, (httpx.ReadError, httpx.WriteError, httpx.ConnectError)):
            logger.warning(f"[ChatGPT] Network error during streaming: {type(e).__name__}: {e}")
            return LLMConnectionError(
                message=f"Network error during ChatGPT streaming: {str(e)}",
                code=ErrorCode.INTERNAL_SERVER_ERROR,
                details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__, "is_byok": is_byok},
            )

        return super().handle_llm_error(e, llm_config=llm_config)

    def _handle_http_error(self, e: httpx.HTTPStatusError, is_byok: bool | None = None) -> Exception:
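Finally, the Temporal angle from the commit message: Temporal treats plain exceptions raised from an activity as retryable under the activity's RetryPolicy, so mapping transport failures to LLMConnectionError is enough for retries to kick in. A hypothetical sketch of that classification; the activity body, `client`, and the stub error classes stand in for real code that is not part of this diff:

from temporalio.exceptions import ApplicationError

class LLMError(Exception): ...              # stand-in for the real error hierarchy
class LLMConnectionError(LLMError): ...     # stand-in; retryable by design

async def call_llm(client, request) -> dict:
    try:
        return await client.request_async(request)
    except LLMConnectionError:
        raise  # plain exceptions remain retryable under Temporal's default RetryPolicy
    except LLMError as err:
        # Anything else is permanent: tell Temporal not to retry this activity.
        raise ApplicationError(str(err), type=type(err).__name__, non_retryable=True)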