From 1405464a1c3e42ee22ded3530207922ba83b44bd Mon Sep 17 00:00:00 2001 From: cthomas Date: Fri, 13 Jun 2025 16:04:48 -0700 Subject: [PATCH] feat: send stop reason in letta APIs (#2789) --- letta/agents/base_agent.py | 6 ++- letta/agents/helpers.py | 7 +++- letta/agents/letta_agent.py | 40 ++++++++++++++----- letta/agents/voice_sleeptime_agent.py | 3 +- letta/functions/helpers.py | 7 +++- .../anthropic_streaming_interface.py | 3 ++ .../interfaces/openai_streaming_interface.py | 3 ++ letta/schemas/letta_response.py | 7 +++- letta/schemas/letta_stop_reason.py | 30 ++------------ letta/server/rest_api/app.py | 2 - letta/server/server.py | 13 +++++- tests/integration_test_agent_tool_graph.py | 29 +++++++++++--- tests/integration_test_pinecone_tool.py | 3 +- tests/integration_test_send_message.py | 31 +++++++++++--- tests/test_sdk_client.py | 4 +- tests/utils.py | 2 - 16 files changed, 128 insertions(+), 62 deletions(-) diff --git a/letta/agents/base_agent.py b/letta/agents/base_agent.py index 9624fa49..cf903c79 100644 --- a/letta/agents/base_agent.py +++ b/letta/agents/base_agent.py @@ -12,6 +12,7 @@ from letta.schemas.enums import MessageStreamStatus from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage from letta.schemas.letta_message_content import TextContent from letta.schemas.letta_response import LettaResponse +from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.message import Message, MessageCreate, MessageUpdate from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User @@ -138,8 +139,11 @@ class BaseAgent(ABC): logger.exception(f"Failed to rebuild memory for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name})") raise - def get_finish_chunks_for_stream(self, usage: LettaUsageStatistics): + def get_finish_chunks_for_stream(self, usage: LettaUsageStatistics, stop_reason: Optional[LettaStopReason] = None): + if stop_reason is None: + stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value) return [ + stop_reason.model_dump_json(), usage.model_dump_json(), MessageStreamStatus.done.value, ] diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py index 5e96996a..de7c4d15 100644 --- a/letta/agents/helpers.py +++ b/letta/agents/helpers.py @@ -5,6 +5,7 @@ from typing import List, Optional, Tuple from letta.schemas.agent import AgentState from letta.schemas.letta_message import MessageType from letta.schemas.letta_response import LettaResponse +from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.message import Message, MessageCreate from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User @@ -16,6 +17,7 @@ def _create_letta_response( new_in_context_messages: list[Message], use_assistant_message: bool, usage: LettaUsageStatistics, + stop_reason: Optional[LettaStopReason] = None, include_return_message_types: Optional[List[MessageType]] = None, ) -> LettaResponse: """ @@ -32,8 +34,9 @@ def _create_letta_response( # Apply message type filtering if specified if include_return_message_types is not None: response_messages = [msg for msg in response_messages if msg.message_type in include_return_message_types] - - return LettaResponse(messages=response_messages, usage=usage) + if stop_reason is None: + stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value) + return LettaResponse(messages=response_messages, stop_reason=stop_reason, usage=usage) def _prepare_in_context_messages( diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 495278db..c8c5284d 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -30,6 +30,7 @@ from letta.schemas.enums import MessageRole from letta.schemas.letta_message import MessageType from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent from letta.schemas.letta_response import LettaResponse +from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message, MessageCreate from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics @@ -125,7 +126,7 @@ class LettaAgent(BaseAgent): agent_state = await self.agent_manager.get_agent_by_id_async( agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor ) - _, new_in_context_messages, usage = await self._step( + _, new_in_context_messages, usage, stop_reason = await self._step( agent_state=agent_state, input_messages=input_messages, max_steps=max_steps, @@ -134,6 +135,7 @@ class LettaAgent(BaseAgent): return _create_letta_response( new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message, + stop_reason=stop_reason, usage=usage, include_return_message_types=include_return_message_types, ) @@ -160,6 +162,7 @@ class LettaAgent(BaseAgent): put_inner_thoughts_first=True, actor=self.actor, ) + stop_reason = None usage = LettaUsageStatistics() # span for request @@ -218,7 +221,7 @@ class LettaAgent(BaseAgent): logger.info("No reasoning content found.") reasoning = None - persisted_messages, should_continue = await self._handle_ai_response( + persisted_messages, should_continue, stop_reason = await self._handle_ai_response( tool_call, valid_tool_names, agent_state, @@ -285,7 +288,7 @@ class LettaAgent(BaseAgent): request_span.end() # Return back usage - for finish_chunk in self.get_finish_chunks_for_stream(usage): + for finish_chunk in self.get_finish_chunks_for_stream(usage, stop_reason): yield f"data: {finish_chunk}\n\n" async def _step( @@ -294,7 +297,7 @@ class LettaAgent(BaseAgent): input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS, request_start_timestamp_ns: Optional[int] = None, - ) -> Tuple[List[Message], List[Message], LettaUsageStatistics]: + ) -> Tuple[List[Message], List[Message], Optional[LettaStopReason], LettaUsageStatistics]: """ Carries out an invocation of the agent loop. In each step, the agent 1. Rebuilds its memory @@ -317,6 +320,7 @@ class LettaAgent(BaseAgent): request_span = tracer.start_span("time_to_first_token") request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None}) + stop_reason = None usage = LettaUsageStatistics() for i in range(max_steps): step_id = generate_step_id() @@ -364,7 +368,7 @@ class LettaAgent(BaseAgent): logger.info("No reasoning content found.") reasoning = None - persisted_messages, should_continue = await self._handle_ai_response( + persisted_messages, should_continue, stop_reason = await self._handle_ai_response( tool_call, valid_tool_names, agent_state, @@ -420,7 +424,7 @@ class LettaAgent(BaseAgent): force=False, ) - return current_in_context_messages, new_in_context_messages, usage + return current_in_context_messages, new_in_context_messages, usage, stop_reason @trace_method async def step_stream( @@ -453,6 +457,7 @@ class LettaAgent(BaseAgent): put_inner_thoughts_first=True, actor=self.actor, ) + stop_reason = None usage = LettaUsageStatistics() first_chunk, request_span = True, None if request_start_timestamp_ns: @@ -536,9 +541,18 @@ class LettaAgent(BaseAgent): ) # Process resulting stream content - tool_call = interface.get_tool_call_object() + try: + tool_call = interface.get_tool_call_object() + except ValueError as e: + stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value) + yield f"data: {stop_reason.model_dump_json()}\n\n" + raise e + except Exception as e: + stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_tool_call.value) + yield f"data: {stop_reason.model_dump_json()}\n\n" + raise e reasoning_content = interface.get_reasoning_content() - persisted_messages, should_continue = await self._handle_ai_response( + persisted_messages, should_continue, stop_reason = await self._handle_ai_response( tool_call, valid_tool_names, agent_state, @@ -621,7 +635,7 @@ class LettaAgent(BaseAgent): request_span.add_event(name="letta_request_ms", attributes={"duration_ms": ns_to_ms(request_ns)}) request_span.end() - for finish_chunk in self.get_finish_chunks_for_stream(usage): + for finish_chunk in self.get_finish_chunks_for_stream(usage, stop_reason): yield f"data: {finish_chunk}\n\n" # noinspection PyInconsistentReturns @@ -876,12 +890,13 @@ class LettaAgent(BaseAgent): initial_messages: Optional[List[Message]] = None, agent_step_span: Optional["Span"] = None, is_final_step: Optional[bool] = None, - ) -> Tuple[List[Message], bool]: + ) -> Tuple[List[Message], bool, Optional[LettaStopReason]]: """ Now that streaming is done, handle the final AI response. This might yield additional SSE tokens if we do stalling. At the end, set self._continue_execution accordingly. """ + stop_reason = None # Check if the called tool is allowed by tool name: tool_call_name = tool_call.function.name tool_call_args_str = tool_call.function.arguments @@ -899,6 +914,7 @@ class LettaAgent(BaseAgent): tool_args = json.loads(tool_args) if is_final_step: + stop_reason = LettaStopReason(stop_reason=StopReasonType.max_steps.value) logger.info("Agent has reached max steps.") request_heartbeat = False else: @@ -967,6 +983,8 @@ class LettaAgent(BaseAgent): continue_stepping = request_heartbeat tool_rules_solver.register_tool_call(tool_name=tool_call_name) if tool_rules_solver.is_terminal_tool(tool_name=tool_call_name): + if continue_stepping: + stop_reason = LettaStopReason(stop_reason=StopReasonType.tool_rule.value) continue_stepping = False elif tool_rules_solver.has_children_tools(tool_name=tool_call_name): continue_stepping = True @@ -1013,7 +1031,7 @@ class LettaAgent(BaseAgent): ) self.last_function_response = function_response - return persisted_messages, continue_stepping + return persisted_messages, continue_stepping, stop_reason @trace_method async def _execute_tool( diff --git a/letta/agents/voice_sleeptime_agent.py b/letta/agents/voice_sleeptime_agent.py index 8a9c61c6..4e70c56e 100644 --- a/letta/agents/voice_sleeptime_agent.py +++ b/letta/agents/voice_sleeptime_agent.py @@ -82,7 +82,7 @@ class VoiceSleeptimeAgent(LettaAgent): ] # Summarize - current_in_context_messages, new_in_context_messages, usage = await super()._step( + current_in_context_messages, new_in_context_messages, usage, stop_reason = await super()._step( agent_state=agent_state, input_messages=input_messages, max_steps=max_steps ) new_in_context_messages, updated = self.summarizer.summarize( @@ -95,6 +95,7 @@ class VoiceSleeptimeAgent(LettaAgent): return _create_letta_response( new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message, + stop_reason=stop_reason, usage=usage, include_return_message_types=include_return_message_types, ) diff --git a/letta/functions/helpers.py b/letta/functions/helpers.py index 238f85d2..161bef96 100644 --- a/letta/functions/helpers.py +++ b/letta/functions/helpers.py @@ -14,6 +14,7 @@ from letta.orm.errors import NoResultFound from letta.schemas.enums import MessageRole from letta.schemas.letta_message import AssistantMessage from letta.schemas.letta_response import LettaResponse +from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.message import Message, MessageCreate from letta.schemas.user import User from letta.server.rest_api.utils import get_letta_server @@ -292,7 +293,11 @@ async def _send_message_to_agent_no_stream( ) final_messages = interface.get_captured_send_messages() - return LettaResponse(messages=final_messages, usage=usage_stats) + return LettaResponse( + messages=final_messages, + stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value), + usage=usage_stats, + ) async def _async_send_message_with_retries( diff --git a/letta/interfaces/anthropic_streaming_interface.py b/letta/interfaces/anthropic_streaming_interface.py index ad3be05a..3259bdd9 100644 --- a/letta/interfaces/anthropic_streaming_interface.py +++ b/letta/interfaces/anthropic_streaming_interface.py @@ -37,6 +37,7 @@ from letta.schemas.letta_message import ( ToolCallMessage, ) from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent +from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall from letta.server.rest_api.json_parser import JSONParser, PydanticJSONParser @@ -385,6 +386,8 @@ class AnthropicStreamingInterface: self.anthropic_mode = None except Exception as e: logger.error("Error processing stream: %s", e) + stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) + yield f"data: {stop_reason.model_dump_json()}\n\n" raise finally: logger.info("AnthropicStreamingInterface: Stream processing complete.") diff --git a/letta/interfaces/openai_streaming_interface.py b/letta/interfaces/openai_streaming_interface.py index 47d2ab40..25a945a2 100644 --- a/letta/interfaces/openai_streaming_interface.py +++ b/letta/interfaces/openai_streaming_interface.py @@ -11,6 +11,7 @@ from letta.otel.context import get_ctx_attributes from letta.otel.metric_registry import MetricRegistry from letta.schemas.letta_message import AssistantMessage, LettaMessage, ReasoningMessage, ToolCallDelta, ToolCallMessage from letta.schemas.letta_message_content import TextContent +from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall from letta.server.rest_api.json_parser import OptimisticJSONParser @@ -343,6 +344,8 @@ class OpenAIStreamingInterface: self.function_id_buffer = None except Exception as e: logger.error("Error processing stream: %s", e) + stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) + yield f"data: {stop_reason.model_dump_json()}\n\n" raise finally: logger.info("OpenAIStreamingInterface: Stream processing complete.") diff --git a/letta/schemas/letta_response.py b/letta/schemas/letta_response.py index a4057298..428d263a 100644 --- a/letta/schemas/letta_response.py +++ b/letta/schemas/letta_response.py @@ -9,6 +9,7 @@ from pydantic import BaseModel, Field from letta.helpers.json_helpers import json_dumps from letta.schemas.enums import JobStatus, MessageStreamStatus from letta.schemas.letta_message import LettaMessage, LettaMessageUnion +from letta.schemas.letta_stop_reason import LettaStopReason from letta.schemas.message import Message from letta.schemas.usage import LettaUsageStatistics @@ -34,6 +35,10 @@ class LettaResponse(BaseModel): } }, ) + stop_reason: LettaStopReason = Field( + ..., + description="The stop reason from Letta indicating why agent loop stopped execution.", + ) usage: LettaUsageStatistics = Field( ..., description="The usage statistics of the agent.", @@ -166,7 +171,7 @@ class LettaResponse(BaseModel): # The streaming response is either [DONE], [DONE_STEP], [DONE], an error, or a LettaMessage -LettaStreamingResponse = Union[LettaMessage, MessageStreamStatus, LettaUsageStatistics] +LettaStreamingResponse = Union[LettaMessage, MessageStreamStatus, LettaStopReason, LettaUsageStatistics] class LettaBatchResponse(BaseModel): diff --git a/letta/schemas/letta_stop_reason.py b/letta/schemas/letta_stop_reason.py index e5c65a3b..66761222 100644 --- a/letta/schemas/letta_stop_reason.py +++ b/letta/schemas/letta_stop_reason.py @@ -10,35 +10,13 @@ class StopReasonType(str, Enum): invalid_tool_call = "invalid_tool_call" max_steps = "max_steps" no_tool_call = "no_tool_call" + tool_rule = "tool_rule" class LettaStopReason(BaseModel): """ - The stop reason from letta used during streaming response. + The stop reason from Letta indicating why agent loop stopped execution. """ - message_type: Literal["stop_reason"] = "stop_reason" - stop_reason: StopReasonType = Field(..., description="The type of the message.") - - -def create_letta_stop_reason_schema(): - return { - "properties": { - "message_type": { - "type": "string", - "const": "stop_reason", - "title": "Message Type", - "description": "The type of the message.", - "default": "stop_reason", - }, - "stop_reason": { - "type": "string", - "enum": list(StopReasonType.__members__.keys()), - "title": "Stop Reason", - }, - }, - "type": "object", - "required": ["stop_reason"], - "title": "LettaStopReason", - "description": "Letta provided stop reason for why agent loop ended.", - } + message_type: Literal["stop_reason"] = Field("stop_reason", description="The type of the message.") + stop_reason: StopReasonType = Field(..., description="The reason why execution stopped.") diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py index 5d166c67..7c249087 100644 --- a/letta/server/rest_api/app.py +++ b/letta/server/rest_api/app.py @@ -24,7 +24,6 @@ from letta.schemas.letta_message_content import ( create_letta_message_content_union_schema, create_letta_user_message_content_union_schema, ) -from letta.schemas.letta_stop_reason import create_letta_stop_reason_schema from letta.server.constants import REST_DEFAULT_PORT # NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests @@ -69,7 +68,6 @@ def generate_openapi_schema(app: FastAPI): letta_docs["components"]["schemas"]["LettaMessageContentUnion"] = create_letta_message_content_union_schema() letta_docs["components"]["schemas"]["LettaAssistantMessageContentUnion"] = create_letta_assistant_message_content_union_schema() letta_docs["components"]["schemas"]["LettaUserMessageContentUnion"] = create_letta_user_message_content_union_schema() - letta_docs["components"]["schemas"]["LettaStopReason"] = create_letta_stop_reason_schema() # Update the app's schema with our modified version app.openapi_schema = letta_docs diff --git a/letta/server/server.py b/letta/server/server.py index 2cdb0d33..6d3fb0a0 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -48,6 +48,7 @@ from letta.schemas.job import Job, JobUpdate from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, MessageType, ToolReturnMessage from letta.schemas.letta_message_content import TextContent from letta.schemas.letta_response import LettaResponse +from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary from letta.schemas.message import Message, MessageCreate, MessageUpdate @@ -2359,7 +2360,11 @@ class SyncServer(Server): # If we want to convert these to Message, we can use the attached IDs # NOTE: we will need to de-duplicate the Messsage IDs though (since Assistant->Inner+Func_Call) # TODO: eventually update the interface to use `Message` and `MessageChunk` (new) inside the deque instead - return LettaResponse(messages=filtered_stream, usage=usage) + return LettaResponse( + messages=filtered_stream, + stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value), + usage=usage, + ) except HTTPException: raise @@ -2461,4 +2466,8 @@ class SyncServer(Server): # If we want to convert these to Message, we can use the attached IDs # NOTE: we will need to de-duplicate the Messsage IDs though (since Assistant->Inner+Func_Call) # TODO: eventually update the interface to use `Message` and `MessageChunk` (new) inside the deque instead - return LettaResponse(messages=filtered_stream, usage=usage) + return LettaResponse( + messages=filtered_stream, + stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value), + usage=usage, + ) diff --git a/tests/integration_test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py index 9647eb1b..6c413dcf 100644 --- a/tests/integration_test_agent_tool_graph.py +++ b/tests/integration_test_agent_tool_graph.py @@ -6,6 +6,7 @@ import pytest from letta.config import LettaConfig from letta.schemas.letta_message import ToolCallMessage from letta.schemas.letta_response import LettaResponse +from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.message import MessageCreate from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, TerminalToolRule from letta.server.server import SyncServer @@ -216,7 +217,9 @@ def test_single_path_agent_tool_call_graph( for m in messages: letta_messages += m.to_letta_messages() - response = LettaResponse(messages=letta_messages, usage=usage_stats) + response = LettaResponse( + messages=letta_messages, stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value), usage=usage_stats + ) # Make checks assert_sanity_checks(response) @@ -332,7 +335,11 @@ def test_claude_initial_tool_rule_enforced( for m in messages: letta_messages += m.to_letta_messages() - response = LettaResponse(messages=letta_messages, usage=usage_stats) + response = LettaResponse( + messages=letta_messages, + stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value), + usage=usage_stats, + ) assert_sanity_checks(response) @@ -407,7 +414,11 @@ def test_agent_no_structured_output_with_one_child_tool_parametrized( for m in messages: letta_messages += m.to_letta_messages() - response = LettaResponse(messages=letta_messages, usage=usage_stats) + response = LettaResponse( + messages=letta_messages, + stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value), + usage=usage_stats, + ) # Run assertions assert_sanity_checks(response) @@ -465,7 +476,11 @@ def test_init_tool_rule_always_fails( ) messages = [m for step in usage_stats.steps_messages for m in step] letta_messages = [msg for m in messages for msg in m.to_letta_messages()] - response = LettaResponse(messages=letta_messages, usage=usage_stats) + response = LettaResponse( + messages=letta_messages, + stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value), + usage=usage_stats, + ) assert_invoked_function_call(response.messages, auto_error_tool.name) @@ -504,7 +519,11 @@ def test_continue_tool_rule(server, default_user): ) messages = [m for step in usage_stats.steps_messages for m in step] letta_messages = [msg for m in messages for msg in m.to_letta_messages()] - response = LettaResponse(messages=letta_messages, usage=usage_stats) + response = LettaResponse( + messages=letta_messages, + stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value), + usage=usage_stats, + ) assert_invoked_function_call(response.messages, "send_message") assert_invoked_function_call(response.messages, "core_memory_append") diff --git a/tests/integration_test_pinecone_tool.py b/tests/integration_test_pinecone_tool.py index 9ce7589d..c9aecfb7 100644 --- a/tests/integration_test_pinecone_tool.py +++ b/tests/integration_test_pinecone_tool.py @@ -164,8 +164,9 @@ async def test_pinecone_tool(client: AsyncLetta) -> None: assert pinecone_results is not None, "No Pinecone results received from the agent." assert len(queries) > 0, "No queries received from the agent." + assert messages[-2].message_type == "stop_reason", "Penultimate message in stream must be stop reason." assert messages[-1].message_type == "usage_statistics", "Last message in stream must be usage stats." - response_messages_from_stream = [m for m in messages if m.message_type != "usage_statistics"] + response_messages_from_stream = [m for m in messages if m.message_type not in ["stop_reason", "usage_statistics"]] response_message_types_from_stream = [m.message_type for m in response_messages_from_stream] messages_from_db = await client.agents.messages.list( diff --git a/tests/integration_test_send_message.py b/tests/integration_test_send_message.py index 93aea777..9a4ba908 100644 --- a/tests/integration_test_send_message.py +++ b/tests/integration_test_send_message.py @@ -17,6 +17,7 @@ from letta_client.types import ( Base64Image, HiddenReasoningMessage, ImageContent, + LettaStopReason, LettaUsageStatistics, ReasoningMessage, TextContent, @@ -67,7 +68,7 @@ USER_MESSAGE_FORCE_REPLY: List[MessageCreate] = [ USER_MESSAGE_ROLL_DICE: List[MessageCreate] = [ MessageCreate( role="user", - content="This is an automated test message. Call the roll_dice tool with 16 sides and tell me the outcome.", + content="This is an automated test message. Call the roll_dice tool with 16 sides and send me a message with the outcome.", otid=USER_MESSAGE_OTID, ) ] @@ -128,7 +129,7 @@ def assert_greeting_with_assistant_message_response( Asserts that the messages list follows the expected sequence: ReasoningMessage -> AssistantMessage. """ - expected_message_count = 3 if streaming or from_db else 2 + expected_message_count = 4 if streaming else 3 if from_db else 2 assert len(messages) == expected_message_count index = 0 @@ -153,6 +154,9 @@ def assert_greeting_with_assistant_message_response( index += 1 if streaming: + assert isinstance(messages[index], LettaStopReason) + assert messages[index].stop_reason == "end_turn" + index += 1 assert isinstance(messages[index], LettaUsageStatistics) assert messages[index].prompt_tokens > 0 assert messages[index].completion_tokens > 0 @@ -171,7 +175,7 @@ def assert_greeting_without_assistant_message_response( Asserts that the messages list follows the expected sequence: ReasoningMessage -> ToolCallMessage -> ToolReturnMessage. """ - expected_message_count = 4 if streaming or from_db else 3 + expected_message_count = 5 if streaming else 4 if from_db else 3 assert len(messages) == expected_message_count index = 0 @@ -201,7 +205,14 @@ def assert_greeting_without_assistant_message_response( index += 1 if streaming: + assert isinstance(messages[index], LettaStopReason) + assert messages[index].stop_reason == "end_turn" + index += 1 assert isinstance(messages[index], LettaUsageStatistics) + assert messages[index].prompt_tokens > 0 + assert messages[index].completion_tokens > 0 + assert messages[index].total_tokens > 0 + assert messages[index].step_count > 0 def assert_tool_call_response( @@ -215,7 +226,7 @@ def assert_tool_call_response( ReasoningMessage -> ToolCallMessage -> ToolReturnMessage -> ReasoningMessage -> AssistantMessage. """ - expected_message_count = 6 if streaming else 7 if from_db else 5 + expected_message_count = 7 if streaming or from_db else 5 assert len(messages) == expected_message_count index = 0 @@ -260,7 +271,14 @@ def assert_tool_call_response( index += 1 if streaming: + assert isinstance(messages[index], LettaStopReason) + assert messages[index].stop_reason == "end_turn" + index += 1 assert isinstance(messages[index], LettaUsageStatistics) + assert messages[index].prompt_tokens > 0 + assert messages[index].completion_tokens > 0 + assert messages[index].total_tokens > 0 + assert messages[index].step_count > 0 def assert_image_input_response( @@ -274,7 +292,7 @@ def assert_image_input_response( Asserts that the messages list follows the expected sequence: ReasoningMessage -> AssistantMessage. """ - expected_message_count = 3 if streaming or from_db else 2 + expected_message_count = 4 if streaming else 3 if from_db else 2 assert len(messages) == expected_message_count index = 0 @@ -296,6 +314,9 @@ def assert_image_input_response( index += 1 if streaming: + assert isinstance(messages[index], LettaStopReason) + assert messages[index].stop_reason == "end_turn" + index += 1 assert isinstance(messages[index], LettaUsageStatistics) assert messages[index].prompt_tokens > 0 assert messages[index].completion_tokens > 0 diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index 916f1863..8bcd0772 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -666,7 +666,7 @@ def test_include_return_message_types(client: LettaSDKClient, agent: AgentState, ], include_return_message_types=message_types, ) - messages = [message for message in list(response) if message.message_type != "usage_statistics"] + messages = [message for message in list(response) if message.message_type not in ["stop_reason", "usage_statistics"]] verify_message_types(messages, message_types) elif message_create == "async": @@ -698,7 +698,7 @@ def test_include_return_message_types(client: LettaSDKClient, agent: AgentState, ], include_return_message_types=message_types, ) - messages = [message for message in list(response) if message.message_type != "usage_statistics"] + messages = [message for message in list(response) if message.message_type not in ["stop_reason", "usage_statistics"]] verify_message_types(messages, message_types) elif message_create == "sync": diff --git a/tests/utils.py b/tests/utils.py index 86215d0f..d8b03395 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,10 +3,8 @@ import random import string import time from datetime import datetime, timezone -from importlib import util from typing import Dict, Iterator, List, Optional, Tuple -import requests from letta_client import Letta, SystemMessage from letta.config import LettaConfig