feat: send stop reason in letta APIs (#2789)

This commit is contained in:
cthomas
2025-06-13 16:04:48 -07:00
committed by GitHub
parent 93c15244ab
commit 1405464a1c
16 changed files with 128 additions and 62 deletions

View File

@@ -6,6 +6,7 @@ import pytest
from letta.config import LettaConfig
from letta.schemas.letta_message import ToolCallMessage
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import MessageCreate
from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, TerminalToolRule
from letta.server.server import SyncServer
@@ -216,7 +217,9 @@ def test_single_path_agent_tool_call_graph(
for m in messages:
letta_messages += m.to_letta_messages()
response = LettaResponse(messages=letta_messages, usage=usage_stats)
response = LettaResponse(
messages=letta_messages, stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value), usage=usage_stats
)
# Make checks
assert_sanity_checks(response)
@@ -332,7 +335,11 @@ def test_claude_initial_tool_rule_enforced(
for m in messages:
letta_messages += m.to_letta_messages()
response = LettaResponse(messages=letta_messages, usage=usage_stats)
response = LettaResponse(
messages=letta_messages,
stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value),
usage=usage_stats,
)
assert_sanity_checks(response)
@@ -407,7 +414,11 @@ def test_agent_no_structured_output_with_one_child_tool_parametrized(
for m in messages:
letta_messages += m.to_letta_messages()
response = LettaResponse(messages=letta_messages, usage=usage_stats)
response = LettaResponse(
messages=letta_messages,
stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value),
usage=usage_stats,
)
# Run assertions
assert_sanity_checks(response)
@@ -465,7 +476,11 @@ def test_init_tool_rule_always_fails(
)
messages = [m for step in usage_stats.steps_messages for m in step]
letta_messages = [msg for m in messages for msg in m.to_letta_messages()]
response = LettaResponse(messages=letta_messages, usage=usage_stats)
response = LettaResponse(
messages=letta_messages,
stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value),
usage=usage_stats,
)
assert_invoked_function_call(response.messages, auto_error_tool.name)
@@ -504,7 +519,11 @@ def test_continue_tool_rule(server, default_user):
)
messages = [m for step in usage_stats.steps_messages for m in step]
letta_messages = [msg for m in messages for msg in m.to_letta_messages()]
response = LettaResponse(messages=letta_messages, usage=usage_stats)
response = LettaResponse(
messages=letta_messages,
stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value),
usage=usage_stats,
)
assert_invoked_function_call(response.messages, "send_message")
assert_invoked_function_call(response.messages, "core_memory_append")