diff --git a/letta/adapters/letta_llm_stream_adapter.py b/letta/adapters/letta_llm_stream_adapter.py
index 1d3b64e4..f8408e3e 100644
--- a/letta/adapters/letta_llm_stream_adapter.py
+++ b/letta/adapters/letta_llm_stream_adapter.py
@@ -59,6 +59,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
                 put_inner_thoughts_in_kwarg=self.llm_config.put_inner_thoughts_in_kwargs,
                 requires_approval_tools=requires_approval_tools,
                 run_id=self.run_id,
+                step_id=step_id,
             )
         elif self.llm_config.model_endpoint_type == ProviderType.openai:
             # For non-v1 agents, always use Chat Completions streaming interface
@@ -70,6 +71,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
                 tools=tools,
                 requires_approval_tools=requires_approval_tools,
                 run_id=self.run_id,
+                step_id=step_id,
             )
         else:
             raise ValueError(f"Streaming not supported for provider {self.llm_config.model_endpoint_type}")
diff --git a/letta/adapters/simple_llm_stream_adapter.py b/letta/adapters/simple_llm_stream_adapter.py
index 97f65bc3..1e55a07e 100644
--- a/letta/adapters/simple_llm_stream_adapter.py
+++ b/letta/adapters/simple_llm_stream_adapter.py
@@ -52,6 +52,7 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
             self.interface = SimpleAnthropicStreamingInterface(
                 requires_approval_tools=requires_approval_tools,
                 run_id=self.run_id,
+                step_id=step_id,
             )
         elif self.llm_config.model_endpoint_type == ProviderType.openai:
             # Decide interface based on payload shape
@@ -65,6 +66,7 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
                     tools=tools,
                     requires_approval_tools=requires_approval_tools,
                     run_id=self.run_id,
+                    step_id=step_id,
                 )
             else:
                 self.interface = SimpleOpenAIStreamingInterface(
@@ -74,6 +76,7 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
                     requires_approval_tools=requires_approval_tools,
                     model=self.llm_config.model,
                     run_id=self.run_id,
+                    step_id=step_id,
                 )
         else:
             raise ValueError(f"Streaming not supported for provider {self.llm_config.model_endpoint_type}")
diff --git a/letta/agents/ephemeral_summary_agent.py b/letta/agents/ephemeral_summary_agent.py
index af56235c..b73c3f26 100644
--- a/letta/agents/ephemeral_summary_agent.py
+++ b/letta/agents/ephemeral_summary_agent.py
@@ -82,6 +82,7 @@ class EphemeralSummaryAgent(BaseAgent):
             message_creates=[system_message_create] + input_messages,
             agent_id=self.agent_id,
             timezone=agent_state.timezone,
+            run_id=None,  # TODO: add this
         )

         request_data = llm_client.build_request_data(agent_state.agent_type, messages, agent_state.llm_config, tools=[])
diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py
index e511c846..228a2f7c 100644
--- a/letta/agents/helpers.py
+++ b/letta/agents/helpers.py
@@ -56,6 +56,7 @@ def _prepare_in_context_messages(
     agent_state: AgentState,
     message_manager: MessageManager,
     actor: User,
+    run_id: str,
 ) -> Tuple[List[Message], List[Message]]:
     """
     Prepares in-context messages for an agent, based on the current state and a new user input.
@@ -65,6 +66,7 @@ def _prepare_in_context_messages(
         agent_state (AgentState): The current state of the agent, including message buffer config.
         message_manager (MessageManager): The manager used to retrieve and create messages.
         actor (User): The user performing the action, used for access control and attribution.
+        run_id (str): The run ID associated with this message processing.

     Returns:
         Tuple[List[Message], List[Message]]: A tuple containing:
@@ -81,7 +83,9 @@ def _prepare_in_context_messages(

     # Create a new user message from the input and store it
     new_in_context_messages = message_manager.create_many_messages(
-        create_input_messages(input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, actor=actor),
+        create_input_messages(
+            input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, run_id=run_id, actor=actor
+        ),
         actor=actor,
     )

@@ -93,6 +97,7 @@ async def _prepare_in_context_messages_async(
     agent_state: AgentState,
     message_manager: MessageManager,
     actor: User,
+    run_id: str,
 ) -> Tuple[List[Message], List[Message]]:
     """
     Prepares in-context messages for an agent, based on the current state and a new user input.
