From b83f77af225fcf96fcbfc3e2757e57964ca4d75a Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Tue, 8 Oct 2024 16:55:11 -0700 Subject: [PATCH] fix: refactor Google AI Provider / helper functions and add endpoint test (#1850) Co-authored-by: Matt Zhou --- letta/credentials.py | 4 +- letta/llm_api/google_ai.py | 103 +++++++----------- letta/llm_api/helpers.py | 59 +++++++++- letta/llm_api/llm_api_tools.py | 3 +- letta/llm_api/openai.py | 54 +-------- letta/providers.py | 9 +- .../openai/chat_completion_response.py | 3 + .../configs/llm_model_configs/gemini-pro.json | 7 ++ tests/test_endpoints.py | 10 ++ 9 files changed, 127 insertions(+), 125 deletions(-) create mode 100644 tests/configs/llm_model_configs/gemini-pro.json diff --git a/letta/credentials.py b/letta/credentials.py index 05d683ae..91d9cce7 100644 --- a/letta/credentials.py +++ b/letta/credentials.py @@ -76,7 +76,7 @@ class LettaCredentials: "azure_embedding_deployment": get_field(config, "azure", "embedding_deployment"), # gemini "google_ai_key": get_field(config, "google_ai", "key"), - "google_ai_service_endpoint": get_field(config, "google_ai", "service_endpoint"), + # "google_ai_service_endpoint": get_field(config, "google_ai", "service_endpoint"), # anthropic "anthropic_key": get_field(config, "anthropic", "key"), # cohere @@ -117,7 +117,7 @@ class LettaCredentials: # gemini set_field(config, "google_ai", "key", self.google_ai_key) - set_field(config, "google_ai", "service_endpoint", self.google_ai_service_endpoint) + # set_field(config, "google_ai", "service_endpoint", self.google_ai_service_endpoint) # anthropic set_field(config, "anthropic", "key", self.anthropic_key) diff --git a/letta/llm_api/google_ai.py b/letta/llm_api/google_ai.py index fd49b8fd..71f64dae 100644 --- a/letta/llm_api/google_ai.py +++ b/letta/llm_api/google_ai.py @@ -1,9 +1,10 @@ import uuid -from typing import List, Optional +from typing import List, Optional, Tuple import requests from letta.constants import NON_USER_MSG_PREFIX +from letta.llm_api.helpers import make_post_request from letta.local_llm.json_parser import clean_json_string_extra_backslash from letta.local_llm.utils import count_tokens from letta.schemas.openai.chat_completion_request import Tool @@ -15,27 +16,41 @@ from letta.schemas.openai.chat_completion_response import ( ToolCall, UsageStatistics, ) -from letta.utils import get_tool_call_id, get_utc_time - -# from letta.data_types import ToolCall +from letta.utils import get_tool_call_id, get_utc_time, json_dumps -SUPPORTED_MODELS = [ - "gemini-pro", -] +def get_gemini_endpoint_and_headers( + base_url: str, model: Optional[str], api_key: str, key_in_header: bool = True, generate_content: bool = False +) -> Tuple[str, dict]: + """ + Dynamically generate the model endpoint and headers. 
+ """ + url = f"{base_url}/v1beta/models" + # Add the model + if model is not None: + url += f"/{model}" -def google_ai_get_model_details(service_endpoint: str, api_key: str, model: str, key_in_header: bool = True) -> List[dict]: - from letta.utils import printd + # Add extension for generating content if we're hitting the LM + if generate_content: + url += ":generateContent" + # Decide if api key should be in header or not # Two ways to pass the key: https://ai.google.dev/tutorials/setup if key_in_header: - url = f"https://{service_endpoint}.googleapis.com/v1beta/models/{model}" headers = {"Content-Type": "application/json", "x-goog-api-key": api_key} else: - url = f"https://{service_endpoint}.googleapis.com/v1beta/models/{model}?key={api_key}" + url += f"?key={api_key}" headers = {"Content-Type": "application/json"} + return url, headers + + +def google_ai_get_model_details(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> List[dict]: + from letta.utils import printd + + url, headers = get_gemini_endpoint_and_headers(base_url, model, api_key, key_in_header) + try: response = requests.get(url, headers=headers) printd(f"response = {response}") @@ -66,25 +81,17 @@ def google_ai_get_model_details(service_endpoint: str, api_key: str, model: str, raise e -def google_ai_get_model_context_window(service_endpoint: str, api_key: str, model: str, key_in_header: bool = True) -> int: - model_details = google_ai_get_model_details( - service_endpoint=service_endpoint, api_key=api_key, model=model, key_in_header=key_in_header - ) +def google_ai_get_model_context_window(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> int: + model_details = google_ai_get_model_details(base_url=base_url, api_key=api_key, model=model, key_in_header=key_in_header) # TODO should this be: # return model_details["inputTokenLimit"] + model_details["outputTokenLimit"] return int(model_details["inputTokenLimit"]) -def google_ai_get_model_list(service_endpoint: str, api_key: str, key_in_header: bool = True) -> List[dict]: +def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool = True) -> List[dict]: from letta.utils import printd - # Two ways to pass the key: https://ai.google.dev/tutorials/setup - if key_in_header: - url = f"https://{service_endpoint}.googleapis.com/v1beta/models" - headers = {"Content-Type": "application/json", "x-goog-api-key": api_key} - else: - url = f"https://{service_endpoint}.googleapis.com/v1beta/models?key={api_key}" - headers = {"Content-Type": "application/json"} + url, headers = get_gemini_endpoint_and_headers(base_url, None, api_key, key_in_header) try: response = requests.get(url, headers=headers) @@ -396,7 +403,7 @@ def convert_google_ai_response_to_chatcompletion( # TODO convert 'data' type to pydantic def google_ai_chat_completions_request( - service_endpoint: str, + base_url: str, model: str, api_key: str, data: dict, @@ -414,55 +421,23 @@ def google_ai_chat_completions_request( This service has the following service endpoint and all URIs below are relative to this service endpoint: https://xxx.googleapis.com """ - from letta.utils import printd - assert service_endpoint is not None, "Missing service_endpoint when calling Google AI" assert api_key is not None, "Missing api_key when calling Google AI" - assert model in SUPPORTED_MODELS, f"Model '{model}' not in supported models: {', '.join(SUPPORTED_MODELS)}" - # Two ways to pass the key: https://ai.google.dev/tutorials/setup - if key_in_header: - url = 
f"https://{service_endpoint}.googleapis.com/v1beta/models/{model}:generateContent" - headers = {"Content-Type": "application/json", "x-goog-api-key": api_key} - else: - url = f"https://{service_endpoint}.googleapis.com/v1beta/models/{model}:generateContent?key={api_key}" - headers = {"Content-Type": "application/json"} + url, headers = get_gemini_endpoint_and_headers(base_url, model, api_key, key_in_header, generate_content=True) # data["contents"][-1]["role"] = "model" if add_postfunc_model_messages: data["contents"] = add_dummy_model_messages(data["contents"]) - printd(f"Sending request to {url}") + response_json = make_post_request(url, headers, data) try: - response = requests.post(url, headers=headers, json=data) - printd(f"response = {response}") - response.raise_for_status() # Raises HTTPError for 4XX/5XX status - response = response.json() # convert to dict from string - printd(f"response.json = {response}") - - # Convert Google AI response to ChatCompletion style return convert_google_ai_response_to_chatcompletion( - response_json=response, - model=model, + response_json=response_json, + model=data.get("model"), input_messages=data["contents"], - pull_inner_thoughts_from_args=inner_thoughts_in_kwargs, + pull_inner_thoughts_from_args=data.get("inner_thoughts_in_kwargs", False), ) - - except requests.exceptions.HTTPError as http_err: - # Handle HTTP errors (e.g., response 4XX, 5XX) - printd(f"Got HTTPError, exception={http_err}, payload={data}") - # Print the HTTP status code - print(f"HTTP Error: {http_err.response.status_code}") - # Print the response content (error message from server) - print(f"Message: {http_err.response.text}") - raise http_err - - except requests.exceptions.RequestException as req_err: - # Handle other requests-related errors (e.g., connection error) - printd(f"Got RequestException, exception={req_err}") - raise req_err - - except Exception as e: - # Handle other potential errors - printd(f"Got unknown Exception, exception={e}") - raise e + except Exception as conversion_error: + print(f"Error during response conversion: {conversion_error}") + raise conversion_error diff --git a/letta/llm_api/helpers.py b/letta/llm_api/helpers.py index 3fae442a..6c532361 100644 --- a/letta/llm_api/helpers.py +++ b/letta/llm_api/helpers.py @@ -1,14 +1,69 @@ import copy import json import warnings -from typing import List, Union +from typing import Any, List, Union import requests from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING from letta.schemas.enums import OptionState from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice -from letta.utils import json_dumps +from letta.utils import json_dumps, printd + + +def make_post_request(url: str, headers: dict[str, str], data: dict[str, Any]) -> dict[str, Any]: + printd(f"Sending request to {url}") + try: + # Make the POST request + response = requests.post(url, headers=headers, json=data) + printd(f"Response status code: {response.status_code}") + + # Raise for 4XX/5XX HTTP errors + response.raise_for_status() + + # Ensure the content is JSON before parsing + if response.headers.get("Content-Type") == "application/json": + response_data = response.json() # Convert to dict from JSON + printd(f"Response JSON: {response_data}") + else: + error_message = f"Unexpected content type returned: {response.headers.get('Content-Type')}" + printd(error_message) + raise ValueError(error_message) + + # Process the response using the callback function + return response_data + + except 
requests.exceptions.HTTPError as http_err: + # HTTP errors (4XX, 5XX) + error_message = f"HTTP error occurred: {http_err}" + if http_err.response is not None: + error_message += f" | Status code: {http_err.response.status_code}, Message: {http_err.response.text}" + printd(error_message) + raise requests.exceptions.HTTPError(error_message) from http_err + + except requests.exceptions.Timeout as timeout_err: + # Handle timeout errors + error_message = f"Request timed out: {timeout_err}" + printd(error_message) + raise requests.exceptions.Timeout(error_message) from timeout_err + + except requests.exceptions.RequestException as req_err: + # Non-HTTP errors (e.g., connection, SSL errors) + error_message = f"Request failed: {req_err}" + printd(error_message) + raise requests.exceptions.RequestException(error_message) from req_err + + except ValueError as val_err: + # Handle content-type or non-JSON response issues + error_message = f"ValueError: {val_err}" + printd(error_message) + raise ValueError(error_message) from val_err + + except Exception as e: + # Catch any other unknown exceptions + error_message = f"An unexpected error occurred: {e}" + printd(error_message) + raise Exception(error_message) from e # TODO update to use better types diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index 6d30236a..b85d5739 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -28,7 +28,6 @@ from letta.local_llm.constants import ( INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, ) -from letta.providers import GoogleAIProvider from letta.schemas.enums import OptionState from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message @@ -231,7 +230,7 @@ def create( return google_ai_chat_completions_request( inner_thoughts_in_kwargs=google_ai_inner_thoughts_in_kwarg, - service_endpoint=GoogleAIProvider(model_settings.gemini_api_key).service_endpoint, + base_url=llm_config.model_endpoint, model=llm_config.model, api_key=model_settings.gemini_api_key, # see structure of payload here: https://ai.google.dev/docs/function_calling diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index 18f92372..0e8cda99 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -9,7 +9,7 @@ from httpx_sse._exceptions import SSEError from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING from letta.errors import LLMError -from letta.llm_api.helpers import add_inner_thoughts_to_functions +from letta.llm_api.helpers import add_inner_thoughts_to_functions, make_post_request from letta.local_llm.constants import ( INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, @@ -483,58 +483,14 @@ def openai_chat_completions_request( data.pop("tools") data.pop("tool_choice", None) # extra safe, should exist always (default="auto") - printd(f"Sending request to {url}") - try: - response = requests.post(url, headers=headers, json=data) - printd(f"response = {response}, response.text = {response.text}") - # print(json.dumps(data, indent=4)) - # raise requests.exceptions.HTTPError - response.raise_for_status() # Raises HTTPError for 4XX/5XX status - - response = response.json() # convert to dict from string - printd(f"response.json = {response}") - - response = ChatCompletionResponse(**response) # convert to 'dot-dict' style which is the openai python client default - return response - except requests.exceptions.HTTPError as http_err: - # Handle HTTP errors (e.g., response 4XX, 5XX) - printd(f"Got HTTPError, 
exception={http_err}, payload={data}") - raise http_err - except requests.exceptions.RequestException as req_err: - # Handle other requests-related errors (e.g., connection error) - printd(f"Got RequestException, exception={req_err}") - raise req_err - except Exception as e: - # Handle other potential errors - printd(f"Got unknown Exception, exception={e}") - raise e + response_json = make_post_request(url, headers, data) + return ChatCompletionResponse(**response_json) def openai_embeddings_request(url: str, api_key: str, data: dict) -> EmbeddingResponse: """https://platform.openai.com/docs/api-reference/embeddings/create""" - from letta.utils import printd url = smart_urljoin(url, "embeddings") headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} - - printd(f"Sending request to {url}") - try: - response = requests.post(url, headers=headers, json=data) - printd(f"response = {response}") - response.raise_for_status() # Raises HTTPError for 4XX/5XX status - response = response.json() # convert to dict from string - printd(f"response.json = {response}") - response = EmbeddingResponse(**response) # convert to 'dot-dict' style which is the openai python client default - return response - except requests.exceptions.HTTPError as http_err: - # Handle HTTP errors (e.g., response 4XX, 5XX) - printd(f"Got HTTPError, exception={http_err}, payload={data}") - raise http_err - except requests.exceptions.RequestException as req_err: - # Handle other requests-related errors (e.g., connection error) - printd(f"Got RequestException, exception={req_err}") - raise req_err - except Exception as e: - # Handle other potential errors - printd(f"Got unknown Exception, exception={e}") - raise e + response_json = make_post_request(url, headers, data) + return EmbeddingResponse(**response_json) diff --git a/letta/providers.py b/letta/providers.py index e3fe71d7..35912214 100644 --- a/letta/providers.py +++ b/letta/providers.py @@ -217,14 +217,12 @@ class GroqProvider(OpenAIProvider): class GoogleAIProvider(Provider): # gemini api_key: str = Field(..., description="API key for the Google AI API.") - service_endpoint: str = "generativelanguage" # TODO: remove once old functions are refactored to just use base_url base_url: str = "https://generativelanguage.googleapis.com" def list_llm_models(self): from letta.llm_api.google_ai import google_ai_get_model_list - # TODO: use base_url instead - model_options = google_ai_get_model_list(service_endpoint=self.service_endpoint, api_key=self.api_key) + model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key) # filter by 'generateContent' models model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]] model_options = [str(m["name"]) for m in model_options] @@ -251,7 +249,7 @@ class GoogleAIProvider(Provider): from letta.llm_api.google_ai import google_ai_get_model_list # TODO: use base_url instead - model_options = google_ai_get_model_list(service_endpoint=self.service_endpoint, api_key=self.api_key) + model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key) # filter by 'generateContent' models model_options = [mo for mo in model_options if "embedContent" in mo["supportedGenerationMethods"]] model_options = [str(m["name"]) for m in model_options] @@ -273,8 +271,7 @@ class GoogleAIProvider(Provider): def get_model_context_window(self, model_name: str): from letta.llm_api.google_ai import google_ai_get_model_context_window - # TODO: use base_url 
instead - return google_ai_get_model_context_window(self.service_endpoint, self.api_key, model_name) + return google_ai_get_model_context_window(self.base_url, self.api_key, model_name) class AzureProvider(Provider): diff --git a/letta/schemas/openai/chat_completion_response.py b/letta/schemas/openai/chat_completion_response.py index 7b74ca88..ea37ec45 100644 --- a/letta/schemas/openai/chat_completion_response.py +++ b/letta/schemas/openai/chat_completion_response.py @@ -74,6 +74,9 @@ class ChatCompletionResponse(BaseModel): object: Literal["chat.completion"] = "chat.completion" usage: UsageStatistics + def __str__(self): + return self.model_dump_json(indent=4) + class FunctionCallDelta(BaseModel): # arguments: Optional[str] = None diff --git a/tests/configs/llm_model_configs/gemini-pro.json b/tests/configs/llm_model_configs/gemini-pro.json new file mode 100644 index 00000000..5c425b6d --- /dev/null +++ b/tests/configs/llm_model_configs/gemini-pro.json @@ -0,0 +1,7 @@ +{ + "context_window": 2097152, + "model": "gemini-1.5-pro-latest", + "model_endpoint_type": "google_ai", + "model_endpoint": "https://generativelanguage.googleapis.com", + "model_wrapper": null +} diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index b65042e1..2b767937 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -273,3 +273,13 @@ def test_groq_llama31_70b_edit_core_memory(): response = check_agent_edit_core_memory(filename) # Log out successful response print(f"Got successful response from client: \n\n{response}") + + +# ====================================================================================================================== +# GEMINI TESTS +# ====================================================================================================================== +def test_gemini_pro_15_returns_valid_first_message(): + filename = os.path.join(llm_config_dir, "gemini-pro.json") + response = check_first_response_is_valid_for_llm_endpoint(filename) + # Log out successful response + print(f"Got successful response from client: \n\n{response}")
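
A minimal sketch of how the refactored helpers compose, assuming a GEMINI_API_KEY
environment variable and a hand-built generateContent payload (neither appears in
this patch):

    import os

    from letta.llm_api.google_ai import get_gemini_endpoint_and_headers
    from letta.llm_api.helpers import make_post_request

    # Build the generateContent URL and headers; with key_in_header=True the
    # API key goes in the x-goog-api-key header instead of a ?key= query param.
    url, headers = get_gemini_endpoint_and_headers(
        base_url="https://generativelanguage.googleapis.com",
        model="gemini-1.5-pro-latest",
        api_key=os.environ["GEMINI_API_KEY"],  # assumed env var, illustration only
        key_in_header=True,
        generate_content=True,
    )

    # make_post_request raises on HTTP errors, timeouts, and non-JSON bodies,
    # so callers only handle the parsed dict on the happy path.
    data = {"contents": [{"role": "user", "parts": [{"text": "Hello!"}]}]}
    response_json = make_post_request(url, headers, data)
    print(response_json)

Centralizing the POST and its error handling in make_post_request is what lets the
OpenAI and Google AI call sites above shrink to a single request plus a response
conversion.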