fix: structured outputs for send_message, LettaMessage
This commit is contained in:
@@ -42,7 +42,7 @@ def get_local_time_timezone(timezone=DEFAULT_TIMEZONE):
|
||||
return formatted_time
|
||||
|
||||
|
||||
def get_local_time(timezone=DEFAULT_TIMEZONE):
|
||||
def get_local_time(timezone: str | None = DEFAULT_TIMEZONE):
|
||||
if timezone is not None:
|
||||
time_str = get_local_time_timezone(timezone)
|
||||
else:
|
||||
|
||||
@@ -7,7 +7,7 @@ def json_loads(data):
|
||||
return json.loads(data, strict=False)
|
||||
|
||||
|
||||
def json_dumps(data, indent=2):
|
||||
def json_dumps(data, indent=2) -> str:
|
||||
def safe_serializer(obj):
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
|
||||
@@ -41,7 +41,7 @@ from letta.schemas.letta_message_content import (
|
||||
get_letta_message_content_union_str_json_schema,
|
||||
)
|
||||
from letta.system import unpack_message
|
||||
from letta.utils import parse_json
|
||||
from letta.utils import parse_json, validate_function_response
|
||||
|
||||
|
||||
def add_inner_thoughts_to_tool_call(
|
||||
@@ -251,10 +251,10 @@ class Message(BaseMessage):
|
||||
include_err: Optional[bool] = None,
|
||||
) -> List[LettaMessage]:
|
||||
"""Convert message object (in DB format) to the style used by the original Letta API"""
|
||||
messages = []
|
||||
|
||||
# TODO (cliandy): break this into more manageable pieces
|
||||
if self.role == MessageRole.assistant:
|
||||
|
||||
messages = []
|
||||
# Handle reasoning
|
||||
if self.content:
|
||||
# Check for ReACT-style COT inside of TextContent
|
||||
@@ -348,7 +348,7 @@ class Message(BaseMessage):
|
||||
# We need to unpack the actual message contents from the function call
|
||||
try:
|
||||
func_args = parse_json(tool_call.function.arguments)
|
||||
message_string = func_args[assistant_message_tool_kwarg]
|
||||
message_string = validate_function_response(func_args[assistant_message_tool_kwarg], 0, truncate=False)
|
||||
except KeyError:
|
||||
raise ValueError(f"Function call {tool_call.function.name} missing {assistant_message_tool_kwarg} argument")
|
||||
messages.append(
|
||||
@@ -380,99 +380,106 @@ class Message(BaseMessage):
|
||||
is_err=self.is_err,
|
||||
)
|
||||
)
|
||||
|
||||
elif self.role == MessageRole.tool:
|
||||
# This is type ToolReturnMessage
|
||||
# Try to interpret the function return, recall that this is how we packaged:
|
||||
# def package_function_response(was_success, response_string, timestamp=None):
|
||||
# formatted_time = get_local_time() if timestamp is None else timestamp
|
||||
# packaged_message = {
|
||||
# "status": "OK" if was_success else "Failed",
|
||||
# "message": response_string,
|
||||
# "time": formatted_time,
|
||||
# }
|
||||
if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent):
|
||||
text_content = self.content[0].text
|
||||
else:
|
||||
raise ValueError(f"Invalid tool return (no text object on message): {self.content}")
|
||||
|
||||
try:
|
||||
function_return = parse_json(text_content)
|
||||
text_content = str(function_return.get("message", text_content))
|
||||
status = function_return["status"]
|
||||
if status == "OK":
|
||||
status_enum = "success"
|
||||
elif status == "Failed":
|
||||
status_enum = "error"
|
||||
else:
|
||||
raise ValueError(f"Invalid status: {status}")
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Failed to decode function return: {text_content}")
|
||||
assert self.tool_call_id is not None
|
||||
messages.append(
|
||||
# TODO make sure this is what the API returns
|
||||
# function_return may not match exactly...
|
||||
ToolReturnMessage(
|
||||
id=self.id,
|
||||
date=self.created_at,
|
||||
tool_return=text_content,
|
||||
status=self.tool_returns[0].status if self.tool_returns else status_enum,
|
||||
tool_call_id=self.tool_call_id,
|
||||
stdout=self.tool_returns[0].stdout if self.tool_returns else None,
|
||||
stderr=self.tool_returns[0].stderr if self.tool_returns else None,
|
||||
name=self.name,
|
||||
otid=Message.generate_otid_from_id(self.id, len(messages)),
|
||||
sender_id=self.sender_id,
|
||||
step_id=self.step_id,
|
||||
is_err=self.is_err,
|
||||
)
|
||||
)
|
||||
messages = [self._convert_tool_message()]
|
||||
elif self.role == MessageRole.user:
|
||||
# This is type UserMessage
|
||||
if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent):
|
||||
text_content = self.content[0].text
|
||||
elif self.content:
|
||||
text_content = self.content
|
||||
else:
|
||||
raise ValueError(f"Invalid user message (no text object on message): {self.content}")
|
||||
|
||||
message = unpack_message(text_content)
|
||||
messages.append(
|
||||
UserMessage(
|
||||
id=self.id,
|
||||
date=self.created_at,
|
||||
content=message,
|
||||
name=self.name,
|
||||
otid=self.otid,
|
||||
sender_id=self.sender_id,
|
||||
step_id=self.step_id,
|
||||
is_err=self.is_err,
|
||||
)
|
||||
)
|
||||
messages = [self._convert_user_message()]
|
||||
elif self.role == MessageRole.system:
|
||||
# This is type SystemMessage
|
||||
if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent):
|
||||
text_content = self.content[0].text
|
||||
else:
|
||||
raise ValueError(f"Invalid system message (no text object on system): {self.content}")
|
||||
|
||||
messages.append(
|
||||
SystemMessage(
|
||||
id=self.id,
|
||||
date=self.created_at,
|
||||
content=text_content,
|
||||
name=self.name,
|
||||
otid=self.otid,
|
||||
sender_id=self.sender_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
)
|
||||
messages = [self._convert_system_message()]
|
||||
else:
|
||||
raise ValueError(self.role)
|
||||
raise ValueError(f"Unknown role: {self.role}")
|
||||
|
||||
if reverse:
|
||||
messages.reverse()
|
||||
return messages[::-1] if reverse else messages
|
||||
|
||||
return messages
|
||||
def _convert_tool_message(self) -> ToolReturnMessage:
    """Convert a tool-role message into a ToolReturnMessage.

    The stored text is expected to be the JSON envelope built when the tool
    return was packaged:

        packaged_message = {
            "status": "OK" if was_success else "Failed",
            "message": response_string,
            "time": formatted_time,
        }

    Returns:
        A ToolReturnMessage whose ``tool_return`` is the envelope's "message"
        (or the raw text when that key is absent). Status, stdout and stderr
        are taken from the first entry of ``self.tool_returns`` when there is
        one; otherwise the status is derived from the envelope and
        stdout/stderr are None.

    Raises:
        ValueError: if the message does not hold exactly one TextContent, the
            text cannot be decoded, the decoded value is not an object with a
            valid "status", or ``tool_call_id`` is unset.
    """
    if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent):
        text_content = self.content[0].text
    else:
        raise ValueError(f"Invalid tool return (no text object on message): {self.content}")

    try:
        function_return = parse_json(text_content)
    except json.JSONDecodeError as err:
        raise ValueError(f"Failed to decode function return: {text_content}") from err

    # A payload that decodes to a non-object (e.g. a list) or lacks "status"
    # would otherwise surface as a bare AttributeError / KeyError.
    if not isinstance(function_return, dict) or "status" not in function_return:
        raise ValueError(f"Invalid tool return (expected an object with a 'status' field): {text_content}")

    message_text = str(function_return.get("message", text_content))
    # Validated even when tool_returns overrides it, so malformed envelopes are still rejected.
    status = self._parse_tool_status(function_return["status"])

    # Explicit check rather than `assert`, which is removed under `python -O`.
    if self.tool_call_id is None:
        raise ValueError(f"Invalid tool return (missing tool_call_id) on message: {self.id}")

    recorded = self.tool_returns[0] if self.tool_returns else None

    return ToolReturnMessage(
        id=self.id,
        date=self.created_at,
        tool_return=message_text,
        status=recorded.status if recorded is not None else status,
        tool_call_id=self.tool_call_id,
        stdout=recorded.stdout if recorded is not None else None,
        stderr=recorded.stderr if recorded is not None else None,
        name=self.name,
        otid=Message.generate_otid_from_id(self.id, 0),
        sender_id=self.sender_id,
        step_id=self.step_id,
        is_err=self.is_err,
    )
