fix: structured outputs for send_message, LettaMessage

This commit is contained in:
Andy Li
2025-07-24 14:50:52 -07:00
committed by GitHub
parent 357e30fc55
commit 94d10589c8
6 changed files with 238 additions and 130 deletions

View File

@@ -42,7 +42,7 @@ def get_local_time_timezone(timezone=DEFAULT_TIMEZONE):
return formatted_time
def get_local_time(timezone=DEFAULT_TIMEZONE):
def get_local_time(timezone: str | None = DEFAULT_TIMEZONE):
if timezone is not None:
time_str = get_local_time_timezone(timezone)
else:

View File

@@ -7,7 +7,7 @@ def json_loads(data):
return json.loads(data, strict=False)
def json_dumps(data, indent=2):
def json_dumps(data, indent=2) -> str:
def safe_serializer(obj):
if isinstance(obj, datetime):
return obj.isoformat()

View File

@@ -41,7 +41,7 @@ from letta.schemas.letta_message_content import (
get_letta_message_content_union_str_json_schema,
)
from letta.system import unpack_message
from letta.utils import parse_json
from letta.utils import parse_json, validate_function_response
def add_inner_thoughts_to_tool_call(
@@ -251,10 +251,10 @@ class Message(BaseMessage):
include_err: Optional[bool] = None,
) -> List[LettaMessage]:
"""Convert message object (in DB format) to the style used by the original Letta API"""
messages = []
# TODO (cliandy): break this into more manageable pieces
if self.role == MessageRole.assistant:
messages = []
# Handle reasoning
if self.content:
# Check for ReACT-style COT inside of TextContent
@@ -348,7 +348,7 @@ class Message(BaseMessage):
# We need to unpack the actual message contents from the function call
try:
func_args = parse_json(tool_call.function.arguments)
message_string = func_args[assistant_message_tool_kwarg]
message_string = validate_function_response(func_args[assistant_message_tool_kwarg], 0, truncate=False)
except KeyError:
raise ValueError(f"Function call {tool_call.function.name} missing {assistant_message_tool_kwarg} argument")
messages.append(
@@ -380,99 +380,106 @@ class Message(BaseMessage):
is_err=self.is_err,
)
)
elif self.role == MessageRole.tool:
# This is type ToolReturnMessage
# Try to interpret the function return, recall that this is how we packaged:
# def package_function_response(was_success, response_string, timestamp=None):
# formatted_time = get_local_time() if timestamp is None else timestamp
# packaged_message = {
# "status": "OK" if was_success else "Failed",
# "message": response_string,
# "time": formatted_time,
# }
if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent):
text_content = self.content[0].text
else:
raise ValueError(f"Invalid tool return (no text object on message): {self.content}")
try:
function_return = parse_json(text_content)
text_content = str(function_return.get("message", text_content))
status = function_return["status"]
if status == "OK":
status_enum = "success"
elif status == "Failed":
status_enum = "error"
else:
raise ValueError(f"Invalid status: {status}")
except json.JSONDecodeError:
raise ValueError(f"Failed to decode function return: {text_content}")
assert self.tool_call_id is not None
messages.append(
# TODO make sure this is what the API returns
# function_return may not match exactly...
ToolReturnMessage(
id=self.id,
date=self.created_at,
tool_return=text_content,
status=self.tool_returns[0].status if self.tool_returns else status_enum,
tool_call_id=self.tool_call_id,
stdout=self.tool_returns[0].stdout if self.tool_returns else None,
stderr=self.tool_returns[0].stderr if self.tool_returns else None,
name=self.name,
otid=Message.generate_otid_from_id(self.id, len(messages)),
sender_id=self.sender_id,
step_id=self.step_id,
is_err=self.is_err,
)
)
messages = [self._convert_tool_message()]
elif self.role == MessageRole.user:
# This is type UserMessage
if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent):
text_content = self.content[0].text
elif self.content:
text_content = self.content
else:
raise ValueError(f"Invalid user message (no text object on message): {self.content}")
message = unpack_message(text_content)
messages.append(
UserMessage(
id=self.id,
date=self.created_at,
content=message,
name=self.name,
otid=self.otid,
sender_id=self.sender_id,
step_id=self.step_id,
is_err=self.is_err,
)
)
messages = [self._convert_user_message()]
elif self.role == MessageRole.system:
# This is type SystemMessage
if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent):
text_content = self.content[0].text
else:
raise ValueError(f"Invalid system message (no text object on system): {self.content}")
messages.append(
SystemMessage(
id=self.id,
date=self.created_at,
content=text_content,
name=self.name,
otid=self.otid,
sender_id=self.sender_id,
step_id=self.step_id,
)
)
messages = [self._convert_system_message()]
else:
raise ValueError(self.role)
raise ValueError(f"Unknown role: {self.role}")
if reverse:
messages.reverse()
return messages[::-1] if reverse else messages
return messages
def _convert_tool_message(self) -> ToolReturnMessage:
    """Convert a tool-role message into a ToolReturnMessage.

    The single text part of the message is expected to hold the packaged tool return:

        packaged_message = {
            "status": "OK" if was_success else "Failed",
            "message": response_string,
            "time": formatted_time,
        }

    Returns:
        A ToolReturnMessage whose ``tool_return`` is the packaged "message" value (falling back
        to the raw text when that key is absent). ``status``, ``stdout`` and ``stderr`` are taken
        from ``self.tool_returns[0]`` when tool returns are present; otherwise ``status`` is
        derived from the packaged "status" value and stdout/stderr are None.

    Raises:
        ValueError: if the content is not exactly one TextContent, if parse_json raises
            json.JSONDecodeError, or if the packaged status is neither "OK" nor "Failed".
        KeyError: if the decoded payload has no "status" key.
    """
    if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent):
        text_content = self.content[0].text
    else:
        raise ValueError(f"Invalid tool return (no text object on message): {self.content}")

    # Only the decode step can raise JSONDecodeError; keep the original error as the cause
    try:
        function_return = parse_json(text_content)
    except json.JSONDecodeError as err:
        raise ValueError(f"Failed to decode function return: {text_content}") from err

    message_text = str(function_return.get("message", text_content))
    # The packaged status is validated even when tool_returns will supply the final status
    status = self._parse_tool_status(function_return["status"])

    assert self.tool_call_id is not None

    # Prefer the structured tool return record when one is attached to the message
    first_return = self.tool_returns[0] if self.tool_returns else None

    return ToolReturnMessage(
        id=self.id,
        date=self.created_at,
        tool_return=message_text,
        status=first_return.status if first_return is not None else status,
        tool_call_id=self.tool_call_id,
        stdout=first_return.stdout if first_return is not None else None,
        stderr=first_return.stderr if first_return is not None else None,
        name=self.name,
        otid=Message.generate_otid_from_id(self.id, 0),
        sender_id=self.sender_id,
        step_id=self.step_id,
        is_err=self.is_err,
    )
