diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index 86111b25..bc826275 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -1224,3 +1224,224 @@ def test_update_tool_source_code_duplicate_name_error(client: LettaSDKClient): # Clean up both tools client.tools.delete(tool_id=tool1.id) client.tools.delete(tool_id=tool2.id) + + +def test_add_tool_with_multiple_functions_in_source_code(client: LettaSDKClient): + """Test adding a tool with multiple functions in the source code""" + import textwrap + + # Define source code with multiple functions + source_code = textwrap.dedent( + """ + def helper_function(x: int) -> int: + ''' + Helper function that doubles the input + + Args: + x: The input number + + Returns: + The input multiplied by 2 + ''' + return x * 2 + + def another_helper(text: str) -> str: + ''' + Another helper that uppercases text + + Args: + text: The input text to uppercase + + Returns: + The uppercased text + ''' + return text.upper() + + def main_function(x: int, y: int) -> int: + ''' + Main function that uses the helper + + Args: + x: First number + y: Second number + + Returns: + Result of (x * 2) + y + ''' + doubled_x = helper_function(x) + return doubled_x + y + """ + ).strip() + + # Create the tool with multiple functions + tool = client.tools.create( + source_code=source_code, + ) + + try: + # Verify the tool was created + assert tool is not None + assert tool.name == "main_function" + assert tool.source_code == source_code + + # Verify the JSON schema was generated for the main function + assert tool.json_schema is not None + assert tool.json_schema["name"] == "main_function" + assert tool.json_schema["description"] == "Main function that uses the helper" + + # Check parameters + params = tool.json_schema.get("parameters", {}) + properties = params.get("properties", {}) + assert "x" in properties + assert "y" in properties + assert properties["x"]["type"] == "integer" + assert properties["y"]["type"] == "integer" + assert params["required"] == ["x", "y"] + + # Test that we can retrieve the tool + retrieved_tool = client.tools.retrieve(tool_id=tool.id) + assert retrieved_tool.name == "main_function" + assert retrieved_tool.source_code == source_code + + finally: + # Clean up + client.tools.delete(tool_id=tool.id) + + +def test_tool_name_auto_update_with_multiple_functions(client: LettaSDKClient): + """Test that tool name auto-updates when source code changes with multiple functions""" + import textwrap + + # Initial source code with multiple functions + initial_source_code = textwrap.dedent( + """ + def helper_function(x: int) -> int: + ''' + Helper function that doubles the input + + Args: + x: The input number + + Returns: + The input multiplied by 2 + ''' + return x * 2 + + def another_helper(text: str) -> str: + ''' + Another helper that uppercases text + + Args: + text: The input text to uppercase + + Returns: + The uppercased text + ''' + return text.upper() + + def main_function(x: int, y: int) -> int: + ''' + Main function that uses the helper + + Args: + x: First number + y: Second number + + Returns: + Result of (x * 2) + y + ''' + doubled_x = helper_function(x) + return doubled_x + y + """ + ).strip() + + # Create tool with initial source code + tool = client.tools.create( + source_code=initial_source_code, + ) + + try: + # Verify the tool was created with the last function's name + assert tool is not None + assert tool.name == "main_function" + assert tool.source_code == initial_source_code + + # Now modify the source code with a different function order + new_source_code = textwrap.dedent( + """ + def process_data(data: str, count: int) -> str: + ''' + Process data by repeating it + + Args: + data: The input data + count: Number of times to repeat + + Returns: + The processed data + ''' + return data * count + + def helper_utility(x: float) -> float: + ''' + Helper utility function + + Args: + x: Input value + + Returns: + Squared value + ''' + return x * x + """ + ).strip() + + # Modify the tool with new source code + modified_tool = client.tools.modify(tool_id=tool.id, source_code=new_source_code) + + # Verify the name automatically updated to the last function + assert modified_tool.name == "helper_utility" + assert modified_tool.source_code == new_source_code + + # Verify the JSON schema updated correctly + assert modified_tool.json_schema is not None + assert modified_tool.json_schema["name"] == "helper_utility" + assert modified_tool.json_schema["description"] == "Helper utility function" + + # Check parameters updated correctly + params = modified_tool.json_schema.get("parameters", {}) + properties = params.get("properties", {}) + assert "x" in properties + assert properties["x"]["type"] == "number" # float maps to number + assert params["required"] == ["x"] + + # Test one more modification with only one function + single_function_code = textwrap.dedent( + """ + def calculate_total(items: list, tax_rate: float) -> float: + ''' + Calculate total with tax + + Args: + items: List of item prices + tax_rate: Tax rate as decimal + + Returns: + Total including tax + ''' + subtotal = sum(items) + return subtotal * (1 + tax_rate) + """ + ).strip() + + # Modify again + final_tool = client.tools.modify(tool_id=tool.id, source_code=single_function_code) + + # Verify name updated again + assert final_tool.name == "calculate_total" + assert final_tool.source_code == single_function_code + assert final_tool.json_schema["description"] == "Calculate total with tax" + + finally: + # Clean up + client.tools.delete(tool_id=tool.id)