@@ -103,6 +108,7 @@ async def _prepare_in_context_messages_async(
         agent_state (AgentState): The current state of the agent, including message buffer config.
         message_manager (MessageManager): The manager used to retrieve and create messages.
         actor (User): The user performing the action, used for access control and attribution.
+        run_id (str): The run ID associated with this message processing.

     Returns:
         Tuple[List[Message], List[Message]]: A tuple containing:
@@ -119,7 +125,9 @@ async def _prepare_in_context_messages_async(

     # Create a new user message from the input and store it
     new_in_context_messages = await message_manager.create_many_messages_async(
-        create_input_messages(input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, actor=actor),
+        create_input_messages(
+            input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, run_id=run_id, actor=actor
+        ),
         actor=actor,
         project_id=agent_state.project_id,
     )
@@ -132,6 +140,7 @@ async def _prepare_in_context_messages_no_persist_async(
     agent_state: AgentState,
     message_manager: MessageManager,
     actor: User,
+    run_id: Optional[str] = None,
 ) -> Tuple[List[Message], List[Message]]:
     """
     Prepares in-context messages for an agent, based on the current state and a new user input.
@@ -141,6 +150,7 @@ async def _prepare_in_context_messages_no_persist_async(
         agent_state (AgentState): The current state of the agent, including message buffer config.
         message_manager (MessageManager): The manager used to retrieve and create messages.
         actor (User): The user performing the action, used for access control and attribution.
+        run_id (str): The run ID associated with this message processing.

     Returns:
         Tuple[List[Message], List[Message]]: A tuple containing:
@@ -176,7 +186,7 @@ async def _prepare_in_context_messages_no_persist_async(

     # Create a new user message from the input but dont store it yet
     new_in_context_messages = create_input_messages(
-        input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, actor=actor
+        input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, run_id=run_id, actor=actor
     )

     return current_in_context_messages, new_in_context_messages
diff --git a/letta/agents/letta_agent_batch.py b/letta/agents/letta_agent_batch.py
index c0385e1d..875147f7 100644
--- a/letta/agents/letta_agent_batch.py
+++ b/letta/agents/letta_agent_batch.py
@@ -616,7 +616,7 @@ class LettaAgentBatch(BaseAgent):
         self, agent_state: AgentState, input_messages: List[MessageCreate]
     ) -> List[Message]:
         current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
-            input_messages, agent_state, self.message_manager, self.actor
+            input_messages, agent_state, self.message_manager, self.actor, run_id=None
         )

         in_context_messages = await self._rebuild_memory_async(current_in_context_messages + new_in_context_messages, agent_state)
diff --git a/letta/agents/letta_agent_v2.py b/letta/agents/letta_agent_v2.py
index 82ec1950..e12a2306 100644
--- a/letta/agents/letta_agent_v2.py
+++ b/letta/agents/letta_agent_v2.py
@@ -142,7 +142,7 @@ class LettaAgentV2(BaseAgentV2):
         """
         request = {}
         in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
-            input_messages, self.agent_state, self.message_manager, self.actor
+            input_messages, self.agent_state, self.message_manager, self.actor, None
         )
         response = self._step(
             run_id=None,
@@ -185,7 +185,7 @@ class LettaAgentV2(BaseAgentV2):
         request_span = self._request_checkpoint_start(request_start_timestamp_ns=request_start_timestamp_ns)

         in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
-            input_messages, self.agent_state, self.message_manager, self.actor
+            input_messages, self.agent_state, self.message_manager, self.actor, run_id
         )
         in_context_messages = in_context_messages + input_messages_to_persist
         response_letta_messages = []
@@ -283,7 +283,7 @@ class LettaAgentV2(BaseAgentV2):

         try:
             in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
-                input_messages, self.agent_state, self.message_manager, self.actor
+                input_messages, self.agent_state, self.message_manager, self.actor, run_id
             )
             in_context_messages = in_context_messages + input_messages_to_persist
             for i in range(max_steps):
