fix: Google AI client logging as Vertex (#7337)

fix
This commit is contained in:
Kevin Lin
2025-12-17 15:38:05 -08:00
committed by Caren Thomas
parent 5312129587
commit 33afb930fc
2 changed files with 33 additions and 24 deletions

View File

@@ -14,6 +14,8 @@ logger = get_logger(__name__)
class GoogleAIClient(GoogleVertexClient):
provider_label = "Google AI"
def _get_client(self):
timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
return genai.Client(

View File

@@ -46,6 +46,7 @@ logger = get_logger(__name__)
class GoogleVertexClient(LLMClientBase):
MAX_RETRIES = model_settings.gemini_max_retries
provider_label = "Google Vertex"
def _get_client(self):
timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
@@ -56,6 +57,12 @@ class GoogleVertexClient(LLMClientBase):
http_options=HttpOptions(api_version="v1", timeout=timeout_ms),
)
def _provider_prefix(self) -> str:
return f"[{self.provider_label}]"
def _provider_name(self) -> str:
return self.provider_label
@trace_method
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""
@@ -148,7 +155,7 @@ class GoogleVertexClient(LLMClientBase):
config=request_data["config"],
)
except Exception as e:
logger.error(f"Error streaming Google Vertex request: {e} with request data: {json.dumps(request_data)}")
logger.error(f"Error streaming {self._provider_name()} request: {e} with request data: {json.dumps(request_data)}")
raise e
# Direct yield - keeps response alive in generator's local scope throughout iteration
# This is required because the SDK's connection lifecycle is tied to the response object
@@ -448,9 +455,9 @@ class GoogleVertexClient(LLMClientBase):
if content is None or content.role is None or content.parts is None:
# This means the response is malformed like MALFORMED_FUNCTION_CALL
if candidate.finish_reason == "MALFORMED_FUNCTION_CALL":
raise LLMServerError(f"Malformed response from Google Vertex: {candidate.finish_reason}")
raise LLMServerError(f"Malformed response from {self._provider_name()}: {candidate.finish_reason}")
else:
raise LLMServerError(f"Invalid response data from Google Vertex: {candidate.model_dump()}")
raise LLMServerError(f"Invalid response data from {self._provider_name()}: {candidate.model_dump()}")
role = content.role
assert role == "model", f"Unknown role in response: {role}"
@@ -742,55 +749,55 @@ class GoogleVertexClient(LLMClientBase):
def handle_llm_error(self, e: Exception) -> Exception:
# Handle Google GenAI specific errors
if isinstance(e, errors.ClientError):
logger.warning(f"[Google Vertex] Client error ({e.code}): {e}")
logger.warning(f"{self._provider_prefix()} Client error ({e.code}): {e}")
# Handle specific error codes
if e.code == 400:
error_str = str(e).lower()
if "context" in error_str and ("exceed" in error_str or "limit" in error_str or "too long" in error_str):
return ContextWindowExceededError(
message=f"Bad request to Google Vertex (context window exceeded): {str(e)}",
message=f"Bad request to {self._provider_name()} (context window exceeded): {str(e)}",
)
else:
return LLMBadRequestError(
message=f"Bad request to Google Vertex: {str(e)}",
message=f"Bad request to {self._provider_name()}: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
)
elif e.code == 401:
return LLMAuthenticationError(
message=f"Authentication failed with Google Vertex: {str(e)}",
message=f"Authentication failed with {self._provider_name()}: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
)
elif e.code == 403:
return LLMPermissionDeniedError(
message=f"Permission denied by Google Vertex: {str(e)}",
message=f"Permission denied by {self._provider_name()}: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
)
elif e.code == 404:
return LLMNotFoundError(
message=f"Resource not found in Google Vertex: {str(e)}",
message=f"Resource not found in {self._provider_name()}: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
)
elif e.code == 408:
return LLMTimeoutError(
message=f"Request to Google Vertex timed out: {str(e)}",
message=f"Request to {self._provider_name()} timed out: {str(e)}",
code=ErrorCode.TIMEOUT,
details={"cause": str(e.__cause__) if e.__cause__ else None},
)
elif e.code == 422:
return LLMUnprocessableEntityError(
message=f"Invalid request content for Google Vertex: {str(e)}",
message=f"Invalid request content for {self._provider_name()}: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
)
elif e.code == 429:
logger.warning("[Google Vertex] Rate limited (429). Consider backoff.")
logger.warning(f"{self._provider_prefix()} Rate limited (429). Consider backoff.")
return LLMRateLimitError(
message=f"Rate limited by Google Vertex: {str(e)}",
message=f"Rate limited by {self._provider_name()}: {str(e)}",
code=ErrorCode.RATE_LIMIT_EXCEEDED,
)
else:
return LLMServerError(
message=f"Google Vertex client error: {str(e)}",
message=f"{self._provider_name()} client error: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={
"status_code": e.code,
@@ -799,12 +806,12 @@ class GoogleVertexClient(LLMClientBase):
)
if isinstance(e, errors.ServerError):
logger.warning(f"[Google Vertex] Server error ({e.code}): {e}")
logger.warning(f"{self._provider_prefix()} Server error ({e.code}): {e}")
# Handle specific server error codes
if e.code == 500:
return LLMServerError(
message=f"Google Vertex internal server error: {str(e)}",
message=f"{self._provider_name()} internal server error: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={
"status_code": e.code,
@@ -813,13 +820,13 @@ class GoogleVertexClient(LLMClientBase):
)
elif e.code == 502:
return LLMConnectionError(
message=f"Bad gateway from Google Vertex: {str(e)}",
message=f"Bad gateway from {self._provider_name()}: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None},
)
elif e.code == 503:
return LLMServerError(
message=f"Google Vertex service unavailable: {str(e)}",
message=f"{self._provider_name()} service unavailable: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={
"status_code": e.code,
@@ -828,13 +835,13 @@ class GoogleVertexClient(LLMClientBase):
)
elif e.code == 504:
return LLMTimeoutError(
message=f"Gateway timeout from Google Vertex: {str(e)}",
message=f"Gateway timeout from {self._provider_name()}: {str(e)}",
code=ErrorCode.TIMEOUT,
details={"cause": str(e.__cause__) if e.__cause__ else None},
)
else:
return LLMServerError(
message=f"Google Vertex server error: {str(e)}",
message=f"{self._provider_name()} server error: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={
"status_code": e.code,
@@ -843,9 +850,9 @@ class GoogleVertexClient(LLMClientBase):
)
if isinstance(e, errors.APIError):
logger.warning(f"[Google Vertex] API error ({e.code}): {e}")
logger.warning(f"{self._provider_prefix()} API error ({e.code}): {e}")
return LLMServerError(
message=f"Google Vertex API error: {str(e)}",
message=f"{self._provider_name()} API error: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={
"status_code": e.code,
@@ -855,9 +862,9 @@ class GoogleVertexClient(LLMClientBase):
# Handle connection-related errors
if "connection" in str(e).lower() or "timeout" in str(e).lower():
logger.warning(f"[Google Vertex] Connection/timeout error: {e}")
logger.warning(f"{self._provider_prefix()} Connection/timeout error: {e}")
return LLMConnectionError(
message=f"Failed to connect to Google Vertex: {str(e)}",
message=f"Failed to connect to {self._provider_name()}: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None},
)