fix: handle transient network errors in ChatGPT OAuth client (#9462)
- Map httpx.ReadError/WriteError/ConnectError to LLMConnectionError in handle_llm_error so Temporal correctly classifies them as retryable (previously they fell through to the generic, non-retryable LLMError)
- Add client-level retry with exponential backoff (up to 3 attempts) on request_async and stream_async for transient transport errors
- Guard the stream retry with a has_yielded flag to avoid corrupting partial responses already consumed by the caller
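Context for the first bullet: Temporal decides retryability by matching the raised error's type name against the activity's retry policy, so surfacing transport drops as LLMConnectionError (instead of the generic LLMError) is what makes them retryable. A minimal sketch of such a policy, assuming temporalio's RetryPolicy; the values and type names below are illustrative, not this repo's actual configuration:

# Illustrative Temporal retry policy (not this repo's actual settings).
# Temporal matches non_retryable_error_types against the error's type name,
# so an error surfaced as LLMConnectionError escapes the terminal list and
# gets retried, while a bare LLMError does not.
from datetime import timedelta

from temporalio.common import RetryPolicy

retry_policy = RetryPolicy(
    initial_interval=timedelta(seconds=1),
    backoff_coefficient=2.0,
    maximum_attempts=5,
    non_retryable_error_types=["LLMError", "LLMBadRequestError"],
)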
@@ -1,5 +1,6 @@
 """ChatGPT OAuth Client - handles requests to chatgpt.com/backend-api/codex/responses."""

+import asyncio
 import json
 from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union

@@ -102,6 +103,10 @@ class ChatGPTOAuthClient(LLMClientBase):
     4. Transforms responses back to OpenAI ChatCompletion format
     """

+    MAX_RETRIES = 3
+    # Transient httpx errors that are safe to retry (connection drops, transport-level failures)
+    _RETRYABLE_ERRORS = (httpx.ReadError, httpx.WriteError, httpx.ConnectError, httpx.RemoteProtocolError)
+
     @trace_method
     async def _get_provider_and_credentials_async(self, llm_config: LLMConfig) -> tuple[ChatGPTOAuthProvider, ChatGPTOAuthCredentials]:
         """Get the ChatGPT OAuth provider and credentials with automatic refresh if needed.
@@ -371,18 +376,20 @@
         endpoint = llm_config.model_endpoint or CHATGPT_CODEX_ENDPOINT

         # ChatGPT backend requires streaming, so we use client.stream() to handle SSE
-        async with httpx.AsyncClient() as client:
+        # Retry on transient network errors with exponential backoff
+        for attempt in range(self.MAX_RETRIES):
             try:
-                async with client.stream(
-                    "POST",
-                    endpoint,
-                    json=request_data,
-                    headers=headers,
-                    timeout=120.0,
-                ) as response:
-                    response.raise_for_status()
-                    # Accumulate SSE events into a final response
-                    return await self._accumulate_sse_response(response)
+                async with httpx.AsyncClient() as client:
+                    async with client.stream(
+                        "POST",
+                        endpoint,
+                        json=request_data,
+                        headers=headers,
+                        timeout=120.0,
+                    ) as response:
+                        response.raise_for_status()
+                        # Accumulate SSE events into a final response
+                        return await self._accumulate_sse_response(response)

             except httpx.HTTPStatusError as e:
                 raise self._handle_http_error(e)
@@ -391,12 +398,29 @@
                     message="ChatGPT backend request timed out",
                     code=ErrorCode.TIMEOUT,
                 )
+            except self._RETRYABLE_ERRORS as e:
+                if attempt < self.MAX_RETRIES - 1:
+                    wait = 2**attempt
+                    logger.warning(
+                        f"[ChatGPT] Transient error on request (attempt {attempt + 1}/{self.MAX_RETRIES}), "
+                        f"retrying in {wait}s: {type(e).__name__}: {e}"
+                    )
+                    await asyncio.sleep(wait)
+                    continue
+                raise LLMConnectionError(
+                    message=f"Failed to connect to ChatGPT backend after {self.MAX_RETRIES} attempts: {str(e)}",
+                    code=ErrorCode.INTERNAL_SERVER_ERROR,
+                    details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__},
+                )
             except httpx.RequestError as e:
                 raise LLMConnectionError(
                     message=f"Failed to connect to ChatGPT backend: {str(e)}",
                     code=ErrorCode.INTERNAL_SERVER_ERROR,
                 )
+
+        # Should not be reached, but satisfy type checker
+        raise LLMConnectionError(message="ChatGPT request failed after all retries", code=ErrorCode.INTERNAL_SERVER_ERROR)

     async def _accumulate_sse_response(self, response: httpx.Response) -> dict:
         """Accumulate SSE stream into a final response.

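Taken out of the diff context, the request-side retry above reduces to the following shape (a minimal sketch; send() is a hypothetical stand-in for the real client.stream() call). With MAX_RETRIES = 3 the backoff waits are 1s and 2s, so at most about 3s is spent sleeping before the error is finally raised:

# Minimal sketch of the retry shape above; send() is a hypothetical
# stand-in for the real streaming request.
import asyncio

import httpx

MAX_RETRIES = 3
RETRYABLE = (httpx.ReadError, httpx.WriteError, httpx.ConnectError, httpx.RemoteProtocolError)


async def request_with_retry(send):
    for attempt in range(MAX_RETRIES):
        try:
            return await send()
        except RETRYABLE:
            if attempt < MAX_RETRIES - 1:
                await asyncio.sleep(2**attempt)  # waits 1s, then 2s
                continue
            raise
    raise RuntimeError("unreachable")  # parallels the diff's type-checker fallthrough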
@@ -572,69 +596,89 @@
             # Track sequence_number in case backend doesn't provide it
             # (OpenAI SDK expects incrementing sequence numbers starting at 0)
             sequence_counter = 0
+            # Track whether we've yielded any events — once we have, we can't
+            # transparently retry because the caller has already consumed partial data.
+            has_yielded = False

-            async with httpx.AsyncClient() as client:
-                async with client.stream(
-                    "POST",
-                    endpoint,
-                    json=request_data,
-                    headers=headers,
-                    timeout=120.0,
-                ) as response:
-                    # Check for error status
-                    if response.status_code != 200:
-                        error_body = await response.aread()
-                        logger.error(f"ChatGPT SSE error: {response.status_code} - {error_body}")
-                        raise self._handle_http_error_from_status(response.status_code, error_body.decode())
-
-                    async for line in response.aiter_lines():
-                        if not line or not line.startswith("data: "):
-                            continue
-
-                        data_str = line[6:]
-                        if data_str == "[DONE]":
-                            break
-
-                        try:
-                            raw_event = json.loads(data_str)
-                            event_type = raw_event.get("type")
-
-                            # Check for error events from the API (context window, rate limit, etc.)
-                            if event_type == "error":
-                                logger.error(f"ChatGPT SSE error event: {json.dumps(raw_event, default=str)[:1000]}")
-                                raise self._handle_sse_error_event(raw_event)
-
-                            # Check for response.failed or response.incomplete events
-                            if event_type in ("response.failed", "response.incomplete"):
-                                logger.error(f"ChatGPT SSE {event_type} event: {json.dumps(raw_event, default=str)[:1000]}")
-                                resp_obj = raw_event.get("response", {})
-                                error_info = resp_obj.get("error", {})
-                                if error_info:
-                                    raise self._handle_sse_error_event({"error": error_info, "type": event_type})
-                                else:
-                                    raise LLMBadRequestError(
-                                        message=f"ChatGPT request failed with status '{event_type}' (no error details provided)",
-                                        code=ErrorCode.INTERNAL_SERVER_ERROR,
-                                    )
-
-                            # Use backend-provided sequence_number if available, else use counter
-                            # This ensures proper ordering even if backend doesn't provide it
-                            if "sequence_number" not in raw_event:
-                                raw_event["sequence_number"] = sequence_counter
-                            sequence_counter = raw_event["sequence_number"] + 1
-
-                            # Track output index for output_item.added events
-                            if event_type == "response.output_item.added":
-                                output_index = raw_event.get("output_index", output_index)
-
-                            # Convert to OpenAI SDK ResponseStreamEvent
-                            sdk_event = self._convert_to_sdk_event(raw_event, output_index)
-                            if sdk_event:
-                                yield sdk_event
-
-                        except json.JSONDecodeError:
-                            logger.warning(f"Failed to parse SSE event: {data_str[:100]}")
-                            continue
+            for attempt in range(self.MAX_RETRIES):
+                try:
+                    async with httpx.AsyncClient() as client:
+                        async with client.stream(
+                            "POST",
+                            endpoint,
+                            json=request_data,
+                            headers=headers,
+                            timeout=120.0,
+                        ) as response:
+                            # Check for error status
+                            if response.status_code != 200:
+                                error_body = await response.aread()
+                                logger.error(f"ChatGPT SSE error: {response.status_code} - {error_body}")
+                                raise self._handle_http_error_from_status(response.status_code, error_body.decode())
+
+                            async for line in response.aiter_lines():
+                                if not line or not line.startswith("data: "):
+                                    continue
+
+                                data_str = line[6:]
+                                if data_str == "[DONE]":
+                                    break
+
+                                try:
+                                    raw_event = json.loads(data_str)
+                                    event_type = raw_event.get("type")
+
+                                    # Check for error events from the API (context window, rate limit, etc.)
+                                    if event_type == "error":
+                                        logger.error(f"ChatGPT SSE error event: {json.dumps(raw_event, default=str)[:1000]}")
+                                        raise self._handle_sse_error_event(raw_event)
+
+                                    # Check for response.failed or response.incomplete events
+                                    if event_type in ("response.failed", "response.incomplete"):
+                                        logger.error(f"ChatGPT SSE {event_type} event: {json.dumps(raw_event, default=str)[:1000]}")
+                                        resp_obj = raw_event.get("response", {})
+                                        error_info = resp_obj.get("error", {})
+                                        if error_info:
+                                            raise self._handle_sse_error_event({"error": error_info, "type": event_type})
+                                        else:
+                                            raise LLMBadRequestError(
+                                                message=f"ChatGPT request failed with status '{event_type}' (no error details provided)",
+                                                code=ErrorCode.INTERNAL_SERVER_ERROR,
+                                            )
+
+                                    # Use backend-provided sequence_number if available, else use counter
+                                    # This ensures proper ordering even if backend doesn't provide it
+                                    if "sequence_number" not in raw_event:
+                                        raw_event["sequence_number"] = sequence_counter
+                                    sequence_counter = raw_event["sequence_number"] + 1
+
+                                    # Track output index for output_item.added events
+                                    if event_type == "response.output_item.added":
+                                        output_index = raw_event.get("output_index", output_index)
+
+                                    # Convert to OpenAI SDK ResponseStreamEvent
+                                    sdk_event = self._convert_to_sdk_event(raw_event, output_index)
+                                    if sdk_event:
+                                        yield sdk_event
+                                        has_yielded = True
+
+                                except json.JSONDecodeError:
+                                    logger.warning(f"Failed to parse SSE event: {data_str[:100]}")
+                                    continue
+
+                    # Stream completed successfully
+                    return
+
+                except self._RETRYABLE_ERRORS as e:
+                    if has_yielded or attempt >= self.MAX_RETRIES - 1:
+                        # Already yielded partial data or exhausted retries — must propagate
+                        raise
+                    wait = 2**attempt
+                    logger.warning(
+                        f"[ChatGPT] Transient error on stream (attempt {attempt + 1}/{self.MAX_RETRIES}), "
+                        f"retrying in {wait}s: {type(e).__name__}: {e}"
+                    )
+                    await asyncio.sleep(wait)

         # Wrap the async generator in AsyncStreamWrapper to provide context manager protocol
         return AsyncStreamWrapper(stream_generator())
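The has_yielded guard generalizes to any retried async generator: once the consumer has seen an event, silently restarting the stream would replay data it already processed, so the error must propagate instead. A minimal sketch, with a hypothetical open_stream() factory standing in for the real SSE request:

# Sketch of the has_yielded guard; open_stream() is a hypothetical async
# generator factory standing in for the real SSE request.
import asyncio

import httpx

MAX_RETRIES = 3


async def stream_with_retry(open_stream):
    has_yielded = False
    for attempt in range(MAX_RETRIES):
        try:
            async for event in open_stream():
                yield event
                has_yielded = True
            return  # stream completed cleanly
        except (httpx.ReadError, httpx.WriteError, httpx.ConnectError):
            if has_yielded or attempt >= MAX_RETRIES - 1:
                raise  # partial data already consumed, or retries exhausted
            await asyncio.sleep(2**attempt)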
@@ -1038,6 +1082,16 @@ class ChatGPTOAuthClient(LLMClientBase):
         if isinstance(e, httpx.HTTPStatusError):
             return self._handle_http_error(e, is_byok=is_byok)

+        # Handle httpx network errors which can occur during streaming
+        # when the connection is unexpectedly closed while reading/writing
+        if isinstance(e, (httpx.ReadError, httpx.WriteError, httpx.ConnectError)):
+            logger.warning(f"[ChatGPT] Network error during streaming: {type(e).__name__}: {e}")
+            return LLMConnectionError(
+                message=f"Network error during ChatGPT streaming: {str(e)}",
+                code=ErrorCode.INTERNAL_SERVER_ERROR,
+                details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__, "is_byok": is_byok},
+            )
+
         return super().handle_llm_error(e, llm_config=llm_config)

     def _handle_http_error(self, e: httpx.HTTPStatusError, is_byok: bool | None = None) -> Exception:
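A self-contained spot-check of the mapping this hunk adds (the error class and map_stream_error helper below are stand-ins for illustration, not the repo's API):

# Stand-in demonstration of the new isinstance branch; not the repo's code.
import httpx


class LLMConnectionError(Exception):
    """Stand-in for the repo's retryable connection error."""

    def __init__(self, message: str, error_type: str):
        super().__init__(message)
        self.error_type = error_type


def map_stream_error(e: Exception) -> Exception:
    # Mirrors the isinstance branch in the hunk above.
    if isinstance(e, (httpx.ReadError, httpx.WriteError, httpx.ConnectError)):
        return LLMConnectionError(f"Network error during streaming: {e}", type(e).__name__)
    return e


assert isinstance(map_stream_error(httpx.ReadError("peer reset")), LLMConnectionError)
assert isinstance(map_stream_error(ValueError("bad input")), ValueError)  # others pass through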