diff --git a/letta/llm_api/google_ai_client.py b/letta/llm_api/google_ai_client.py
index b1281b62..f765a121 100644
--- a/letta/llm_api/google_ai_client.py
+++ b/letta/llm_api/google_ai_client.py
@@ -8,6 +8,7 @@ from letta.errors import ErrorCode, LLMAuthenticationError, LLMError
 from letta.llm_api.google_constants import GOOGLE_MODEL_FOR_API_KEY_CHECK
 from letta.llm_api.google_vertex_client import GoogleVertexClient
 from letta.log import get_logger
+from letta.schemas.llm_config import LLMConfig
 from letta.settings import model_settings, settings
 
 logger = get_logger(__name__)
@@ -16,10 +17,27 @@ logger = get_logger(__name__)
 class GoogleAIClient(GoogleVertexClient):
     provider_label = "Google AI"
 
-    def _get_client(self):
+    def _get_client(self, llm_config: Optional[LLMConfig] = None):
         timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
+        api_key = None
+        if llm_config:
+            api_key, _, _ = self.get_byok_overrides(llm_config)
+        if not api_key:
+            api_key = model_settings.gemini_api_key
         return genai.Client(
-            api_key=model_settings.gemini_api_key,
+            api_key=api_key,
+            http_options=HttpOptions(timeout=timeout_ms),
+        )
+
+    async def _get_client_async(self, llm_config: Optional[LLMConfig] = None):
+        timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
+        api_key = None
+        if llm_config:
+            api_key, _, _ = await self.get_byok_overrides_async(llm_config)
+        if not api_key:
+            api_key = model_settings.gemini_api_key
+        return genai.Client(
+            api_key=api_key,
             http_options=HttpOptions(timeout=timeout_ms),
         )
 
diff --git a/letta/llm_api/google_vertex_client.py b/letta/llm_api/google_vertex_client.py
index 488494a0..368ba5e7 100644
--- a/letta/llm_api/google_vertex_client.py
+++ b/letta/llm_api/google_vertex_client.py
@@ -53,8 +53,31 @@ class GoogleVertexClient(LLMClientBase):
     MAX_RETRIES = model_settings.gemini_max_retries
     provider_label = "Google Vertex"
 
-    def _get_client(self):
+    def _get_client(self, llm_config: Optional[LLMConfig] = None):
         timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
+        if llm_config:
+            api_key, _, _ = self.get_byok_overrides(llm_config)
+            if api_key:
+                return Client(
+                    api_key=api_key,
+                    http_options=HttpOptions(timeout=timeout_ms),
+                )
+        return Client(
+            vertexai=True,
+            project=model_settings.google_cloud_project,
+            location=model_settings.google_cloud_location,
+            http_options=HttpOptions(api_version="v1", timeout=timeout_ms),
+        )
+
+    async def _get_client_async(self, llm_config: Optional[LLMConfig] = None):
+        timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
+        if llm_config:
+            api_key, _, _ = await self.get_byok_overrides_async(llm_config)
+            if api_key:
+                return Client(
+                    api_key=api_key,
+                    http_options=HttpOptions(timeout=timeout_ms),
+                )
         return Client(
             vertexai=True,
             project=model_settings.google_cloud_project,
@@ -74,7 +97,7 @@ class GoogleVertexClient(LLMClientBase):
         Performs underlying request to llm and returns raw response.
         """
         try:
-            client = self._get_client()
+            client = self._get_client(llm_config)
             response = client.models.generate_content(
                 model=llm_config.model,
                 contents=request_data["contents"],
@@ -103,7 +126,7 @@ class GoogleVertexClient(LLMClientBase):
         """
         request_data = sanitize_unicode_surrogates(request_data)
 
-        client = self._get_client()
+        client = await self._get_client_async(llm_config)
 
         # Gemini 2.5 models will often return MALFORMED_FUNCTION_CALL, force a retry
         # https://github.com/googleapis/python-aiplatform/issues/4472
@@ -180,7 +203,7 @@ class GoogleVertexClient(LLMClientBase):
     async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncIterator[GenerateContentResponse]:
         request_data = sanitize_unicode_surrogates(request_data)
 
-        client = self._get_client()
+        client = await self._get_client_async(llm_config)
 
         try:
             response = await client.aio.models.generate_content_stream(