diff --git a/letta/agents/letta_agent_v3.py b/letta/agents/letta_agent_v3.py
index db95249d..a6beb6d1 100644
--- a/letta/agents/letta_agent_v3.py
+++ b/letta/agents/letta_agent_v3.py
@@ -84,7 +84,7 @@ class LettaAgentV3(LettaAgentV2):
         request_span = self._request_checkpoint_start(request_start_timestamp_ns=request_start_timestamp_ns)

         in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
-            input_messages, self.agent_state, self.message_manager, self.actor
+            input_messages, self.agent_state, self.message_manager, self.actor, run_id
         )
         in_context_messages = in_context_messages + input_messages_to_persist
         response_letta_messages = []
@@ -180,7 +180,7 @@ class LettaAgentV3(LettaAgentV2):

         try:
             in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
-                input_messages, self.agent_state, self.message_manager, self.actor
+                input_messages, self.agent_state, self.message_manager, self.actor, run_id
             )
             in_context_messages = in_context_messages + input_messages_to_persist
             for i in range(max_steps):
diff --git a/letta/helpers/message_helper.py b/letta/helpers/message_helper.py
index 47c58f71..b834a305 100644
--- a/letta/helpers/message_helper.py
+++ b/letta/helpers/message_helper.py
@@ -13,6 +13,7 @@ def convert_message_creates_to_messages(
     message_creates: list[MessageCreate],
     agent_id: str,
     timezone: str,
+    run_id: str,
     wrap_user_message: bool = True,
     wrap_system_message: bool = True,
 ) -> list[Message]:
@@ -21,6 +22,7 @@
             message_create=create,
             agent_id=agent_id,
             timezone=timezone,
+            run_id=run_id,
             wrap_user_message=wrap_user_message,
             wrap_system_message=wrap_system_message,
         )
@@ -32,6 +34,7 @@ def _convert_message_create_to_message(
     message_create: MessageCreate,
     agent_id: str,
     timezone: str,
+    run_id: str,
     wrap_user_message: bool = True,
     wrap_system_message: bool = True,
 ) -> Message:
@@ -81,4 +84,5 @@ def _convert_message_create_to_message(
         sender_id=message_create.sender_id,
         group_id=message_create.group_id,
         batch_item_id=message_create.batch_item_id,
+        run_id=run_id,
     )
diff --git a/letta/interfaces/anthropic_streaming_interface.py b/letta/interfaces/anthropic_streaming_interface.py
index 86e82723..3b87b4ff 100644
--- a/letta/interfaces/anthropic_streaming_interface.py
+++ b/letta/interfaces/anthropic_streaming_interface.py
@@ -67,10 +67,12 @@ class AnthropicStreamingInterface:
         put_inner_thoughts_in_kwarg: bool = False,
         requires_approval_tools: list = [],
         run_id: str | None = None,
+        step_id: str | None = None,
     ):
         self.json_parser: JSONParser = PydanticJSONParser()
         self.use_assistant_message = use_assistant_message
         self.run_id = run_id
+        self.step_id = step_id

         # Premake IDs for database writes
         self.letta_message_id = Message.generate_id()
@@ -299,6 +301,7 @@ class AnthropicStreamingInterface:
                         date=datetime.now(timezone.utc).isoformat(),
                         otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                         run_id=self.run_id,
+                        step_id=self.step_id,
                     )
                     self.reasoning_messages.append(hidden_reasoning_message)
                     prev_message_type = hidden_reasoning_message.message_type
@@ -345,6 +348,7 @@ class AnthropicStreamingInterface:
                         date=datetime.now(timezone.utc).isoformat(),
                         otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                         run_id=self.run_id,
+                        step_id=self.step_id,
                     )
                     self.reasoning_messages.append(reasoning_message)
                     prev_message_type = reasoning_message.message_type
@@ -373,6 +377,7 @@ class AnthropicStreamingInterface:
                         date=datetime.now(timezone.utc).isoformat(),
                         otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                         run_id=self.run_id,
+                        step_id=self.step_id,
                     )
                     self.reasoning_messages.append(reasoning_message)
                     prev_message_type = reasoning_message.message_type
@@ -495,6 +500,7 @@ class AnthropicStreamingInterface:
                         date=datetime.now(timezone.utc).isoformat(),
                         otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                         run_id=self.run_id,
+                        step_id=self.step_id,
                     )
                     self.reasoning_messages.append(reasoning_message)
                     prev_message_type = reasoning_message.message_type
