test: add complex testing for Groq Llama 3.1 70b (#1845)
Co-authored-by: Matt Zhou <mattzhou@Matts-MacBook-Pro.local>
This commit is contained in:
@@ -491,7 +491,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Tool(description=None, source_type='python', module=None, user_id='user-552dee3c-baaf-443a-9d23-8bb54f4af964', id='tool-7559f3f1-e988-4363-a1dd-2dfff8d91a64', name='query_birthday_db', tags=['extras'], source_code='def query_birthday_db(self, name: str): \\n \"\"\"\\n This tool queries an external database to \\n lookup the birthday of someone given their name.\\n\\n Args: \\n name (str): The name to look up \\n\\n Returns: \\n birthday (str): The birthday in mm-dd-yyyy format\\n\\n \"\"\"\\n my_fake_data = {\\n \"bob\": \"03-06-1997\", \\n \"sarah\": \"03-06-1997\"\\n } \\n name = name.lower() \\n if name not in my_fake_data: \\n return None\\n else: \\n return my_fake_data[name]\\n', json_schema={'name': 'query_birthday_db', 'description': 'This tool queries an external database to ', 'parameters': {'type': 'object', 'properties': {'name': {'type': 'string', 'description': 'The name to look up '}, 'request_heartbeat': {'type': 'boolean', 'description': \"Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.\"}}, 'required': ['name', 'request_heartbeat']}})"
|
||||
"Tool(description=None, source_type='python', module=None, user_id='user-552dee3c-baaf-443a-9d23-8bb54f4af964', id='tool-7559f3f1-e988-4363-a1dd-2dfff8d91a64', name='query_birthday_db', tags=['extras'], source_code='def query_birthday_db(self, name: str): \\n \"\"\"\\n This tool queries an external database to \\n lookup the birthday of someone given their name.\\n\\n Args: \\n name (str): The name to look up \\n\\n Returns: \\n birthday (str): The birthday in mm-dd-yyyy format\\n\\n \"\"\"\\n my_fake_data = {\\n \"bob\": \"03-06-1997\", \\n \"sarah\": \"03-06-1997\"\\n } \\n name = name.lower() \\n if name not in my_fake_data: \\n return None\\n else: \\n return my_fake_data[name]\\n', json_schema={'name': 'query_birthday_db', 'description': 'This tool queries an external database to ', 'parameters': {'type': 'object', 'properties': {'name': {'type': 'string', 'description': 'The name to look up '}, 'request_heartbeat': {'type': 'boolean', 'description': \"Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.\"}}, 'required': ['name', 'request_heartbeat']}})"
|
||||
]
|
||||
},
|
||||
"execution_count": 23,
|
||||
|
||||
@@ -482,7 +482,7 @@ class Agent(BaseAgent):
|
||||
inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs_option,
|
||||
)
|
||||
|
||||
if len(response.choices) == 0:
|
||||
if len(response.choices) == 0 or response.choices[0] is None:
|
||||
raise Exception(f"API call didn't return a message: {response}")
|
||||
|
||||
# special case for 'length'
|
||||
@@ -621,6 +621,11 @@ class Agent(BaseAgent):
|
||||
# (Still parsing function args)
|
||||
# Handle requests for immediate heartbeat
|
||||
heartbeat_request = function_args.pop("request_heartbeat", None)
|
||||
|
||||
# Edge case: heartbeat_request is returned as a stringified boolean, we will attempt to parse:
|
||||
if isinstance(heartbeat_request, str) and heartbeat_request.lower().strip() == "true":
|
||||
heartbeat_request = True
|
||||
|
||||
if not isinstance(heartbeat_request, bool) or heartbeat_request is None:
|
||||
printd(
|
||||
f"{CLI_WARNING_PREFIX}'request_heartbeat' arg parsed was not a bool or None, type={type(heartbeat_request)}, value={heartbeat_request}"
|
||||
|
||||
@@ -130,7 +130,7 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[
|
||||
if function.__name__ not in ["send_message", "pause_heartbeats"]:
|
||||
schema["parameters"]["properties"]["request_heartbeat"] = {
|
||||
"type": "boolean",
|
||||
"description": "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.",
|
||||
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
|
||||
}
|
||||
schema["parameters"]["required"].append("request_heartbeat")
|
||||
|
||||
|
||||
@@ -296,7 +296,6 @@ def create(
|
||||
raise NotImplementedError(f"Streaming not yet implemented for Groq.")
|
||||
|
||||
if model_settings.groq_api_key is None and llm_config.model_endpoint == "https://api.groq.com/openai/v1/chat/completions":
|
||||
# only is a problem if we are *not* using an openai proxy
|
||||
raise ValueError(f"Groq key is missing from letta config file")
|
||||
|
||||
# force to true for groq, since they don't support 'content' is non-null
|
||||
|
||||
@@ -93,7 +93,7 @@ class Tool(BaseTool):
|
||||
# append heartbeat (necessary for triggering another reasoning step after this tool call)
|
||||
json_schema["parameters"]["properties"]["request_heartbeat"] = {
|
||||
"type": "boolean",
|
||||
"description": "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.",
|
||||
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
|
||||
}
|
||||
json_schema["parameters"]["required"].append("request_heartbeat")
|
||||
|
||||
@@ -128,7 +128,7 @@ class Tool(BaseTool):
|
||||
# append heartbeat (necessary for triggering another reasoning step after this tool call)
|
||||
json_schema["parameters"]["properties"]["request_heartbeat"] = {
|
||||
"type": "boolean",
|
||||
"description": "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.",
|
||||
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
|
||||
}
|
||||
json_schema["parameters"]["required"].append("request_heartbeat")
|
||||
|
||||
@@ -161,7 +161,7 @@ class Tool(BaseTool):
|
||||
# append heartbeat (necessary for triggering another reasoning step after this tool call)
|
||||
json_schema["parameters"]["properties"]["request_heartbeat"] = {
|
||||
"type": "boolean",
|
||||
"description": "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.",
|
||||
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
|
||||
}
|
||||
json_schema["parameters"]["required"].append("request_heartbeat")
|
||||
|
||||
|
||||
@@ -399,6 +399,8 @@ class SyncServer(Server):
|
||||
token_warning = step_response.in_context_memory_warning
|
||||
usage = step_response.usage
|
||||
|
||||
print(step_response.model_dump_json(indent=4))
|
||||
|
||||
step_count += 1
|
||||
total_usage += usage
|
||||
counter += 1
|
||||
@@ -602,7 +604,7 @@ class SyncServer(Server):
|
||||
)
|
||||
|
||||
# Run the agent state forward
|
||||
usage = self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_user_message, timestamp=timestamp)
|
||||
usage = self._step(user_id=user_id, agent_id=agent_id, input_message=message, timestamp=timestamp)
|
||||
return usage
|
||||
|
||||
def system_message(
|
||||
|
||||
@@ -13,8 +13,8 @@ class ModelSettings(BaseSettings):
|
||||
openai_api_key: Optional[str] = None
|
||||
# TODO: provide overriding BASE_URL?
|
||||
|
||||
# grok
|
||||
grok_api_key: Optional[str] = None
|
||||
# groq
|
||||
groq_api_key: Optional[str] = None
|
||||
|
||||
# anthropic
|
||||
anthropic_api_key: Optional[str] = None
|
||||
|
||||
18
poetry.lock
generated
18
poetry.lock
generated
@@ -2877,6 +2877,22 @@ mistralai = ["mistralai (>=0.1.8,<0.2.0)"]
|
||||
test-docs = ["anthropic (>=0.27.0,<0.28.0)", "cohere (>=5.1.8,<6.0.0)", "diskcache (>=5.6.3,<6.0.0)", "fastapi (>=0.109.2,<0.110.0)", "groq (>=0.4.2,<0.5.0)", "litellm (>=1.35.31,<2.0.0)", "mistralai (>=0.1.8,<0.2.0)", "pandas (>=2.2.0,<3.0.0)", "pydantic_extra_types (>=2.6.0,<3.0.0)", "redis (>=5.0.1,<6.0.0)", "tabulate (>=0.9.0,<0.10.0)"]
|
||||
vertexai = ["google-cloud-aiplatform (>=1.52.0,<2.0.0)", "jsonref (>=1.1.0,<2.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "ipdb"
|
||||
version = "0.13.13"
|
||||
description = "IPython-enabled pdb"
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
|
||||
files = [
|
||||
{file = "ipdb-0.13.13-py3-none-any.whl", hash = "sha256:45529994741c4ab6d2388bfa5d7b725c2cf7fe9deffabdb8a6113aa5ed449ed4"},
|
||||
{file = "ipdb-0.13.13.tar.gz", hash = "sha256:e3ac6018ef05126d442af680aad863006ec19d02290561ac88b8b1c0b0cfc726"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
decorator = {version = "*", markers = "python_version > \"3.6\""}
|
||||
ipython = {version = ">=7.31.1", markers = "python_version > \"3.6\""}
|
||||
tomli = {version = "*", markers = "python_version > \"3.6\" and python_version < \"3.11\""}
|
||||
|
||||
[[package]]
|
||||
name = "ipykernel"
|
||||
version = "6.29.5"
|
||||
@@ -8338,4 +8354,4 @@ tests = ["wikipedia"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "<3.13,>=3.10"
|
||||
content-hash = "07f97bbb6e045f76ff1667215e15d8778b0ccbdd816810d802fc76b796012dd1"
|
||||
content-hash = "aa0bbf5825741bdc9c06388e7e27c1d9a2d85d517abb7f51cca71cc8349d1170"
|
||||
|
||||
@@ -92,6 +92,7 @@ tests = ["wikipedia"]
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
black = "^24.4.2"
|
||||
ipykernel = "^6.29.5"
|
||||
ipdb = "^0.13.13"
|
||||
|
||||
[tool.black]
|
||||
line-length = 140
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"context_window": 8192,
|
||||
"model": "llama3-groq-70b-8192-tool-use-preview",
|
||||
"model": "llama-3.1-70b-versatile",
|
||||
"model_endpoint_type": "groq",
|
||||
"model_endpoint": "https://api.groq.com/openai/v1",
|
||||
"model_wrapper": null
|
||||
|
||||
@@ -20,7 +20,6 @@ from letta.embeddings import embedding_model
|
||||
from letta.errors import (
|
||||
InvalidFunctionCallError,
|
||||
InvalidInnerMonologueError,
|
||||
LettaError,
|
||||
MissingFunctionCallError,
|
||||
MissingInnerMonologueError,
|
||||
)
|
||||
@@ -122,6 +121,7 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str) -> ChatComplet
|
||||
assert response is not None, response
|
||||
assert response.choices is not None, response
|
||||
assert len(response.choices) > 0, response
|
||||
assert response.choices[0] is not None, response
|
||||
|
||||
# Select first choice
|
||||
choice = response.choices[0]
|
||||
@@ -249,7 +249,10 @@ def check_agent_archival_memory_retrieval(filename: str) -> LettaResponse:
|
||||
secret_word = "banana"
|
||||
client.insert_archival_memory(agent_state.id, f"The secret word is {secret_word}!")
|
||||
|
||||
response = client.user_message(agent_id=agent_state.id, message="Search archival memory for the secret word and repeat it back to me.")
|
||||
response = client.user_message(
|
||||
agent_id=agent_state.id,
|
||||
message="Search archival memory for the secret word. If you find it successfully, you MUST respond by using the `send_message` function with a message that includes the secret word so I know you found it.",
|
||||
)
|
||||
|
||||
# Basic checks
|
||||
assert_sanity_checks(response)
|
||||
@@ -328,7 +331,7 @@ def assert_invoked_send_message_with_keyword(messages: List[LettaMessage], keywo
|
||||
|
||||
# No messages found with `send_messages`
|
||||
if target_message is None:
|
||||
raise LettaError("Missing send_message function call")
|
||||
raise MissingFunctionCallError(messages=messages, explanation="Missing `send_message` function call")
|
||||
|
||||
send_message_function_call = target_message.function_call
|
||||
try:
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import functools
|
||||
import os
|
||||
import time
|
||||
|
||||
from tests.helpers.endpoints_helper import (
|
||||
check_agent_archival_memory_retrieval,
|
||||
@@ -15,6 +17,43 @@ embedding_config_dir = "configs/embedding_model_configs"
|
||||
llm_config_dir = "tests/configs/llm_model_configs"
|
||||
|
||||
|
||||
def retry_until_threshold(threshold=0.5, max_attempts=10, sleep_time_seconds=4):
    """
    Decorator that runs a test repeatedly and checks its observed passing rate.

    The wrapped test is always invoked ``max_attempts`` times; it is not stopped
    early. Any exception raised by an attempt is counted as a failure and printed.
    After the last attempt the wrapper raises ``AssertionError`` when
    ``successes / max_attempts`` is below ``threshold``; otherwise it prints a
    summary and returns ``None`` (the test's own return value is discarded).

    :param threshold: Expected passing rate (e.g., 0.5 means 50% success rate expected).
    :param max_attempts: Number of times the test is run. Must be at least 1.
    :param sleep_time_seconds: Pause between consecutive attempts, in seconds.
    :raises ValueError: If ``max_attempts`` is less than 1.
    """
    if max_attempts < 1:
        # Fail at decoration time instead of with a ZeroDivisionError when the rate is computed.
        raise ValueError(f"max_attempts must be at least 1, got {max_attempts}")

    def decorator_retry(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            success_count = 0

            for attempt in range(max_attempts):
                if attempt > 0:
                    # Pause only between attempts; sleeping after the last one just delays the verdict.
                    time.sleep(sleep_time_seconds)

                try:
                    func(*args, **kwargs)
                except Exception as e:
                    print(f"\033[93mAn attempt failed with error:\n{e}\033[0m")
                else:
                    success_count += 1

            rate = success_count / max_attempts
            if rate < threshold:
                raise AssertionError(
                    f"Test did not meet expected passing rate of {threshold:.2f}. Actual rate: {success_count}/{max_attempts}"
                )

            print(f"Test met expected passing rate of {threshold:.2f}. Actual rate: {success_count}/{max_attempts}")

        return wrapper

    return decorator_retry
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# OPENAI TESTS
|
||||
# ======================================================================================================================
|
||||
@@ -192,6 +231,45 @@ def test_claude_opus_3_edit_core_memory():
|
||||
# ======================================================================================================================
|
||||
# GROQ TESTS
|
||||
# ======================================================================================================================
|
||||
def test_llm_endpoint_groq():
|
||||
def test_groq_llama31_70b_returns_valid_first_message():
    """Call `check_first_response_is_valid_for_llm_endpoint` once with the Groq config and print the result."""
    filename = os.path.join(llm_config_dir, "groq.json")
    # Call the check a single time and keep its result; a second, discarded call would only repeat the request.
    response = check_first_response_is_valid_for_llm_endpoint(filename)
    # Log out successful response
    print(f"Got successful response from client: \n\n{response}")
|
||||
|
||||
|
||||
def test_groq_llama31_70b_returns_keyword():
    """Call `check_response_contains_keyword` with the Groq config and the keyword "banana", then print the result."""
    expected_word = "banana"
    config_path = os.path.join(llm_config_dir, "groq.json")
    result = check_response_contains_keyword(config_path, keyword=expected_word)
    # Only reached when the helper did not raise.
    print(f"Got successful response from client: \n\n{result}")
|
||||
|
||||
|
||||
def test_groq_llama31_70b_uses_external_tool():
    """Call `check_agent_uses_external_tool` with the Groq config and print the result."""
    config_path = os.path.join(llm_config_dir, "groq.json")
    result = check_agent_uses_external_tool(config_path)
    # Only reached when the helper did not raise.
    print(f"Got successful response from client: \n\n{result}")
|
||||
|
||||
|
||||
def test_groq_llama31_70b_recall_chat_memory():
    """Call `check_agent_recall_chat_memory` with the Groq config and print the result."""
    config_path = os.path.join(llm_config_dir, "groq.json")
    result = check_agent_recall_chat_memory(config_path)
    # Only reached when the helper did not raise.
    print(f"Got successful response from client: \n\n{result}")
|
||||
|
||||
|
||||
@retry_until_threshold(threshold=0.75, max_attempts=4)
def test_groq_llama31_70b_archival_memory_retrieval():
    """Call `check_agent_archival_memory_retrieval` with the Groq config and print the result.

    Wrapped in `retry_until_threshold`, so the body runs 4 times and must pass at a rate of at least 0.75.
    """
    config_path = os.path.join(llm_config_dir, "groq.json")
    result = check_agent_archival_memory_retrieval(config_path)
    # Only reached when the helper did not raise on this attempt.
    print(f"Got successful response from client: \n\n{result}")
|
||||
|
||||
|
||||
def test_groq_llama31_70b_edit_core_memory():
    """Call `check_agent_edit_core_memory` with the Groq config and print the result."""
    config_path = os.path.join(llm_config_dir, "groq.json")
    result = check_agent_edit_core_memory(config_path)
    # Only reached when the helper did not raise.
    print(f"Got successful response from client: \n\n{result}")
|
||||
|
||||
Reference in New Issue
Block a user