diff --git a/letta/errors.py b/letta/errors.py index 6ac70181..45c38090 100644 --- a/letta/errors.py +++ b/letta/errors.py @@ -1,8 +1,19 @@ -class LLMError(Exception): - """Base class for all LLM-related errors.""" +from typing import TYPE_CHECKING + +# Avoid circular imports +if TYPE_CHECKING: + from letta.schemas.message import Message -class LLMJSONParsingError(LLMError): +class LettaError(Exception): + """Base class for all Letta related errors.""" + + +class LLMError(LettaError): + pass + + +class LLMJSONParsingError(LettaError): """Exception raised for errors in the JSON parsing process.""" def __init__(self, message="Error parsing JSON generated by LLM"): @@ -10,7 +21,7 @@ class LLMJSONParsingError(LLMError): super().__init__(self.message) -class LocalLLMError(LLMError): +class LocalLLMError(LettaError): """Generic catch-all error for local LLM problems""" def __init__(self, message="Encountered an error while running local LLM"): @@ -18,9 +29,81 @@ class LocalLLMError(LLMError): super().__init__(self.message) -class LocalLLMConnectionError(LLMError): +class LocalLLMConnectionError(LettaError): """Error for when local LLM cannot be reached with provided IP/port""" def __init__(self, message="Could not connect to local LLM"): self.message = message super().__init__(self.message) + + +class MissingFunctionCallError(LettaError): + message: "Message" + """ The message that caused this error. + + This error should be raised when a message that we expect to have a function call does not. + """ + + def __init__(self, *, message: "Message") -> None: + error_msg = "The message is missing a function call: \n\n" + + # Pretty print out message + message_json = message.model_dump_json(indent=4) + error_msg += f"{message_json}" + + super().__init__(error_msg) + self.message = message + + +class InvalidFunctionCallError(LettaError): + message: "Message" + """ The message that caused this error. 
+ + This error should be raised when a message uses a function that is unexpected or invalid, or if the usage is incorrect. + """ + + def __init__(self, *, message: "Message") -> None: + error_msg = "The message uses an invalid function call or has improper usage of a function call: \n\n" + + # Pretty print out message + message_json = message.model_dump_json(indent=4) + error_msg += f"{message_json}" + + super().__init__(error_msg) + self.message = message + + +class MissingInnerMonologueError(LettaError): + message: "Message" + """ The message that caused this error. + + This error should be raised when a message that we expect to have an inner monologue does not. + """ + + def __init__(self, *, message: "Message") -> None: + error_msg = "The message is missing an inner monologue: \n\n" + + # Pretty print out message + message_json = message.model_dump_json(indent=4) + error_msg += f"{message_json}" + + super().__init__(error_msg) + self.message = message + + +class InvalidInnerMonologueError(LettaError): + message: "Message" + """ The message that caused this error. + + This error should be raised when a message has an improperly formatted inner monologue. 
+ """ + + def __init__(self, *, message: "Message") -> None: + error_msg = "The message has a malformed inner monologue: \n\n" + + # Pretty print out message + message_json = message.model_dump_json(indent=4) + error_msg += f"{message_json}" + + super().__init__(error_msg) + self.message = message diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index 93753a55..7ff9193d 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -37,7 +37,7 @@ from letta.schemas.openai.chat_completion_request import ( Tool, cast_message_to_subtype, ) -from letta.schemas.openai.chat_completion_response import ChatCompletionResponse +from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice from letta.streaming_interface import ( AgentChunkStreamingInterface, AgentRefreshStreamingInterface, @@ -83,7 +83,7 @@ def add_inner_thoughts_to_functions( return new_functions -def unpack_inner_thoughts_from_kwargs( +def unpack_all_inner_thoughts_from_kwargs( response: ChatCompletionResponse, inner_thoughts_key: str, ) -> ChatCompletionResponse: @@ -93,36 +93,7 @@ def unpack_inner_thoughts_from_kwargs( new_choices = [] for choice in response.choices: - msg = choice.message - if msg.role == "assistant" and msg.tool_calls and len(msg.tool_calls) >= 1: - if len(msg.tool_calls) > 1: - warnings.warn(f"Unpacking inner thoughts from more than one tool call ({len(msg.tool_calls)}) is not supported") - # TODO support multiple tool calls - tool_call = msg.tool_calls[0] - - try: - # Sadly we need to parse the JSON since args are in string format - func_args = dict(json.loads(tool_call.function.arguments)) - if inner_thoughts_key in func_args: - # extract the inner thoughts - inner_thoughts = func_args.pop(inner_thoughts_key) - - # replace the kwargs - new_choice = choice.model_copy(deep=True) - new_choice.message.tool_calls[0].function.arguments = json_dumps(func_args) - # also replace the message content - if 
new_choice.message.content is not None: - warnings.warn(f"Overwriting existing inner monologue ({new_choice.message.content}) with kwarg ({inner_thoughts})") - new_choice.message.content = inner_thoughts - - # save copy - new_choices.append(new_choice) - else: - warnings.warn(f"Did not find inner thoughts in tool call: {str(tool_call)}") - - except json.JSONDecodeError as e: - warnings.warn(f"Failed to strip inner thoughts from kwargs: {e}") - raise e + new_choices.append(unpack_inner_thoughts_from_kwargs(choice, inner_thoughts_key)) # return an updated copy new_response = response.model_copy(deep=True) @@ -130,6 +101,38 @@ def unpack_inner_thoughts_from_kwargs( return new_response +def unpack_inner_thoughts_from_kwargs(choice: Choice, inner_thoughts_key: str) -> Choice: + message = choice.message + if message.role == "assistant" and message.tool_calls and len(message.tool_calls) >= 1: + if len(message.tool_calls) > 1: + warnings.warn(f"Unpacking inner thoughts from more than one tool call ({len(message.tool_calls)}) is not supported") + # TODO support multiple tool calls + tool_call = message.tool_calls[0] + + try: + # Sadly we need to parse the JSON since args are in string format + func_args = dict(json.loads(tool_call.function.arguments)) + if inner_thoughts_key in func_args: + # extract the inner thoughts + inner_thoughts = func_args.pop(inner_thoughts_key) + + # replace the kwargs + new_choice = choice.model_copy(deep=True) + new_choice.message.tool_calls[0].function.arguments = json_dumps(func_args) + # also replace the message content + if new_choice.message.content is not None: + warnings.warn(f"Overwriting existing inner monologue ({new_choice.message.content}) with kwarg ({inner_thoughts})") + new_choice.message.content = inner_thoughts + + return new_choice + else: + warnings.warn(f"Did not find inner thoughts in tool call: {str(tool_call)}") + + except json.JSONDecodeError as e: + warnings.warn(f"Failed to strip inner thoughts from kwargs: {e}") + 
raise e + + def is_context_overflow_error(exception: requests.exceptions.RequestException) -> bool: """Checks if an exception is due to context overflow (based on common OpenAI response messages)""" from letta.utils import printd @@ -343,7 +346,7 @@ def create( stream_inferface.stream_end() if inner_thoughts_in_kwargs: - response = unpack_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG) + response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG) return response diff --git a/tests/helpers/endpoints_helper.py b/tests/helpers/endpoints_helper.py new file mode 100644 index 00000000..7d9578f1 --- /dev/null +++ b/tests/helpers/endpoints_helper.py @@ -0,0 +1,71 @@ +import json +from typing import Callable, Optional + +from letta.config import LettaConfig +from letta.errors import ( + InvalidFunctionCallError, + InvalidInnerMonologueError, + MissingFunctionCallError, + MissingInnerMonologueError, +) +from letta.llm_api.llm_api_tools import unpack_inner_thoughts_from_kwargs +from letta.local_llm.constants import INNER_THOUGHTS_KWARG +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.llm_config import LLMConfig +from letta.schemas.openai.chat_completion_response import Choice, FunctionCall, Message + + +def setup_llm_endpoint(filename: str, embedding_config_path: str) -> [LLMConfig, EmbeddingConfig]: + config_data = json.load(open(filename, "r")) + llm_config = LLMConfig(**config_data) + embedding_config = EmbeddingConfig(**json.load(open(embedding_config_path))) + + # setup config + config = LettaConfig() + config.default_llm_config = llm_config + config.default_embedding_config = embedding_config + config.save() + + return llm_config, embedding_config + + +def assert_contains_valid_function_call(message: Message, function_call_validator: Optional[Callable[[FunctionCall], bool]] = None) -> None: + """ + Helper function to check that a message contains a 
valid function call.
+
+    There is an Optional parameter `function_call_validator` that specifies a validator function.
+    This function gets called on the resulting function_call to validate the function is what we expect.
+    """
+    if (hasattr(message, "function_call") and message.function_call is not None) and (
+        hasattr(message, "tool_calls") and message.tool_calls is not None
+    ):
+        raise InvalidFunctionCallError(message=message)
+    elif hasattr(message, "function_call") and message.function_call is not None:
+        function_call = message.function_call
+    elif hasattr(message, "tool_calls") and message.tool_calls is not None:
+        function_call = message.tool_calls[0].function
+    else:
+        # Throw a missing function call error
+        raise MissingFunctionCallError(message=message)
+
+    if function_call_validator and not function_call_validator(function_call):
+        raise InvalidFunctionCallError(message=message)
+
+
+def inner_monologue_is_valid(monologue: str) -> bool:
+    invalid_chars = '(){}[]"'
+    # Sometimes the syntax won't be correct and internal syntax will leak into message
+    invalid_phrases = ["functions", "send_message"]
+
+    return not (any(char in monologue for char in invalid_chars) or any(p in monologue for p in invalid_phrases))
+
+
+def assert_contains_correct_inner_monologue(choice: Choice, inner_thoughts_in_kwargs: bool) -> None:
+    if inner_thoughts_in_kwargs:
+        choice = unpack_inner_thoughts_from_kwargs(choice, INNER_THOUGHTS_KWARG)
+
+    monologue = choice.message.content
+    if not monologue or monologue is None or monologue == "":
+        raise MissingInnerMonologueError(message=choice.message)
+    elif not inner_monologue_is_valid(monologue):
+        raise InvalidInnerMonologueError(message=choice.message)
diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py
new file mode 100644
index 00000000..c5f1c156
--- /dev/null
+++ b/tests/helpers/utils.py
@@ -0,0 +1,11 @@
+from typing import Union
+
+from letta import LocalClient, RESTClient
+
+
+def cleanup(client: Union[LocalClient, RESTClient], agent_uuid: str):
+    # 
Clear all agents + for agent_state in client.list_agents(): + if agent_state.name == agent_uuid: + client.delete_agent(agent_id=agent_state.id) + print(f"Deleted agent: {agent_state.name} with ID {str(agent_state.id)}") diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 9bb4fe16..5f65f84e 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -4,13 +4,17 @@ import uuid from letta import create_client from letta.agent import Agent -from letta.config import LettaConfig from letta.embeddings import embedding_model from letta.llm_api.llm_api_tools import create from letta.prompts import gpt_system from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message +from tests.helpers.endpoints_helper import ( + assert_contains_correct_inner_monologue, + assert_contains_valid_function_call, + setup_llm_endpoint, +) +from tests.helpers.utils import cleanup messages = [Message(role="system", text=gpt_system.get_system_text("memgpt_chat")), Message(role="user", text="How are you?")] @@ -27,36 +31,17 @@ namespace = uuid.NAMESPACE_DNS agent_uuid = str(uuid.uuid5(namespace, "test-endpoints-agent")) -def cleanup(client): - # Clear all agents - for agent_state in client.list_agents(): - if agent_state.name == agent_uuid: - client.delete_agent(agent_id=agent_state.id) - print(f"Deleted agent: {agent_state.name} with ID {str(agent_state.id)}") - - -def run_llm_endpoint(filename): - config_data = json.load(open(filename, "r")) - print(config_data) - llm_config = LLMConfig(**config_data) - embedding_config = EmbeddingConfig(**json.load(open(embedding_config_path))) - - # setup config - config = LettaConfig() - config.default_llm_config = llm_config - config.default_embedding_config = embedding_config - config.save() +def check_first_response_is_valid_for_llm_endpoint(filename: str, inner_thoughts_in_kwargs: bool = False): + llm_config, embedding_config = 
setup_llm_endpoint(filename, embedding_config_path) client = create_client() - cleanup(client) + cleanup(client=client, agent_uuid=agent_uuid) agent_state = client.create_agent(name=agent_uuid, llm_config=llm_config, embedding_config=embedding_config) tools = [client.get_tool(client.get_tool_id(name=name)) for name in agent_state.tools] agent = Agent( interface=None, tools=tools, agent_state=agent_state, - # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now - first_message_verify_mono=True, ) response = create( @@ -67,10 +52,20 @@ def run_llm_endpoint(filename): functions=agent.functions, functions_python=agent.functions_python, ) - client.delete_agent(agent_state.id) - print(response) + + # Basic check assert response is not None + # Select first choice + choice = response.choices[0] + + # Ensure that the first message returns a "send_message" + validator_func = lambda function_call: function_call.name == "send_message" or function_call.name == "archival_memory_search" + assert_contains_valid_function_call(choice.message, validator_func) + + # Assert that the choice has an inner monologue + assert_contains_correct_inner_monologue(choice, inner_thoughts_in_kwargs) + def run_embedding_endpoint(filename): # load JSON file @@ -86,7 +81,7 @@ def run_embedding_endpoint(filename): def test_llm_endpoint_openai(): filename = os.path.join(llm_config_dir, "gpt-4.json") - run_llm_endpoint(filename) + check_first_response_is_valid_for_llm_endpoint(filename) def test_embedding_endpoint_openai(): @@ -96,7 +91,7 @@ def test_embedding_endpoint_openai(): def test_llm_endpoint_letta_hosted(): filename = os.path.join(llm_config_dir, "letta-hosted.json") - run_llm_endpoint(filename) + check_first_response_is_valid_for_llm_endpoint(filename) def test_embedding_endpoint_letta_hosted(): @@ -111,7 +106,7 @@ def test_embedding_endpoint_local(): def test_llm_endpoint_ollama(): filename = os.path.join(llm_config_dir, "ollama.json") - run_llm_endpoint(filename) + 
check_first_response_is_valid_for_llm_endpoint(filename) def test_embedding_endpoint_ollama(): @@ -121,4 +116,4 @@ def test_embedding_endpoint_ollama(): def test_llm_endpoint_anthropic(): filename = os.path.join(llm_config_dir, "anthropic.json") - run_llm_endpoint(filename) + check_first_response_is_valid_for_llm_endpoint(filename)