diff --git a/letta/helpers/datetime_helpers.py b/letta/helpers/datetime_helpers.py index 0633add3..07856495 100644 --- a/letta/helpers/datetime_helpers.py +++ b/letta/helpers/datetime_helpers.py @@ -42,7 +42,7 @@ def get_local_time_timezone(timezone=DEFAULT_TIMEZONE): return formatted_time -def get_local_time(timezone=DEFAULT_TIMEZONE): +def get_local_time(timezone: str | None = DEFAULT_TIMEZONE): if timezone is not None: time_str = get_local_time_timezone(timezone) else: diff --git a/letta/helpers/json_helpers.py b/letta/helpers/json_helpers.py index 2075c58d..6618f274 100644 --- a/letta/helpers/json_helpers.py +++ b/letta/helpers/json_helpers.py @@ -7,7 +7,7 @@ def json_loads(data): return json.loads(data, strict=False) -def json_dumps(data, indent=2): +def json_dumps(data, indent=2) -> str: def safe_serializer(obj): if isinstance(obj, datetime): return obj.isoformat() diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 73bf73db..8c15bead 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -41,7 +41,7 @@ from letta.schemas.letta_message_content import ( get_letta_message_content_union_str_json_schema, ) from letta.system import unpack_message -from letta.utils import parse_json +from letta.utils import parse_json, validate_function_response def add_inner_thoughts_to_tool_call( @@ -251,10 +251,10 @@ class Message(BaseMessage): include_err: Optional[bool] = None, ) -> List[LettaMessage]: """Convert message object (in DB format) to the style used by the original Letta API""" - messages = [] + # TODO (cliandy): break this into more manageable pieces if self.role == MessageRole.assistant: - + messages = [] # Handle reasoning if self.content: # Check for ReACT-style COT inside of TextContent @@ -348,7 +348,7 @@ class Message(BaseMessage): # We need to unpack the actual message contents from the function call try: func_args = parse_json(tool_call.function.arguments) - message_string = func_args[assistant_message_tool_kwarg] + message_string = validate_function_response(func_args[assistant_message_tool_kwarg], 0, truncate=False) except KeyError: raise ValueError(f"Function call {tool_call.function.name} missing {assistant_message_tool_kwarg} argument") messages.append( @@ -380,99 +380,106 @@ class Message(BaseMessage): is_err=self.is_err, ) ) + elif self.role == MessageRole.tool: - # This is type ToolReturnMessage - # Try to interpret the function return, recall that this is how we packaged: - # def package_function_response(was_success, response_string, timestamp=None): - # formatted_time = get_local_time() if timestamp is None else timestamp - # packaged_message = { - # "status": "OK" if was_success else "Failed", - # "message": response_string, - # "time": formatted_time, - # } - if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent): - text_content = self.content[0].text - else: - raise ValueError(f"Invalid tool return (no text object on message): {self.content}") - - try: - function_return = parse_json(text_content) - text_content = str(function_return.get("message", text_content)) - status = function_return["status"] - if status == "OK": - status_enum = "success" - elif status == "Failed": - status_enum = "error" - else: - raise ValueError(f"Invalid status: {status}") - except json.JSONDecodeError: - raise ValueError(f"Failed to decode function return: {text_content}") - assert self.tool_call_id is not None - messages.append( - # TODO make sure this is what the API returns - # function_return may not match exactly... - ToolReturnMessage( - id=self.id, - date=self.created_at, - tool_return=text_content, - status=self.tool_returns[0].status if self.tool_returns else status_enum, - tool_call_id=self.tool_call_id, - stdout=self.tool_returns[0].stdout if self.tool_returns else None, - stderr=self.tool_returns[0].stderr if self.tool_returns else None, - name=self.name, - otid=Message.generate_otid_from_id(self.id, len(messages)), - sender_id=self.sender_id, - step_id=self.step_id, - is_err=self.is_err, - ) - ) + messages = [self._convert_tool_message()] elif self.role == MessageRole.user: - # This is type UserMessage - if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent): - text_content = self.content[0].text - elif self.content: - text_content = self.content - else: - raise ValueError(f"Invalid user message (no text object on message): {self.content}") - - message = unpack_message(text_content) - messages.append( - UserMessage( - id=self.id, - date=self.created_at, - content=message, - name=self.name, - otid=self.otid, - sender_id=self.sender_id, - step_id=self.step_id, - is_err=self.is_err, - ) - ) + messages = [self._convert_user_message()] elif self.role == MessageRole.system: - # This is type SystemMessage - if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent): - text_content = self.content[0].text - else: - raise ValueError(f"Invalid system message (no text object on system): {self.content}") - - messages.append( - SystemMessage( - id=self.id, - date=self.created_at, - content=text_content, - name=self.name, - otid=self.otid, - sender_id=self.sender_id, - step_id=self.step_id, - ) - ) + messages = [self._convert_system_message()] else: - raise ValueError(self.role) + raise ValueError(f"Unknown role: {self.role}") - if reverse: - messages.reverse() + return messages[::-1] if reverse else messages - return messages + def _convert_tool_message(self) -> ToolReturnMessage: + """Convert tool role message to ToolReturnMessage + + the tool return is packaged as follows: + packaged_message = { + "status": "OK" if was_success else "Failed", + "message": response_string, + "time": formatted_time, + } + """ + if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent): + text_content = self.content[0].text + else: + raise ValueError(f"Invalid tool return (no text object on message): {self.content}") + + try: + function_return = parse_json(text_content) + message_text = str(function_return.get("message", text_content)) + status = self._parse_tool_status(function_return["status"]) + except json.JSONDecodeError: + raise ValueError(f"Failed to decode function return: {text_content}") + + assert self.tool_call_id is not None + + return ToolReturnMessage( + id=self.id, + date=self.created_at, + tool_return=message_text, + status=self.tool_returns[0].status if self.tool_returns else status, + tool_call_id=self.tool_call_id, + stdout=self.tool_returns[0].stdout if self.tool_returns else None, + stderr=self.tool_returns[0].stderr if self.tool_returns else None, + name=self.name, + otid=Message.generate_otid_from_id(self.id, 0), + sender_id=self.sender_id, + step_id=self.step_id, + is_err=self.is_err, + ) + + @staticmethod + def _parse_tool_status(status: str) -> Literal["success", "error"]: + """Convert tool status string to enum value""" + if status == "OK": + return "success" + elif status == "Failed": + return "error" + else: + raise ValueError(f"Invalid status: {status}") + + def _convert_user_message(self) -> UserMessage: + """Convert user role message to UserMessage""" + # Extract text content + if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent): + text_content = self.content[0].text + elif self.content: + text_content = self.content + else: + raise ValueError(f"Invalid user message (no text object on message): {self.content}") + + message = unpack_message(text_content) + + return UserMessage( + id=self.id, + date=self.created_at, + content=message, + name=self.name, + otid=self.otid, + sender_id=self.sender_id, + step_id=self.step_id, + is_err=self.is_err, + ) + + def _convert_system_message(self) -> SystemMessage: + """Convert system role message to SystemMessage""" + if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent): + text_content = self.content[0].text + else: + raise ValueError(f"Invalid system message (no text object on system): {self.content}") + + return SystemMessage( + id=self.id, + date=self.created_at, + content=text_content, + name=self.name, + otid=self.otid, + sender_id=self.sender_id, + step_id=self.step_id, + ) @staticmethod def dict_to_message( diff --git a/letta/system.py b/letta/system.py index 888cc304..f64cc9d0 100644 --- a/letta/system.py +++ b/letta/system.py @@ -141,7 +141,7 @@ def package_user_message( return json_dumps(packaged_message) -def package_function_response(was_success, response_string, timezone): +def package_function_response(was_success: bool, response_string: str, timezone: str | None) -> str: formatted_time = get_local_time(timezone=timezone) packaged_message = { "status": "OK" if was_success else "Failed", diff --git a/letta/utils.py b/letta/utils.py index b2f35642..73abb1a3 100644 --- a/letta/utils.py +++ b/letta/utils.py @@ -27,7 +27,6 @@ from sqlalchemy import text import letta from letta.constants import ( - CLI_WARNING_PREFIX, CORE_MEMORY_HUMAN_CHAR_LIMIT, CORE_MEMORY_PERSONA_CHAR_LIMIT, DEFAULT_CORE_MEMORY_SOURCE_CHAR_LIMIT, @@ -851,47 +850,32 @@ def parse_json(string) -> dict: raise e -def validate_function_response(function_response_string: any, return_char_limit: int, strict: bool = False, truncate: bool = True) -> str: +def validate_function_response(function_response: Any, return_char_limit: int, strict: bool = False, truncate: bool = True) -> str: """Check to make sure that a function used by Letta returned a valid response. Truncates to return_char_limit if necessary. - Responses need to be strings (or None) that fall under a certain text count limit. + This makes sure that we can coerce the function_response into a string that meets our criteria. We handle some soft coercion. + If strict is True, we raise a ValueError if function_response is not a string or None. """ - if not isinstance(function_response_string, str): - # Soft correction for a few basic types + if isinstance(function_response, str): + function_response_string = function_response - if function_response_string is None: - # function_response_string = "Empty (no function output)" - function_response_string = "None" # backcompat + elif function_response is None: + function_response_string = "None" - elif isinstance(function_response_string, dict): - if strict: - # TODO add better error message - raise ValueError(function_response_string) + elif strict: + raise ValueError(f"Strict mode violation. Function returned type: {type(function_response).__name__}") - # Allow dict through since it will be cast to json.dumps() - try: - # TODO find a better way to do this that won't result in double escapes - function_response_string = json_dumps(function_response_string) - except: - raise ValueError(function_response_string) + elif isinstance(function_response, dict): + # As functions can return arbitrary data, if there's already nesting somewhere in the response, it's difficult + # for us to not result in double escapes. + function_response_string = json_dumps(function_response) + else: + logger.debug(f"Function returned type {type(function_response).__name__}. Coercing to string.") + function_response_string = str(function_response) - else: - if strict: - # TODO add better error message - raise ValueError(function_response_string) - - # Try to convert to a string, but throw a warning to alert the user - try: - function_response_string = str(function_response_string) - except: - raise ValueError(function_response_string) - - # Now check the length and make sure it doesn't go over the limit # TODO we should change this to a max token limit that's variable based on tokens remaining (or context-window) if truncate and len(function_response_string) > return_char_limit: - print( - f"{CLI_WARNING_PREFIX}function return was over limit ({len(function_response_string)} > {return_char_limit}) and was truncated" - ) + logger.warning(f"function return was over limit ({len(function_response_string)} > {return_char_limit}) and was truncated") function_response_string = f"{function_response_string[:return_char_limit]}... [NOTE: function output was truncated since it exceeded the character limit ({len(function_response_string)} > {return_char_limit})]" return function_response_string diff --git a/tests/test_utils.py b/tests/test_utils.py index 23658b84..84622882 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,3 @@ -import asyncio - import pytest from letta.constants import MAX_FILENAME_LENGTH @@ -7,7 +5,7 @@ from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_fun from letta.schemas.file import FileMetadata from letta.services.file_processor.chunker.line_chunker import LineChunker from letta.services.helpers.agent_manager_helper import safe_format -from letta.utils import sanitize_filename +from letta.utils import sanitize_filename, validate_function_response CORE_MEMORY_VAR = "My core memory is that I like to eat bananas" VARS_DICT = {"CORE_MEMORY": CORE_MEMORY_VAR} @@ -555,3 +553,122 @@ async def test_get_latest_alembic_revision_consistency(event_loop): # They should be identical assert revision_id1 == revision_id2 + + +# ---------------------- validate_function_response TESTS ---------------------- # + + +def test_validate_function_response_string_input(): + """Test that string inputs are returned unchanged when within limit""" + response = validate_function_response("hello world", return_char_limit=100) + assert response == "hello world" + + +def test_validate_function_response_none_input(): + """Test that None inputs are converted to 'None' string""" + response = validate_function_response(None, return_char_limit=100) + assert response == "None" + + +def test_validate_function_response_dict_input(): + """Test that dict inputs are JSON serialized""" + test_dict = {"key": "value", "number": 42} + response = validate_function_response(test_dict, return_char_limit=100) + # Response should be valid JSON string + import json + + parsed = json.loads(response) + assert parsed == test_dict + + +def test_validate_function_response_other_types(): + """Test that other types are converted to strings""" + # Test integer + response = validate_function_response(42, return_char_limit=100) + assert response == "42" + + # Test list + response = validate_function_response([1, 2, 3], return_char_limit=100) + assert response == "[1, 2, 3]" + + # Test boolean + response = validate_function_response(True, return_char_limit=100) + assert response == "True" + + +def test_validate_function_response_strict_mode_string(): + """Test strict mode allows strings""" + response = validate_function_response("test", return_char_limit=100, strict=True) + assert response == "test" + + +def test_validate_function_response_strict_mode_none(): + """Test strict mode allows None""" + response = validate_function_response(None, return_char_limit=100, strict=True) + assert response == "None" + + +def test_validate_function_response_strict_mode_violation(): + """Test strict mode raises ValueError for non-string/None types""" + with pytest.raises(ValueError, match="Strict mode violation. Function returned type: int"): + validate_function_response(42, return_char_limit=100, strict=True) + + with pytest.raises(ValueError, match="Strict mode violation. Function returned type: dict"): + validate_function_response({"key": "value"}, return_char_limit=100, strict=True) + + +def test_validate_function_response_truncation(): + """Test that long responses are truncated when truncate=True""" + long_string = "a" * 200 + response = validate_function_response(long_string, return_char_limit=50, truncate=True) + assert len(response) > 50 # Should include truncation message + assert response.startswith("a" * 50) + assert "NOTE: function output was truncated" in response + assert "200 > 50" in response + + +def test_validate_function_response_no_truncation(): + """Test that long responses are not truncated when truncate=False""" + long_string = "a" * 200 + response = validate_function_response(long_string, return_char_limit=50, truncate=False) + assert response == long_string + assert len(response) == 200 + + +def test_validate_function_response_exact_limit(): + """Test response exactly at the character limit""" + exact_string = "a" * 50 + response = validate_function_response(exact_string, return_char_limit=50, truncate=True) + assert response == exact_string + + +def test_validate_function_response_complex_dict(): + """Test with complex nested dictionary""" + complex_dict = {"nested": {"key": "value"}, "list": [1, 2, {"inner": "dict"}], "null": None, "bool": True} + response = validate_function_response(complex_dict, return_char_limit=1000) + # Should be valid JSON + import json + + parsed = json.loads(response) + assert parsed == complex_dict + + +def test_validate_function_response_dict_truncation(): + """Test that serialized dict gets truncated properly""" + # Create a dict that when serialized will exceed limit + large_dict = {"data": "x" * 100} + response = validate_function_response(large_dict, return_char_limit=20, truncate=True) + assert "NOTE: function output was truncated" in response + assert len(response) > 20 # Includes truncation message + + +def test_validate_function_response_empty_string(): + """Test empty string handling""" + response = validate_function_response("", return_char_limit=100) + assert response == "" + + +def test_validate_function_response_whitespace(): + """Test whitespace-only string handling""" + response = validate_function_response(" \n\t ", return_char_limit=100) + assert response == " \n\t "