feat: add max_steps as argument to messages.create (#2664)

Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
This commit is contained in:
Shangyin Tan
2025-06-09 16:54:48 -07:00
committed by GitHub
parent 9601e74738
commit b8a0651578
6 changed files with 676 additions and 67 deletions

View File

@@ -136,6 +136,7 @@ class AbstractClient(object):
stream: Optional[bool] = False,
stream_steps: bool = False,
stream_tokens: bool = False,
max_steps: Optional[int] = None,
) -> LettaResponse:
raise NotImplementedError
@@ -977,7 +978,8 @@ class RESTClient(AbstractClient):
stream: Optional[bool] = False,
stream_steps: bool = False,
stream_tokens: bool = False,
) -> Union[LettaResponse, Generator[LettaStreamingResponse, None, None]]:
max_steps: Optional[int] = 10,
) -> LettaResponse:
"""
Send a message to an agent
@@ -988,6 +990,7 @@ class RESTClient(AbstractClient):
name(str): Name of the sender
stream (bool): Stream the response (default: `False`)
stream_tokens (bool): Stream tokens (default: `False`)
max_steps (int): Maximum number of steps the agent should take (default: 10)
Returns:
response (LettaResponse): Response from the agent

View File

@@ -9,6 +9,10 @@ from letta.schemas.message import MessageCreate
class LettaRequest(BaseModel):
messages: List[MessageCreate] = Field(..., description="The messages to be sent to the agent.")
max_steps: int = Field(
default=10,
description="Maximum number of steps the agent should take to process the request.",
)
use_assistant_message: bool = Field(
default=True,
description="Whether the server should parse specific tool call arguments (default `send_message`) as `AssistantMessage` objects.",

View File

@@ -709,7 +709,7 @@ async def send_message(
result = await agent_loop.step(
request.messages,
max_steps=10,
max_steps=request.max_steps,
use_assistant_message=request.use_assistant_message,
request_start_timestamp_ns=request_start_timestamp_ns,
include_return_message_types=request.include_return_message_types,
@@ -798,7 +798,7 @@ async def send_message_streaming(
result = StreamingResponseWithStatusCode(
agent_loop.step_stream(
input_messages=request.messages,
max_steps=10,
max_steps=request.max_steps,
use_assistant_message=request.use_assistant_message,
request_start_timestamp_ns=request_start_timestamp_ns,
include_return_message_types=request.include_return_message_types,
@@ -809,7 +809,7 @@ async def send_message_streaming(
result = StreamingResponseWithStatusCode(
agent_loop.step_stream_no_tokens(
request.messages,
max_steps=10,
max_steps=request.max_steps,
use_assistant_message=request.use_assistant_message,
request_start_timestamp_ns=request_start_timestamp_ns,
include_return_message_types=request.include_return_message_types,
@@ -843,6 +843,7 @@ async def process_message_background(
use_assistant_message: bool,
assistant_message_tool_name: str,
assistant_message_tool_kwarg: str,
max_steps: int = 10,
include_return_message_types: Optional[List[MessageType]] = None,
) -> None:
"""Background task to process the message and update job status."""
@@ -927,6 +928,7 @@ async def send_message_async(
use_assistant_message=request.use_assistant_message,
assistant_message_tool_name=request.assistant_message_tool_name,
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
max_steps=request.max_steps,
include_return_message_types=request.include_return_message_types,
)

541
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -73,7 +73,7 @@ llama-index = "^0.12.2"
llama-index-embeddings-openai = "^0.3.1"
e2b-code-interpreter = {version = "^1.0.3", optional = true}
anthropic = "^0.49.0"
letta_client = "^0.1.148"
letta_client = "^0.1.149"
openai = "^1.60.0"
opentelemetry-api = "1.30.0"
opentelemetry-sdk = "1.30.0"

View File

@@ -0,0 +1,183 @@
"""Tests for LettaRequest schema validation"""
import pytest
from pydantic import ValidationError
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.schemas.letta_request import CreateBatch, LettaBatchRequest, LettaRequest, LettaStreamingRequest
from letta.schemas.message import MessageCreate
class TestLettaRequest:
"""Test cases for LettaRequest schema"""
def test_letta_request_with_default_max_steps(self):
"""Test that LettaRequest uses default max_steps value"""
messages = [MessageCreate(role="user", content="Test message")]
request = LettaRequest(messages=messages)
assert request.max_steps == 10
assert request.messages == messages
assert request.use_assistant_message is True
assert request.assistant_message_tool_name == DEFAULT_MESSAGE_TOOL
assert request.assistant_message_tool_kwarg == DEFAULT_MESSAGE_TOOL_KWARG
def test_letta_request_with_custom_max_steps(self):
"""Test that LettaRequest accepts custom max_steps value"""
messages = [MessageCreate(role="user", content="Test message")]
request = LettaRequest(messages=messages, max_steps=5)
assert request.max_steps == 5
assert request.messages == messages
def test_letta_request_with_zero_max_steps(self):
"""Test that LettaRequest accepts zero max_steps"""
messages = [MessageCreate(role="user", content="Test message")]
request = LettaRequest(messages=messages, max_steps=0)
assert request.max_steps == 0
def test_letta_request_with_negative_max_steps(self):
"""Test that LettaRequest accepts negative max_steps (edge case)"""
messages = [MessageCreate(role="user", content="Test message")]
request = LettaRequest(messages=messages, max_steps=-1)
assert request.max_steps == -1
def test_letta_request_required_fields(self):
"""Test that messages field is required"""
with pytest.raises(ValidationError) as exc_info:
LettaRequest()
assert "messages" in str(exc_info.value)
def test_letta_request_with_all_fields(self):
"""Test LettaRequest with all fields specified"""
messages = [MessageCreate(role="user", content="Test message")]
request = LettaRequest(
messages=messages,
max_steps=15,
use_assistant_message=False,
assistant_message_tool_name="custom_tool",
assistant_message_tool_kwarg="custom_kwarg",
)
assert request.max_steps == 15
assert request.use_assistant_message is False
assert request.assistant_message_tool_name == "custom_tool"
assert request.assistant_message_tool_kwarg == "custom_kwarg"
def test_letta_request_json_serialization(self):
"""Test that LettaRequest can be serialized to/from JSON"""
messages = [MessageCreate(role="user", content="Test message")]
request = LettaRequest(messages=messages, max_steps=7)
# Serialize to dict
request_dict = request.model_dump()
assert request_dict["max_steps"] == 7
# Deserialize from dict
request_from_dict = LettaRequest.model_validate(request_dict)
assert request_from_dict.max_steps == 7
assert request_from_dict.messages[0].role == "user"
class TestLettaStreamingRequest:
"""Test cases for LettaStreamingRequest schema"""
def test_letta_streaming_request_inherits_max_steps(self):
"""Test that LettaStreamingRequest inherits max_steps from LettaRequest"""
messages = [MessageCreate(role="user", content="Test message")]
request = LettaStreamingRequest(messages=messages, max_steps=12)
assert request.max_steps == 12
assert request.stream_tokens is False # Default value
def test_letta_streaming_request_with_streaming_options(self):
"""Test LettaStreamingRequest with streaming-specific options"""
messages = [MessageCreate(role="user", content="Test message")]
request = LettaStreamingRequest(messages=messages, max_steps=8, stream_tokens=True)
assert request.max_steps == 8
assert request.stream_tokens is True
class TestLettaBatchRequest:
"""Test cases for LettaBatchRequest schema"""
def test_letta_batch_request_inherits_max_steps(self):
"""Test that LettaBatchRequest inherits max_steps from LettaRequest"""
messages = [MessageCreate(role="user", content="Test message")]
request = LettaBatchRequest(messages=messages, agent_id="test-agent-id", max_steps=20)
assert request.max_steps == 20
assert request.agent_id == "test-agent-id"
def test_letta_batch_request_required_agent_id(self):
"""Test that agent_id is required for LettaBatchRequest"""
messages = [MessageCreate(role="user", content="Test message")]
with pytest.raises(ValidationError) as exc_info:
LettaBatchRequest(messages=messages)
assert "agent_id" in str(exc_info.value)
class TestCreateBatch:
"""Test cases for CreateBatch schema"""
def test_create_batch_with_max_steps(self):
"""Test CreateBatch containing requests with max_steps"""
messages = [MessageCreate(role="user", content="Test message")]
batch_requests = [
LettaBatchRequest(messages=messages, agent_id="agent-1", max_steps=5),
LettaBatchRequest(messages=messages, agent_id="agent-2", max_steps=10),
]
batch = CreateBatch(requests=batch_requests)
assert len(batch.requests) == 2
assert batch.requests[0].max_steps == 5
assert batch.requests[1].max_steps == 10
def test_create_batch_with_callback_url(self):
"""Test CreateBatch with callback URL"""
messages = [MessageCreate(role="user", content="Test message")]
batch_requests = [LettaBatchRequest(messages=messages, agent_id="agent-1", max_steps=3)]
batch = CreateBatch(requests=batch_requests, callback_url="https://example.com/callback")
assert str(batch.callback_url) == "https://example.com/callback"
assert batch.requests[0].max_steps == 3
class TestLettaRequestIntegration:
"""Integration tests for LettaRequest usage patterns"""
def test_max_steps_propagation_in_inheritance_chain(self):
"""Test that max_steps works correctly across the inheritance chain"""
messages = [MessageCreate(role="user", content="Test message")]
# Test base LettaRequest
base_request = LettaRequest(messages=messages, max_steps=3)
assert base_request.max_steps == 3
# Test LettaStreamingRequest inheritance
streaming_request = LettaStreamingRequest(messages=messages, max_steps=6)
assert streaming_request.max_steps == 6
# Test LettaBatchRequest inheritance
batch_request = LettaBatchRequest(messages=messages, agent_id="test-agent", max_steps=9)
assert batch_request.max_steps == 9
def test_backwards_compatibility(self):
"""Test that existing code without max_steps still works"""
messages = [MessageCreate(role="user", content="Test message")]
# Should work without max_steps (uses default)
request = LettaRequest(messages=messages)
assert request.max_steps == 10
# Should work with all other fields
request = LettaRequest(messages=messages, use_assistant_message=False, assistant_message_tool_name="custom_tool")
assert request.max_steps == 10 # Still uses default