From c0501a743ded77ed244a2bfd443df6d74aa78b77 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Wed, 9 Oct 2024 10:00:21 -0700 Subject: [PATCH] test: add complex gemini tests (#1853) Co-authored-by: Matt Zhou --- letta/llm_api/google_ai.py | 2 +- letta/llm_api/helpers.py | 15 +++++++++---- letta/llm_api/llm_api_tools.py | 14 ++---------- tests/helpers/endpoints_helper.py | 1 - tests/test_endpoints.py | 36 +++++++++++++++++++++++++++++++ 5 files changed, 50 insertions(+), 18 deletions(-) diff --git a/letta/llm_api/google_ai.py b/letta/llm_api/google_ai.py index 71f64dae..5d4e1798 100644 --- a/letta/llm_api/google_ai.py +++ b/letta/llm_api/google_ai.py @@ -436,7 +436,7 @@ def google_ai_chat_completions_request( response_json=response_json, model=data.get("model"), input_messages=data["contents"], - pull_inner_thoughts_from_args=data.get("inner_thoughts_in_kwargs", False), + pull_inner_thoughts_from_args=inner_thoughts_in_kwargs, ) except Exception as conversion_error: print(f"Error during response conversion: {conversion_error}") diff --git a/letta/llm_api/helpers.py b/letta/llm_api/helpers.py index 6c532361..a5aa43b2 100644 --- a/letta/llm_api/helpers.py +++ b/letta/llm_api/helpers.py @@ -21,10 +21,17 @@ def make_post_request(url: str, headers: dict[str, str], data: dict[str, Any]) - # Raise for 4XX/5XX HTTP errors response.raise_for_status() - # Ensure the content is JSON before parsing - if response.headers.get("Content-Type") == "application/json": - response_data = response.json() # Convert to dict from JSON - printd(f"Response JSON: {response_data}") + # Check if the response content type indicates JSON and attempt to parse it + content_type = response.headers.get("Content-Type", "") + if "application/json" in content_type.lower(): + try: + response_data = response.json() # Attempt to parse the response as JSON + printd(f"Response JSON: {response_data}") + except ValueError as json_err: + # Handle the case where the content type says JSON but the body is invalid + error_message = f"Failed to parse JSON despite Content-Type being {content_type}: {json_err}" + printd(error_message) + raise ValueError(error_message) from json_err else: error_message = f"Unexpected content type returned: {response.headers.get('Content-Type')}" printd(error_message) diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index b85d5739..6327d1cb 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -217,19 +217,14 @@ def create( if not use_tool_naming: raise NotImplementedError("Only tool calling supported on Google AI API requests") - # NOTE: until Google AI supports CoT / text alongside function calls, - # we need to put it in a kwarg (unless we want to split the message into two) - google_ai_inner_thoughts_in_kwarg = True - if functions is not None: tools = [{"type": "function", "function": f} for f in functions] tools = [Tool(**t) for t in tools] - tools = convert_tools_to_google_ai_format(tools, inner_thoughts_in_kwargs=google_ai_inner_thoughts_in_kwarg) + tools = convert_tools_to_google_ai_format(tools, inner_thoughts_in_kwargs=True) else: tools = None return google_ai_chat_completions_request( - inner_thoughts_in_kwargs=google_ai_inner_thoughts_in_kwarg, base_url=llm_config.model_endpoint, model=llm_config.model, api_key=model_settings.gemini_api_key, @@ -238,6 +233,7 @@ def create( contents=[m.to_google_ai_dict() for m in messages], tools=tools, ), + inner_thoughts_in_kwargs=True, ) elif llm_config.model_endpoint_type == "anthropic": @@ -246,12 +242,6 @@ def create( if not use_tool_naming: raise NotImplementedError("Only tool calling supported on Anthropic API requests") - if functions is not None: - tools = [{"type": "function", "function": f} for f in functions] - tools = [Tool(**t) for t in tools] - else: - tools = None - return anthropic_chat_completions_request( url=llm_config.model_endpoint, api_key=model_settings.anthropic_api_key, diff --git a/tests/helpers/endpoints_helper.py b/tests/helpers/endpoints_helper.py index 138c2ef4..19cbcb20 100644 --- a/tests/helpers/endpoints_helper.py +++ b/tests/helpers/endpoints_helper.py @@ -132,7 +132,6 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str) -> ChatComplet # Get inner_thoughts_in_kwargs inner_thoughts_in_kwargs = derive_inner_thoughts_in_kwargs(OptionState.DEFAULT, agent_state.llm_config.model) - # Assert that the message has an inner monologue assert_contains_correct_inner_monologue(choice, inner_thoughts_in_kwargs) diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 2b767937..a938617a 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -283,3 +283,39 @@ def test_gemini_pro_15_returns_valid_first_message(): response = check_first_response_is_valid_for_llm_endpoint(filename) # Log out successful response print(f"Got successful response from client: \n\n{response}") + + +def test_gemini_pro_15_returns_keyword(): + keyword = "banana" + filename = os.path.join(llm_config_dir, "gemini-pro.json") + response = check_response_contains_keyword(filename, keyword=keyword) + # Log out successful response + print(f"Got successful response from client: \n\n{response}") + + +def test_gemini_pro_15_uses_external_tool(): + filename = os.path.join(llm_config_dir, "gemini-pro.json") + response = check_agent_uses_external_tool(filename) + # Log out successful response + print(f"Got successful response from client: \n\n{response}") + + +def test_gemini_pro_15_recall_chat_memory(): + filename = os.path.join(llm_config_dir, "gemini-pro.json") + response = check_agent_recall_chat_memory(filename) + # Log out successful response + print(f"Got successful response from client: \n\n{response}") + + +def test_gemini_pro_15_archival_memory_retrieval(): + filename = os.path.join(llm_config_dir, "gemini-pro.json") + response = check_agent_archival_memory_retrieval(filename) + # Log out successful response + print(f"Got successful response from client: \n\n{response}") + + +def test_gemini_pro_15_edit_core_memory(): + filename = os.path.join(llm_config_dir, "gemini-pro.json") + response = check_agent_edit_core_memory(filename) + # Log out successful response + print(f"Got successful response from client: \n\n{response}")