fix: google model listing api (#1454)
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import uuid
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
from letta.constants import NON_USER_MSG_PREFIX
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
@@ -21,7 +23,13 @@ class GoogleAIClient(LLMClientBase):
|
||||
"""
|
||||
Performs underlying request to llm and returns raw response.
|
||||
"""
|
||||
url, headers = self.get_gemini_endpoint_and_headers(generate_content=True)
|
||||
url, headers = get_gemini_endpoint_and_headers(
|
||||
base_url=str(self.llm_config.model_endpoint),
|
||||
model=self.llm_config.model,
|
||||
api_key=str(model_settings.gemini_api_key),
|
||||
key_in_header=True,
|
||||
generate_content=True,
|
||||
)
|
||||
return make_post_request(url, headers, request_data)
|
||||
|
||||
def build_request_data(
|
||||
@@ -208,34 +216,6 @@ class GoogleAIClient(LLMClientBase):
|
||||
except KeyError as e:
|
||||
raise e
|
||||
|
||||
def get_gemini_endpoint_and_headers(
    self,
    key_in_header: bool = True,
    generate_content: bool = False,
) -> Tuple[str, dict]:
    """
    Dynamically generate the model endpoint and headers.

    Builds the Gemini REST URL from ``self.llm_config.model_endpoint`` and
    ``self.llm_config.model``.

    Args:
        key_in_header: If True, send the API key in the ``x-goog-api-key``
            header; otherwise append it as a ``?key=`` query parameter.
        generate_content: If True, append ``:generateContent`` so the URL
            targets the generation action instead of the bare model resource.

    Returns:
        Tuple of (url, headers) ready for an HTTP request.
    """

    url = f"{self.llm_config.model_endpoint}/v1beta/models"

    # Add the model
    url += f"/{self.llm_config.model}"

    # Add extension for generating content if we're hitting the LM
    if generate_content:
        url += ":generateContent"

    # Decide if api key should be in header or not
    # Two ways to pass the key: https://ai.google.dev/tutorials/setup
    # NOTE(review): `model_settings` is read from enclosing module scope —
    # verify it is imported at the top of this file.
    if key_in_header:
        headers = {"Content-Type": "application/json", "x-goog-api-key": model_settings.gemini_api_key}
    else:
        url += f"?key={model_settings.gemini_api_key}"
        headers = {"Content-Type": "application/json"}

    return url, headers
|
||||
|
||||
def convert_tools_to_google_ai_format(self, tools: List[Tool]) -> List[dict]:
|
||||
"""
|
||||
OpenAI style:
|
||||
@@ -330,3 +310,106 @@ class GoogleAIClient(LLMClientBase):
|
||||
messages_with_padding.append(dummy_yield_message)
|
||||
|
||||
return messages_with_padding
|
||||
|
||||
|
||||
def get_gemini_endpoint_and_headers(
    base_url: str, model: Optional[str], api_key: str, key_in_header: bool = True, generate_content: bool = False
) -> Tuple[str, dict]:
    """
    Dynamically generate the model endpoint and headers.

    Args:
        base_url: API base, e.g. "https://generativelanguage.googleapis.com".
        model: Model name to address, or None to target the models collection.
        api_key: Gemini API key.
        key_in_header: If True, send the key via the ``x-goog-api-key`` header;
            otherwise append it as a ``?key=`` query parameter.
            (Two ways to pass the key: https://ai.google.dev/tutorials/setup)
        generate_content: If True, append ``:generateContent`` to hit the LM.

    Returns:
        Tuple of (url, headers) for the request.
    """
    segments = [f"{base_url}/v1beta/models"]

    # Address a specific model when one was given
    if model is not None:
        segments.append(f"/{model}")

    # Target the generation action rather than the bare model resource
    if generate_content:
        segments.append(":generateContent")

    url = "".join(segments)

    # The key travels either in a header or in the query string
    headers = {"Content-Type": "application/json"}
    if key_in_header:
        headers["x-goog-api-key"] = api_key
    else:
        url = f"{url}?key={api_key}"

    return url, headers
|
||||
|
||||
|
||||
def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool = True) -> List[dict]:
    """
    Fetch the list of available models from the Google AI API.

    Args:
        base_url: API base, e.g. "https://generativelanguage.googleapis.com".
        api_key: Gemini API key.
        key_in_header: Pass the key in the ``x-goog-api-key`` header (True)
            or as a ``?key=`` query parameter (False).

    Returns:
        The list of model dicts from the ``models`` field of the response.

    Raises:
        requests.exceptions.HTTPError: On 4XX/5XX responses.
        requests.exceptions.RequestException: On connection-level failures.
    """
    from letta.utils import printd

    # No model name: address the models collection itself
    url, headers = get_gemini_endpoint_and_headers(base_url, None, api_key, key_in_header)

    try:
        raw_response = requests.get(url, headers=headers)
        raw_response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
        payload = raw_response.json()  # convert to dict from string
        # Grab the models out
        return payload["models"]

    except requests.exceptions.HTTPError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX)
        printd(f"Got HTTPError, exception={http_err}")
        # Print the HTTP status code
        print(f"HTTP Error: {http_err.response.status_code}")
        # Print the response content (error message from server)
        print(f"Message: {http_err.response.text}")
        raise http_err

    except requests.exceptions.RequestException as req_err:
        # Handle other requests-related errors (e.g., connection error)
        printd(f"Got RequestException, exception={req_err}")
        raise req_err

    except Exception as e:
        # Handle other potential errors
        printd(f"Got unknown Exception, exception={e}")
        raise e
|
||||
|
||||
|
||||
def google_ai_get_model_details(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> dict:
    """
    Fetch the details of a single Google AI model.

    Args:
        base_url: API base, e.g. "https://generativelanguage.googleapis.com".
        api_key: Gemini API key.
        model: Name of the model to look up.
        key_in_header: Pass the key in the ``x-goog-api-key`` header (True)
            or as a ``?key=`` query parameter (False).

    Returns:
        The parsed JSON model resource as a dict (e.g. containing
        "inputTokenLimit"/"outputTokenLimit"). The original ``List[dict]``
        annotation was wrong: this endpoint returns one model object, and the
        caller indexes the result like a dict.

    Raises:
        requests.exceptions.HTTPError: On 4XX/5XX responses.
        requests.exceptions.RequestException: On connection-level failures.
    """
    from letta.utils import printd

    url, headers = get_gemini_endpoint_and_headers(base_url, model, api_key, key_in_header)

    try:
        response = requests.get(url, headers=headers)
        printd(f"response = {response}")
        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
        response = response.json()  # convert to dict from string
        printd(f"response.json = {response}")

        # The whole JSON body is the model resource
        return response

    except requests.exceptions.HTTPError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX)
        printd(f"Got HTTPError, exception={http_err}")
        # Print the HTTP status code
        print(f"HTTP Error: {http_err.response.status_code}")
        # Print the response content (error message from server)
        print(f"Message: {http_err.response.text}")
        raise http_err

    except requests.exceptions.RequestException as req_err:
        # Handle other requests-related errors (e.g., connection error)
        printd(f"Got RequestException, exception={req_err}")
        raise req_err

    except Exception as e:
        # Handle other potential errors
        printd(f"Got unknown Exception, exception={e}")
        raise e
|
||||
|
||||
|
||||
def google_ai_get_model_context_window(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> int:
    """
    Return the context window size (input token limit) of a Google AI model.

    Args:
        base_url: API base, e.g. "https://generativelanguage.googleapis.com".
        api_key: Gemini API key.
        model: Name of the model to look up.
        key_in_header: Pass the key in the ``x-goog-api-key`` header (True)
            or as a ``?key=`` query parameter (False).

    Returns:
        The model's "inputTokenLimit" as an int.
    """
    details = google_ai_get_model_details(base_url=base_url, api_key=api_key, model=model, key_in_header=key_in_header)
    # TODO should this be:
    # return model_details["inputTokenLimit"] + model_details["outputTokenLimit"]
    return int(details["inputTokenLimit"])
|
||||
|
||||
@@ -1120,7 +1120,7 @@ class GoogleAIProvider(Provider):
|
||||
base_url: str = "https://generativelanguage.googleapis.com"
|
||||
|
||||
def list_llm_models(self):
|
||||
from letta.llm_api.google_ai import google_ai_get_model_list
|
||||
from letta.llm_api.google_ai_client import google_ai_get_model_list
|
||||
|
||||
model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key)
|
||||
# filter by 'generateContent' models
|
||||
@@ -1149,7 +1149,7 @@ class GoogleAIProvider(Provider):
|
||||
return configs
|
||||
|
||||
def list_embedding_models(self):
|
||||
from letta.llm_api.google_ai import google_ai_get_model_list
|
||||
from letta.llm_api.google_ai_client import google_ai_get_model_list
|
||||
|
||||
# TODO: use base_url instead
|
||||
model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key)
|
||||
@@ -1173,7 +1173,7 @@ class GoogleAIProvider(Provider):
|
||||
return configs
|
||||
|
||||
def get_model_context_window(self, model_name: str) -> Optional[int]:
|
||||
from letta.llm_api.google_ai import google_ai_get_model_context_window
|
||||
from letta.llm_api.google_ai_client import google_ai_get_model_context_window
|
||||
|
||||
return google_ai_get_model_context_window(self.base_url, self.api_key, model_name)
|
||||
|
||||
|
||||
@@ -340,11 +340,9 @@ class SandboxToolExecutor(ToolExecutor):
|
||||
else:
|
||||
inject_agent_state = False
|
||||
|
||||
|
||||
# Execute in sandbox
|
||||
sandbox_run_result = await AsyncToolExecutionSandbox(function_name, function_args, actor, tool_object=tool).run(
|
||||
agent_state=agent_state_copy,
|
||||
inject_agent_state=inject_agent_state
|
||||
agent_state=agent_state_copy, inject_agent_state=inject_agent_state
|
||||
)
|
||||
|
||||
function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
|
||||
|
||||
@@ -10,7 +10,6 @@ from dotenv import load_dotenv
|
||||
from letta_client import Letta
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
|
||||
from letta import create_client
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageStreamStatus
|
||||
@@ -185,7 +184,7 @@ def agent_state(client, roll_dice_tool, weather_tool, rethink_tool):
|
||||
},
|
||||
],
|
||||
llm_config=llm_config,
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai")
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
)
|
||||
yield agent_state
|
||||
client.agents.delete(agent_state.id)
|
||||
|
||||
@@ -569,9 +569,8 @@ def test_list_llm_models(client: RESTClient):
|
||||
assert has_model_endpoint_type(models, "azure")
|
||||
if model_settings.openai_api_key:
|
||||
assert has_model_endpoint_type(models, "openai")
|
||||
# TODO: Fix this
|
||||
# if model_settings.gemini_api_key:
|
||||
# assert has_model_endpoint_type(models, "google_ai")
|
||||
if model_settings.gemini_api_key:
|
||||
assert has_model_endpoint_type(models, "google_ai")
|
||||
if model_settings.anthropic_api_key:
|
||||
assert has_model_endpoint_type(models, "anthropic")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user