test: add complex testing for Groq Llama 3.1 70b (#1845)

Co-authored-by: Matt Zhou <mattzhou@Matts-MacBook-Pro.local>
This commit is contained in:
Matthew Zhou
2024-10-08 14:22:13 -07:00
committed by GitHub
parent 91287a76c9
commit cffd493f75
12 changed files with 121 additions and 17 deletions

View File

@@ -491,7 +491,7 @@
{
"data": {
"text/plain": [
"Tool(description=None, source_type='python', module=None, user_id='user-552dee3c-baaf-443a-9d23-8bb54f4af964', id='tool-7559f3f1-e988-4363-a1dd-2dfff8d91a64', name='query_birthday_db', tags=['extras'], source_code='def query_birthday_db(self, name: str): \\n \"\"\"\\n This tool queries an external database to \\n lookup the birthday of someone given their name.\\n\\n Args: \\n name (str): The name to look up \\n\\n Returns: \\n birthday (str): The birthday in mm-dd-yyyy format\\n\\n \"\"\"\\n my_fake_data = {\\n \"bob\": \"03-06-1997\", \\n \"sarah\": \"03-06-1997\"\\n } \\n name = name.lower() \\n if name not in my_fake_data: \\n return None\\n else: \\n return my_fake_data[name]\\n', json_schema={'name': 'query_birthday_db', 'description': 'This tool queries an external database to ', 'parameters': {'type': 'object', 'properties': {'name': {'type': 'string', 'description': 'The name to look up '}, 'request_heartbeat': {'type': 'boolean', 'description': \"Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.\"}}, 'required': ['name', 'request_heartbeat']}})"
"Tool(description=None, source_type='python', module=None, user_id='user-552dee3c-baaf-443a-9d23-8bb54f4af964', id='tool-7559f3f1-e988-4363-a1dd-2dfff8d91a64', name='query_birthday_db', tags=['extras'], source_code='def query_birthday_db(self, name: str): \\n \"\"\"\\n This tool queries an external database to \\n lookup the birthday of someone given their name.\\n\\n Args: \\n name (str): The name to look up \\n\\n Returns: \\n birthday (str): The birthday in mm-dd-yyyy format\\n\\n \"\"\"\\n my_fake_data = {\\n \"bob\": \"03-06-1997\", \\n \"sarah\": \"03-06-1997\"\\n } \\n name = name.lower() \\n if name not in my_fake_data: \\n return None\\n else: \\n return my_fake_data[name]\\n', json_schema={'name': 'query_birthday_db', 'description': 'This tool queries an external database to ', 'parameters': {'type': 'object', 'properties': {'name': {'type': 'string', 'description': 'The name to look up '}, 'request_heartbeat': {'type': 'boolean', 'description': \"Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.\"}}, 'required': ['name', 'request_heartbeat']}})"
]
},
"execution_count": 23,

View File

@@ -482,7 +482,7 @@ class Agent(BaseAgent):
inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs_option,
)
if len(response.choices) == 0:
if len(response.choices) == 0 or response.choices[0] is None:
raise Exception(f"API call didn't return a message: {response}")
# special case for 'length'
@@ -621,6 +621,11 @@ class Agent(BaseAgent):
# (Still parsing function args)
# Handle requests for immediate heartbeat
heartbeat_request = function_args.pop("request_heartbeat", None)
# Edge case: heartbeat_request is returned as a stringified boolean, we will attempt to parse:
if isinstance(heartbeat_request, str) and heartbeat_request.lower().strip() == "true":
heartbeat_request = True
if not isinstance(heartbeat_request, bool) or heartbeat_request is None:
printd(
f"{CLI_WARNING_PREFIX}'request_heartbeat' arg parsed was not a bool or None, type={type(heartbeat_request)}, value={heartbeat_request}"

View File

@@ -130,7 +130,7 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[
if function.__name__ not in ["send_message", "pause_heartbeats"]:
schema["parameters"]["properties"]["request_heartbeat"] = {
"type": "boolean",
"description": "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.",
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
}
schema["parameters"]["required"].append("request_heartbeat")

View File

@@ -296,7 +296,6 @@ def create(
raise NotImplementedError(f"Streaming not yet implemented for Groq.")
if model_settings.groq_api_key is None and llm_config.model_endpoint == "https://api.groq.com/openai/v1/chat/completions":
# only is a problem if we are *not* using an openai proxy
raise ValueError(f"Groq key is missing from letta config file")
# force to true for groq, since they don't support 'content' is non-null

View File

@@ -93,7 +93,7 @@ class Tool(BaseTool):
# append heartbeat (necessary for triggering another reasoning step after this tool call)
json_schema["parameters"]["properties"]["request_heartbeat"] = {
"type": "boolean",
"description": "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.",
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
}
json_schema["parameters"]["required"].append("request_heartbeat")
@@ -128,7 +128,7 @@ class Tool(BaseTool):
# append heartbeat (necessary for triggering another reasoning step after this tool call)
json_schema["parameters"]["properties"]["request_heartbeat"] = {
"type": "boolean",
"description": "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.",
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
}
json_schema["parameters"]["required"].append("request_heartbeat")
@@ -161,7 +161,7 @@ class Tool(BaseTool):
# append heartbeat (necessary for triggering another reasoning step after this tool call)
json_schema["parameters"]["properties"]["request_heartbeat"] = {
"type": "boolean",
"description": "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.",
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
}
json_schema["parameters"]["required"].append("request_heartbeat")

View File

@@ -399,6 +399,8 @@ class SyncServer(Server):
token_warning = step_response.in_context_memory_warning
usage = step_response.usage
print(step_response.model_dump_json(indent=4))
step_count += 1
total_usage += usage
counter += 1
@@ -602,7 +604,7 @@ class SyncServer(Server):
)
# Run the agent state forward
usage = self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_user_message, timestamp=timestamp)
usage = self._step(user_id=user_id, agent_id=agent_id, input_message=message, timestamp=timestamp)
return usage
def system_message(

View File

@@ -13,8 +13,8 @@ class ModelSettings(BaseSettings):
openai_api_key: Optional[str] = None
# TODO: provide overriding BASE_URL?
# grok
grok_api_key: Optional[str] = None
# groq
groq_api_key: Optional[str] = None
# anthropic
anthropic_api_key: Optional[str] = None

18
poetry.lock generated
View File

@@ -2877,6 +2877,22 @@ mistralai = ["mistralai (>=0.1.8,<0.2.0)"]
test-docs = ["anthropic (>=0.27.0,<0.28.0)", "cohere (>=5.1.8,<6.0.0)", "diskcache (>=5.6.3,<6.0.0)", "fastapi (>=0.109.2,<0.110.0)", "groq (>=0.4.2,<0.5.0)", "litellm (>=1.35.31,<2.0.0)", "mistralai (>=0.1.8,<0.2.0)", "pandas (>=2.2.0,<3.0.0)", "pydantic_extra_types (>=2.6.0,<3.0.0)", "redis (>=5.0.1,<6.0.0)", "tabulate (>=0.9.0,<0.10.0)"]
vertexai = ["google-cloud-aiplatform (>=1.52.0,<2.0.0)", "jsonref (>=1.1.0,<2.0.0)"]
[[package]]
name = "ipdb"
version = "0.13.13"
description = "IPython-enabled pdb"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
files = [
{file = "ipdb-0.13.13-py3-none-any.whl", hash = "sha256:45529994741c4ab6d2388bfa5d7b725c2cf7fe9deffabdb8a6113aa5ed449ed4"},
{file = "ipdb-0.13.13.tar.gz", hash = "sha256:e3ac6018ef05126d442af680aad863006ec19d02290561ac88b8b1c0b0cfc726"},
]
[package.dependencies]
decorator = {version = "*", markers = "python_version > \"3.6\""}
ipython = {version = ">=7.31.1", markers = "python_version > \"3.6\""}
tomli = {version = "*", markers = "python_version > \"3.6\" and python_version < \"3.11\""}
[[package]]
name = "ipykernel"
version = "6.29.5"
@@ -8338,4 +8354,4 @@ tests = ["wikipedia"]
[metadata]
lock-version = "2.0"
python-versions = "<3.13,>=3.10"
content-hash = "07f97bbb6e045f76ff1667215e15d8778b0ccbdd816810d802fc76b796012dd1"
content-hash = "aa0bbf5825741bdc9c06388e7e27c1d9a2d85d517abb7f51cca71cc8349d1170"

View File

@@ -92,6 +92,7 @@ tests = ["wikipedia"]
[tool.poetry.group.dev.dependencies]
black = "^24.4.2"
ipykernel = "^6.29.5"
ipdb = "^0.13.13"
[tool.black]
line-length = 140

View File

@@ -1,6 +1,6 @@
{
"context_window": 8192,
"model": "llama3-groq-70b-8192-tool-use-preview",
"model": "llama-3.1-70b-versatile",
"model_endpoint_type": "groq",
"model_endpoint": "https://api.groq.com/openai/v1",
"model_wrapper": null

View File

@@ -20,7 +20,6 @@ from letta.embeddings import embedding_model
from letta.errors import (
InvalidFunctionCallError,
InvalidInnerMonologueError,
LettaError,
MissingFunctionCallError,
MissingInnerMonologueError,
)
@@ -122,6 +121,7 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str) -> ChatComplet
assert response is not None, response
assert response.choices is not None, response
assert len(response.choices) > 0, response
assert response.choices[0] is not None, response
# Select first choice
choice = response.choices[0]
@@ -249,7 +249,10 @@ def check_agent_archival_memory_retrieval(filename: str) -> LettaResponse:
secret_word = "banana"
client.insert_archival_memory(agent_state.id, f"The secret word is {secret_word}!")
response = client.user_message(agent_id=agent_state.id, message="Search archival memory for the secret word and repeat it back to me.")
response = client.user_message(
agent_id=agent_state.id,
message="Search archival memory for the secret word. If you find it successfully, you MUST respond by using the `send_message` function with a message that includes the secret word so I know you found it.",
)
# Basic checks
assert_sanity_checks(response)
@@ -328,7 +331,7 @@ def assert_invoked_send_message_with_keyword(messages: List[LettaMessage], keywo
# No messages found with `send_messages`
if target_message is None:
raise LettaError("Missing send_message function call")
raise MissingFunctionCallError(messages=messages, explanation="Missing `send_message` function call")
send_message_function_call = target_message.function_call
try:

View File

@@ -1,4 +1,6 @@
import functools
import os
import time
from tests.helpers.endpoints_helper import (
check_agent_archival_memory_retrieval,
@@ -15,6 +17,43 @@ embedding_config_dir = "configs/embedding_model_configs"
llm_config_dir = "tests/configs/llm_model_configs"
def retry_until_threshold(threshold=0.5, max_attempts=10, sleep_time_seconds=4):
    """
    Decorator to retry a test until a failure threshold is crossed.

    The wrapped test is executed ``max_attempts`` times; the run as a whole
    passes when the fraction of successful attempts is at least ``threshold``.

    :param threshold: Expected passing rate (e.g., 0.5 means 50% success rate expected).
    :param max_attempts: Maximum number of attempts to retry the test.
    :param sleep_time_seconds: Seconds to pause after a failed attempt
        (useful for rate-limited LLM endpoints).
    :raises AssertionError: If the observed passing rate is below ``threshold``.
    """

    def decorator_retry(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            success_count = 0
            failure_count = 0

            for attempt in range(max_attempts):
                try:
                    func(*args, **kwargs)
                    success_count += 1
                except Exception as e:
                    failure_count += 1
                    # Yellow ANSI text so intermittent failures stand out in test logs.
                    print(f"\033[93mAn attempt failed with error:\n{e}\033[0m")
                    time.sleep(sleep_time_seconds)

            # Evaluate the aggregate pass rate only after all attempts have run.
            rate = success_count / max_attempts
            if rate >= threshold:
                print(f"Test met expected passing rate of {threshold:.2f}. Actual rate: {success_count}/{max_attempts}")
            else:
                raise AssertionError(
                    f"Test did not meet expected passing rate of {threshold:.2f}. Actual rate: {success_count}/{max_attempts}"
                )

        return wrapper

    return decorator_retry
# ======================================================================================================================
# OPENAI TESTS
# ======================================================================================================================
@@ -192,6 +231,45 @@ def test_claude_opus_3_edit_core_memory():
# ======================================================================================================================
# GROQ TESTS
# ======================================================================================================================
def test_groq_llama31_70b_returns_valid_first_message():
    """Check the Groq Llama 3.1 70b endpoint returns a well-formed first response."""
    filename = os.path.join(llm_config_dir, "groq.json")
    response = check_first_response_is_valid_for_llm_endpoint(filename)
    # Log out successful response
    print(f"Got successful response from client: \n\n{response}")


def test_groq_llama31_70b_returns_keyword():
    """Check the model can echo a requested keyword back via send_message."""
    keyword = "banana"
    filename = os.path.join(llm_config_dir, "groq.json")
    response = check_response_contains_keyword(filename, keyword=keyword)
    # Log out successful response
    print(f"Got successful response from client: \n\n{response}")


def test_groq_llama31_70b_uses_external_tool():
    """Check the model invokes a user-supplied external tool."""
    filename = os.path.join(llm_config_dir, "groq.json")
    response = check_agent_uses_external_tool(filename)
    # Log out successful response
    print(f"Got successful response from client: \n\n{response}")


def test_groq_llama31_70b_recall_chat_memory():
    """Check the model can recall information from earlier in the conversation."""
    filename = os.path.join(llm_config_dir, "groq.json")
    response = check_agent_recall_chat_memory(filename)
    # Log out successful response
    print(f"Got successful response from client: \n\n{response}")


# Archival retrieval is flaky on this model, so accept a 75% pass rate over 4 runs.
@retry_until_threshold(threshold=0.75, max_attempts=4)
def test_groq_llama31_70b_archival_memory_retrieval():
    """Check the model can search archival memory and report the stored secret."""
    filename = os.path.join(llm_config_dir, "groq.json")
    response = check_agent_archival_memory_retrieval(filename)
    # Log out successful response
    print(f"Got successful response from client: \n\n{response}")


def test_groq_llama31_70b_edit_core_memory():
    """Check the model can edit its own core memory block."""
    filename = os.path.join(llm_config_dir, "groq.json")
    response = check_agent_edit_core_memory(filename)
    # Log out successful response
    print(f"Got successful response from client: \n\n{response}")