From e212b3ec17d042825fa35dc0f7ab5bf26852eb39 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Thu, 3 Oct 2024 18:33:42 -0700 Subject: [PATCH] Refactor out the testing functions so we can use this for benchmarking --- tests/helpers/endpoints_helper.py | 187 ++++++++++++++++++++++++++- tests/test_endpoints.py | 205 +----------------------------- 2 files changed, 190 insertions(+), 202 deletions(-) diff --git a/tests/helpers/endpoints_helper.py b/tests/helpers/endpoints_helper.py index 540277f3..9775658f 100644 --- a/tests/helpers/endpoints_helper.py +++ b/tests/helpers/endpoints_helper.py @@ -2,9 +2,11 @@ import json import uuid from typing import Callable, List, Optional, Union -from letta import LocalClient, RESTClient +from letta import LocalClient, RESTClient, create_client +from letta.agent import Agent from letta.config import LettaConfig from letta.constants import DEFAULT_HUMAN, DEFAULT_PERSONA +from letta.embeddings import embedding_model from letta.errors import ( InvalidFunctionCallError, InvalidInnerMonologueError, @@ -12,7 +14,7 @@ from letta.errors import ( MissingFunctionCallError, MissingInnerMonologueError, ) -from letta.llm_api.llm_api_tools import unpack_inner_thoughts_from_kwargs +from letta.llm_api.llm_api_tools import create, unpack_inner_thoughts_from_kwargs from letta.local_llm.constants import INNER_THOUGHTS_KWARG from letta.schemas.agent import AgentState from letta.schemas.embedding_config import EmbeddingConfig @@ -26,11 +28,16 @@ from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import ChatMemory from letta.schemas.openai.chat_completion_response import Choice, FunctionCall, Message from letta.utils import get_human_text, get_persona_text +from tests.helpers.utils import cleanup # Generate uuid for agent name for this example namespace = uuid.NAMESPACE_DNS agent_uuid = str(uuid.uuid5(namespace, "test-endpoints-agent")) +# defaults (letta hosted) +embedding_config_path = "configs/embedding_model_configs/letta-hosted.json" +llm_config_path = "configs/llm_model_configs/letta-hosted.json" + # ====================================================================================================================== # Section: Test Setup @@ -41,7 +48,6 @@ agent_uuid = str(uuid.uuid5(namespace, "test-endpoints-agent")) def setup_agent( client: Union[LocalClient, RESTClient], filename: str, - embedding_config_path: str, memory_human_str: str = get_human_text(DEFAULT_HUMAN), memory_persona_str: str = get_persona_text(DEFAULT_PERSONA), tools: Optional[List[str]] = None, @@ -62,6 +68,181 @@ def setup_agent( return agent_state +def check_first_response_is_valid_for_llm_endpoint(filename: str, inner_thoughts_in_kwargs: bool = False): + """ + Checks that the first response is valid: + + 1. Contains either send_message or archival_memory_search + 2. Contains valid usage of the function + 3. 
Contains inner monologue + + Note: This is acting on the raw LLM response, note the usage of `create` + """ + client = create_client() + cleanup(client=client, agent_uuid=agent_uuid) + agent_state = setup_agent(client, filename, embedding_config_path) + + tools = [client.get_tool(client.get_tool_id(name=name)) for name in agent_state.tools] + agent = Agent( + interface=None, + tools=tools, + agent_state=agent_state, + ) + + response = create( + llm_config=agent_state.llm_config, + user_id=str(uuid.UUID(int=1)), # dummy user_id + messages=agent._messages, + functions=agent.functions, + functions_python=agent.functions_python, + ) + + # Basic check + assert response is not None + + # Select first choice + choice = response.choices[0] + + # Ensure that the first message returns a "send_message" + validator_func = lambda function_call: function_call.name == "send_message" or function_call.name == "archival_memory_search" + assert_contains_valid_function_call(choice.message, validator_func) + + # Assert that the message has an inner monologue + assert_contains_correct_inner_monologue(choice, inner_thoughts_in_kwargs) + + +def check_response_contains_keyword(filename: str): + """ + Checks that the prompted response from the LLM contains a chosen keyword + + Note: This is acting on the Letta response, note the usage of `user_message` + """ + client = create_client() + cleanup(client=client, agent_uuid=agent_uuid) + agent_state = setup_agent(client, filename, embedding_config_path) + + keyword = "banana" + keyword_message = f'This is a test to see if you can see my message. If you can see my message, please respond by calling send_message using a message that includes the word "{keyword}"' + response = client.user_message(agent_id=agent_state.id, message=keyword_message) + + # Basic checks + assert_sanity_checks(response) + + # Make sure the message was sent + assert_invoked_send_message_with_keyword(response.messages, keyword) + + # Make sure some inner monologue is present + assert_inner_monologue_is_present_and_valid(response.messages) + + +def check_agent_uses_external_tool(filename: str): + """ + Checks that the LLM will use external tools if instructed + + Note: This is acting on the Letta response, note the usage of `user_message` + """ + from crewai_tools import ScrapeWebsiteTool + + from letta.schemas.tool import Tool + + crewai_tool = ScrapeWebsiteTool(website_url="https://www.example.com") + tool = Tool.from_crewai(crewai_tool) + tool_name = tool.name + + # Set up client + client = create_client() + cleanup(client=client, agent_uuid=agent_uuid) + client.add_tool(tool) + + # Set up persona for tool usage + persona = f""" + + My name is Letta. + + I am a personal assistant who answers a user's questions about a website `example.com`. When a user asks me a question about `example.com`, I will use a tool called {tool_name} which will search `example.com` and answer the relevant question. + + Don’t forget - inner monologue / inner thoughts should always be different than the contents of send_message! send_message is how you communicate with the user, whereas inner thoughts are your own personal inner thoughts. 
+ """ + + agent_state = setup_agent(client, filename, embedding_config_path, memory_persona_str=persona, tools=[tool_name]) + + response = client.user_message(agent_id=agent_state.id, message="What's on the example.com website?") + + # Basic checks + assert_sanity_checks(response) + + # Make sure the tool was called + assert_invoked_function_call(response.messages, tool_name) + + # Make sure some inner monologue is present + assert_inner_monologue_is_present_and_valid(response.messages) + + +def check_agent_recall_chat_memory(filename: str): + """ + Checks that the LLM will recall the chat memory, specifically the human persona. + + Note: This is acting on the Letta response, note the usage of `user_message` + """ + # Set up client + client = create_client() + cleanup(client=client, agent_uuid=agent_uuid) + + human_name = "BananaBoy" + agent_state = setup_agent(client, filename, embedding_config_path, memory_human_str=f"My name is {human_name}") + + response = client.user_message(agent_id=agent_state.id, message="Repeat my name back to me.") + + # Basic checks + assert_sanity_checks(response) + + # Make sure my name was repeated back to me + assert_invoked_send_message_with_keyword(response.messages, human_name) + + # Make sure some inner monologue is present + assert_inner_monologue_is_present_and_valid(response.messages) + + +def check_agent_archival_memory_retrieval(filename: str): + """ + Checks that the LLM will execute an archival memory retrieval. + + Note: This is acting on the Letta response, note the usage of `user_message` + """ + # Set up client + client = create_client() + cleanup(client=client, agent_uuid=agent_uuid) + agent_state = setup_agent(client, filename, embedding_config_path) + secret_word = "banana" + client.insert_archival_memory(agent_state.id, f"The secret word is {secret_word}!") + + response = client.user_message(agent_id=agent_state.id, message="Search archival memory for the secret word and repeat it back to me.") + + # Basic checks + assert_sanity_checks(response) + + # Make sure archival_memory_search was called + assert_invoked_function_call(response.messages, "archival_memory_search") + + # Make sure secret was repeated back to me + assert_invoked_send_message_with_keyword(response.messages, secret_word) + + # Make sure some inner monologue is present + assert_inner_monologue_is_present_and_valid(response.messages) + + +def run_embedding_endpoint(filename): + # load JSON file + config_data = json.load(open(filename, "r")) + print(config_data) + embedding_config = EmbeddingConfig(**config_data) + model = embedding_model(embedding_config) + query_text = "hello" + query_vec = model.get_text_embedding(query_text) + print("vector dim", len(query_vec)) + assert query_vec is not None + + # ====================================================================================================================== # Section: Letta Message Assertions # These functions are validating elements of parsed Letta Messsage diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index b08bb4bb..3e5e3d0f 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -1,212 +1,19 @@ -import json import os -import uuid -from letta import create_client -from letta.agent import Agent -from letta.embeddings import embedding_model -from letta.llm_api.llm_api_tools import create -from letta.prompts import gpt_system -from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.message import Message from tests.helpers.endpoints_helper import ( - 
agent_uuid, - assert_contains_correct_inner_monologue, - assert_contains_valid_function_call, - assert_inner_monologue_is_present_and_valid, - assert_invoked_function_call, - assert_invoked_send_message_with_keyword, - assert_sanity_checks, - setup_agent, + check_agent_archival_memory_retrieval, + check_agent_recall_chat_memory, + check_agent_uses_external_tool, + check_first_response_is_valid_for_llm_endpoint, + check_response_contains_keyword, + run_embedding_endpoint, ) -from tests.helpers.utils import cleanup - -messages = [Message(role="system", text=gpt_system.get_system_text("memgpt_chat")), Message(role="user", text="How are you?")] - -# defaults (letta hosted) -embedding_config_path = "configs/embedding_model_configs/letta-hosted.json" -llm_config_path = "configs/llm_model_configs/letta-hosted.json" # directories embedding_config_dir = "configs/embedding_model_configs" llm_config_dir = "configs/llm_model_configs" -def check_first_response_is_valid_for_llm_endpoint(filename: str, inner_thoughts_in_kwargs: bool = False): - """ - Checks that the first response is valid: - - 1. Contains either send_message or archival_memory_search - 2. Contains valid usage of the function - 3. Contains inner monologue - - Note: This is acting on the raw LLM response, note the usage of `create` - """ - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - agent_state = setup_agent(client, filename, embedding_config_path) - - tools = [client.get_tool(client.get_tool_id(name=name)) for name in agent_state.tools] - agent = Agent( - interface=None, - tools=tools, - agent_state=agent_state, - ) - - response = create( - llm_config=agent_state.llm_config, - user_id=str(uuid.UUID(int=1)), # dummy user_id - messages=agent._messages, - functions=agent.functions, - functions_python=agent.functions_python, - ) - - # Basic check - assert response is not None - - # Select first choice - choice = response.choices[0] - - # Ensure that the first message returns a "send_message" - validator_func = lambda function_call: function_call.name == "send_message" or function_call.name == "archival_memory_search" - assert_contains_valid_function_call(choice.message, validator_func) - - # Assert that the message has an inner monologue - assert_contains_correct_inner_monologue(choice, inner_thoughts_in_kwargs) - - -def check_response_contains_keyword(filename: str): - """ - Checks that the prompted response from the LLM contains a chosen keyword - - Note: This is acting on the Letta response, note the usage of `user_message` - """ - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - agent_state = setup_agent(client, filename, embedding_config_path) - - keyword = "banana" - keyword_message = f'This is a test to see if you can see my message. 
If you can see my message, please respond by calling send_message using a message that includes the word "{keyword}"' - response = client.user_message(agent_id=agent_state.id, message=keyword_message) - - # Basic checks - assert_sanity_checks(response) - - # Make sure the message was sent - assert_invoked_send_message_with_keyword(response.messages, keyword) - - # Make sure some inner monologue is present - assert_inner_monologue_is_present_and_valid(response.messages) - - -def check_agent_uses_external_tool(filename: str): - """ - Checks that the LLM will use external tools if instructed - - Note: This is acting on the Letta response, note the usage of `user_message` - """ - from crewai_tools import ScrapeWebsiteTool - - from letta.schemas.tool import Tool - - crewai_tool = ScrapeWebsiteTool(website_url="https://www.example.com") - tool = Tool.from_crewai(crewai_tool) - tool_name = tool.name - - # Set up client - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - client.add_tool(tool) - - # Set up persona for tool usage - persona = f""" - - My name is Letta. - - I am a personal assistant who answers a user's questions about a website `example.com`. When a user asks me a question about `example.com`, I will use a tool called {tool_name} which will search `example.com` and answer the relevant question. - - Don’t forget - inner monologue / inner thoughts should always be different than the contents of send_message! send_message is how you communicate with the user, whereas inner thoughts are your own personal inner thoughts. - """ - - agent_state = setup_agent(client, filename, embedding_config_path, memory_persona_str=persona, tools=[tool_name]) - - response = client.user_message(agent_id=agent_state.id, message="What's on the example.com website?") - - # Basic checks - assert_sanity_checks(response) - - # Make sure the tool was called - assert_invoked_function_call(response.messages, tool_name) - - # Make sure some inner monologue is present - assert_inner_monologue_is_present_and_valid(response.messages) - - -def check_agent_recall_chat_memory(filename: str): - """ - Checks that the LLM will recall the chat memory, specifically the human persona. - - Note: This is acting on the Letta response, note the usage of `user_message` - """ - # Set up client - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - - human_name = "BananaBoy" - agent_state = setup_agent(client, filename, embedding_config_path, memory_human_str=f"My name is {human_name}") - - response = client.user_message(agent_id=agent_state.id, message="Repeat my name back to me.") - - # Basic checks - assert_sanity_checks(response) - - # Make sure my name was repeated back to me - assert_invoked_send_message_with_keyword(response.messages, human_name) - - # Make sure some inner monologue is present - assert_inner_monologue_is_present_and_valid(response.messages) - - -def check_agent_archival_memory_retrieval(filename: str): - """ - Checks that the LLM will execute an archival memory retrieval. 
- - Note: This is acting on the Letta response, note the usage of `user_message` - """ - # Set up client - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - agent_state = setup_agent(client, filename, embedding_config_path) - secret_word = "banana" - client.insert_archival_memory(agent_state.id, f"The secret word is {secret_word}!") - - response = client.user_message(agent_id=agent_state.id, message="Search archival memory for the secret word and repeat it back to me.") - - # Basic checks - assert_sanity_checks(response) - - # Make sure archival_memory_search was called - assert_invoked_function_call(response.messages, "archival_memory_search") - - # Make sure secret was repeated back to me - assert_invoked_send_message_with_keyword(response.messages, secret_word) - - # Make sure some inner monologue is present - assert_inner_monologue_is_present_and_valid(response.messages) - - -def run_embedding_endpoint(filename): - # load JSON file - config_data = json.load(open(filename, "r")) - print(config_data) - embedding_config = EmbeddingConfig(**config_data) - model = embedding_model(embedding_config) - query_text = "hello" - query_vec = model.get_text_embedding(query_text) - print("vector dim", len(query_vec)) - assert query_vec is not None - - # ====================================================================================================================== # OPENAI TESTS # ======================================================================================================================
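
The commit message frames this refactor as making the endpoint checks reusable for benchmarking. As a rough illustration of that use case (the driver script below is NOT part of this patch; the config directory layout is taken from the tests, and the timing loop and script structure are assumptions), the extracted helpers could be exercised like this:

    # hypothetical benchmark driver -- illustrative sketch only, not included in this patch
    import os
    import time

    from tests.helpers.endpoints_helper import (
        check_agent_archival_memory_retrieval,
        check_agent_recall_chat_memory,
        check_agent_uses_external_tool,
        check_first_response_is_valid_for_llm_endpoint,
        check_response_contains_keyword,
    )

    # same config directory the tests point at (assumed to hold one JSON per model)
    llm_config_dir = "configs/llm_model_configs"

    checks = [
        check_first_response_is_valid_for_llm_endpoint,
        check_response_contains_keyword,
        check_agent_uses_external_tool,
        check_agent_recall_chat_memory,
        check_agent_archival_memory_retrieval,
    ]

    for config_file in sorted(os.listdir(llm_config_dir)):
        if not config_file.endswith(".json"):
            continue
        filename = os.path.join(llm_config_dir, config_file)
        for check in checks:
            start = time.perf_counter()
            try:
                check(filename)  # each helper creates its own client and cleans up the shared agent
                outcome = "ok"
            except Exception as e:  # keep the sweep going even if one model/check fails
                outcome = f"failed: {e}"
            elapsed = time.perf_counter() - start
            print(f"{config_file} :: {check.__name__} :: {outcome} :: {elapsed:.2f}s")

Because each helper now calls create_client() and cleanup(agent_uuid) itself, a script like this can run the checks back to back across model configs without any per-check setup beyond passing the config filename.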