feat: Asyncify model listing for Gemini (#2284)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
from google import genai
|
||||
|
||||
from letta.errors import ErrorCode, LLMAuthenticationError, LLMError
|
||||
@@ -63,20 +63,24 @@ def google_ai_check_valid_api_key(api_key: str):
|
||||
|
||||
|
||||
def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool = True) -> List[dict]:
    """Synchronous version to get model list from Google AI API using httpx.

    Args:
        base_url: Base URL of the Google AI (Gemini) API.
        api_key: API key used for authentication.
        key_in_header: If True, send the API key in a request header; otherwise
            it is placed in the URL by `get_gemini_endpoint_and_headers`.

    Returns:
        The list of model dicts from the API response's "models" field.

    Raises:
        httpx.HTTPStatusError: On 4XX/5XX responses.
        httpx.RequestError: On connection or other transport-level errors.
    """
    from letta.utils import printd

    url, headers = get_gemini_endpoint_and_headers(base_url, None, api_key, key_in_header)

    try:
        # httpx is imported at module level; the client is closed by the context manager
        with httpx.Client() as client:
            response = client.get(url, headers=headers)
            response.raise_for_status()  # Raises HTTPStatusError for 4XX/5XX status
            response_data = response.json()  # convert to dict from string

            # Grab the models out
            model_list = response_data["models"]
            return model_list

    except httpx.HTTPStatusError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX)
        printd(f"Got HTTPError, exception={http_err}")
        # Print the HTTP status code
        print(f"HTTP Error: {http_err.response.status_code}")
        # Print the response content (error message from server)
        print(f"Message: {http_err.response.text}")
        raise http_err

    except httpx.RequestError as req_err:
        # Handle other httpx-related errors (e.g., connection error)
        printd(f"Got RequestException, exception={req_err}")
        raise req_err

    except Exception as e:
        # Handle other potential errors
        printd(f"Got unknown Exception, exception={e}")
        raise e
|
||||
|
||||
|
||||
def google_ai_get_model_details(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> List[dict]:
|
||||
async def google_ai_get_model_list_async(
    base_url: str, api_key: str, key_in_header: bool = True, client: Optional[httpx.AsyncClient] = None
) -> List[dict]:
    """Asynchronous version to get model list from Google AI API using httpx.

    If *client* is None, a temporary AsyncClient is created and closed before
    returning; a caller-supplied client is left open for the caller to manage.
    """
    from letta.utils import printd

    url, headers = get_gemini_endpoint_and_headers(base_url, None, api_key, key_in_header)

    # If no client was supplied, create one and remember to close it ourselves
    owns_client = client is None
    if owns_client:
        client = httpx.AsyncClient()

    try:
        response = await client.get(url, headers=headers)
        response.raise_for_status()  # Raises HTTPStatusError for 4XX/5XX status
        payload = response.json()  # convert to dict from string

        # Grab the models out
        return payload["models"]

    except httpx.HTTPStatusError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX)
        printd(f"Got HTTPError, exception={http_err}")
        # Print the HTTP status code
        print(f"HTTP Error: {http_err.response.status_code}")
        # Print the response content (error message from server)
        print(f"Message: {http_err.response.text}")
        raise http_err

    except httpx.RequestError as req_err:
        # Handle other httpx-related errors (e.g., connection error)
        printd(f"Got RequestException, exception={req_err}")
        raise req_err

    except Exception as e:
        # Handle other potential errors
        printd(f"Got unknown Exception, exception={e}")
        raise e

    finally:
        # Close the client only if this function created it
        if owns_client:
            await client.aclose()