@@ -516,6 +522,7 @@ class AnthropicStreamingInterface:
                         signature=delta.signature,
                         otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                         run_id=self.run_id,
+                        step_id=self.step_id,
                     )
                     self.reasoning_messages.append(reasoning_message)
                     prev_message_type = reasoning_message.message_type
@@ -550,9 +557,11 @@ class SimpleAnthropicStreamingInterface:
         self,
         requires_approval_tools: list = [],
         run_id: str | None = None,
+        step_id: str | None = None,
     ):
         self.json_parser: JSONParser = PydanticJSONParser()
         self.run_id = run_id
+        self.step_id = step_id

         # Premake IDs for database writes
         self.letta_message_id = Message.generate_id()
@@ -764,6 +773,7 @@ class SimpleAnthropicStreamingInterface:
                     date=datetime.now(timezone.utc).isoformat(),
                     otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                     run_id=self.run_id,
+                    step_id=self.step_id,
                 )
             else:
                 if prev_message_type and prev_message_type != "tool_call_message":
@@ -774,6 +784,7 @@ class SimpleAnthropicStreamingInterface:
                     date=datetime.now(timezone.utc).isoformat(),
                     otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                     run_id=self.run_id,
+                    step_id=self.step_id,
                 )
             prev_message_type = tool_call_msg.message_type
             yield tool_call_msg
@@ -795,6 +806,7 @@ class SimpleAnthropicStreamingInterface:
                 date=datetime.now(timezone.utc).isoformat(),
                 otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                 run_id=self.run_id,
+                step_id=self.step_id,
             )
             self.reasoning_messages.append(hidden_reasoning_message)

@@ -819,6 +831,7 @@ class SimpleAnthropicStreamingInterface:
                 date=datetime.now(timezone.utc).isoformat(),
                 otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                 run_id=self.run_id,
+                step_id=self.step_id,
             )
             # self.assistant_messages.append(assistant_msg)
             self.reasoning_messages.append(assistant_msg)
@@ -842,6 +855,7 @@ class SimpleAnthropicStreamingInterface:
                     date=datetime.now(timezone.utc).isoformat(),
                     otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                     run_id=self.run_id,
+                    step_id=self.step_id,
                 )
             else:
                 if prev_message_type and prev_message_type != "tool_call_message":
@@ -852,6 +866,7 @@ class SimpleAnthropicStreamingInterface:
                     date=datetime.now(timezone.utc).isoformat(),
                     otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                     run_id=self.run_id,
+                    step_id=self.step_id,
                 )

             yield tool_call_msg
@@ -872,6 +887,7 @@ class SimpleAnthropicStreamingInterface:
                 date=datetime.now(timezone.utc).isoformat(),
                 otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                 run_id=self.run_id,
+                step_id=self.step_id,
             )
             self.reasoning_messages.append(reasoning_message)
             prev_message_type = reasoning_message.message_type
@@ -894,6 +910,7 @@ class SimpleAnthropicStreamingInterface:
                 signature=delta.signature,
                 otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                 run_id=self.run_id,
+                step_id=self.step_id,
             )
             self.reasoning_messages.append(reasoning_message)
             prev_message_type = reasoning_message.message_type
diff --git a/letta/interfaces/openai_streaming_interface.py b/letta/interfaces/openai_streaming_interface.py
index d5c58df6..6e539579 100644
--- a/letta/interfaces/openai_streaming_interface.py
+++ b/letta/interfaces/openai_streaming_interface.py
@@ -77,12 +77,14 @@ class OpenAIStreamingInterface:
         put_inner_thoughts_in_kwarg: bool = True,
         requires_approval_tools: list = [],
         run_id: str | None = None,
+        step_id: str | None = None,
     ):
         self.use_assistant_message = use_assistant_message
         self.assistant_message_tool_name = DEFAULT_MESSAGE_TOOL
         self.assistant_message_tool_kwarg = DEFAULT_MESSAGE_TOOL_KWARG
         self.put_inner_thoughts_in_kwarg = put_inner_thoughts_in_kwarg
         self.run_id = run_id
+        self.step_id = step_id

         self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
         self.function_args_reader = JSONInnerThoughtsExtractor(wait_for_first_key=put_inner_thoughts_in_kwarg)
