diff --git a/letta/llm_api/chatgpt_oauth_client.py b/letta/llm_api/chatgpt_oauth_client.py
index 96a0f15b..35ba3887 100644
--- a/letta/llm_api/chatgpt_oauth_client.py
+++ b/letta/llm_api/chatgpt_oauth_client.py
@@ -1,5 +1,6 @@
 """ChatGPT OAuth Client - handles requests to chatgpt.com/backend-api/codex/responses."""
 
+import asyncio
 import json
 from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union
 
@@ -102,6 +103,10 @@ class ChatGPTOAuthClient(LLMClientBase):
     4. Transforms responses back to OpenAI ChatCompletion format
     """
 
+    MAX_RETRIES = 3
+    # Transient httpx errors that are safe to retry (connection drops, transport-level failures)
+    _RETRYABLE_ERRORS = (httpx.ReadError, httpx.WriteError, httpx.ConnectError, httpx.RemoteProtocolError)
+
     @trace_method
     async def _get_provider_and_credentials_async(self, llm_config: LLMConfig) -> tuple[ChatGPTOAuthProvider, ChatGPTOAuthCredentials]:
         """Get the ChatGPT OAuth provider and credentials with automatic refresh if needed.
@@ -371,18 +376,20 @@ class ChatGPTOAuthClient(LLMClientBase):
         endpoint = llm_config.model_endpoint or CHATGPT_CODEX_ENDPOINT
 
         # ChatGPT backend requires streaming, so we use client.stream() to handle SSE
-        async with httpx.AsyncClient() as client:
+        # Retry on transient network errors with exponential backoff
+        for attempt in range(self.MAX_RETRIES):
             try:
-                async with client.stream(
-                    "POST",
-                    endpoint,
-                    json=request_data,
-                    headers=headers,
-                    timeout=120.0,
-                ) as response:
-                    response.raise_for_status()
-                    # Accumulate SSE events into a final response
-                    return await self._accumulate_sse_response(response)
+                async with httpx.AsyncClient() as client:
+                    async with client.stream(
+                        "POST",
+                        endpoint,
+                        json=request_data,
+                        headers=headers,
+                        timeout=120.0,
+                    ) as response:
+                        response.raise_for_status()
+                        # Accumulate SSE events into a final response
+                        return await self._accumulate_sse_response(response)
             except httpx.HTTPStatusError as e:
                 raise self._handle_http_error(e)
@@ -391,12 +398,29 @@
                     message="ChatGPT backend request timed out",
                     code=ErrorCode.TIMEOUT,
                 )
+            except self._RETRYABLE_ERRORS as e:
+                if attempt < self.MAX_RETRIES - 1:
+                    wait = 2**attempt
+                    logger.warning(
+                        f"[ChatGPT] Transient error on request (attempt {attempt + 1}/{self.MAX_RETRIES}), "
+                        f"retrying in {wait}s: {type(e).__name__}: {e}"
+                    )
+                    await asyncio.sleep(wait)
+                    continue
+                raise LLMConnectionError(
+                    message=f"Failed to connect to ChatGPT backend after {self.MAX_RETRIES} attempts: {str(e)}",
+                    code=ErrorCode.INTERNAL_SERVER_ERROR,
+                    details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__},
+                )
             except httpx.RequestError as e:
                 raise LLMConnectionError(
                     message=f"Failed to connect to ChatGPT backend: {str(e)}",
                     code=ErrorCode.INTERNAL_SERVER_ERROR,
                 )
 
+        # Should not be reached, but satisfy type checker
+        raise LLMConnectionError(message="ChatGPT request failed after all retries", code=ErrorCode.INTERNAL_SERVER_ERROR)
+
     async def _accumulate_sse_response(self, response: httpx.Response) -> dict:
         """Accumulate SSE stream into a final response.
@@ -572,69 +596,89 @@ class ChatGPTOAuthClient(LLMClientBase):
             # Track sequence_number in case backend doesn't provide it
             # (OpenAI SDK expects incrementing sequence numbers starting at 0)
             sequence_counter = 0
+            # Track whether we've yielded any events — once we have, we can't
+            # transparently retry because the caller has already consumed partial data.
+            has_yielded = False
 
-            async with httpx.AsyncClient() as client:
-                async with client.stream(
-                    "POST",
-                    endpoint,
-                    json=request_data,
-                    headers=headers,
-                    timeout=120.0,
-                ) as response:
-                    # Check for error status
-                    if response.status_code != 200:
-                        error_body = await response.aread()
-                        logger.error(f"ChatGPT SSE error: {response.status_code} - {error_body}")
-                        raise self._handle_http_error_from_status(response.status_code, error_body.decode())
+            for attempt in range(self.MAX_RETRIES):
+                try:
+                    async with httpx.AsyncClient() as client:
+                        async with client.stream(
+                            "POST",
+                            endpoint,
+                            json=request_data,
+                            headers=headers,
+                            timeout=120.0,
+                        ) as response:
+                            # Check for error status
+                            if response.status_code != 200:
+                                error_body = await response.aread()
+                                logger.error(f"ChatGPT SSE error: {response.status_code} - {error_body}")
+                                raise self._handle_http_error_from_status(response.status_code, error_body.decode())
 
-                    async for line in response.aiter_lines():
-                        if not line or not line.startswith("data: "):
-                            continue
+                            async for line in response.aiter_lines():
+                                if not line or not line.startswith("data: "):
+                                    continue
 
-                        data_str = line[6:]
-                        if data_str == "[DONE]":
-                            break
+                                data_str = line[6:]
+                                if data_str == "[DONE]":
+                                    break
 
-                        try:
-                            raw_event = json.loads(data_str)
-                            event_type = raw_event.get("type")
+                                try:
+                                    raw_event = json.loads(data_str)
+                                    event_type = raw_event.get("type")
 
-                            # Check for error events from the API (context window, rate limit, etc.)
-                            if event_type == "error":
-                                logger.error(f"ChatGPT SSE error event: {json.dumps(raw_event, default=str)[:1000]}")
-                                raise self._handle_sse_error_event(raw_event)
+                                    # Check for error events from the API (context window, rate limit, etc.)
+                                    if event_type == "error":
+                                        logger.error(f"ChatGPT SSE error event: {json.dumps(raw_event, default=str)[:1000]}")
+                                        raise self._handle_sse_error_event(raw_event)
 
-                            # Check for response.failed or response.incomplete events
-                            if event_type in ("response.failed", "response.incomplete"):
-                                logger.error(f"ChatGPT SSE {event_type} event: {json.dumps(raw_event, default=str)[:1000]}")
-                                resp_obj = raw_event.get("response", {})
-                                error_info = resp_obj.get("error", {})
-                                if error_info:
-                                    raise self._handle_sse_error_event({"error": error_info, "type": event_type})
-                                else:
-                                    raise LLMBadRequestError(
-                                        message=f"ChatGPT request failed with status '{event_type}' (no error details provided)",
-                                        code=ErrorCode.INTERNAL_SERVER_ERROR,
-                                    )
+                                    # Check for response.failed or response.incomplete events
+                                    if event_type in ("response.failed", "response.incomplete"):
+                                        logger.error(f"ChatGPT SSE {event_type} event: {json.dumps(raw_event, default=str)[:1000]}")
+                                        resp_obj = raw_event.get("response", {})
+                                        error_info = resp_obj.get("error", {})
+                                        if error_info:
+                                            raise self._handle_sse_error_event({"error": error_info, "type": event_type})
+                                        else:
+                                            raise LLMBadRequestError(
+                                                message=f"ChatGPT request failed with status '{event_type}' (no error details provided)",
+                                                code=ErrorCode.INTERNAL_SERVER_ERROR,
+                                            )
 
-                            # Use backend-provided sequence_number if available, else use counter
-                            # This ensures proper ordering even if backend doesn't provide it
-                            if "sequence_number" not in raw_event:
-                                raw_event["sequence_number"] = sequence_counter
-                            sequence_counter = raw_event["sequence_number"] + 1
+                                    # Use backend-provided sequence_number if available, else use counter
+                                    # This ensures proper ordering even if backend doesn't provide it
+                                    if "sequence_number" not in raw_event:
+                                        raw_event["sequence_number"] = sequence_counter
+                                    sequence_counter = raw_event["sequence_number"] + 1
 
-                            # Track output index for output_item.added events
-                            if event_type == "response.output_item.added":
-                                output_index = raw_event.get("output_index", output_index)
+                                    # Track output index for output_item.added events
+                                    if event_type == "response.output_item.added":
+                                        output_index = raw_event.get("output_index", output_index)
 
-                            # Convert to OpenAI SDK ResponseStreamEvent
-                            sdk_event = self._convert_to_sdk_event(raw_event, output_index)
-                            if sdk_event:
-                                yield sdk_event
+                                    # Convert to OpenAI SDK ResponseStreamEvent
+                                    sdk_event = self._convert_to_sdk_event(raw_event, output_index)
+                                    if sdk_event:
+                                        yield sdk_event
+                                        has_yielded = True
 
-                        except json.JSONDecodeError:
-                            logger.warning(f"Failed to parse SSE event: {data_str[:100]}")
-                            continue
+                                except json.JSONDecodeError:
+                                    logger.warning(f"Failed to parse SSE event: {data_str[:100]}")
+                                    continue
+
+                    # Stream completed successfully
+                    return
+
+                except self._RETRYABLE_ERRORS as e:
+                    if has_yielded or attempt >= self.MAX_RETRIES - 1:
+                        # Already yielded partial data or exhausted retries — must propagate
+                        raise
+                    wait = 2**attempt
+                    logger.warning(
+                        f"[ChatGPT] Transient error on stream (attempt {attempt + 1}/{self.MAX_RETRIES}), "
+                        f"retrying in {wait}s: {type(e).__name__}: {e}"
+                    )
+                    await asyncio.sleep(wait)
 
         # Wrap the async generator in AsyncStreamWrapper to provide context manager protocol
         return AsyncStreamWrapper(stream_generator())
@@ -1038,6 +1082,16 @@ class ChatGPTOAuthClient(LLMClientBase):
         if isinstance(e, httpx.HTTPStatusError):
             return self._handle_http_error(e, is_byok=is_byok)
 
+        # Handle httpx network errors which can occur during streaming
+        # when the connection is unexpectedly closed while reading/writing
+        if isinstance(e, (httpx.ReadError, httpx.WriteError, httpx.ConnectError)):
+            logger.warning(f"[ChatGPT] Network error during streaming: {type(e).__name__}: {e}")
+            return LLMConnectionError(
+                message=f"Network error during ChatGPT streaming: {str(e)}",
+                code=ErrorCode.INTERNAL_SERVER_ERROR,
+                details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__, "is_byok": is_byok},
+            )
+
         return super().handle_llm_error(e, llm_config=llm_config)
 
     def _handle_http_error(self, e: httpx.HTTPStatusError, is_byok: bool | None = None) -> Exception:
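
Sanity check on the retry semantics above: all four retryable classes derive from `httpx.RequestError` (via `httpx.TransportError`), so the new `except self._RETRYABLE_ERRORS` clause only takes effect because it is ordered ahead of the pre-existing generic `except httpx.RequestError` handler. Below is a minimal, self-contained sketch of the same backoff pattern in isolation; `fetch_with_retries` and the plain GET are illustrative stand-ins, not code from this PR:

```python
import asyncio

import httpx

MAX_RETRIES = 3
# All subclasses of httpx.RequestError (via httpx.TransportError); they must be
# caught before any generic `except httpx.RequestError` clause or they will
# never reach the retry logic.
RETRYABLE_ERRORS = (httpx.ReadError, httpx.WriteError, httpx.ConnectError, httpx.RemoteProtocolError)


async def fetch_with_retries(url: str) -> httpx.Response:
    """Retry transient transport failures with exponential backoff (1s, 2s, ...)."""
    for attempt in range(MAX_RETRIES):
        try:
            async with httpx.AsyncClient() as client:
                # Non-streaming GET: the body is fully read before the client
                # closes, so returning the response out of the context is safe.
                response = await client.get(url, timeout=120.0)
                response.raise_for_status()
                return response
        except RETRYABLE_ERRORS as e:
            if attempt >= MAX_RETRIES - 1:
                raise  # retries exhausted; propagate the transport error
            wait = 2**attempt
            print(f"transient {type(e).__name__}, retrying in {wait}s")
            await asyncio.sleep(wait)
        except httpx.RequestError:
            raise  # outside the retryable set; fail immediately
    raise RuntimeError("unreachable")  # mirrors the diff's post-loop raise for the type checker


if __name__ == "__main__":
    asyncio.run(fetch_with_retries("https://example.com"))
```

The streaming path in the diff cannot reuse this shape verbatim: once `stream_generator()` has yielded an event, the consumer has already observed partial output, and re-running the request would duplicate or reorder events. That is why the stream retry is gated on `has_yielded` and otherwise re-raises.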