diff --git a/letta/llm_api/google_ai_client.py b/letta/llm_api/google_ai_client.py
index c0343c65..47671398 100644
--- a/letta/llm_api/google_ai_client.py
+++ b/letta/llm_api/google_ai_client.py
@@ -1,6 +1,6 @@
 from typing import List, Optional, Tuple
 
-import requests
+import httpx
 from google import genai
 
 from letta.errors import ErrorCode, LLMAuthenticationError, LLMError
@@ -63,20 +63,24 @@ def google_ai_check_valid_api_key(api_key: str):
 
 
 def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool = True) -> List[dict]:
+    """Synchronous version to get model list from Google AI API using httpx."""
+    import httpx
+
     from letta.utils import printd
 
     url, headers = get_gemini_endpoint_and_headers(base_url, None, api_key, key_in_header)
 
     try:
-        response = requests.get(url, headers=headers)
-        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
-        response = response.json()  # convert to dict from string
+        with httpx.Client() as client:
+            response = client.get(url, headers=headers)
+            response.raise_for_status()  # Raises HTTPStatusError for 4XX/5XX status
+            response_data = response.json()  # convert to dict from string
 
-        # Grab the models out
-        model_list = response["models"]
-        return model_list
+            # Grab the models out
+            model_list = response_data["models"]
+            return model_list
 
-    except requests.exceptions.HTTPError as http_err:
+    except httpx.HTTPStatusError as http_err:
         # Handle HTTP errors (e.g., response 4XX, 5XX)
         printd(f"Got HTTPError, exception={http_err}")
         # Print the HTTP status code
@@ -85,8 +89,8 @@ def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool =
         print(f"Message: {http_err.response.text}")
         raise http_err
 
-    except requests.exceptions.RequestException as req_err:
-        # Handle other requests-related errors (e.g., connection error)
+    except httpx.RequestError as req_err:
+        # Handle other httpx-related errors (e.g., connection error)
         printd(f"Got RequestException, exception={req_err}")
         raise req_err
 
@@ -96,22 +100,74 @@ def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool =
         raise e
 
 
-def google_ai_get_model_details(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> List[dict]:
+async def google_ai_get_model_list_async(
+    base_url: str, api_key: str, key_in_header: bool = True, client: Optional[httpx.AsyncClient] = None
+) -> List[dict]:
+    """Asynchronous version to get model list from Google AI API using httpx."""
+    from letta.utils import printd
+
+    url, headers = get_gemini_endpoint_and_headers(base_url, None, api_key, key_in_header)
+
+    # Determine if we need to close the client at the end
+    close_client = False
+    if client is None:
+        client = httpx.AsyncClient()
+        close_client = True
+
+    try:
+        response = await client.get(url, headers=headers)
+        response.raise_for_status()  # Raises HTTPStatusError for 4XX/5XX status
+        response_data = response.json()  # convert to dict from string
+
+        # Grab the models out
+        model_list = response_data["models"]
+        return model_list
+
+    except httpx.HTTPStatusError as http_err:
+        # Handle HTTP errors (e.g., response 4XX, 5XX)
+        printd(f"Got HTTPError, exception={http_err}")
+        # Print the HTTP status code
+        print(f"HTTP Error: {http_err.response.status_code}")
+        # Print the response content (error message from server)
+        print(f"Message: {http_err.response.text}")
+        raise http_err
+
+    except httpx.RequestError as req_err:
+        # Handle other httpx-related errors (e.g., connection error)
+        printd(f"Got RequestException, exception={req_err}")
+        raise req_err
+
+    except Exception as e:
+        # Handle other potential errors
+        printd(f"Got unknown Exception, exception={e}")
+        raise e
+
+    finally:
+        # Close the client if we created it
+        if close_client:
+            await client.aclose()
+
+
+def google_ai_get_model_details(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> dict:
+    """Synchronous version to get model details from Google AI API using httpx."""
+    import httpx
+
     from letta.utils import printd
 
     url, headers = get_gemini_endpoint_and_headers(base_url, model, api_key, key_in_header)
 
     try:
-        response = requests.get(url, headers=headers)
-        printd(f"response = {response}")
-        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
-        response = response.json()  # convert to dict from string
-        printd(f"response.json = {response}")
+        with httpx.Client() as client:
+            response = client.get(url, headers=headers)
+            printd(f"response = {response}")
+            response.raise_for_status()  # Raises HTTPStatusError for 4XX/5XX status
+            response_data = response.json()  # convert to dict from string
+            printd(f"response.json = {response_data}")
 
-        # Grab the models out
-        return response
+            # Return the model details
+            return response_data
 
-    except requests.exceptions.HTTPError as http_err:
+    except httpx.HTTPStatusError as http_err:
         # Handle HTTP errors (e.g., response 4XX, 5XX)
         printd(f"Got HTTPError, exception={http_err}")
         # Print the HTTP status code
@@ -120,8 +176,8 @@ def google_ai_get_model_details(base_url: str, api_key: str, model: str, key_in_
         print(f"Message: {http_err.response.text}")
         raise http_err
 
-    except requests.exceptions.RequestException as req_err:
-        # Handle other requests-related errors (e.g., connection error)
+    except httpx.RequestError as req_err:
+        # Handle other httpx-related errors (e.g., connection error)
         printd(f"Got RequestException, exception={req_err}")
         raise req_err
 
@@ -131,8 +187,66 @@ def google_ai_get_model_details(base_url: str, api_key: str, model: str, key_in_
         raise e
 
 
+async def google_ai_get_model_details_async(
+    base_url: str, api_key: str, model: str, key_in_header: bool = True, client: Optional[httpx.AsyncClient] = None
+) -> dict:
+    """Asynchronous version to get model details from Google AI API using httpx."""
+    import httpx
+
+    from letta.utils import printd
+
+    url, headers = get_gemini_endpoint_and_headers(base_url, model, api_key, key_in_header)
+
+    # Determine if we need to close the client at the end
+    close_client = False
+    if client is None:
+        client = httpx.AsyncClient()
+        close_client = True
+
+    try:
+        response = await client.get(url, headers=headers)
+        printd(f"response = {response}")
+        response.raise_for_status()  # Raises HTTPStatusError for 4XX/5XX status
+        response_data = response.json()  # convert to dict from string
+        printd(f"response.json = {response_data}")
+
+        # Return the model details
+        return response_data
+
+    except httpx.HTTPStatusError as http_err:
+        # Handle HTTP errors (e.g., response 4XX, 5XX)
+        printd(f"Got HTTPError, exception={http_err}")
+        # Print the HTTP status code
+        print(f"HTTP Error: {http_err.response.status_code}")
+        # Print the response content (error message from server)
+        print(f"Message: {http_err.response.text}")
+        raise http_err
+
+    except httpx.RequestError as req_err:
+        # Handle other httpx-related errors (e.g., connection error)
+        printd(f"Got RequestException, exception={req_err}")
+        raise req_err
+
+    except Exception as e:
+        # Handle other potential errors
+        printd(f"Got unknown Exception, exception={e}")
+        raise e
+
+    finally:
+        # Close the client if we created it
+        if close_client:
+            await client.aclose()
+
+
 def google_ai_get_model_context_window(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> int:
     model_details = google_ai_get_model_details(base_url=base_url, api_key=api_key, model=model, key_in_header=key_in_header)
     # TODO should this be:
     # return model_details["inputTokenLimit"] + model_details["outputTokenLimit"]
     return int(model_details["inputTokenLimit"])
+
+
+async def google_ai_get_model_context_window_async(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> int:
+    model_details = await google_ai_get_model_details_async(base_url=base_url, api_key=api_key, model=model, key_in_header=key_in_header)
+    # TODO should this be:
+    # return model_details["inputTokenLimit"] + model_details["outputTokenLimit"]
+    return int(model_details["inputTokenLimit"])
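Note: both new async helpers accept an optional httpx.AsyncClient. A caller that makes several requests can pass one in to reuse its connection pool; the finally block only closes clients the helper created itself. A minimal usage sketch (not part of the diff; the base URL and API key are placeholders):

import asyncio

import httpx

from letta.llm_api.google_ai_client import (
    google_ai_get_model_details_async,
    google_ai_get_model_list_async,
)


async def main():
    base_url = "https://generativelanguage.googleapis.com"  # assumed Gemini endpoint
    api_key = "YOUR_GEMINI_API_KEY"  # placeholder

    # One shared client: both requests reuse the same connection pool, and the
    # helpers' finally blocks leave it open because close_client stays False.
    async with httpx.AsyncClient() as client:
        models = await google_ai_get_model_list_async(base_url, api_key, client=client)
        # The API returns names like "models/gemini-..."; strip the prefix the
        # same way providers.py does before asking for per-model details.
        name = models[0]["name"].removeprefix("models/")
        details = await google_ai_get_model_details_async(base_url, api_key, model=name, client=client)
        print(name, details.get("inputTokenLimit"))


asyncio.run(main())

One observation: google_ai_get_model_context_window_async does not expose the client parameter, so each context-window lookup in the provider below still opens its own client; threading a shared client through would be a natural follow-up.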
diff --git a/letta/schemas/providers.py b/letta/schemas/providers.py
index fb8e5e05..52066b8d 100644
--- a/letta/schemas/providers.py
+++ b/letta/schemas/providers.py
@@ -59,6 +59,9 @@ class Provider(ProviderBase):
     def get_model_context_window(self, model_name: str) -> Optional[int]:
         raise NotImplementedError
 
+    async def get_model_context_window_async(self, model_name: str) -> Optional[int]:
+        raise NotImplementedError
+
     def provider_tag(self) -> str:
         """String representation of the provider for display purposes"""
         raise NotImplementedError
@@ -1133,7 +1136,6 @@ class GoogleAIProvider(Provider):
         from letta.llm_api.google_ai_client import google_ai_get_model_list
 
         model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key)
-        # filter by 'generateContent' models
         model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]]
         model_options = [str(m["name"]) for m in model_options]
 
@@ -1159,11 +1161,56 @@
         )
         return configs
 
+    async def list_llm_models_async(self):
+        import asyncio
+
+        from letta.llm_api.google_ai_client import google_ai_get_model_list_async
+
+        # Get and filter the model list
+        model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key)
+        model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]]
+        model_options = [str(m["name"]) for m in model_options]
+
+        # filter by model names
+        model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]
+
+        # Add support for all gemini models
+        model_options = [mo for mo in model_options if str(mo).startswith("gemini-")]
+
+        # Prepare tasks for context window lookups in parallel
+        async def create_config(model):
+            context_window = await self.get_model_context_window_async(model)
+            return LLMConfig(
+                model=model,
+                model_endpoint_type="google_ai",
+                model_endpoint=self.base_url,
+                context_window=context_window,
+                handle=self.get_handle(model),
+                max_tokens=8192,
+                provider_name=self.name,
+                provider_category=self.provider_category,
+            )
+
+        # Execute all config creation tasks concurrently
+        configs = await asyncio.gather(*[create_config(model) for model in model_options])
+
+        return configs
+
     def list_embedding_models(self):
         from letta.llm_api.google_ai_client import google_ai_get_model_list
 
         # TODO: use base_url instead
         model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key)
+        return self._list_embedding_models(model_options)
+
+    async def list_embedding_models_async(self):
+        from letta.llm_api.google_ai_client import google_ai_get_model_list_async
+
+        # TODO: use base_url instead
+        model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key)
+        return self._list_embedding_models(model_options)
+
+    def _list_embedding_models(self, model_options):
         # filter by 'generateContent' models
         model_options = [mo for mo in model_options if "embedContent" in mo["supportedGenerationMethods"]]
         model_options = [str(m["name"]) for m in model_options]
@@ -1188,6 +1235,11 @@
 
         return google_ai_get_model_context_window(self.base_url, self.api_key, model_name)
 
+    async def get_model_context_window_async(self, model_name: str) -> Optional[int]:
+        from letta.llm_api.google_ai_client import google_ai_get_model_context_window_async
+
+        return await google_ai_get_model_context_window_async(self.base_url, self.api_key, model_name)
+
 
 class GoogleVertexProvider(Provider):
     provider_type: Literal[ProviderType.google_vertex] = Field(ProviderType.google_vertex, description="The type of the provider.")
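Note: list_llm_models_async builds every LLMConfig concurrently; asyncio.gather fires one context-window request per Gemini model at once instead of the sync version's one-at-a-time loop. If that burst ever needs throttling, a bounded variant is straightforward. The sketch below is hypothetical (the diff gathers without a limit), and max_concurrency is an invented knob:

import asyncio


async def gather_bounded(coros, max_concurrency: int = 8):
    """Like asyncio.gather, but at most max_concurrency coroutines run at once."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def run(coro):
        async with semaphore:  # wait for a free slot before starting the work
            return await coro

    return await asyncio.gather(*(run(c) for c in coros))

Since asyncio.gather preserves input order, wrapping the create_config coroutines this way would cap the burst without reordering the returned configs.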
diff --git a/tests/test_providers.py b/tests/test_providers.py
index ea79126d..44846d09 100644
--- a/tests/test_providers.py
+++ b/tests/test_providers.py
@@ -130,6 +130,23 @@ def test_googleai():
     assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
 
 
+@pytest.mark.asyncio
+async def test_googleai_async():
+    api_key = model_settings.gemini_api_key
+    assert api_key is not None
+    provider = GoogleAIProvider(
+        name="google_ai",
+        api_key=api_key,
+    )
+    models = await provider.list_llm_models_async()
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"
+
+    embedding_models = await provider.list_embedding_models_async()
+    assert len(embedding_models) > 0
+    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
+
+
 def test_google_vertex():
     provider = GoogleVertexProvider(
         name="google_vertex",
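Note: the new test drives the whole async path end to end, so it needs pytest-asyncio installed and a real Gemini key in model_settings. The same flow can be exercised outside pytest; in the sketch below, reading the key from a GEMINI_API_KEY environment variable is an assumption (the test takes it from model_settings.gemini_api_key):

import asyncio
import os

from letta.schemas.providers import GoogleAIProvider


async def main():
    # GEMINI_API_KEY is an assumed variable name, standing in for model_settings.
    provider = GoogleAIProvider(name="google_ai", api_key=os.environ["GEMINI_API_KEY"])

    # Run both listings concurrently; each returns the same configs as the
    # sync methods, just without blocking on every HTTP round trip.
    llm_configs, embedding_configs = await asyncio.gather(
        provider.list_llm_models_async(),
        provider.list_embedding_models_async(),
    )
    print(f"{len(llm_configs)} LLM configs, {len(embedding_configs)} embedding configs")


asyncio.run(main())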