fix(core): use BYOK API keys for Google AI/Vertex LLM requests (#9439)

GoogleAIClient and GoogleVertexClient were hardcoding Letta's managed
credentials for all requests, ignoring user-provided BYOK API keys.
As a result, Letta was absorbing Google API costs on behalf of BYOK users.

Add _get_client_async and update _get_client to check BYOK overrides
(via get_byok_overrides / get_byok_overrides_async) before falling back
to managed credentials, matching the pattern used by OpenAIClient and
AnthropicClient.

🤖 Generated with [Letta Code](https://letta.com)

Co-authored-by: Letta <noreply@letta.com>
This commit is contained in:
Kian Jones
2026-02-11 15:22:06 -08:00
committed by Caren Thomas
parent d0e25ae471
commit 5b7dd15905
2 changed files with 47 additions and 6 deletions

View File

@@ -8,6 +8,7 @@ from letta.errors import ErrorCode, LLMAuthenticationError, LLMError
from letta.llm_api.google_constants import GOOGLE_MODEL_FOR_API_KEY_CHECK
from letta.llm_api.google_vertex_client import GoogleVertexClient
from letta.log import get_logger
from letta.schemas.llm_config import LLMConfig
from letta.settings import model_settings, settings
logger = get_logger(__name__)
@@ -16,10 +17,27 @@ logger = get_logger(__name__)
class GoogleAIClient(GoogleVertexClient):
provider_label = "Google AI"
def _get_client(self):
def _get_client(self, llm_config: Optional[LLMConfig] = None):
timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
api_key = None
if llm_config:
api_key, _, _ = self.get_byok_overrides(llm_config)
if not api_key:
api_key = model_settings.gemini_api_key
return genai.Client(
api_key=model_settings.gemini_api_key,
api_key=api_key,
http_options=HttpOptions(timeout=timeout_ms),
)
async def _get_client_async(self, llm_config: Optional[LLMConfig] = None):
timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
api_key = None
if llm_config:
api_key, _, _ = await self.get_byok_overrides_async(llm_config)
if not api_key:
api_key = model_settings.gemini_api_key
return genai.Client(
api_key=api_key,
http_options=HttpOptions(timeout=timeout_ms),
)

View File

@@ -53,8 +53,31 @@ class GoogleVertexClient(LLMClientBase):
MAX_RETRIES = model_settings.gemini_max_retries
provider_label = "Google Vertex"
def _get_client(self):
def _get_client(self, llm_config: Optional[LLMConfig] = None):
timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
if llm_config:
api_key, _, _ = self.get_byok_overrides(llm_config)
if api_key:
return Client(
api_key=api_key,
http_options=HttpOptions(timeout=timeout_ms),
)
return Client(
vertexai=True,
project=model_settings.google_cloud_project,
location=model_settings.google_cloud_location,
http_options=HttpOptions(api_version="v1", timeout=timeout_ms),
)
async def _get_client_async(self, llm_config: Optional[LLMConfig] = None):
timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
if llm_config:
api_key, _, _ = await self.get_byok_overrides_async(llm_config)
if api_key:
return Client(
api_key=api_key,
http_options=HttpOptions(timeout=timeout_ms),
)
return Client(
vertexai=True,
project=model_settings.google_cloud_project,
@@ -74,7 +97,7 @@ class GoogleVertexClient(LLMClientBase):
Performs underlying request to llm and returns raw response.
"""
try:
client = self._get_client()
client = self._get_client(llm_config)
response = client.models.generate_content(
model=llm_config.model,
contents=request_data["contents"],
@@ -103,7 +126,7 @@ class GoogleVertexClient(LLMClientBase):
"""
request_data = sanitize_unicode_surrogates(request_data)
client = self._get_client()
client = await self._get_client_async(llm_config)
# Gemini 2.5 models will often return MALFORMED_FUNCTION_CALL, force a retry
# https://github.com/googleapis/python-aiplatform/issues/4472
@@ -180,7 +203,7 @@ class GoogleVertexClient(LLMClientBase):
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncIterator[GenerateContentResponse]:
request_data = sanitize_unicode_surrogates(request_data)
client = self._get_client()
client = await self._get_client_async(llm_config)
try:
response = await client.aio.models.generate_content_stream(