@@ -247,6 +249,7 @@ class OpenAIStreamingInterface:
                             hidden_reasoning=None,
                             otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                             run_id=self.run_id,
+                            step_id=self.step_id,
                         )
                         yield hidden_message
                         prev_message_type = hidden_message.message_type
@@ -287,6 +290,7 @@ class OpenAIStreamingInterface:
                             # name=name,
                             otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                             run_id=self.run_id,
+                            step_id=self.step_id,
                         )
                         prev_message_type = reasoning_message.message_type
                         yield reasoning_message
@@ -329,6 +333,7 @@ class OpenAIStreamingInterface:
                                 ),
                                 otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                                 run_id=self.run_id,
+                                step_id=self.step_id,
                             )
                         else:
                             tool_call_msg = ToolCallMessage(
@@ -341,6 +346,7 @@ class OpenAIStreamingInterface:
                                 ),
                                 otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                                 run_id=self.run_id,
+                                step_id=self.step_id,
                             )
                         prev_message_type = tool_call_msg.message_type
                         yield tool_call_msg
@@ -387,6 +393,7 @@ class OpenAIStreamingInterface:
                                 content=extracted,
                                 otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                                 run_id=self.run_id,
+                                step_id=self.step_id,
                             )
                             prev_message_type = assistant_message.message_type
                             yield assistant_message
@@ -413,6 +420,7 @@ class OpenAIStreamingInterface:
                                 # name=name,
                                 otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                                 run_id=self.run_id,
+                                step_id=self.step_id,
                             )
                         else:
                             tool_call_msg = ToolCallMessage(
@@ -426,6 +434,7 @@ class OpenAIStreamingInterface:
                                 # name=name,
                                 otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                                 run_id=self.run_id,
+                                step_id=self.step_id,
                             )
                         prev_message_type = tool_call_msg.message_type
                         yield tool_call_msg
@@ -448,6 +457,7 @@ class OpenAIStreamingInterface:
                                 # name=name,
                                 otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                                 run_id=self.run_id,
+                                step_id=self.step_id,
                             )
                         else:
                             tool_call_msg = ToolCallMessage(
@@ -461,6 +471,7 @@ class OpenAIStreamingInterface:
                                 # name=name,
                                 otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                                 run_id=self.run_id,
+                                step_id=self.step_id,
                             )
                         prev_message_type = tool_call_msg.message_type
                         yield tool_call_msg
@@ -482,8 +493,10 @@ class SimpleOpenAIStreamingInterface:
         requires_approval_tools: list = [],
         model: str = None,
         run_id: str | None = None,
+        step_id: str | None = None,
     ):
         self.run_id = run_id
+        self.step_id = step_id

         # Premake IDs for database writes
         self.letta_message_id = Message.generate_id()
@@ -579,6 +592,7 @@ class SimpleOpenAIStreamingInterface:
                 hidden_reasoning=None,
                 otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                 run_id=self.run_id,
+                step_id=self.step_id,
             )
             self.content_messages.append(hidden_message)
             prev_message_type = hidden_message.message_type
@@ -650,6 +664,7 @@ class SimpleOpenAIStreamingInterface:
                 date=datetime.now(timezone.utc).isoformat(),
                 otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                 run_id=self.run_id,
+                step_id=self.step_id,
             )
             self.content_messages.append(assistant_msg)
             prev_message_type = assistant_msg.message_type
@@ -699,6 +714,7 @@ class SimpleOpenAIStreamingInterface:
                     # name=name,
                     otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                     run_id=self.run_id,
+                    step_id=self.step_id,
                 )
             else:
                 tool_call_msg = ToolCallMessage(
@@ -712,6 +728,7 @@ class SimpleOpenAIStreamingInterface:
                     # name=name,
                     otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                     run_id=self.run_id,
+                    step_id=self.step_id,
                 )
             prev_message_type = tool_call_msg.message_type
             message_index += 1  # Increment for the next message
@@ -731,6 +748,7 @@ class SimpleOpenAIResponsesStreamingInterface:
         requires_approval_tools: list = [],
         model: str = None,
         run_id: str | None = None,
+        step_id: str | None = None,
     ):
         self.is_openai_proxy = is_openai_proxy
         self.messages = messages
