diff --git a/letta/llm_api/google_ai_client.py b/letta/llm_api/google_ai_client.py index c75deefd..e42dbdbf 100644 --- a/letta/llm_api/google_ai_client.py +++ b/letta/llm_api/google_ai_client.py @@ -1,6 +1,8 @@ import uuid from typing import List, Optional, Tuple +import requests + from letta.constants import NON_USER_MSG_PREFIX from letta.helpers.datetime_helpers import get_utc_time from letta.helpers.json_helpers import json_dumps @@ -21,7 +23,13 @@ class GoogleAIClient(LLMClientBase): """ Performs underlying request to llm and returns raw response. """ - url, headers = self.get_gemini_endpoint_and_headers(generate_content=True) + url, headers = get_gemini_endpoint_and_headers( + base_url=str(self.llm_config.model_endpoint), + model=self.llm_config.model, + api_key=str(model_settings.gemini_api_key), + key_in_header=True, + generate_content=True, + ) return make_post_request(url, headers, request_data) def build_request_data( @@ -208,34 +216,6 @@ class GoogleAIClient(LLMClientBase): except KeyError as e: raise e - def get_gemini_endpoint_and_headers( - self, - key_in_header: bool = True, - generate_content: bool = False, - ) -> Tuple[str, dict]: - """ - Dynamically generate the model endpoint and headers. - """ - - url = f"{self.llm_config.model_endpoint}/v1beta/models" - - # Add the model - url += f"/{self.llm_config.model}" - - # Add extension for generating content if we're hitting the LM - if generate_content: - url += ":generateContent" - - # Decide if api key should be in header or not - # Two ways to pass the key: https://ai.google.dev/tutorials/setup - if key_in_header: - headers = {"Content-Type": "application/json", "x-goog-api-key": model_settings.gemini_api_key} - else: - url += f"?key={model_settings.gemini_api_key}" - headers = {"Content-Type": "application/json"} - - return url, headers - def convert_tools_to_google_ai_format(self, tools: List[Tool]) -> List[dict]: """ OpenAI style: @@ -330,3 +310,106 @@ class GoogleAIClient(LLMClientBase): messages_with_padding.append(dummy_yield_message) return messages_with_padding + + +def get_gemini_endpoint_and_headers( + base_url: str, model: Optional[str], api_key: str, key_in_header: bool = True, generate_content: bool = False +) -> Tuple[str, dict]: + """ + Dynamically generate the model endpoint and headers. + """ + url = f"{base_url}/v1beta/models" + + # Add the model + if model is not None: + url += f"/{model}" + + # Add extension for generating content if we're hitting the LM + if generate_content: + url += ":generateContent" + + # Decide if api key should be in header or not + # Two ways to pass the key: https://ai.google.dev/tutorials/setup + if key_in_header: + headers = {"Content-Type": "application/json", "x-goog-api-key": api_key} + else: + url += f"?key={api_key}" + headers = {"Content-Type": "application/json"} + + return url, headers + + +def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool = True) -> List[dict]: + from letta.utils import printd + + url, headers = get_gemini_endpoint_and_headers(base_url, None, api_key, key_in_header) + + try: + response = requests.get(url, headers=headers) + response.raise_for_status() # Raises HTTPError for 4XX/5XX status + response = response.json() # convert to dict from string + + # Grab the models out + model_list = response["models"] + return model_list + + except requests.exceptions.HTTPError as http_err: + # Handle HTTP errors (e.g., response 4XX, 5XX) + printd(f"Got HTTPError, exception={http_err}") + # Print the HTTP status code + print(f"HTTP Error: {http_err.response.status_code}") + # Print the response content (error message from server) + print(f"Message: {http_err.response.text}") + raise http_err + + except requests.exceptions.RequestException as req_err: + # Handle other requests-related errors (e.g., connection error) + printd(f"Got RequestException, exception={req_err}") + raise req_err + + except Exception as e: + # Handle other potential errors + printd(f"Got unknown Exception, exception={e}") + raise e + + +def google_ai_get_model_details(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> List[dict]: + from letta.utils import printd + + url, headers = get_gemini_endpoint_and_headers(base_url, model, api_key, key_in_header) + + try: + response = requests.get(url, headers=headers) + printd(f"response = {response}") + response.raise_for_status() # Raises HTTPError for 4XX/5XX status + response = response.json() # convert to dict from string + printd(f"response.json = {response}") + + # Grab the models out + return response + + except requests.exceptions.HTTPError as http_err: + # Handle HTTP errors (e.g., response 4XX, 5XX) + printd(f"Got HTTPError, exception={http_err}") + # Print the HTTP status code + print(f"HTTP Error: {http_err.response.status_code}") + # Print the response content (error message from server) + print(f"Message: {http_err.response.text}") + raise http_err + + except requests.exceptions.RequestException as req_err: + # Handle other requests-related errors (e.g., connection error) + printd(f"Got RequestException, exception={req_err}") + raise req_err + + except Exception as e: + # Handle other potential errors + printd(f"Got unknown Exception, exception={e}") + raise e + + +def google_ai_get_model_context_window(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> int: + model_details = google_ai_get_model_details(base_url=base_url, api_key=api_key, model=model, key_in_header=key_in_header) + # TODO should this be: + # return model_details["inputTokenLimit"] + model_details["outputTokenLimit"] + return int(model_details["inputTokenLimit"]) diff --git a/letta/schemas/providers.py b/letta/schemas/providers.py index 647ab001..90776f9d 100644 --- a/letta/schemas/providers.py +++ b/letta/schemas/providers.py @@ -1120,7 +1120,7 @@ class GoogleAIProvider(Provider): base_url: str = "https://generativelanguage.googleapis.com" def list_llm_models(self): - from letta.llm_api.google_ai import google_ai_get_model_list + from letta.llm_api.google_ai_client import google_ai_get_model_list model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key) # filter by 'generateContent' models @@ -1149,7 +1149,7 @@ class GoogleAIProvider(Provider): return configs def list_embedding_models(self): - from letta.llm_api.google_ai import google_ai_get_model_list + from letta.llm_api.google_ai_client import google_ai_get_model_list # TODO: use base_url instead model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key) @@ -1173,7 +1173,7 @@ class GoogleAIProvider(Provider): return configs def get_model_context_window(self, model_name: str) -> Optional[int]: - from letta.llm_api.google_ai import google_ai_get_model_context_window + from letta.llm_api.google_ai_client import google_ai_get_model_context_window return google_ai_get_model_context_window(self.base_url, self.api_key, model_name) diff --git a/letta/services/tool_executor/tool_executor.py b/letta/services/tool_executor/tool_executor.py index e84f6865..69345b00 100644 --- a/letta/services/tool_executor/tool_executor.py +++ b/letta/services/tool_executor/tool_executor.py @@ -340,11 +340,9 @@ class SandboxToolExecutor(ToolExecutor): else: inject_agent_state = False - # Execute in sandbox sandbox_run_result = await AsyncToolExecutionSandbox(function_name, function_args, actor, tool_object=tool).run( - agent_state=agent_state_copy, - inject_agent_state=inject_agent_state + agent_state=agent_state_copy, inject_agent_state=inject_agent_state ) function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state diff --git a/tests/integration_test_experimental.py b/tests/integration_test_experimental.py index 98802554..cbc5ab74 100644 --- a/tests/integration_test_experimental.py +++ b/tests/integration_test_experimental.py @@ -10,7 +10,6 @@ from dotenv import load_dotenv from letta_client import Letta from openai.types.chat.chat_completion_chunk import ChatCompletionChunk -from letta import create_client from letta.agents.letta_agent import LettaAgent from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageStreamStatus @@ -185,7 +184,7 @@ def agent_state(client, roll_dice_tool, weather_tool, rethink_tool): }, ], llm_config=llm_config, - embedding_config=EmbeddingConfig.default_config(provider="openai") + embedding_config=EmbeddingConfig.default_config(provider="openai"), ) yield agent_state client.agents.delete(agent_state.id) diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py index 6cf832c4..3a3a1b3e 100644 --- a/tests/test_client_legacy.py +++ b/tests/test_client_legacy.py @@ -569,9 +569,8 @@ def test_list_llm_models(client: RESTClient): assert has_model_endpoint_type(models, "azure") if model_settings.openai_api_key: assert has_model_endpoint_type(models, "openai") - # TODO: Fix this - # if model_settings.gemini_api_key: - # assert has_model_endpoint_type(models, "google_ai") + if model_settings.gemini_api_key: + assert has_model_endpoint_type(models, "google_ai") if model_settings.anthropic_api_key: assert has_model_endpoint_type(models, "anthropic")