|
||||
|
||||
@staticmethod
|
||||
def _parse_tool_status(status: str) -> Literal["success", "error"]:
    """Map a packaged tool status to its API value.

    "OK" becomes "success" and "Failed" becomes "error".

    Raises:
        ValueError: if ``status`` is neither "OK" nor "Failed".
    """
    if status == "OK":
        return "success"
    if status == "Failed":
        return "error"
    raise ValueError(f"Invalid status: {status}")
|
||||
|
||||
def _convert_user_message(self) -> UserMessage:
    """Convert a user-role message into a UserMessage.

    A message holding exactly one TextContent contributes that part's text;
    any other non-empty content is handed to ``unpack_message`` as-is.

    Raises:
        ValueError: if the message has no content.
    """
    parts = self.content
    if not parts:
        raise ValueError(f"Invalid user message (no text object on message): {self.content}")

    is_single_text = len(parts) == 1 and isinstance(parts[0], TextContent)
    raw_content = parts[0].text if is_single_text else parts

    unpacked = unpack_message(raw_content)

    return UserMessage(
        id=self.id,
        date=self.created_at,
        content=unpacked,
        name=self.name,
        otid=self.otid,
        sender_id=self.sender_id,
        step_id=self.step_id,
        is_err=self.is_err,
    )