|
||||
|
||||
|
||||
def google_ai_get_model_details(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> dict:
    """Synchronous version to get model details from Google AI API using httpx.

    Args:
        base_url: Base URL of the Google AI (Gemini) API.
        api_key: API key used for authentication.
        model: Name of the model to fetch details for.
        key_in_header: If True, send the API key in a request header; otherwise
            it is placed in the URL by `get_gemini_endpoint_and_headers`.

    Returns:
        The model-details payload as a dict.

    Raises:
        httpx.HTTPStatusError: On 4XX/5XX responses.
        httpx.RequestError: On connection or other transport-level errors.
    """
    from letta.utils import printd

    url, headers = get_gemini_endpoint_and_headers(base_url, model, api_key, key_in_header)

    try:
        # httpx is imported at module level; the client is closed by the context manager
        with httpx.Client() as client:
            response = client.get(url, headers=headers)
            printd(f"response = {response}")
            response.raise_for_status()  # Raises HTTPStatusError for 4XX/5XX status
            response_data = response.json()  # convert to dict from string
            printd(f"response.json = {response_data}")

            # Return the model details
            return response_data

    except httpx.HTTPStatusError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX)
        printd(f"Got HTTPError, exception={http_err}")
        # Print the HTTP status code
        print(f"HTTP Error: {http_err.response.status_code}")
        # Print the response content (error message from server)
        print(f"Message: {http_err.response.text}")
        raise http_err

    except httpx.RequestError as req_err:
        # Handle other httpx-related errors (e.g., connection error)
        printd(f"Got RequestException, exception={req_err}")
        raise req_err

    except Exception as e:
        # Handle other potential errors
        printd(f"Got unknown Exception, exception={e}")
        raise e
|
||||
|
||||
|
||||
async def google_ai_get_model_details_async(
    base_url: str, api_key: str, model: str, key_in_header: bool = True, client: Optional[httpx.AsyncClient] = None
) -> dict:
    """Asynchronous version to get model details from Google AI API using httpx.

    If *client* is None, a temporary AsyncClient is created and closed before
    returning; a caller-supplied client is left open for the caller to manage.

    Raises:
        httpx.HTTPStatusError: On 4XX/5XX responses.
        httpx.RequestError: On connection or other transport-level errors.
    """
    # NOTE: the redundant function-local `import httpx` was removed — httpx is
    # already imported at module level (the signature itself references it)
    from letta.utils import printd

    url, headers = get_gemini_endpoint_and_headers(base_url, model, api_key, key_in_header)

    # Determine if we need to close the client at the end
    close_client = False
    if client is None:
        client = httpx.AsyncClient()
        close_client = True

    try:
        response = await client.get(url, headers=headers)
        printd(f"response = {response}")
        response.raise_for_status()  # Raises HTTPStatusError for 4XX/5XX status
        response_data = response.json()  # convert to dict from string
        printd(f"response.json = {response_data}")

        # Return the model details
        return response_data

    except httpx.HTTPStatusError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX)
        printd(f"Got HTTPError, exception={http_err}")
        # Print the HTTP status code
        print(f"HTTP Error: {http_err.response.status_code}")
        # Print the response content (error message from server)
        print(f"Message: {http_err.response.text}")
        raise http_err

    except httpx.RequestError as req_err:
        # Handle other httpx-related errors (e.g., connection error)
        printd(f"Got RequestException, exception={req_err}")
        raise req_err

    except Exception as e:
        # Handle other potential errors
        printd(f"Got unknown Exception, exception={e}")
        raise e

    finally:
        # Close the client if we created it
        if close_client:
            await client.aclose()
