@@ -14,6 +14,8 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
class GoogleAIClient(GoogleVertexClient):
|
||||
provider_label = "Google AI"
|
||||
|
||||
def _get_client(self):
|
||||
timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
|
||||
return genai.Client(
|
||||
|
||||
@@ -46,6 +46,7 @@ logger = get_logger(__name__)
|
||||
|
||||
class GoogleVertexClient(LLMClientBase):
|
||||
MAX_RETRIES = model_settings.gemini_max_retries
|
||||
provider_label = "Google Vertex"
|
||||
|
||||
def _get_client(self):
|
||||
timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
|
||||
@@ -56,6 +57,12 @@ class GoogleVertexClient(LLMClientBase):
|
||||
http_options=HttpOptions(api_version="v1", timeout=timeout_ms),
|
||||
)
|
||||
|
||||
def _provider_prefix(self) -> str:
|
||||
return f"[{self.provider_label}]"
|
||||
|
||||
def _provider_name(self) -> str:
|
||||
return self.provider_label
|
||||
|
||||
@trace_method
|
||||
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
||||
"""
|
||||
@@ -148,7 +155,7 @@ class GoogleVertexClient(LLMClientBase):
|
||||
config=request_data["config"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error streaming Google Vertex request: {e} with request data: {json.dumps(request_data)}")
|
||||
logger.error(f"Error streaming {self._provider_name()} request: {e} with request data: {json.dumps(request_data)}")
|
||||
raise e
|
||||
# Direct yield - keeps response alive in generator's local scope throughout iteration
|
||||
# This is required because the SDK's connection lifecycle is tied to the response object
|
||||
@@ -448,9 +455,9 @@ class GoogleVertexClient(LLMClientBase):
|
||||
if content is None or content.role is None or content.parts is None:
|
||||
# This means the response is malformed like MALFORMED_FUNCTION_CALL
|
||||
if candidate.finish_reason == "MALFORMED_FUNCTION_CALL":
|
||||
raise LLMServerError(f"Malformed response from Google Vertex: {candidate.finish_reason}")
|
||||
raise LLMServerError(f"Malformed response from {self._provider_name()}: {candidate.finish_reason}")
|
||||
else:
|
||||
raise LLMServerError(f"Invalid response data from Google Vertex: {candidate.model_dump()}")
|
||||
raise LLMServerError(f"Invalid response data from {self._provider_name()}: {candidate.model_dump()}")
|
||||
|
||||
role = content.role
|
||||
assert role == "model", f"Unknown role in response: {role}"
|
||||
@@ -742,55 +749,55 @@ class GoogleVertexClient(LLMClientBase):
|
||||
def handle_llm_error(self, e: Exception) -> Exception:
|
||||
# Handle Google GenAI specific errors
|
||||
if isinstance(e, errors.ClientError):
|
||||
logger.warning(f"[Google Vertex] Client error ({e.code}): {e}")
|
||||
logger.warning(f"{self._provider_prefix()} Client error ({e.code}): {e}")
|
||||
|
||||
# Handle specific error codes
|
||||
if e.code == 400:
|
||||
error_str = str(e).lower()
|
||||
if "context" in error_str and ("exceed" in error_str or "limit" in error_str or "too long" in error_str):
|
||||
return ContextWindowExceededError(
|
||||
message=f"Bad request to Google Vertex (context window exceeded): {str(e)}",
|
||||
message=f"Bad request to {self._provider_name()} (context window exceeded): {str(e)}",
|
||||
)
|
||||
else:
|
||||
return LLMBadRequestError(
|
||||
message=f"Bad request to Google Vertex: {str(e)}",
|
||||
message=f"Bad request to {self._provider_name()}: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
elif e.code == 401:
|
||||
return LLMAuthenticationError(
|
||||
message=f"Authentication failed with Google Vertex: {str(e)}",
|
||||
message=f"Authentication failed with {self._provider_name()}: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
elif e.code == 403:
|
||||
return LLMPermissionDeniedError(
|
||||
message=f"Permission denied by Google Vertex: {str(e)}",
|
||||
message=f"Permission denied by {self._provider_name()}: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
elif e.code == 404:
|
||||
return LLMNotFoundError(
|
||||
message=f"Resource not found in Google Vertex: {str(e)}",
|
||||
message=f"Resource not found in {self._provider_name()}: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
elif e.code == 408:
|
||||
return LLMTimeoutError(
|
||||
message=f"Request to Google Vertex timed out: {str(e)}",
|
||||
message=f"Request to {self._provider_name()} timed out: {str(e)}",
|
||||
code=ErrorCode.TIMEOUT,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None},
|
||||
)
|
||||
elif e.code == 422:
|
||||
return LLMUnprocessableEntityError(
|
||||
message=f"Invalid request content for Google Vertex: {str(e)}",
|
||||
message=f"Invalid request content for {self._provider_name()}: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
elif e.code == 429:
|
||||
logger.warning("[Google Vertex] Rate limited (429). Consider backoff.")
|
||||
logger.warning(f"{self._provider_prefix()} Rate limited (429). Consider backoff.")
|
||||
return LLMRateLimitError(
|
||||
message=f"Rate limited by Google Vertex: {str(e)}",
|
||||
message=f"Rate limited by {self._provider_name()}: {str(e)}",
|
||||
code=ErrorCode.RATE_LIMIT_EXCEEDED,
|
||||
)
|
||||
else:
|
||||
return LLMServerError(
|
||||
message=f"Google Vertex client error: {str(e)}",
|
||||
message=f"{self._provider_name()} client error: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={
|
||||
"status_code": e.code,
|
||||
@@ -799,12 +806,12 @@ class GoogleVertexClient(LLMClientBase):
|
||||
)
|
||||
|
||||
if isinstance(e, errors.ServerError):
|
||||
logger.warning(f"[Google Vertex] Server error ({e.code}): {e}")
|
||||
logger.warning(f"{self._provider_prefix()} Server error ({e.code}): {e}")
|
||||
|
||||
# Handle specific server error codes
|
||||
if e.code == 500:
|
||||
return LLMServerError(
|
||||
message=f"Google Vertex internal server error: {str(e)}",
|
||||
message=f"{self._provider_name()} internal server error: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={
|
||||
"status_code": e.code,
|
||||
@@ -813,13 +820,13 @@ class GoogleVertexClient(LLMClientBase):
|
||||
)
|
||||
elif e.code == 502:
|
||||
return LLMConnectionError(
|
||||
message=f"Bad gateway from Google Vertex: {str(e)}",
|
||||
message=f"Bad gateway from {self._provider_name()}: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None},
|
||||
)
|
||||
elif e.code == 503:
|
||||
return LLMServerError(
|
||||
message=f"Google Vertex service unavailable: {str(e)}",
|
||||
message=f"{self._provider_name()} service unavailable: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={
|
||||
"status_code": e.code,
|
||||
@@ -828,13 +835,13 @@ class GoogleVertexClient(LLMClientBase):
|
||||
)
|
||||
elif e.code == 504:
|
||||
return LLMTimeoutError(
|
||||
message=f"Gateway timeout from Google Vertex: {str(e)}",
|
||||
message=f"Gateway timeout from {self._provider_name()}: {str(e)}",
|
||||
code=ErrorCode.TIMEOUT,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None},
|
||||
)
|
||||
else:
|
||||
return LLMServerError(
|
||||
message=f"Google Vertex server error: {str(e)}",
|
||||
message=f"{self._provider_name()} server error: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={
|
||||
"status_code": e.code,
|
||||
@@ -843,9 +850,9 @@ class GoogleVertexClient(LLMClientBase):
|
||||
)
|
||||
|
||||
if isinstance(e, errors.APIError):
|
||||
logger.warning(f"[Google Vertex] API error ({e.code}): {e}")
|
||||
logger.warning(f"{self._provider_prefix()} API error ({e.code}): {e}")
|
||||
return LLMServerError(
|
||||
message=f"Google Vertex API error: {str(e)}",
|
||||
message=f"{self._provider_name()} API error: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={
|
||||
"status_code": e.code,
|
||||
@@ -855,9 +862,9 @@ class GoogleVertexClient(LLMClientBase):
|
||||
|
||||
# Handle connection-related errors
|
||||
if "connection" in str(e).lower() or "timeout" in str(e).lower():
|
||||
logger.warning(f"[Google Vertex] Connection/timeout error: {e}")
|
||||
logger.warning(f"{self._provider_prefix()} Connection/timeout error: {e}")
|
||||
return LLMConnectionError(
|
||||
message=f"Failed to connect to Google Vertex: {str(e)}",
|
||||
message=f"Failed to connect to {self._provider_name()}: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user