feat: Asyncify model listing for Gemini (#2284)

This commit is contained in:
Matthew Zhou
2025-05-20 16:00:20 -07:00
committed by GitHub
parent b0f38cd2b1
commit 0103ea6fcf
3 changed files with 205 additions and 22 deletions

View File

@@ -1,6 +1,6 @@
from typing import List, Optional, Tuple
import requests
import httpx
from google import genai
from letta.errors import ErrorCode, LLMAuthenticationError, LLMError
@@ -63,20 +63,24 @@ def google_ai_check_valid_api_key(api_key: str):
def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool = True) -> List[dict]:
    """Synchronous version to get model list from Google AI API using httpx.

    Args:
        base_url: Base URL of the Google AI (Gemini) REST API.
        api_key: API key used for authentication.
        key_in_header: If True, the key is sent as a header; otherwise as a
            query parameter (resolved by get_gemini_endpoint_and_headers).

    Returns:
        The list of model-description dicts from the response's "models" field.

    Raises:
        httpx.HTTPStatusError: On 4XX/5XX responses.
        httpx.RequestError: On connection-level failures.
    """
    from letta.utils import printd

    url, headers = get_gemini_endpoint_and_headers(base_url, None, api_key, key_in_header)
    try:
        # Context manager guarantees the connection pool is released
        with httpx.Client() as client:
            response = client.get(url, headers=headers)
            response.raise_for_status()  # Raises HTTPStatusError for 4XX/5XX status
            response_data = response.json()  # convert to dict from string

        # Grab the models out
        model_list = response_data["models"]
        return model_list
    except httpx.HTTPStatusError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX)
        printd(f"Got HTTPError, exception={http_err}")
        # Print the HTTP status code
        print(f"HTTP Error: {http_err.response.status_code}")
        # Print the response content (error message from server)
        print(f"Message: {http_err.response.text}")
        raise http_err
    except httpx.RequestError as req_err:
        # Handle other httpx-related errors (e.g., connection error)
        printd(f"Got RequestException, exception={req_err}")
        raise req_err
    except Exception as e:
        # Handle other potential errors
        printd(f"Got unknown Exception, exception={e}")
        raise e
def google_ai_get_model_details(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> List[dict]:
async def google_ai_get_model_list_async(
    base_url: str, api_key: str, key_in_header: bool = True, client: Optional[httpx.AsyncClient] = None
) -> List[dict]:
    """Asynchronous version to get model list from Google AI API using httpx.

    Args:
        base_url: Base URL of the Google AI (Gemini) REST API.
        api_key: API key used for authentication.
        key_in_header: If True, the key is sent as a header; otherwise as a
            query parameter (resolved by get_gemini_endpoint_and_headers).
        client: Optional shared AsyncClient. If omitted, a temporary client is
            created and closed before this function returns.

    Returns:
        The list of model-description dicts from the response's "models" field.

    Raises:
        httpx.HTTPStatusError: On 4XX/5XX responses.
        httpx.RequestError: On connection-level failures.
    """
    from letta.utils import printd

    url, headers = get_gemini_endpoint_and_headers(base_url, None, api_key, key_in_header)

    # Determine if we need to close the client at the end (only close what we created)
    close_client = False
    if client is None:
        client = httpx.AsyncClient()
        close_client = True

    try:
        response = await client.get(url, headers=headers)
        response.raise_for_status()  # Raises HTTPStatusError for 4XX/5XX status
        response_data = response.json()  # convert to dict from string

        # Grab the models out
        model_list = response_data["models"]
        return model_list
    except httpx.HTTPStatusError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX)
        printd(f"Got HTTPError, exception={http_err}")
        # Print the HTTP status code
        print(f"HTTP Error: {http_err.response.status_code}")
        # Print the response content (error message from server)
        print(f"Message: {http_err.response.text}")
        raise  # bare raise preserves the original traceback
    except httpx.RequestError as req_err:
        # Handle other httpx-related errors (e.g., connection error)
        printd(f"Got RequestException, exception={req_err}")
        raise
    except Exception as e:
        # Handle other potential errors
        printd(f"Got unknown Exception, exception={e}")
        raise
    finally:
        # Close the client if we created it
        if close_client:
            await client.aclose()
