fix: Google AI client logging as Vertex (#7337)

fix
This commit is contained in:
Kevin Lin
2025-12-17 15:38:05 -08:00
committed by Caren Thomas
parent 5312129587
commit 33afb930fc
2 changed files with 33 additions and 24 deletions

View File

@@ -14,6 +14,8 @@ logger = get_logger(__name__)
class GoogleAIClient(GoogleVertexClient): class GoogleAIClient(GoogleVertexClient):
provider_label = "Google AI"
def _get_client(self): def _get_client(self):
timeout_ms = int(settings.llm_request_timeout_seconds * 1000) timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
return genai.Client( return genai.Client(

View File

@@ -46,6 +46,7 @@ logger = get_logger(__name__)
class GoogleVertexClient(LLMClientBase): class GoogleVertexClient(LLMClientBase):
MAX_RETRIES = model_settings.gemini_max_retries MAX_RETRIES = model_settings.gemini_max_retries
provider_label = "Google Vertex"
def _get_client(self): def _get_client(self):
timeout_ms = int(settings.llm_request_timeout_seconds * 1000) timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
@@ -56,6 +57,12 @@ class GoogleVertexClient(LLMClientBase):
http_options=HttpOptions(api_version="v1", timeout=timeout_ms), http_options=HttpOptions(api_version="v1", timeout=timeout_ms),
) )
def _provider_prefix(self) -> str:
return f"[{self.provider_label}]"
def _provider_name(self) -> str:
return self.provider_label
@trace_method @trace_method
def request(self, request_data: dict, llm_config: LLMConfig) -> dict: def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
""" """
@@ -148,7 +155,7 @@ class GoogleVertexClient(LLMClientBase):
config=request_data["config"], config=request_data["config"],
) )
except Exception as e: except Exception as e:
logger.error(f"Error streaming Google Vertex request: {e} with request data: {json.dumps(request_data)}") logger.error(f"Error streaming {self._provider_name()} request: {e} with request data: {json.dumps(request_data)}")
raise e raise e
# Direct yield - keeps response alive in generator's local scope throughout iteration # Direct yield - keeps response alive in generator's local scope throughout iteration
# This is required because the SDK's connection lifecycle is tied to the response object # This is required because the SDK's connection lifecycle is tied to the response object
@@ -448,9 +455,9 @@ class GoogleVertexClient(LLMClientBase):
if content is None or content.role is None or content.parts is None: if content is None or content.role is None or content.parts is None:
# This means the response is malformed like MALFORMED_FUNCTION_CALL # This means the response is malformed like MALFORMED_FUNCTION_CALL
if candidate.finish_reason == "MALFORMED_FUNCTION_CALL": if candidate.finish_reason == "MALFORMED_FUNCTION_CALL":
raise LLMServerError(f"Malformed response from Google Vertex: {candidate.finish_reason}") raise LLMServerError(f"Malformed response from {self._provider_name()}: {candidate.finish_reason}")
else: else:
raise LLMServerError(f"Invalid response data from Google Vertex: {candidate.model_dump()}") raise LLMServerError(f"Invalid response data from {self._provider_name()}: {candidate.model_dump()}")
role = content.role role = content.role
assert role == "model", f"Unknown role in response: {role}" assert role == "model", f"Unknown role in response: {role}"
@@ -742,55 +749,55 @@ class GoogleVertexClient(LLMClientBase):
def handle_llm_error(self, e: Exception) -> Exception: def handle_llm_error(self, e: Exception) -> Exception:
# Handle Google GenAI specific errors # Handle Google GenAI specific errors
if isinstance(e, errors.ClientError): if isinstance(e, errors.ClientError):
logger.warning(f"[Google Vertex] Client error ({e.code}): {e}") logger.warning(f"{self._provider_prefix()} Client error ({e.code}): {e}")
# Handle specific error codes # Handle specific error codes
if e.code == 400: if e.code == 400:
error_str = str(e).lower() error_str = str(e).lower()
if "context" in error_str and ("exceed" in error_str or "limit" in error_str or "too long" in error_str): if "context" in error_str and ("exceed" in error_str or "limit" in error_str or "too long" in error_str):
return ContextWindowExceededError( return ContextWindowExceededError(
message=f"Bad request to Google Vertex (context window exceeded): {str(e)}", message=f"Bad request to {self._provider_name()} (context window exceeded): {str(e)}",
) )
else: else:
return LLMBadRequestError( return LLMBadRequestError(
message=f"Bad request to Google Vertex: {str(e)}", message=f"Bad request to {self._provider_name()}: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR, code=ErrorCode.INTERNAL_SERVER_ERROR,
) )
elif e.code == 401: elif e.code == 401:
return LLMAuthenticationError( return LLMAuthenticationError(
message=f"Authentication failed with Google Vertex: {str(e)}", message=f"Authentication failed with {self._provider_name()}: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR, code=ErrorCode.INTERNAL_SERVER_ERROR,
) )
elif e.code == 403: elif e.code == 403:
return LLMPermissionDeniedError( return LLMPermissionDeniedError(
message=f"Permission denied by Google Vertex: {str(e)}", message=f"Permission denied by {self._provider_name()}: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR, code=ErrorCode.INTERNAL_SERVER_ERROR,
) )
elif e.code == 404: elif e.code == 404:
return LLMNotFoundError( return LLMNotFoundError(
message=f"Resource not found in Google Vertex: {str(e)}", message=f"Resource not found in {self._provider_name()}: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR, code=ErrorCode.INTERNAL_SERVER_ERROR,
) )
elif e.code == 408: elif e.code == 408:
return LLMTimeoutError( return LLMTimeoutError(
message=f"Request to Google Vertex timed out: {str(e)}", message=f"Request to {self._provider_name()} timed out: {str(e)}",
code=ErrorCode.TIMEOUT, code=ErrorCode.TIMEOUT,
details={"cause": str(e.__cause__) if e.__cause__ else None}, details={"cause": str(e.__cause__) if e.__cause__ else None},
) )
elif e.code == 422: elif e.code == 422:
return LLMUnprocessableEntityError( return LLMUnprocessableEntityError(
message=f"Invalid request content for Google Vertex: {str(e)}", message=f"Invalid request content for {self._provider_name()}: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR, code=ErrorCode.INTERNAL_SERVER_ERROR,
) )
elif e.code == 429: elif e.code == 429:
logger.warning("[Google Vertex] Rate limited (429). Consider backoff.") logger.warning(f"{self._provider_prefix()} Rate limited (429). Consider backoff.")
return LLMRateLimitError( return LLMRateLimitError(
message=f"Rate limited by Google Vertex: {str(e)}", message=f"Rate limited by {self._provider_name()}: {str(e)}",
code=ErrorCode.RATE_LIMIT_EXCEEDED, code=ErrorCode.RATE_LIMIT_EXCEEDED,
) )
else: else:
return LLMServerError( return LLMServerError(
message=f"Google Vertex client error: {str(e)}", message=f"{self._provider_name()} client error: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR, code=ErrorCode.INTERNAL_SERVER_ERROR,
details={ details={
"status_code": e.code, "status_code": e.code,
@@ -799,12 +806,12 @@ class GoogleVertexClient(LLMClientBase):
) )
if isinstance(e, errors.ServerError): if isinstance(e, errors.ServerError):
logger.warning(f"[Google Vertex] Server error ({e.code}): {e}") logger.warning(f"{self._provider_prefix()} Server error ({e.code}): {e}")
# Handle specific server error codes # Handle specific server error codes
if e.code == 500: if e.code == 500:
return LLMServerError( return LLMServerError(
message=f"Google Vertex internal server error: {str(e)}", message=f"{self._provider_name()} internal server error: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR, code=ErrorCode.INTERNAL_SERVER_ERROR,
details={ details={
"status_code": e.code, "status_code": e.code,
@@ -813,13 +820,13 @@ class GoogleVertexClient(LLMClientBase):
) )
elif e.code == 502: elif e.code == 502:
return LLMConnectionError( return LLMConnectionError(
message=f"Bad gateway from Google Vertex: {str(e)}", message=f"Bad gateway from {self._provider_name()}: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR, code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None}, details={"cause": str(e.__cause__) if e.__cause__ else None},
) )
elif e.code == 503: elif e.code == 503:
return LLMServerError( return LLMServerError(
message=f"Google Vertex service unavailable: {str(e)}", message=f"{self._provider_name()} service unavailable: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR, code=ErrorCode.INTERNAL_SERVER_ERROR,
details={ details={
"status_code": e.code, "status_code": e.code,
@@ -828,13 +835,13 @@ class GoogleVertexClient(LLMClientBase):
) )
elif e.code == 504: elif e.code == 504:
return LLMTimeoutError( return LLMTimeoutError(
message=f"Gateway timeout from Google Vertex: {str(e)}", message=f"Gateway timeout from {self._provider_name()}: {str(e)}",
code=ErrorCode.TIMEOUT, code=ErrorCode.TIMEOUT,
details={"cause": str(e.__cause__) if e.__cause__ else None}, details={"cause": str(e.__cause__) if e.__cause__ else None},
) )
else: else:
return LLMServerError( return LLMServerError(
message=f"Google Vertex server error: {str(e)}", message=f"{self._provider_name()} server error: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR, code=ErrorCode.INTERNAL_SERVER_ERROR,
details={ details={
"status_code": e.code, "status_code": e.code,
@@ -843,9 +850,9 @@ class GoogleVertexClient(LLMClientBase):
) )
if isinstance(e, errors.APIError): if isinstance(e, errors.APIError):
logger.warning(f"[Google Vertex] API error ({e.code}): {e}") logger.warning(f"{self._provider_prefix()} API error ({e.code}): {e}")
return LLMServerError( return LLMServerError(
message=f"Google Vertex API error: {str(e)}", message=f"{self._provider_name()} API error: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR, code=ErrorCode.INTERNAL_SERVER_ERROR,
details={ details={
"status_code": e.code, "status_code": e.code,
@@ -855,9 +862,9 @@ class GoogleVertexClient(LLMClientBase):
# Handle connection-related errors # Handle connection-related errors
if "connection" in str(e).lower() or "timeout" in str(e).lower(): if "connection" in str(e).lower() or "timeout" in str(e).lower():
logger.warning(f"[Google Vertex] Connection/timeout error: {e}") logger.warning(f"{self._provider_prefix()} Connection/timeout error: {e}")
return LLMConnectionError( return LLMConnectionError(
message=f"Failed to connect to Google Vertex: {str(e)}", message=f"Failed to connect to {self._provider_name()}: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR, code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None}, details={"cause": str(e.__cause__) if e.__cause__ else None},
) )