@@ -741,6 +759,7 @@ class SimpleOpenAIResponsesStreamingInterface:
         # ID responses used
         self.message_id = None
         self.run_id = run_id
+        self.step_id = step_id

         # Premake IDs for database writes
         self.letta_message_id = Message.generate_id()
@@ -894,6 +913,7 @@ class SimpleOpenAIResponsesStreamingInterface:
                     source="reasoner_model",
                     reasoning=concat_summary,
                     run_id=self.run_id,
+                    step_id=self.step_id,
                 )
                 prev_message_type = "reasoning_message"
             else:
@@ -919,6 +939,7 @@ class SimpleOpenAIResponsesStreamingInterface:
                         tool_call_id=call_id,
                     ),
                     run_id=self.run_id,
+                    step_id=self.step_id,
                 )
                 prev_message_type = "tool_call_message"
             else:
@@ -934,6 +955,7 @@ class SimpleOpenAIResponsesStreamingInterface:
                         tool_call_id=call_id,
                     ),
                     run_id=self.run_id,
+                    step_id=self.step_id,
                 )
                 prev_message_type = "tool_call_message"

@@ -951,6 +973,7 @@ class SimpleOpenAIResponsesStreamingInterface:
                     date=datetime.now(timezone.utc),
                     content=content_item.text,
                     run_id=self.run_id,
+                    step_id=self.step_id,
                 )
                 prev_message_type = "assistant_message"
             else:
@@ -1004,6 +1027,7 @@ class SimpleOpenAIResponsesStreamingInterface:
                     source="reasoner_model",
                     reasoning=delta,
                     run_id=self.run_id,
+                    step_id=self.step_id,
                 )
                 prev_message_type = "reasoning_message"
             else:
@@ -1047,6 +1071,7 @@ class SimpleOpenAIResponsesStreamingInterface:
                     date=datetime.now(timezone.utc),
                     content=delta,
                     run_id=self.run_id,
+                    step_id=self.step_id,
                 )
                 prev_message_type = "assistant_message"
             else:
@@ -1082,6 +1107,7 @@ class SimpleOpenAIResponsesStreamingInterface:
                         tool_call_id=None,
                     ),
                     run_id=self.run_id,
+                    step_id=self.step_id,
                 )
                 prev_message_type = "approval_request_message"
             else:
@@ -1097,6 +1123,7 @@ class SimpleOpenAIResponsesStreamingInterface:
                         tool_call_id=None,
                     ),
                     run_id=self.run_id,
+                    step_id=self.step_id,
                 )
                 prev_message_type = "tool_call_message"

diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py
index 45c634bb..726f7c0a 100644
--- a/letta/server/rest_api/utils.py
+++ b/letta/server/rest_api/utils.py
@@ -154,7 +154,7 @@ def capture_sentry_exception(e: BaseException):
     sentry_sdk.capture_exception(e)