|
||||
|
||||
def _convert_system_message(self) -> SystemMessage:
    """Convert a system-role message into a SystemMessage.

    Raises:
        ValueError: if the message does not hold exactly one TextContent.
    """
    parts = self.content
    has_single_text = bool(parts) and len(parts) == 1 and isinstance(parts[0], TextContent)
    if not has_single_text:
        raise ValueError(f"Invalid system message (no text object on system): {self.content}")

    return SystemMessage(
        id=self.id,
        date=self.created_at,
        content=parts[0].text,
        name=self.name,
        otid=self.otid,
        sender_id=self.sender_id,
        step_id=self.step_id,
    )
|
||||
|
||||
@staticmethod
|
||||
def dict_to_message(
|
||||
|
||||
@@ -141,7 +141,7 @@ def package_user_message(
|
||||
return json_dumps(packaged_message)
|
||||
|
||||
|
||||
def package_function_response(was_success, response_string, timezone):
|
||||
def package_function_response(was_success: bool, response_string: str, timezone: str | None) -> str:
|
||||
formatted_time = get_local_time(timezone=timezone)
|
||||
packaged_message = {
|
||||
"status": "OK" if was_success else "Failed",
|
||||
|
||||
@@ -27,7 +27,6 @@ from sqlalchemy import text
|
||||
|
||||
import letta
|
||||
from letta.constants import (
|
||||
CLI_WARNING_PREFIX,
|
||||
CORE_MEMORY_HUMAN_CHAR_LIMIT,
|
||||
CORE_MEMORY_PERSONA_CHAR_LIMIT,
|
||||
DEFAULT_CORE_MEMORY_SOURCE_CHAR_LIMIT,
|
||||
@@ -851,47 +850,32 @@ def parse_json(string) -> dict:
|
||||
raise e
|
||||
|
||||
|
||||
def validate_function_response(function_response: Any, return_char_limit: int, strict: bool = False, truncate: bool = True) -> str:
    """Coerce a tool/function return value into a string, truncating it if necessary.

    Args:
        function_response: The raw value returned by the function.
        return_char_limit: Maximum number of characters kept when ``truncate`` is True.
        strict: When True, anything other than a string or None raises ValueError.
        truncate: When True, output longer than ``return_char_limit`` is cut and a
            note describing the truncation is appended.

    Returns:
        The response as a string. Strings pass through unchanged, None becomes
        "None", dicts are serialized with ``json_dumps``, and every other type
        is converted with ``str()``.

    Raises:
        ValueError: in strict mode, if ``function_response`` is not a string or None.
    """
    if isinstance(function_response, str):
        function_response_string = function_response

    elif function_response is None:
        function_response_string = "None"

    elif strict:
        # Checked before the dict branch: strict mode rejects dicts as well.
        raise ValueError(f"Strict mode violation. Function returned type: {type(function_response).__name__}")

    elif isinstance(function_response, dict):
        # As functions can return arbitrary data, if there's already nesting somewhere in the response, it's difficult
        # for us to not result in double escapes.
        function_response_string = json_dumps(function_response)

    else:
        logger.debug(f"Function returned type {type(function_response).__name__}. Coercing to string.")
        function_response_string = str(function_response)

    # Now check the length and make sure it doesn't go over the limit
    # TODO we should change this to a max token limit that's variable based on tokens remaining (or context-window)
    if truncate and len(function_response_string) > return_char_limit:
        logger.warning(f"function return was over limit ({len(function_response_string)} > {return_char_limit}) and was truncated")
        # The length in the note is the pre-truncation length: the f-string is evaluated before rebinding.
        function_response_string = f"{function_response_string[:return_char_limit]}... [NOTE: function output was truncated since it exceeded the character limit ({len(function_response_string)} > {return_char_limit})]"

    return function_response_string
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from letta.constants import MAX_FILENAME_LENGTH
|
||||
@@ -7,7 +5,7 @@ from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_fun
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.services.file_processor.chunker.line_chunker import LineChunker
|
||||
from letta.services.helpers.agent_manager_helper import safe_format
|
||||
from letta.utils import sanitize_filename
|
||||
from letta.utils import sanitize_filename, validate_function_response
|
||||
|
||||
CORE_MEMORY_VAR = "My core memory is that I like to eat bananas"
|
||||
VARS_DICT = {"CORE_MEMORY": CORE_MEMORY_VAR}
|
||||
@@ -555,3 +553,122 @@ async def test_get_latest_alembic_revision_consistency(event_loop):
|
||||
|
||||
# They should be identical
|
||||
assert revision_id1 == revision_id2
|
||||
|
||||
|
||||
# ---------------------- validate_function_response TESTS ---------------------- #
|
||||
|
||||
|
||||
def test_validate_function_response_string_input():
    """A string that fits within the limit comes back untouched."""
    text = "hello world"
    assert validate_function_response(text, return_char_limit=100) == "hello world"
|
||||
|
||||
|
||||
def test_validate_function_response_none_input():
    """None is rendered as the literal string 'None'."""
    assert validate_function_response(None, return_char_limit=100) == "None"
|
||||
|
||||
|
||||
def test_validate_function_response_dict_input():
    """A dict comes back as a JSON string that decodes to the same dict."""
    import json

    payload = {"key": "value", "number": 42}
    serialized = validate_function_response(payload, return_char_limit=100)

    # Round-trip through the JSON parser to prove the output is valid JSON
    assert json.loads(serialized) == payload
|
||||
|
||||
|
||||
def test_validate_function_response_other_types():
    """Values that are neither str, None nor dict are converted with str()."""
    # (input, expected string) for an integer, a list and a boolean
    cases = [
        (42, "42"),
        ([1, 2, 3], "[1, 2, 3]"),
        (True, "True"),
    ]
    for value, expected in cases:
        assert validate_function_response(value, return_char_limit=100) == expected
|
||||
|
||||
|
||||
def test_validate_function_response_strict_mode_string():
    """Strict mode accepts a plain string."""
    assert validate_function_response("test", return_char_limit=100, strict=True) == "test"
|
||||
|
||||
|
||||
def test_validate_function_response_strict_mode_none():
    """Strict mode accepts None and renders it as 'None'."""
    assert validate_function_response(None, return_char_limit=100, strict=True) == "None"
|
||||
|
||||
|
||||
def test_validate_function_response_strict_mode_violation():
    """Strict mode raises ValueError for anything that is not a string or None."""
    import re

    for value, type_name in ((42, "int"), ({"key": "value"}, "dict")):
        # `match` is applied as a regular expression, so escape the literal
        # message; otherwise the "." after "violation" matches any character.
        expected = re.escape(f"Strict mode violation. Function returned type: {type_name}")
        with pytest.raises(ValueError, match=expected):
            validate_function_response(value, return_char_limit=100, strict=True)
|
||||
|
||||
|
||||
def test_validate_function_response_truncation():
    """With truncate=True an over-limit response is cut and annotated."""
    oversized = "a" * 200
    result = validate_function_response(oversized, return_char_limit=50, truncate=True)

    # The appended note makes the result longer than the limit itself
    assert len(result) > 50
    assert result.startswith("a" * 50)
    assert "NOTE: function output was truncated" in result
    assert "200 > 50" in result
|
||||
|
||||
|
||||
def test_validate_function_response_no_truncation():
    """With truncate=False an over-limit response is returned in full."""
    oversized = "a" * 200
    result = validate_function_response(oversized, return_char_limit=50, truncate=False)

    assert result == oversized
    assert len(result) == 200
|
||||
|
||||
|
||||
def test_validate_function_response_exact_limit():
    """A response whose length equals the limit is not truncated."""
    at_limit = "a" * 50
    assert validate_function_response(at_limit, return_char_limit=50, truncate=True) == at_limit
|
||||
|
||||
|
||||
def test_validate_function_response_complex_dict():
    """A nested dict survives serialization and decodes back unchanged."""
    import json

    payload = {
        "nested": {"key": "value"},
        "list": [1, 2, {"inner": "dict"}],
        "null": None,
        "bool": True,
    }
    serialized = validate_function_response(payload, return_char_limit=1000)

    # Decoding must succeed and reproduce the original structure
    assert json.loads(serialized) == payload
|
||||
|
||||
|
||||
def test_validate_function_response_dict_truncation():
    """A dict whose serialized form exceeds the limit is truncated and annotated."""
    # Serializing this dict produces well over 20 characters
    payload = {"data": "x" * 100}
    result = validate_function_response(payload, return_char_limit=20, truncate=True)

    assert "NOTE: function output was truncated" in result
    # The truncation note pushes the total length past the limit
    assert len(result) > 20
|
||||
|
||||
|
||||
def test_validate_function_response_empty_string():
    """An empty string is returned as an empty string."""
    assert validate_function_response("", return_char_limit=100) == ""
|
||||
|
||||
|
||||
def test_validate_function_response_whitespace():
    """A whitespace-only string is returned exactly as given."""
    result = validate_function_response(" \n\t ", return_char_limit=100)
    assert result == " \n\t "
|
||||
|
||||
Reference in New Issue
Block a user