184 lines
7.4 KiB
Python
184 lines
7.4 KiB
Python
"""Tests for LettaRequest schema validation"""
|
|
|
|
import pytest
|
|
from pydantic import ValidationError
|
|
|
|
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
|
from letta.schemas.letta_request import CreateBatch, LettaBatchRequest, LettaRequest, LettaStreamingRequest
|
|
from letta.schemas.message import MessageCreate
|
|
|
|
|
|
class TestLettaRequest:
    """Construction, default-value and round-trip checks for the LettaRequest schema."""

    def test_letta_request_with_default_max_steps(self):
        """Building a request from messages alone yields the default field values."""
        user_messages = [MessageCreate(role="user", content="Test message")]
        req = LettaRequest(messages=user_messages)

        assert req.max_steps == 10
        assert req.messages == user_messages
        assert req.use_assistant_message is True
        assert req.assistant_message_tool_name == DEFAULT_MESSAGE_TOOL
        assert req.assistant_message_tool_kwarg == DEFAULT_MESSAGE_TOOL_KWARG

    def test_letta_request_with_custom_max_steps(self):
        """An explicit max_steps value is stored on the request unchanged."""
        user_messages = [MessageCreate(role="user", content="Test message")]
        req = LettaRequest(messages=user_messages, max_steps=5)

        assert req.max_steps == 5
        assert req.messages == user_messages

    def test_letta_request_with_zero_max_steps(self):
        """max_steps=0 is accepted and stored as-is."""
        req = LettaRequest(
            messages=[MessageCreate(role="user", content="Test message")],
            max_steps=0,
        )

        assert req.max_steps == 0

    def test_letta_request_with_negative_max_steps(self):
        """Edge case: max_steps=-1 is accepted and stored as-is."""
        req = LettaRequest(
            messages=[MessageCreate(role="user", content="Test message")],
            max_steps=-1,
        )

        assert req.max_steps == -1

    def test_letta_request_required_fields(self):
        """Omitting messages raises a ValidationError that names the field."""
        with pytest.raises(ValidationError) as raised:
            LettaRequest()

        assert "messages" in str(raised.value)

    def test_letta_request_with_all_fields(self):
        """Every explicitly supplied field is reflected on the constructed request."""
        overrides = {
            "max_steps": 15,
            "use_assistant_message": False,
            "assistant_message_tool_name": "custom_tool",
            "assistant_message_tool_kwarg": "custom_kwarg",
        }
        req = LettaRequest(
            messages=[MessageCreate(role="user", content="Test message")],
            **overrides,
        )

        assert req.max_steps == 15
        assert req.use_assistant_message is False
        assert req.assistant_message_tool_name == "custom_tool"
        assert req.assistant_message_tool_kwarg == "custom_kwarg"

    def test_letta_request_json_serialization(self):
        """A request survives a model_dump() / model_validate() round trip through a dict."""
        original = LettaRequest(
            messages=[MessageCreate(role="user", content="Test message")],
            max_steps=7,
        )

        # Dump to a plain dict and check the field made it across.
        dumped = original.model_dump()
        assert dumped["max_steps"] == 7

        # Rebuild a request from that dict and check it matches what went in.
        restored = LettaRequest.model_validate(dumped)
        assert restored.max_steps == 7
        first_message = restored.messages[0]
        assert first_message.role == "user"
|
|
|
|
|
|
class TestLettaStreamingRequest:
    """Checks for LettaStreamingRequest: max_steps handling plus the stream_tokens flag."""

    def test_letta_streaming_request_inherits_max_steps(self):
        """max_steps is accepted by the streaming request; stream_tokens stays False when omitted."""
        streaming = LettaStreamingRequest(
            messages=[MessageCreate(role="user", content="Test message")],
            max_steps=12,
        )

        assert streaming.max_steps == 12
        # stream_tokens was not passed, so this is the value the schema falls back to.
        assert streaming.stream_tokens is False

    def test_letta_streaming_request_with_streaming_options(self):
        """stream_tokens=True can be supplied alongside max_steps."""
        streaming = LettaStreamingRequest(
            messages=[MessageCreate(role="user", content="Test message")],
            max_steps=8,
            stream_tokens=True,
        )

        assert streaming.max_steps == 8
        assert streaming.stream_tokens is True
|
|
|
|
|
|
class TestLettaBatchRequest:
    """Checks for LettaBatchRequest: max_steps handling and the agent_id field."""

    def test_letta_batch_request_inherits_max_steps(self):
        """A batch request stores both max_steps and agent_id as given."""
        batch_item = LettaBatchRequest(
            messages=[MessageCreate(role="user", content="Test message")],
            agent_id="test-agent-id",
            max_steps=20,
        )

        assert batch_item.max_steps == 20
        assert batch_item.agent_id == "test-agent-id"

    def test_letta_batch_request_required_agent_id(self):
        """Omitting agent_id raises a ValidationError that names the field."""
        user_messages = [MessageCreate(role="user", content="Test message")]

        with pytest.raises(ValidationError) as raised:
            LettaBatchRequest(messages=user_messages)

        assert "agent_id" in str(raised.value)
|
|
|
|
|
|
class TestCreateBatch:
    """Checks for the CreateBatch schema wrapping a list of LettaBatchRequest items."""

    def test_create_batch_with_max_steps(self):
        """Each wrapped request keeps its own max_steps value, in order."""
        user_messages = [MessageCreate(role="user", content="Test message")]
        items = [
            LettaBatchRequest(messages=user_messages, agent_id=agent_id, max_steps=steps)
            for agent_id, steps in (("agent-1", 5), ("agent-2", 10))
        ]

        created = CreateBatch(requests=items)

        assert len(created.requests) == 2
        assert created.requests[0].max_steps == 5
        assert created.requests[1].max_steps == 10

    def test_create_batch_with_callback_url(self):
        """callback_url is kept on the batch and compares equal once converted to str."""
        only_item = LettaBatchRequest(
            messages=[MessageCreate(role="user", content="Test message")],
            agent_id="agent-1",
            max_steps=3,
        )

        created = CreateBatch(requests=[only_item], callback_url="https://example.com/callback")

        assert str(created.callback_url) == "https://example.com/callback"
        assert created.requests[0].max_steps == 3
|
|
|
|
|
|
class TestLettaRequestIntegration:
    """Cross-schema usage patterns for LettaRequest and the request types built on it."""

    def test_max_steps_propagation_in_inheritance_chain(self):
        """max_steps is honoured by the base, streaming and batch request types alike."""
        user_messages = [MessageCreate(role="user", content="Test message")]

        # Base request type.
        plain = LettaRequest(messages=user_messages, max_steps=3)
        assert plain.max_steps == 3

        # Streaming request type.
        streaming = LettaStreamingRequest(messages=user_messages, max_steps=6)
        assert streaming.max_steps == 6

        # Batch request type (additionally needs an agent_id).
        batched = LettaBatchRequest(messages=user_messages, agent_id="test-agent", max_steps=9)
        assert batched.max_steps == 9

    def test_backwards_compatibility(self):
        """Callers that never pass max_steps still get a valid request with the default of 10."""
        user_messages = [MessageCreate(role="user", content="Test message")]

        # No max_steps at all: the default applies.
        minimal = LettaRequest(messages=user_messages)
        assert minimal.max_steps == 10

        # Other fields supplied, max_steps still omitted: the default still applies.
        customised = LettaRequest(
            messages=user_messages,
            use_assistant_message=False,
            assistant_message_tool_name="custom_tool",
        )
        assert customised.max_steps == 10
|