-def create_input_messages(input_messages: List[MessageCreate], agent_id: str, timezone: str, actor: User) -> List[Message]:
+def create_input_messages(input_messages: List[MessageCreate], agent_id: str, timezone: str, run_id: str, actor: User) -> List[Message]:
     """
     Converts a user input message into the internal structured format.

@@ -162,7 +162,9 @@ def create_input_messages(input_messages: List[MessageCreate], agent_id: str, ti
     we should unify this when it's clear what message attributes we need.
""" - messages = convert_message_creates_to_messages(input_messages, agent_id, timezone, wrap_user_message=False, wrap_system_message=False) + messages = convert_message_creates_to_messages( + input_messages, agent_id, timezone, run_id, wrap_user_message=False, wrap_system_message=False + ) return messages @@ -348,6 +350,7 @@ def create_letta_messages_from_llm_response( actor=actor, timezone=timezone, heartbeat_reason=heartbeat_reason, + run_id=run_id, ) messages.append(heartbeat_system_message) @@ -365,6 +368,7 @@ def create_heartbeat_system_message( actor: User, llm_batch_item_id: Optional[str] = None, heartbeat_reason: Optional[str] = None, + run_id: Optional[str] = None, ) -> Message: if heartbeat_reason: text_content = heartbeat_reason @@ -380,6 +384,7 @@ def create_heartbeat_system_message( tool_call_id=None, created_at=get_utc_time(), batch_item_id=llm_batch_item_id, + run_id=run_id, ) return heartbeat_system_message diff --git a/letta/services/summarizer/summarizer.py b/letta/services/summarizer/summarizer.py index 67530d01..ef54fd46 100644 --- a/letta/services/summarizer/summarizer.py +++ b/letta/services/summarizer/summarizer.py @@ -189,6 +189,7 @@ class Summarizer: # We already packed, don't pack again wrap_user_message=False, wrap_system_message=False, + run_id=None, # TODO: add this )[0] # Create the message in the DB diff --git a/tests/integration_test_send_message.py b/tests/integration_test_send_message.py index 209cd49b..2380bda0 100644 --- a/tests/integration_test_send_message.py +++ b/tests/integration_test_send_message.py @@ -232,6 +232,13 @@ TESTED_LLM_CONFIGS = [ ] +def assert_first_message_is_user_message(messages: List[Any]) -> None: + """ + Asserts that the first message is a user message. + """ + assert isinstance(messages[0], UserMessage) + + def assert_greeting_with_assistant_message_response( messages: List[Any], llm_config: LLMConfig, @@ -283,6 +290,24 @@ def assert_greeting_with_assistant_message_response( assert messages[index].step_count > 0 +def assert_contains_run_id(messages: List[Any]) -> None: + """ + Asserts that the messages list contains a run_id. + """ + for message in messages: + if hasattr(message, "run_id"): + assert message.run_id is not None + + +def assert_contains_step_id(messages: List[Any]) -> None: + """ + Asserts that the messages list contains a step_id. 
+ """ + for message in messages: + if hasattr(message, "step_id"): + assert message.step_id is not None + + def assert_greeting_no_reasoning_response( messages: List[Any], streaming: bool = False, @@ -410,6 +435,7 @@ def assert_tool_call_response( and getattr(messages[2], "message_type", None) == "tool_return_message" ): return + try: assert len(messages) == expected_message_count, messages except: @@ -804,8 +830,10 @@ def test_greeting_with_assistant_message( agent_id=agent_state.id, messages=USER_MESSAGE_FORCE_REPLY, ) + assert_contains_run_id(response.messages) assert_greeting_with_assistant_message_response(response.messages, llm_config=llm_config) messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id) + assert_first_message_is_user_message(messages_from_db) assert_greeting_with_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config) @@ -1024,9 +1052,13 @@ def test_step_streaming_greeting_with_assistant_message( agent_id=agent_state.id, messages=USER_MESSAGE_FORCE_REPLY, ) - messages = accumulate_chunks(list(response)) + chunks = list(response) + assert_contains_step_id(chunks) + assert_contains_run_id(chunks) + messages = accumulate_chunks(chunks) assert_greeting_with_assistant_message_response(messages, streaming=True, llm_config=llm_config) messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id) + assert_contains_run_id(messages_from_db) assert_greeting_with_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config) @@ -1514,7 +1546,8 @@ def test_async_greeting_with_assistant_message( messages = client.runs.messages.list(run_id=run.id) usage = client.runs.usage.retrieve(run_id=run.id) - assert_greeting_with_assistant_message_response(messages, llm_config=llm_config) + # TODO: add results API test later + assert_greeting_with_assistant_message_response(messages, from_db=True, llm_config=llm_config) # TODO: remove from_db=True later messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id) assert_greeting_with_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config) @@ -1595,7 +1628,8 @@ def test_async_tool_call( ) run = wait_for_run_completion(client, run.id) messages = client.runs.messages.list(run_id=run.id) - assert_tool_call_response(messages, llm_config=llm_config) + # TODO: add test for response api + assert_tool_call_response(messages, from_db=True, llm_config=llm_config) # NOTE: skip first message which is the user message messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id) assert_tool_call_response(messages_from_db, from_db=True, llm_config=llm_config) @@ -1725,7 +1759,7 @@ def test_async_greeting_with_callback_url( # Validate job completed successfully messages = client.runs.messages.list(run_id=run.id) - assert_greeting_with_assistant_message_response(messages, llm_config=llm_config) + assert_greeting_with_assistant_message_response(messages, from_db=True, llm_config=llm_config) # Validate callback was received assert server.wait_for_callback(timeout=15), "Callback was not received within timeout"