diff --git a/.github/workflows/test_openai.yml b/.github/workflows/test_openai.yml index af791740..00553a43 100644 --- a/.github/workflows/test_openai.yml +++ b/.github/workflows/test_openai.yml @@ -30,11 +30,35 @@ jobs: run: | poetry run letta quickstart --backend openai - - name: Test LLM endpoint + - name: Test first message contains expected function call and inner monologue env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} run: | - poetry run pytest -s -vv tests/test_endpoints.py::test_llm_endpoint_openai + poetry run pytest -s -vv tests/test_endpoints.py::test_openai_gpt_4_returns_valid_first_message + + - name: Test model sends message with keyword + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + run: | + poetry run pytest -s -vv tests/test_endpoints.py::test_openai_gpt_4_returns_keyword + + - name: Test model uses external tool correctly + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + run: | + poetry run pytest -s -vv tests/test_endpoints.py::test_openai_gpt_4_uses_external_tool + + - name: Test model recalls chat memory + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + run: | + poetry run pytest -s -vv tests/test_endpoints.py::test_openai_gpt_4_recall_chat_memory + + - name: Test model uses `archival_memory_search` to find secret + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + run: | + poetry run pytest -s -vv tests/test_endpoints.py::test_openai_gpt_4_archival_memory_retrieval - name: Test embedding endpoint env: diff --git a/letta/client/client.py b/letta/client/client.py index e2d5455b..7be577e8 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -1592,7 +1592,7 @@ class LocalClient(AbstractClient): # memory def get_in_context_memory(self, agent_id: str) -> Memory: """ - Get the in-contxt (i.e. core) memory of an agent + Get the in-context (i.e. 
core) memory of an agent Args: agent_id (str): ID of the agent diff --git a/letta/errors.py b/letta/errors.py index 45c38090..852ec874 100644 --- a/letta/errors.py +++ b/letta/errors.py @@ -1,4 +1,5 @@ -from typing import TYPE_CHECKING +import json +from typing import TYPE_CHECKING, List, Optional, Union # Avoid circular imports if TYPE_CHECKING: @@ -37,73 +38,47 @@ class LocalLLMConnectionError(LettaError): super().__init__(self.message) -class MissingFunctionCallError(LettaError): - message: "Message" - """ The message that caused this error. +class LettaMessageError(LettaError): + """Base error class for handling message-related errors.""" - This error should be raised when a message that we expect to have a function call does not. - """ - - def __init__(self, *, message: "Message") -> None: - error_msg = "The message is missing a function call: \n\n" - - # Pretty print out message - message_json = message.model_dump_json(indent=4) - error_msg += f"{message_json}" + messages: List[Union["Message", "LettaMessage"]] + default_error_message: str = "An error occurred with the message." + def __init__(self, *, messages: List[Union["Message", "LettaMessage"]], explanation: Optional[str] = None) -> None: + error_msg = self.construct_error_message(messages, self.default_error_message, explanation) super().__init__(error_msg) - self.message = message + self.messages = messages + + @staticmethod + def construct_error_message(messages: List[Union["Message", "LettaMessage"]], error_msg: str, explanation: Optional[str] = None) -> str: + """Helper method to construct a clean and formatted error message.""" + if explanation: + error_msg += f" (Explanation: {explanation})" + + # Pretty print out message JSON + message_json = json.dumps([message.model_dump_json(indent=4) for message in messages], indent=4) + return f"{error_msg}\n\n{message_json}" -class InvalidFunctionCallError(LettaError): - message: "Message" - """ The message that caused this error. 
+class MissingFunctionCallError(LettaMessageError): + """Error raised when a message is missing a function call.""" - This error should be raised when a message uses a function that is unexpected or invalid, or if the usage is incorrect. - """ - - def __init__(self, *, message: "Message") -> None: - error_msg = "The message uses an invalid function call or has improper usage of a function call: \n\n" - - # Pretty print out message - message_json = message.model_dump_json(indent=4) - error_msg += f"{message_json}" - - super().__init__(error_msg) - self.message = message + default_error_message = "The message is missing a function call." -class MissingInnerMonologueError(LettaError): - message: "Message" - """ The message that caused this error. +class InvalidFunctionCallError(LettaMessageError): + """Error raised when a message uses an invalid function call.""" - This error should be raised when a message that we expect to have an inner monologue does not. - """ - - def __init__(self, *, message: "Message") -> None: - error_msg = "The message is missing an inner monologue: \n\n" - - # Pretty print out message - message_json = message.model_dump_json(indent=4) - error_msg += f"{message_json}" - - super().__init__(error_msg) - self.message = message + default_error_message = "The message uses an invalid function call or has improper usage of a function call." -class InvalidInnerMonologueError(LettaError): - message: "Message" - """ The message that caused this error. +class MissingInnerMonologueError(LettaMessageError): + """Error raised when a message is missing an inner monologue.""" - This error should be raised when a message has an improperly formatted inner monologue. - """ + default_error_message = "The message is missing an inner monologue." 
- def __init__(self, *, message: "Message") -> None: - error_msg = "The message has a malformed inner monologue: \n\n" - # Pretty print out message - message_json = message.model_dump_json(indent=4) - error_msg += f"{message_json}" +class InvalidInnerMonologueError(LettaMessageError): + """Error raised when a message has a malformed inner monologue.""" - super().__init__(error_msg) - self.message = message + default_error_message = "The message has a malformed inner monologue." diff --git a/tests/helpers/endpoints_helper.py b/tests/helpers/endpoints_helper.py index 7d9578f1..540277f3 100644 --- a/tests/helpers/endpoints_helper.py +++ b/tests/helpers/endpoints_helper.py @@ -1,21 +1,51 @@ import json -from typing import Callable, Optional +import uuid +from typing import Callable, List, Optional, Union +from letta import LocalClient, RESTClient from letta.config import LettaConfig +from letta.constants import DEFAULT_HUMAN, DEFAULT_PERSONA from letta.errors import ( InvalidFunctionCallError, InvalidInnerMonologueError, + LettaError, MissingFunctionCallError, MissingInnerMonologueError, ) from letta.llm_api.llm_api_tools import unpack_inner_thoughts_from_kwargs from letta.local_llm.constants import INNER_THOUGHTS_KWARG +from letta.schemas.agent import AgentState from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.letta_message import ( + FunctionCallMessage, + InternalMonologue, + LettaMessage, +) +from letta.schemas.letta_response import LettaResponse from letta.schemas.llm_config import LLMConfig +from letta.schemas.memory import ChatMemory from letta.schemas.openai.chat_completion_response import Choice, FunctionCall, Message +from letta.utils import get_human_text, get_persona_text + +# Generate uuid for agent name for this example +namespace = uuid.NAMESPACE_DNS +agent_uuid = str(uuid.uuid5(namespace, "test-endpoints-agent")) -def setup_llm_endpoint(filename: str, embedding_config_path: str) -> [LLMConfig, EmbeddingConfig]: +# 
====================================================================================================================== +# Section: Test Setup +# These functions help setup the test +# ====================================================================================================================== + + +def setup_agent( + client: Union[LocalClient, RESTClient], + filename: str, + embedding_config_path: str, + memory_human_str: str = get_human_text(DEFAULT_HUMAN), + memory_persona_str: str = get_persona_text(DEFAULT_PERSONA), + tools: Optional[List[str]] = None, +) -> AgentState: config_data = json.load(open(filename, "r")) llm_config = LLMConfig(**config_data) embedding_config = EmbeddingConfig(**json.load(open(embedding_config_path))) @@ -26,10 +56,84 @@ def setup_llm_endpoint(filename: str, embedding_config_path: str) -> [LLMConfig, config.default_embedding_config = embedding_config config.save() - return llm_config, embedding_config + memory = ChatMemory(human=memory_human_str, persona=memory_persona_str) + agent_state = client.create_agent(name=agent_uuid, llm_config=llm_config, embedding_config=embedding_config, memory=memory, tools=tools) + + return agent_state -def assert_contains_valid_function_call(message: Message, function_call_validator: Optional[Callable[[FunctionCall], bool]] = None) -> None: +# ====================================================================================================================== +# Section: Letta Message Assertions +# These functions are validating elements of parsed Letta Message +# ====================================================================================================================== + + +def assert_sanity_checks(response: LettaResponse): + assert response is not None + assert response.messages is not None + assert len(response.messages) > 0 + + +def assert_invoked_send_message_with_keyword(messages: List[LettaMessage], keyword: str) -> None: + # Find first instance of send_message + 
target_message = None + for message in messages: + if isinstance(message, FunctionCallMessage) and message.function_call.name == "send_message": + target_message = message + break + + # No messages found with `send_messages` + if target_message is None: + raise LettaError("Missing send_message function call") + + send_message_function_call = target_message.function_call + try: + arguments = json.loads(send_message_function_call.arguments) + except: + raise InvalidFunctionCallError(messages=[target_message], explanation="Function call arguments could not be loaded into JSON") + + # Message field not in send_message + if "message" not in arguments: + raise InvalidFunctionCallError( + messages=[target_message], explanation=f"send_message function call does not have required field `message`" + ) + + # Check that the keyword is in the message arguments + if not keyword in arguments["message"]: + raise InvalidFunctionCallError(messages=[target_message], explanation=f"Message argument did not contain keyword={keyword}") + + +def assert_invoked_function_call(messages: List[LettaMessage], function_name: str) -> None: + for message in messages: + if isinstance(message, FunctionCallMessage) and message.function_call.name == function_name: + # Found it, do nothing + return + + raise MissingFunctionCallError( + messages=messages, explanation=f"No messages were found invoking function call with name: {function_name}" + ) + + +def assert_inner_monologue_is_present_and_valid(messages: List[LettaMessage]) -> None: + for message in messages: + if isinstance(message, InternalMonologue): + # Found it, do nothing + return + + raise MissingInnerMonologueError(messages=messages) + + +# ====================================================================================================================== +# Section: Raw API Assertions +# These functions are validating elements of the (close to) raw LLM API's response +# 
====================================================================================================================== + + +def assert_contains_valid_function_call( + message: Message, + function_call_validator: Optional[Callable[[FunctionCall], bool]] = None, + validation_failure_summary: Optional[str] = None, +) -> None: """ Helper function to check that a message contains a valid function call. @@ -39,33 +143,50 @@ def assert_contains_valid_function_call(message: Message, function_call_validato if (hasattr(message, "function_call") and message.function_call is not None) and ( hasattr(message, "tool_calls") and message.tool_calls is not None ): - return False + raise InvalidFunctionCallError(messages=[message], explanation="Both function_call and tool_calls is present in the message") elif hasattr(message, "function_call") and message.function_call is not None: function_call = message.function_call elif hasattr(message, "tool_calls") and message.tool_calls is not None: + # Note: We only take the first one for now. Is this a problem? @charles + # This seems to be standard across the repo function_call = message.tool_calls[0].function else: # Throw a missing function call error - raise MissingFunctionCallError(message=message) + raise MissingFunctionCallError(messages=[message]) if function_call_validator and not function_call_validator(function_call): - raise InvalidFunctionCallError(message=message) + raise InvalidFunctionCallError(messages=[message], explanation=validation_failure_summary) -def inner_monologue_is_valid(monologue: str) -> bool: +def assert_inner_monologue_is_valid(message: Message) -> None: + """ + Helper function to check that the inner monologue is valid. 
+ """ invalid_chars = '(){}[]"' # Sometimes the syntax won't be correct and internal syntax will leak into message invalid_phrases = ["functions", "send_message"] - return any(char in monologue for char in invalid_chars) or any(p in monologue for p in invalid_phrases) + monologue = message.content + for char in invalid_chars: + if char in monologue: + raise InvalidInnerMonologueError(messages=[message], explanation=f"{char} is in monologue") + + for phrase in invalid_phrases: + if phrase in monologue: + raise InvalidInnerMonologueError(messages=[message], explanation=f"{phrase} is in monologue") def assert_contains_correct_inner_monologue(choice: Choice, inner_thoughts_in_kwargs: bool) -> None: + """ + Helper function to check that the inner monologue exists and is valid. + """ + # Unpack inner thoughts out of function kwargs, and repackage into choice if inner_thoughts_in_kwargs: choice = unpack_inner_thoughts_from_kwargs(choice, INNER_THOUGHTS_KWARG) - monologue = choice.message.content + message = choice.message + monologue = message.content if not monologue or monologue is None or monologue == "": - raise MissingInnerMonologueError(message=choice.message) - elif not inner_monologue_is_valid(monologue): - raise InvalidInnerMonologueError(message=choice.message) + raise MissingInnerMonologueError(messages=[message]) + + assert_inner_monologue_is_valid(message) diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 5f65f84e..b08bb4bb 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -10,9 +10,14 @@ from letta.prompts import gpt_system from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.message import Message from tests.helpers.endpoints_helper import ( + agent_uuid, assert_contains_correct_inner_monologue, assert_contains_valid_function_call, - setup_llm_endpoint, + assert_inner_monologue_is_present_and_valid, + assert_invoked_function_call, + assert_invoked_send_message_with_keyword, + 
assert_sanity_checks, + setup_agent, ) from tests.helpers.utils import cleanup @@ -26,17 +31,21 @@ llm_config_path = "configs/llm_model_configs/letta-hosted.json" embedding_config_dir = "configs/embedding_model_configs" llm_config_dir = "configs/llm_model_configs" -# Generate uuid for agent name for this example -namespace = uuid.NAMESPACE_DNS -agent_uuid = str(uuid.uuid5(namespace, "test-endpoints-agent")) - def check_first_response_is_valid_for_llm_endpoint(filename: str, inner_thoughts_in_kwargs: bool = False): - llm_config, embedding_config = setup_llm_endpoint(filename, embedding_config_path) + """ + Checks that the first response is valid: + 1. Contains either send_message or archival_memory_search + 2. Contains valid usage of the function + 3. Contains inner monologue + + Note: This is acting on the raw LLM response, note the usage of `create` + """ client = create_client() cleanup(client=client, agent_uuid=agent_uuid) - agent_state = client.create_agent(name=agent_uuid, llm_config=llm_config, embedding_config=embedding_config) + agent_state = setup_agent(client, filename, embedding_config_path) + tools = [client.get_tool(client.get_tool_id(name=name)) for name in agent_state.tools] agent = Agent( interface=None, @@ -45,9 +54,8 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str, inner_thoughts ) response = create( - llm_config=llm_config, - user_id=uuid.UUID(int=1), # dummy user_id - # messages=agent_state.messages, + llm_config=agent_state.llm_config, + user_id=str(uuid.UUID(int=1)), # dummy user_id messages=agent._messages, functions=agent.functions, functions_python=agent.functions_python, @@ -63,10 +71,130 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str, inner_thoughts validator_func = lambda function_call: function_call.name == "send_message" or function_call.name == "archival_memory_search" assert_contains_valid_function_call(choice.message, validator_func) - # Assert that the choice has an inner monologue + # 
Assert that the message has an inner monologue assert_contains_correct_inner_monologue(choice, inner_thoughts_in_kwargs) +def check_response_contains_keyword(filename: str): + """ + Checks that the prompted response from the LLM contains a chosen keyword + + Note: This is acting on the Letta response, note the usage of `user_message` + """ + client = create_client() + cleanup(client=client, agent_uuid=agent_uuid) + agent_state = setup_agent(client, filename, embedding_config_path) + + keyword = "banana" + keyword_message = f'This is a test to see if you can see my message. If you can see my message, please respond by calling send_message using a message that includes the word "{keyword}"' + response = client.user_message(agent_id=agent_state.id, message=keyword_message) + + # Basic checks + assert_sanity_checks(response) + + # Make sure the message was sent + assert_invoked_send_message_with_keyword(response.messages, keyword) + + # Make sure some inner monologue is present + assert_inner_monologue_is_present_and_valid(response.messages) + + +def check_agent_uses_external_tool(filename: str): + """ + Checks that the LLM will use external tools if instructed + + Note: This is acting on the Letta response, note the usage of `user_message` + """ + from crewai_tools import ScrapeWebsiteTool + + from letta.schemas.tool import Tool + + crewai_tool = ScrapeWebsiteTool(website_url="https://www.example.com") + tool = Tool.from_crewai(crewai_tool) + tool_name = tool.name + + # Set up client + client = create_client() + cleanup(client=client, agent_uuid=agent_uuid) + client.add_tool(tool) + + # Set up persona for tool usage + persona = f""" + + My name is Letta. + + I am a personal assistant who answers a user's questions about a website `example.com`. When a user asks me a question about `example.com`, I will use a tool called {tool_name} which will search `example.com` and answer the relevant question. 
+ + Don’t forget - inner monologue / inner thoughts should always be different than the contents of send_message! send_message is how you communicate with the user, whereas inner thoughts are your own personal inner thoughts. + """ + + agent_state = setup_agent(client, filename, embedding_config_path, memory_persona_str=persona, tools=[tool_name]) + + response = client.user_message(agent_id=agent_state.id, message="What's on the example.com website?") + + # Basic checks + assert_sanity_checks(response) + + # Make sure the tool was called + assert_invoked_function_call(response.messages, tool_name) + + # Make sure some inner monologue is present + assert_inner_monologue_is_present_and_valid(response.messages) + + +def check_agent_recall_chat_memory(filename: str): + """ + Checks that the LLM will recall the chat memory, specifically the human persona. + + Note: This is acting on the Letta response, note the usage of `user_message` + """ + # Set up client + client = create_client() + cleanup(client=client, agent_uuid=agent_uuid) + + human_name = "BananaBoy" + agent_state = setup_agent(client, filename, embedding_config_path, memory_human_str=f"My name is {human_name}") + + response = client.user_message(agent_id=agent_state.id, message="Repeat my name back to me.") + + # Basic checks + assert_sanity_checks(response) + + # Make sure my name was repeated back to me + assert_invoked_send_message_with_keyword(response.messages, human_name) + + # Make sure some inner monologue is present + assert_inner_monologue_is_present_and_valid(response.messages) + + +def check_agent_archival_memory_retrieval(filename: str): + """ + Checks that the LLM will execute an archival memory retrieval. 
+ + Note: This is acting on the Letta response, note the usage of `user_message` + """ + # Set up client + client = create_client() + cleanup(client=client, agent_uuid=agent_uuid) + agent_state = setup_agent(client, filename, embedding_config_path) + secret_word = "banana" + client.insert_archival_memory(agent_state.id, f"The secret word is {secret_word}!") + + response = client.user_message(agent_id=agent_state.id, message="Search archival memory for the secret word and repeat it back to me.") + + # Basic checks + assert_sanity_checks(response) + + # Make sure archival_memory_search was called + assert_invoked_function_call(response.messages, "archival_memory_search") + + # Make sure secret was repeated back to me + assert_invoked_send_message_with_keyword(response.messages, secret_word) + + # Make sure some inner monologue is present + assert_inner_monologue_is_present_and_valid(response.messages) + + def run_embedding_endpoint(filename): # load JSON file config_data = json.load(open(filename, "r")) @@ -79,16 +207,42 @@ def run_embedding_endpoint(filename): assert query_vec is not None -def test_llm_endpoint_openai(): +# ====================================================================================================================== +# OPENAI TESTS +# ====================================================================================================================== +def test_openai_gpt_4_returns_valid_first_message(): filename = os.path.join(llm_config_dir, "gpt-4.json") check_first_response_is_valid_for_llm_endpoint(filename) +def test_openai_gpt_4_returns_keyword(): + filename = os.path.join(llm_config_dir, "gpt-4.json") + check_response_contains_keyword(filename) + + +def test_openai_gpt_4_uses_external_tool(): + filename = os.path.join(llm_config_dir, "gpt-4.json") + check_agent_uses_external_tool(filename) + + +def test_openai_gpt_4_recall_chat_memory(): + filename = os.path.join(llm_config_dir, "gpt-4.json") + 
check_agent_recall_chat_memory(filename) + + +def test_openai_gpt_4_archival_memory_retrieval(): + filename = os.path.join(llm_config_dir, "gpt-4.json") + check_agent_archival_memory_retrieval(filename) + + def test_embedding_endpoint_openai(): filename = os.path.join(embedding_config_dir, "text-embedding-ada-002.json") run_embedding_endpoint(filename) +# ====================================================================================================================== +# LETTA HOSTED +# ====================================================================================================================== def test_llm_endpoint_letta_hosted(): filename = os.path.join(llm_config_dir, "letta-hosted.json") check_first_response_is_valid_for_llm_endpoint(filename) @@ -99,6 +253,9 @@ def test_embedding_endpoint_letta_hosted(): run_embedding_endpoint(filename) +# ====================================================================================================================== +# LOCAL MODELS +# ====================================================================================================================== def test_embedding_endpoint_local(): filename = os.path.join(embedding_config_dir, "local.json") run_embedding_endpoint(filename) @@ -114,6 +271,9 @@ def test_embedding_endpoint_ollama(): run_embedding_endpoint(filename) +# ====================================================================================================================== +# ANTHROPIC TESTS +# ====================================================================================================================== def test_llm_endpoint_anthropic(): filename = os.path.join(llm_config_dir, "anthropic.json") check_first_response_is_valid_for_llm_endpoint(filename)