def google_ai_get_model_details(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> dict:
    """Synchronous version to get model details from Google AI API using httpx.

    Args:
        base_url: Base URL of the Google AI (Gemini) REST API.
        api_key: API key used for authentication.
        model: Name of the model whose details should be fetched.
        key_in_header: If True, the key is sent as a header; otherwise as a
            query parameter (resolved by get_gemini_endpoint_and_headers).

    Returns:
        The model-details dict returned by the API.

    Raises:
        httpx.HTTPStatusError: On 4XX/5XX responses.
        httpx.RequestError: On connection-level failures.
    """
    from letta.utils import printd

    url, headers = get_gemini_endpoint_and_headers(base_url, model, api_key, key_in_header)
    try:
        # Context manager guarantees the connection pool is released
        with httpx.Client() as client:
            response = client.get(url, headers=headers)
            printd(f"response = {response}")
            response.raise_for_status()  # Raises HTTPStatusError for 4XX/5XX status
            response_data = response.json()  # convert to dict from string
            printd(f"response.json = {response_data}")

        # Return the model details
        return response_data
    except httpx.HTTPStatusError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX)
        printd(f"Got HTTPError, exception={http_err}")
        # Print the HTTP status code
        print(f"HTTP Error: {http_err.response.status_code}")
        # Print the response content (error message from server)
        print(f"Message: {http_err.response.text}")
        raise http_err
    except httpx.RequestError as req_err:
        # Handle other httpx-related errors (e.g., connection error)
        printd(f"Got RequestException, exception={req_err}")
        raise req_err
    except Exception as e:
        # Handle other potential errors
        printd(f"Got unknown Exception, exception={e}")
        raise e
async def google_ai_get_model_details_async(
    base_url: str, api_key: str, model: str, key_in_header: bool = True, client: Optional[httpx.AsyncClient] = None
) -> dict:
    """Asynchronous version to get model details from Google AI API using httpx.

    Args:
        base_url: Base URL of the Google AI (Gemini) REST API.
        api_key: API key used for authentication.
        model: Name of the model whose details should be fetched.
        key_in_header: If True, the key is sent as a header; otherwise as a
            query parameter (resolved by get_gemini_endpoint_and_headers).
        client: Optional shared AsyncClient. If omitted, a temporary client is
            created and closed before this function returns.

    Returns:
        The model-details dict returned by the API.

    Raises:
        httpx.HTTPStatusError: On 4XX/5XX responses.
        httpx.RequestError: On connection-level failures.
    """
    import httpx

    from letta.utils import printd

    url, headers = get_gemini_endpoint_and_headers(base_url, model, api_key, key_in_header)

    # Determine if we need to close the client at the end (only close what we created)
    close_client = False
    if client is None:
        client = httpx.AsyncClient()
        close_client = True

    try:
        response = await client.get(url, headers=headers)
        printd(f"response = {response}")
        response.raise_for_status()  # Raises HTTPStatusError for 4XX/5XX status
        response_data = response.json()  # convert to dict from string
        printd(f"response.json = {response_data}")

        # Return the model details
        return response_data
    except httpx.HTTPStatusError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX)
        printd(f"Got HTTPError, exception={http_err}")
        # Print the HTTP status code
        print(f"HTTP Error: {http_err.response.status_code}")
        # Print the response content (error message from server)
        print(f"Message: {http_err.response.text}")
        raise  # bare raise preserves the original traceback
    except httpx.RequestError as req_err:
        # Handle other httpx-related errors (e.g., connection error)
        printd(f"Got RequestException, exception={req_err}")
        raise
    except Exception as e:
        # Handle other potential errors
        printd(f"Got unknown Exception, exception={e}")
        raise
    finally:
        # Close the client if we created it
        if close_client:
            await client.aclose()
