fix: refactor Google AI Provider / helper functions and add endpoint test (#1850)
Co-authored-by: Matt Zhou <mattzhou@Matts-MacBook-Pro.local>
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
from letta.constants import NON_USER_MSG_PREFIX
|
||||
from letta.llm_api.helpers import make_post_request
|
||||
from letta.local_llm.json_parser import clean_json_string_extra_backslash
|
||||
from letta.local_llm.utils import count_tokens
|
||||
from letta.schemas.openai.chat_completion_request import Tool
|
||||
@@ -15,27 +16,41 @@ from letta.schemas.openai.chat_completion_response import (
|
||||
ToolCall,
|
||||
UsageStatistics,
|
||||
)
|
||||
from letta.utils import get_tool_call_id, get_utc_time
|
||||
|
||||
# from letta.data_types import ToolCall
|
||||
from letta.utils import get_tool_call_id, get_utc_time, json_dumps
|
||||
|
||||
|
||||
SUPPORTED_MODELS = [
|
||||
"gemini-pro",
|
||||
]
|
||||
def get_gemini_endpoint_and_headers(
|
||||
base_url: str, model: Optional[str], api_key: str, key_in_header: bool = True, generate_content: bool = False
|
||||
) -> Tuple[str, dict]:
|
||||
"""
|
||||
Dynamically generate the model endpoint and headers.
|
||||
"""
|
||||
url = f"{base_url}/v1beta/models"
|
||||
|
||||
# Add the model
|
||||
if model is not None:
|
||||
url += f"/{model}"
|
||||
|
||||
def google_ai_get_model_details(service_endpoint: str, api_key: str, model: str, key_in_header: bool = True) -> List[dict]:
|
||||
from letta.utils import printd
|
||||
# Add extension for generating content if we're hitting the LM
|
||||
if generate_content:
|
||||
url += ":generateContent"
|
||||
|
||||
# Decide if api key should be in header or not
|
||||
# Two ways to pass the key: https://ai.google.dev/tutorials/setup
|
||||
if key_in_header:
|
||||
url = f"https://{service_endpoint}.googleapis.com/v1beta/models/{model}"
|
||||
headers = {"Content-Type": "application/json", "x-goog-api-key": api_key}
|
||||
else:
|
||||
url = f"https://{service_endpoint}.googleapis.com/v1beta/models/{model}?key={api_key}"
|
||||
url += f"?key={api_key}"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
return url, headers
|
||||
|
||||
|
||||
def google_ai_get_model_details(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> List[dict]:
|
||||
from letta.utils import printd
|
||||
|
||||
url, headers = get_gemini_endpoint_and_headers(base_url, model, api_key, key_in_header)
|
||||
|
||||
try:
|
||||
response = requests.get(url, headers=headers)
|
||||
printd(f"response = {response}")
|
||||
@@ -66,25 +81,17 @@ def google_ai_get_model_details(service_endpoint: str, api_key: str, model: str,
|
||||
raise e
|
||||
|
||||
|
||||
def google_ai_get_model_context_window(service_endpoint: str, api_key: str, model: str, key_in_header: bool = True) -> int:
|
||||
model_details = google_ai_get_model_details(
|
||||
service_endpoint=service_endpoint, api_key=api_key, model=model, key_in_header=key_in_header
|
||||
)
|
||||
def google_ai_get_model_context_window(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> int:
|
||||
model_details = google_ai_get_model_details(base_url=base_url, api_key=api_key, model=model, key_in_header=key_in_header)
|
||||
# TODO should this be:
|
||||
# return model_details["inputTokenLimit"] + model_details["outputTokenLimit"]
|
||||
return int(model_details["inputTokenLimit"])
|
||||
|
||||
|
||||
def google_ai_get_model_list(service_endpoint: str, api_key: str, key_in_header: bool = True) -> List[dict]:
|
||||
def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool = True) -> List[dict]:
|
||||
from letta.utils import printd
|
||||
|
||||
# Two ways to pass the key: https://ai.google.dev/tutorials/setup
|
||||
if key_in_header:
|
||||
url = f"https://{service_endpoint}.googleapis.com/v1beta/models"
|
||||
headers = {"Content-Type": "application/json", "x-goog-api-key": api_key}
|
||||
else:
|
||||
url = f"https://{service_endpoint}.googleapis.com/v1beta/models?key={api_key}"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
url, headers = get_gemini_endpoint_and_headers(base_url, None, api_key, key_in_header)
|
||||
|
||||
try:
|
||||
response = requests.get(url, headers=headers)
|
||||
@@ -396,7 +403,7 @@ def convert_google_ai_response_to_chatcompletion(
|
||||
|
||||
# TODO convert 'data' type to pydantic
|
||||
def google_ai_chat_completions_request(
|
||||
service_endpoint: str,
|
||||
base_url: str,
|
||||
model: str,
|
||||
api_key: str,
|
||||
data: dict,
|
||||
@@ -414,55 +421,23 @@ def google_ai_chat_completions_request(
|
||||
This service has the following service endpoint and all URIs below are relative to this service endpoint:
|
||||
https://xxx.googleapis.com
|
||||
"""
|
||||
from letta.utils import printd
|
||||
|
||||
assert service_endpoint is not None, "Missing service_endpoint when calling Google AI"
|
||||
assert api_key is not None, "Missing api_key when calling Google AI"
|
||||
assert model in SUPPORTED_MODELS, f"Model '{model}' not in supported models: {', '.join(SUPPORTED_MODELS)}"
|
||||
|
||||
# Two ways to pass the key: https://ai.google.dev/tutorials/setup
|
||||
if key_in_header:
|
||||
url = f"https://{service_endpoint}.googleapis.com/v1beta/models/{model}:generateContent"
|
||||
headers = {"Content-Type": "application/json", "x-goog-api-key": api_key}
|
||||
else:
|
||||
url = f"https://{service_endpoint}.googleapis.com/v1beta/models/{model}:generateContent?key={api_key}"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
url, headers = get_gemini_endpoint_and_headers(base_url, model, api_key, key_in_header, generate_content=True)
|
||||
|
||||
# data["contents"][-1]["role"] = "model"
|
||||
if add_postfunc_model_messages:
|
||||
data["contents"] = add_dummy_model_messages(data["contents"])
|
||||
|
||||
printd(f"Sending request to {url}")
|
||||
response_json = make_post_request(url, headers, data)
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
printd(f"response = {response}")
|
||||
response.raise_for_status() # Raises HTTPError for 4XX/5XX status
|
||||
response = response.json() # convert to dict from string
|
||||
printd(f"response.json = {response}")
|
||||
|
||||
# Convert Google AI response to ChatCompletion style
|
||||
return convert_google_ai_response_to_chatcompletion(
|
||||
response_json=response,
|
||||
model=model,
|
||||
response_json=response_json,
|
||||
model=data.get("model"),
|
||||
input_messages=data["contents"],
|
||||
pull_inner_thoughts_from_args=inner_thoughts_in_kwargs,
|
||||
pull_inner_thoughts_from_args=data.get("inner_thoughts_in_kwargs", False),
|
||||
)
|
||||
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
# Handle HTTP errors (e.g., response 4XX, 5XX)
|
||||
printd(f"Got HTTPError, exception={http_err}, payload={data}")
|
||||
# Print the HTTP status code
|
||||
print(f"HTTP Error: {http_err.response.status_code}")
|
||||
# Print the response content (error message from server)
|
||||
print(f"Message: {http_err.response.text}")
|
||||
raise http_err
|
||||
|
||||
except requests.exceptions.RequestException as req_err:
|
||||
# Handle other requests-related errors (e.g., connection error)
|
||||
printd(f"Got RequestException, exception={req_err}")
|
||||
raise req_err
|
||||
|
||||
except Exception as e:
|
||||
# Handle other potential errors
|
||||
printd(f"Got unknown Exception, exception={e}")
|
||||
raise e
|
||||
except Exception as conversion_error:
|
||||
print(f"Error during response conversion: {conversion_error}")
|
||||
raise conversion_error
|
||||
|
||||
Reference in New Issue
Block a user