@staticmethod
def _parse_tool_status(status: str) -> Literal["success", "error"]:
"""Convert tool status string to enum value"""
if status == "OK":
return "success"
elif status == "Failed":
return "error"
else:
raise ValueError(f"Invalid status: {status}")
def _convert_user_message(self) -> UserMessage:
    """Build a UserMessage from a user-role message.

    A message with exactly one TextContent part contributes that part's text; any other
    non-empty content is handed to unpack_message as-is.

    Raises:
        ValueError: if the message has no content.
    """
    parts = self.content
    if not parts:
        raise ValueError(f"Invalid user message (no text object on message): {self.content}")

    # Single text part -> its text; anything else -> the content object itself
    has_single_text = len(parts) == 1 and isinstance(parts[0], TextContent)
    raw_content = parts[0].text if has_single_text else parts

    unpacked = unpack_message(raw_content)

    return UserMessage(
        id=self.id,
        date=self.created_at,
        content=unpacked,
        name=self.name,
        otid=self.otid,
        sender_id=self.sender_id,
        step_id=self.step_id,
        is_err=self.is_err,
    )
def _convert_system_message(self) -> SystemMessage:
    """Build a SystemMessage from a system-role message.

    Raises:
        ValueError: if the content is not exactly one TextContent part.
    """
    parts = self.content
    is_single_text = parts and len(parts) == 1 and isinstance(parts[0], TextContent)
    if not is_single_text:
        raise ValueError(f"Invalid system message (no text object on system): {self.content}")

    return SystemMessage(
        id=self.id,
        date=self.created_at,
        content=parts[0].text,
        name=self.name,
        otid=self.otid,
        sender_id=self.sender_id,
        step_id=self.step_id,
    )
