diff --git a/letta/llm_api/google_vertex_client.py b/letta/llm_api/google_vertex_client.py
index ff7a37b2..62d93ee8 100644
--- a/letta/llm_api/google_vertex_client.py
+++ b/letta/llm_api/google_vertex_client.py
@@ -3,6 +3,7 @@ import uuid
 from typing import List, Optional
 
 from google import genai
+from google.genai import errors
 from google.genai.types import (
     FunctionCallingConfig,
     FunctionCallingConfigMode,
@@ -67,11 +68,19 @@ class GoogleVertexClient(LLMClientBase):
         retry_count = 1
         should_retry = True
         while should_retry and retry_count <= self.MAX_RETRIES:
-            response = await client.aio.models.generate_content(
-                model=llm_config.model,
-                contents=request_data["contents"],
-                config=request_data["config"],
-            )
+            try:
+                response = await client.aio.models.generate_content(
+                    model=llm_config.model,
+                    contents=request_data["contents"],
+                    config=request_data["config"],
+                )
+            except errors.APIError as e:
+                # Retry on 503 and 500 errors as well, usually ephemeral from Gemini
+                if e.code in (500, 503):
+                    logger.warning(f"Received {e}, retrying {retry_count}/{self.MAX_RETRIES}")
+                    retry_count += 1
+                    continue
+                raise
         response_data = response.model_dump()
         is_malformed_function_call = self.is_malformed_function_call(response_data)
         if is_malformed_function_call:
diff --git a/tests/integration_test_send_message.py b/tests/integration_test_send_message.py
index 756adb5d..b47a4d54 100644
--- a/tests/integration_test_send_message.py
+++ b/tests/integration_test_send_message.py
@@ -1334,6 +1334,7 @@ def test_background_token_streaming_tool_call(
         messages=messages_to_send,
         stream_tokens=True,
         background=True,
+        request_options={"timeout_in_seconds": 300},
     )
     verify_token_streaming = (
         llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"] and "claude-3-5-sonnet" not in llm_config.model