feat: add max_steps as argument to messages.create (#2664)

Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
This commit is contained in:
Shangyin Tan
2025-06-09 16:54:48 -07:00
committed by GitHub
parent 8780676905
commit f9b6efa632
6 changed files with 676 additions and 67 deletions

View File

@@ -0,0 +1,183 @@
"""Tests for LettaRequest schema validation"""
import pytest
from pydantic import ValidationError
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.schemas.letta_request import CreateBatch, LettaBatchRequest, LettaRequest, LettaStreamingRequest
from letta.schemas.message import MessageCreate
class TestLettaRequest:
"""Test cases for LettaRequest schema"""
def test_letta_request_with_default_max_steps(self):
"""Test that LettaRequest uses default max_steps value"""
messages = [MessageCreate(role="user", content="Test message")]
request = LettaRequest(messages=messages)
assert request.max_steps == 10
assert request.messages == messages
assert request.use_assistant_message is True
assert request.assistant_message_tool_name == DEFAULT_MESSAGE_TOOL
assert request.assistant_message_tool_kwarg == DEFAULT_MESSAGE_TOOL_KWARG
def test_letta_request_with_custom_max_steps(self):
"""Test that LettaRequest accepts custom max_steps value"""
messages = [MessageCreate(role="user", content="Test message")]
request = LettaRequest(messages=messages, max_steps=5)
assert request.max_steps == 5
assert request.messages == messages
def test_letta_request_with_zero_max_steps(self):
"""Test that LettaRequest accepts zero max_steps"""
messages = [MessageCreate(role="user", content="Test message")]
request = LettaRequest(messages=messages, max_steps=0)
assert request.max_steps == 0
def test_letta_request_with_negative_max_steps(self):
"""Test that LettaRequest accepts negative max_steps (edge case)"""
messages = [MessageCreate(role="user", content="Test message")]
request = LettaRequest(messages=messages, max_steps=-1)
assert request.max_steps == -1
def test_letta_request_required_fields(self):
"""Test that messages field is required"""
with pytest.raises(ValidationError) as exc_info:
LettaRequest()
assert "messages" in str(exc_info.value)
def test_letta_request_with_all_fields(self):
"""Test LettaRequest with all fields specified"""
messages = [MessageCreate(role="user", content="Test message")]
request = LettaRequest(
messages=messages,
max_steps=15,
use_assistant_message=False,
assistant_message_tool_name="custom_tool",
assistant_message_tool_kwarg="custom_kwarg",
)
assert request.max_steps == 15
assert request.use_assistant_message is False
assert request.assistant_message_tool_name == "custom_tool"
assert request.assistant_message_tool_kwarg == "custom_kwarg"
def test_letta_request_json_serialization(self):
"""Test that LettaRequest can be serialized to/from JSON"""
messages = [MessageCreate(role="user", content="Test message")]
request = LettaRequest(messages=messages, max_steps=7)
# Serialize to dict
request_dict = request.model_dump()
assert request_dict["max_steps"] == 7
# Deserialize from dict
request_from_dict = LettaRequest.model_validate(request_dict)
assert request_from_dict.max_steps == 7
assert request_from_dict.messages[0].role == "user"
class TestLettaStreamingRequest:
"""Test cases for LettaStreamingRequest schema"""
def test_letta_streaming_request_inherits_max_steps(self):
"""Test that LettaStreamingRequest inherits max_steps from LettaRequest"""
messages = [MessageCreate(role="user", content="Test message")]
request = LettaStreamingRequest(messages=messages, max_steps=12)
assert request.max_steps == 12
assert request.stream_tokens is False # Default value
def test_letta_streaming_request_with_streaming_options(self):
"""Test LettaStreamingRequest with streaming-specific options"""
messages = [MessageCreate(role="user", content="Test message")]
request = LettaStreamingRequest(messages=messages, max_steps=8, stream_tokens=True)
assert request.max_steps == 8
assert request.stream_tokens is True
class TestLettaBatchRequest:
"""Test cases for LettaBatchRequest schema"""
def test_letta_batch_request_inherits_max_steps(self):
"""Test that LettaBatchRequest inherits max_steps from LettaRequest"""
messages = [MessageCreate(role="user", content="Test message")]
request = LettaBatchRequest(messages=messages, agent_id="test-agent-id", max_steps=20)
assert request.max_steps == 20
assert request.agent_id == "test-agent-id"
def test_letta_batch_request_required_agent_id(self):
"""Test that agent_id is required for LettaBatchRequest"""
messages = [MessageCreate(role="user", content="Test message")]
with pytest.raises(ValidationError) as exc_info:
LettaBatchRequest(messages=messages)
assert "agent_id" in str(exc_info.value)
class TestCreateBatch:
"""Test cases for CreateBatch schema"""
def test_create_batch_with_max_steps(self):
"""Test CreateBatch containing requests with max_steps"""
messages = [MessageCreate(role="user", content="Test message")]
batch_requests = [
LettaBatchRequest(messages=messages, agent_id="agent-1", max_steps=5),
LettaBatchRequest(messages=messages, agent_id="agent-2", max_steps=10),
]
batch = CreateBatch(requests=batch_requests)
assert len(batch.requests) == 2
assert batch.requests[0].max_steps == 5
assert batch.requests[1].max_steps == 10
def test_create_batch_with_callback_url(self):
"""Test CreateBatch with callback URL"""
messages = [MessageCreate(role="user", content="Test message")]
batch_requests = [LettaBatchRequest(messages=messages, agent_id="agent-1", max_steps=3)]
batch = CreateBatch(requests=batch_requests, callback_url="https://example.com/callback")
assert str(batch.callback_url) == "https://example.com/callback"
assert batch.requests[0].max_steps == 3
class TestLettaRequestIntegration:
"""Integration tests for LettaRequest usage patterns"""
def test_max_steps_propagation_in_inheritance_chain(self):
"""Test that max_steps works correctly across the inheritance chain"""
messages = [MessageCreate(role="user", content="Test message")]
# Test base LettaRequest
base_request = LettaRequest(messages=messages, max_steps=3)
assert base_request.max_steps == 3
# Test LettaStreamingRequest inheritance
streaming_request = LettaStreamingRequest(messages=messages, max_steps=6)
assert streaming_request.max_steps == 6
# Test LettaBatchRequest inheritance
batch_request = LettaBatchRequest(messages=messages, agent_id="test-agent", max_steps=9)
assert batch_request.max_steps == 9
def test_backwards_compatibility(self):
"""Test that existing code without max_steps still works"""
messages = [MessageCreate(role="user", content="Test message")]
# Should work without max_steps (uses default)
request = LettaRequest(messages=messages)
assert request.max_steps == 10
# Should work with all other fields
request = LettaRequest(messages=messages, use_assistant_message=False, assistant_message_tool_name="custom_tool")
assert request.max_steps == 10 # Still uses default