@staticmethod
def dict_to_message(

View File

@@ -141,7 +141,7 @@ def package_user_message(
return json_dumps(packaged_message)
def package_function_response(was_success, response_string, timezone):
def package_function_response(was_success: bool, response_string: str, timezone: str | None) -> str:
formatted_time = get_local_time(timezone=timezone)
packaged_message = {
"status": "OK" if was_success else "Failed",

View File

@@ -27,7 +27,6 @@ from sqlalchemy import text
import letta
from letta.constants import (
CLI_WARNING_PREFIX,
CORE_MEMORY_HUMAN_CHAR_LIMIT,
CORE_MEMORY_PERSONA_CHAR_LIMIT,
DEFAULT_CORE_MEMORY_SOURCE_CHAR_LIMIT,
@@ -851,47 +850,32 @@ def parse_json(string) -> dict:
raise e
def validate_function_response(
    function_response: Any, return_char_limit: int, strict: bool = False, truncate: bool = True
) -> str:
    """Check to make sure that a function used by Letta returned a valid response. Truncates to return_char_limit if necessary.

    This makes sure that we can coerce the function_response into a string that meets our criteria. We handle some soft coercion.
    If strict is True, we raise a ValueError if function_response is not a string or None.

    Args:
        function_response: The raw value returned by the function.
        return_char_limit: Maximum number of characters kept when ``truncate`` is True.
        strict: When True, only ``str`` and ``None`` are accepted.
        truncate: When True, responses longer than ``return_char_limit`` are cut and a note is appended.

    Returns:
        The response as a string: strings pass through, ``None`` becomes "None", dicts go
        through ``json_dumps``, and every other type goes through ``str()``.

    Raises:
        ValueError: in strict mode, when the response is neither a string nor None.
    """
    if isinstance(function_response, str):
        function_response_string = function_response

    elif function_response is None:
        function_response_string = "None"

    elif strict:
        raise ValueError(f"Strict mode violation. Function returned type: {type(function_response).__name__}")

    elif isinstance(function_response, dict):
        # As functions can return arbitrary data, if there's already nesting somewhere in the response, it's difficult
        # for us to not result in double escapes.
        function_response_string = json_dumps(function_response)

    else:
        logger.debug(f"Function returned type {type(function_response).__name__}. Coercing to string.")
        function_response_string = str(function_response)

    # Now check the length and make sure it doesn't go over the limit
    # TODO we should change this to a max token limit that's variable based on tokens remaining (or context-window)
    if truncate and len(function_response_string) > return_char_limit:
        logger.warning(f"function return was over limit ({len(function_response_string)} > {return_char_limit}) and was truncated")
        function_response_string = f"{function_response_string[:return_char_limit]}... [NOTE: function output was truncated since it exceeded the character limit ({len(function_response_string)} > {return_char_limit})]"

    return function_response_string

View File

@@ -1,5 +1,3 @@
import asyncio
import pytest
from letta.constants import MAX_FILENAME_LENGTH
@@ -7,7 +5,7 @@ from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_fun
from letta.schemas.file import FileMetadata
from letta.services.file_processor.chunker.line_chunker import LineChunker
from letta.services.helpers.agent_manager_helper import safe_format
from letta.utils import sanitize_filename
from letta.utils import sanitize_filename, validate_function_response
CORE_MEMORY_VAR = "My core memory is that I like to eat bananas"
VARS_DICT = {"CORE_MEMORY": CORE_MEMORY_VAR}
@@ -555,3 +553,122 @@ async def test_get_latest_alembic_revision_consistency(event_loop):
# They should be identical
assert revision_id1 == revision_id2
# ---------------------- validate_function_response TESTS ---------------------- #
def test_validate_function_response_string_input():
    """A string shorter than the limit comes back unmodified."""
    text = "hello world"
    assert validate_function_response(text, return_char_limit=100) == text
def test_validate_function_response_none_input():
    """None is rendered as the literal string 'None'."""
    assert validate_function_response(None, return_char_limit=100) == "None"
def test_validate_function_response_dict_input():
    """A dict is serialized to JSON that round-trips to the same dict."""
    import json

    payload = {"key": "value", "number": 42}
    serialized = validate_function_response(payload, return_char_limit=100)

    # Parsing must succeed and reproduce the input exactly
    assert json.loads(serialized) == payload
def test_validate_function_response_other_types():
    """Non-str, non-dict values are coerced with str()."""
    # (input, expected string) for an integer, a list and a boolean
    cases = [
        (42, "42"),
        ([1, 2, 3], "[1, 2, 3]"),
        (True, "True"),
    ]
    for value, expected in cases:
        assert validate_function_response(value, return_char_limit=100) == expected
def test_validate_function_response_strict_mode_string():
    """Strings are accepted when strict=True."""
    assert validate_function_response("test", return_char_limit=100, strict=True) == "test"
def test_validate_function_response_strict_mode_none():
    """None is accepted when strict=True and still becomes 'None'."""
    assert validate_function_response(None, return_char_limit=100, strict=True) == "None"
def test_validate_function_response_strict_mode_violation():
    """Test strict mode raises ValueError for non-string/None types"""
    import re

    # pytest.raises(match=...) is a regex search; escape the message so "." only matches a literal period
    with pytest.raises(ValueError, match=re.escape("Strict mode violation. Function returned type: int")):
        validate_function_response(42, return_char_limit=100, strict=True)

    with pytest.raises(ValueError, match=re.escape("Strict mode violation. Function returned type: dict")):
        validate_function_response({"key": "value"}, return_char_limit=100, strict=True)
def test_validate_function_response_truncation():
    """An over-limit response is cut at the limit and annotated with a note."""
    limit = 50
    payload = "a" * 200
    truncated = validate_function_response(payload, return_char_limit=limit, truncate=True)

    # The appended note makes the result longer than the limit itself
    assert len(truncated) > limit
    assert truncated.startswith("a" * limit)
    assert "NOTE: function output was truncated" in truncated
    assert "200 > 50" in truncated
def test_validate_function_response_no_truncation():
    """With truncate=False an over-limit response is returned in full."""
    payload = "a" * 200
    result = validate_function_response(payload, return_char_limit=50, truncate=False)
    assert result == payload
    assert len(result) == 200
def test_validate_function_response_exact_limit():
    """A response whose length equals the limit is not truncated."""
    limit = 50
    payload = "a" * limit
    assert validate_function_response(payload, return_char_limit=limit, truncate=True) == payload
def test_validate_function_response_complex_dict():
    """A nested dict serializes to JSON that round-trips to the same structure."""
    import json

    nested = {
        "nested": {"key": "value"},
        "list": [1, 2, {"inner": "dict"}],
        "null": None,
        "bool": True,
    }
    serialized = validate_function_response(nested, return_char_limit=1000)

    # The output must parse as JSON and equal the input
    assert json.loads(serialized) == nested
def test_validate_function_response_dict_truncation():
    """A dict whose serialized form exceeds the limit is truncated with a note."""
    limit = 20
    # The serialized form of this dict is well over the limit
    oversized = {"data": "x" * 100}
    result = validate_function_response(oversized, return_char_limit=limit, truncate=True)
    assert "NOTE: function output was truncated" in result
    # The appended note pushes the length past the limit
    assert len(result) > limit
def test_validate_function_response_empty_string():
    """An empty string is returned as an empty string."""
    assert validate_function_response("", return_char_limit=100) == ""
def test_validate_function_response_whitespace():
    """A whitespace-only string is returned without being stripped."""
    blank = " \n\t "
    assert validate_function_response(blank, return_char_limit=100) == blank