fix: structured outputs for send_message, LettaMessage
This commit is contained in:
@@ -1,5 +1,3 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from letta.constants import MAX_FILENAME_LENGTH
|
||||
@@ -7,7 +5,7 @@ from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_fun
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.services.file_processor.chunker.line_chunker import LineChunker
|
||||
from letta.services.helpers.agent_manager_helper import safe_format
|
||||
from letta.utils import sanitize_filename
|
||||
from letta.utils import sanitize_filename, validate_function_response
|
||||
|
||||
CORE_MEMORY_VAR = "My core memory is that I like to eat bananas"
|
||||
VARS_DICT = {"CORE_MEMORY": CORE_MEMORY_VAR}
|
||||
@@ -555,3 +553,122 @@ async def test_get_latest_alembic_revision_consistency(event_loop):
|
||||
|
||||
# They should be identical
|
||||
assert revision_id1 == revision_id2
|
||||
|
||||
|
||||
# ---------------------- validate_function_response TESTS ---------------------- #
|
||||
|
||||
|
||||
def test_validate_function_response_string_input():
|
||||
"""Test that string inputs are returned unchanged when within limit"""
|
||||
response = validate_function_response("hello world", return_char_limit=100)
|
||||
assert response == "hello world"
|
||||
|
||||
|
||||
def test_validate_function_response_none_input():
|
||||
"""Test that None inputs are converted to 'None' string"""
|
||||
response = validate_function_response(None, return_char_limit=100)
|
||||
assert response == "None"
|
||||
|
||||
|
||||
def test_validate_function_response_dict_input():
|
||||
"""Test that dict inputs are JSON serialized"""
|
||||
test_dict = {"key": "value", "number": 42}
|
||||
response = validate_function_response(test_dict, return_char_limit=100)
|
||||
# Response should be valid JSON string
|
||||
import json
|
||||
|
||||
parsed = json.loads(response)
|
||||
assert parsed == test_dict
|
||||
|
||||
|
||||
def test_validate_function_response_other_types():
|
||||
"""Test that other types are converted to strings"""
|
||||
# Test integer
|
||||
response = validate_function_response(42, return_char_limit=100)
|
||||
assert response == "42"
|
||||
|
||||
# Test list
|
||||
response = validate_function_response([1, 2, 3], return_char_limit=100)
|
||||
assert response == "[1, 2, 3]"
|
||||
|
||||
# Test boolean
|
||||
response = validate_function_response(True, return_char_limit=100)
|
||||
assert response == "True"
|
||||
|
||||
|
||||
def test_validate_function_response_strict_mode_string():
|
||||
"""Test strict mode allows strings"""
|
||||
response = validate_function_response("test", return_char_limit=100, strict=True)
|
||||
assert response == "test"
|
||||
|
||||
|
||||
def test_validate_function_response_strict_mode_none():
|
||||
"""Test strict mode allows None"""
|
||||
response = validate_function_response(None, return_char_limit=100, strict=True)
|
||||
assert response == "None"
|
||||
|
||||
|
||||
def test_validate_function_response_strict_mode_violation():
|
||||
"""Test strict mode raises ValueError for non-string/None types"""
|
||||
with pytest.raises(ValueError, match="Strict mode violation. Function returned type: int"):
|
||||
validate_function_response(42, return_char_limit=100, strict=True)
|
||||
|
||||
with pytest.raises(ValueError, match="Strict mode violation. Function returned type: dict"):
|
||||
validate_function_response({"key": "value"}, return_char_limit=100, strict=True)
|
||||
|
||||
|
||||
def test_validate_function_response_truncation():
|
||||
"""Test that long responses are truncated when truncate=True"""
|
||||
long_string = "a" * 200
|
||||
response = validate_function_response(long_string, return_char_limit=50, truncate=True)
|
||||
assert len(response) > 50 # Should include truncation message
|
||||
assert response.startswith("a" * 50)
|
||||
assert "NOTE: function output was truncated" in response
|
||||
assert "200 > 50" in response
|
||||
|
||||
|
||||
def test_validate_function_response_no_truncation():
|
||||
"""Test that long responses are not truncated when truncate=False"""
|
||||
long_string = "a" * 200
|
||||
response = validate_function_response(long_string, return_char_limit=50, truncate=False)
|
||||
assert response == long_string
|
||||
assert len(response) == 200
|
||||
|
||||
|
||||
def test_validate_function_response_exact_limit():
|
||||
"""Test response exactly at the character limit"""
|
||||
exact_string = "a" * 50
|
||||
response = validate_function_response(exact_string, return_char_limit=50, truncate=True)
|
||||
assert response == exact_string
|
||||
|
||||
|
||||
def test_validate_function_response_complex_dict():
|
||||
"""Test with complex nested dictionary"""
|
||||
complex_dict = {"nested": {"key": "value"}, "list": [1, 2, {"inner": "dict"}], "null": None, "bool": True}
|
||||
response = validate_function_response(complex_dict, return_char_limit=1000)
|
||||
# Should be valid JSON
|
||||
import json
|
||||
|
||||
parsed = json.loads(response)
|
||||
assert parsed == complex_dict
|
||||
|
||||
|
||||
def test_validate_function_response_dict_truncation():
|
||||
"""Test that serialized dict gets truncated properly"""
|
||||
# Create a dict that when serialized will exceed limit
|
||||
large_dict = {"data": "x" * 100}
|
||||
response = validate_function_response(large_dict, return_char_limit=20, truncate=True)
|
||||
assert "NOTE: function output was truncated" in response
|
||||
assert len(response) > 20 # Includes truncation message
|
||||
|
||||
|
||||
def test_validate_function_response_empty_string():
|
||||
"""Test empty string handling"""
|
||||
response = validate_function_response("", return_char_limit=100)
|
||||
assert response == ""
|
||||
|
||||
|
||||
def test_validate_function_response_whitespace():
|
||||
"""Test whitespace-only string handling"""
|
||||
response = validate_function_response(" \n\t ", return_char_limit=100)
|
||||
assert response == " \n\t "
|
||||
|
||||
Reference in New Issue
Block a user