def google_ai_get_model_context_window(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> int:
    """Fetch *model*'s details and return its input-token context window size."""
    details = google_ai_get_model_details(base_url=base_url, api_key=api_key, model=model, key_in_header=key_in_header)
    # TODO should this be:
    # return details["inputTokenLimit"] + details["outputTokenLimit"]
    return int(details["inputTokenLimit"])
async def google_ai_get_model_context_window_async(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> int:
    """Async variant: fetch *model*'s details and return its input-token context window size."""
    details = await google_ai_get_model_details_async(base_url=base_url, api_key=api_key, model=model, key_in_header=key_in_header)
    # TODO should this be:
    # return details["inputTokenLimit"] + details["outputTokenLimit"]
    return int(details["inputTokenLimit"])

View File

@@ -59,6 +59,9 @@ class Provider(ProviderBase):
def get_model_context_window(self, model_name: str) -> Optional[int]:
    """Return the context window (in tokens) for *model_name*; subclasses must override."""
    raise NotImplementedError
async def get_model_context_window_async(self, model_name: str) -> Optional[int]:
    """Async counterpart of get_model_context_window; subclasses must override."""
    raise NotImplementedError
def provider_tag(self) -> str:
    """String representation of the provider for display purposes; subclasses must override."""
    raise NotImplementedError
@@ -1133,7 +1136,6 @@ class GoogleAIProvider(Provider):
from letta.llm_api.google_ai_client import google_ai_get_model_list
model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key)
# filter by 'generateContent' models
model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]]
model_options = [str(m["name"]) for m in model_options]
@@ -1159,11 +1161,56 @@ class GoogleAIProvider(Provider):
)
return configs
async def list_llm_models_async(self):
    """Asynchronously list the provider's Gemini chat models as LLMConfig objects."""
    import asyncio

    from letta.llm_api.google_ai_client import google_ai_get_model_list_async

    # Fetch the raw model list, then keep only chat-capable ("generateContent") models
    raw_models = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key)
    names = [str(m["name"]) for m in raw_models if "generateContent" in m["supportedGenerationMethods"]]

    # Strip the "models/" prefix where present
    names = [n[len("models/") :] if n.startswith("models/") else n for n in names]

    # Add support for all gemini models
    names = [n for n in names if str(n).startswith("gemini-")]

    async def _build_config(model):
        # One context-window lookup per model; gathered concurrently below
        context_window = await self.get_model_context_window_async(model)
        return LLMConfig(
            model=model,
            model_endpoint_type="google_ai",
            model_endpoint=self.base_url,
            context_window=context_window,
            handle=self.get_handle(model),
            max_tokens=8192,
            provider_name=self.name,
            provider_category=self.provider_category,
        )

    # Execute all config creation tasks concurrently
    return await asyncio.gather(*(_build_config(name) for name in names))
def list_embedding_models(self):
    """List the provider's embedding models (synchronous path)."""
    from letta.llm_api.google_ai_client import google_ai_get_model_list

    # TODO: use base_url instead
    raw_models = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key)
    return self._list_embedding_models(raw_models)
async def list_embedding_models_async(self):
    """List the provider's embedding models (asynchronous path)."""
    from letta.llm_api.google_ai_client import google_ai_get_model_list_async

    # TODO: use base_url instead
    raw_models = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key)
    return self._list_embedding_models(raw_models)
def _list_embedding_models(self, model_options):
# filter by 'embedContent' models
model_options = [mo for mo in model_options if "embedContent" in mo["supportedGenerationMethods"]]
model_options = [str(m["name"]) for m in model_options]
@@ -1188,6 +1235,11 @@ class GoogleAIProvider(Provider):
return google_ai_get_model_context_window(self.base_url, self.api_key, model_name)
async def get_model_context_window_async(self, model_name: str) -> Optional[int]:
    """Asynchronously look up *model_name*'s context window via the Google AI API."""
    from letta.llm_api.google_ai_client import google_ai_get_model_context_window_async

    window = await google_ai_get_model_context_window_async(self.base_url, self.api_key, model_name)
    return window
class GoogleVertexProvider(Provider):
provider_type: Literal[ProviderType.google_vertex] = Field(ProviderType.google_vertex, description="The type of the provider.")

View File

@@ -130,6 +130,23 @@ def test_googleai():
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
@pytest.mark.asyncio
async def test_googleai_async():
    """Exercise async LLM and embedding model listing on GoogleAIProvider."""
    api_key = model_settings.gemini_api_key
    assert api_key is not None

    provider = GoogleAIProvider(
        name="google_ai",
        api_key=api_key,
    )

    # LLM models: non-empty, and handle is "<provider name>/<model name>"
    llm_models = await provider.list_llm_models_async()
    assert len(llm_models) > 0
    assert llm_models[0].handle == f"{provider.name}/{llm_models[0].model}"

    # Embedding models: same contract for the embedding handle
    embedding_models = await provider.list_embedding_models_async()
    assert len(embedding_models) > 0
    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_google_vertex():
provider = GoogleVertexProvider(
name="google_vertex",