fix(core): use BYOK API keys for Google AI/Vertex LLM requests (#9439)
GoogleAIClient and GoogleVertexClient were hardcoding Letta's managed credentials for all requests, ignoring user-provided BYOK API keys. This meant Letta was paying Google API costs for BYOK users. Add _get_client_async and update _get_client to check BYOK overrides (via get_byok_overrides / get_byok_overrides_async) before falling back to managed credentials, matching the pattern used by OpenAIClient and AnthropicClient. 🤖 Generated with [Letta Code](https://letta.com) Co-authored-by: Letta <noreply@letta.com>
This commit is contained in:
@@ -53,8 +53,31 @@ class GoogleVertexClient(LLMClientBase):
|
||||
MAX_RETRIES = model_settings.gemini_max_retries
|
||||
provider_label = "Google Vertex"
|
||||
|
||||
def _get_client(self):
|
||||
def _get_client(self, llm_config: Optional[LLMConfig] = None):
|
||||
timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
|
||||
if llm_config:
|
||||
api_key, _, _ = self.get_byok_overrides(llm_config)
|
||||
if api_key:
|
||||
return Client(
|
||||
api_key=api_key,
|
||||
http_options=HttpOptions(timeout=timeout_ms),
|
||||
)
|
||||
return Client(
|
||||
vertexai=True,
|
||||
project=model_settings.google_cloud_project,
|
||||
location=model_settings.google_cloud_location,
|
||||
http_options=HttpOptions(api_version="v1", timeout=timeout_ms),
|
||||
)
|
||||
|
||||
async def _get_client_async(self, llm_config: Optional[LLMConfig] = None):
|
||||
timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
|
||||
if llm_config:
|
||||
api_key, _, _ = await self.get_byok_overrides_async(llm_config)
|
||||
if api_key:
|
||||
return Client(
|
||||
api_key=api_key,
|
||||
http_options=HttpOptions(timeout=timeout_ms),
|
||||
)
|
||||
return Client(
|
||||
vertexai=True,
|
||||
project=model_settings.google_cloud_project,
|
||||
@@ -74,7 +97,7 @@ class GoogleVertexClient(LLMClientBase):
|
||||
Performs underlying request to llm and returns raw response.
|
||||
"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
client = self._get_client(llm_config)
|
||||
response = client.models.generate_content(
|
||||
model=llm_config.model,
|
||||
contents=request_data["contents"],
|
||||
@@ -103,7 +126,7 @@ class GoogleVertexClient(LLMClientBase):
|
||||
"""
|
||||
request_data = sanitize_unicode_surrogates(request_data)
|
||||
|
||||
client = self._get_client()
|
||||
client = await self._get_client_async(llm_config)
|
||||
|
||||
# Gemini 2.5 models will often return MALFORMED_FUNCTION_CALL, force a retry
|
||||
# https://github.com/googleapis/python-aiplatform/issues/4472
|
||||
@@ -180,7 +203,7 @@ class GoogleVertexClient(LLMClientBase):
|
||||
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncIterator[GenerateContentResponse]:
|
||||
request_data = sanitize_unicode_surrogates(request_data)
|
||||
|
||||
client = self._get_client()
|
||||
client = await self._get_client_async(llm_config)
|
||||
|
||||
try:
|
||||
response = await client.aio.models.generate_content_stream(
|
||||
|
||||
Reference in New Issue
Block a user