|
||||
|
||||
|
||||
def google_ai_get_model_context_window(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> int:
    """Return the input-token context window for *model*, as reported by the Google AI API."""
    details = google_ai_get_model_details(base_url=base_url, api_key=api_key, model=model, key_in_header=key_in_header)
    # TODO should this be:
    # return details["inputTokenLimit"] + details["outputTokenLimit"]
    return int(details["inputTokenLimit"])
|
||||
|
||||
|
||||
async def google_ai_get_model_context_window_async(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> int:
    """Asynchronously return the input-token context window for *model*, as reported by the Google AI API."""
    details = await google_ai_get_model_details_async(base_url=base_url, api_key=api_key, model=model, key_in_header=key_in_header)
    # TODO should this be:
    # return details["inputTokenLimit"] + details["outputTokenLimit"]
    return int(details["inputTokenLimit"])
|
||||
|
||||
@@ -59,6 +59,9 @@ class Provider(ProviderBase):
|
||||
def get_model_context_window(self, model_name: str) -> Optional[int]:
    """Return the context window size (in tokens) for *model_name*; provider subclasses must override."""
    raise NotImplementedError
|
||||
|
||||
async def get_model_context_window_async(self, model_name: str) -> Optional[int]:
    """Async counterpart of `get_model_context_window`; provider subclasses must override."""
    raise NotImplementedError
|
||||
|
||||
def provider_tag(self) -> str:
    """String representation of the provider for display purposes; subclasses must override."""
    raise NotImplementedError
|
||||
@@ -1133,7 +1136,6 @@ class GoogleAIProvider(Provider):
|
||||
from letta.llm_api.google_ai_client import google_ai_get_model_list
|
||||
|
||||
model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key)
|
||||
# filter by 'generateContent' models
|
||||
model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]]
|
||||
model_options = [str(m["name"]) for m in model_options]
|
||||
|
||||
@@ -1159,11 +1161,56 @@ class GoogleAIProvider(Provider):
|
||||
)
|
||||
return configs
|
||||
|
||||
async def list_llm_models_async(self):
    """Asynchronously list available Gemini chat models, resolving context windows concurrently."""
    import asyncio

    from letta.llm_api.google_ai_client import google_ai_get_model_list_async

    raw_models = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key)

    # Keep only models that support text generation, then extract their names
    names = [str(m["name"]) for m in raw_models if "generateContent" in m["supportedGenerationMethods"]]

    # Strip the "models/" prefix where present
    prefix = "models/"
    names = [n[len(prefix):] if n.startswith(prefix) else n for n in names]

    # Add support for all gemini models
    names = [n for n in names if n.startswith("gemini-")]

    async def _build_config(model):
        # Each context-window lookup hits the API; they are all gathered concurrently below
        context_window = await self.get_model_context_window_async(model)
        return LLMConfig(
            model=model,
            model_endpoint_type="google_ai",
            model_endpoint=self.base_url,
            context_window=context_window,
            handle=self.get_handle(model),
            max_tokens=8192,
            provider_name=self.name,
            provider_category=self.provider_category,
        )

    # Execute all config-creation tasks concurrently
    configs = await asyncio.gather(*(_build_config(n) for n in names))
    return configs
|
||||
|
||||
def list_embedding_models(self):
    """Synchronously fetch the Gemini model list and filter it to embedding-capable models."""
    from letta.llm_api.google_ai_client import google_ai_get_model_list

    # TODO: use base_url instead
    all_models = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key)
    return self._list_embedding_models(all_models)
|
||||
|
||||
async def list_embedding_models_async(self):
    """Asynchronously fetch the Gemini model list and filter it to embedding-capable models."""
    from letta.llm_api.google_ai_client import google_ai_get_model_list_async

    # TODO: use base_url instead
    all_models = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key)
    return self._list_embedding_models(all_models)
|
||||
|
||||
def _list_embedding_models(self, model_options):
|
||||
# filter by 'generateContent' models
|
||||
model_options = [mo for mo in model_options if "embedContent" in mo["supportedGenerationMethods"]]
|
||||
model_options = [str(m["name"]) for m in model_options]
|
||||
@@ -1188,6 +1235,11 @@ class GoogleAIProvider(Provider):
|
||||
|
||||
return google_ai_get_model_context_window(self.base_url, self.api_key, model_name)
|
||||
|
||||
async def get_model_context_window_async(self, model_name: str) -> Optional[int]:
    """Asynchronously look up the context window size for *model_name* via the Google AI API."""
    from letta.llm_api.google_ai_client import google_ai_get_model_context_window_async

    return await google_ai_get_model_context_window_async(self.base_url, self.api_key, model_name)
|
||||
|
||||
|
||||
class GoogleVertexProvider(Provider):
|
||||
provider_type: Literal[ProviderType.google_vertex] = Field(ProviderType.google_vertex, description="The type of the provider.")
|
||||
|
||||
@@ -130,6 +130,23 @@ def test_googleai():
|
||||
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_googleai_async():
    """Exercise the async Gemini listing paths end-to-end (requires a configured API key)."""
    api_key = model_settings.gemini_api_key
    assert api_key is not None

    provider = GoogleAIProvider(
        name="google_ai",
        api_key=api_key,
    )

    llm_models = await provider.list_llm_models_async()
    assert len(llm_models) > 0
    assert llm_models[0].handle == f"{provider.name}/{llm_models[0].model}"

    embedding_models = await provider.list_embedding_models_async()
    assert len(embedding_models) > 0
    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
|
||||
|
||||
|
||||
def test_google_vertex():
|
||||
provider = GoogleVertexProvider(
|
||||
name="google_vertex",
|
||||
|
||||
Reference in New Issue
Block a user