feat: add run_id to input messages and step_id to messages (#5099)
This commit is contained in:
committed by
Caren Thomas
parent
7c03288c05
commit
ef07e03ee3
@@ -59,6 +59,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
|
||||
put_inner_thoughts_in_kwarg=self.llm_config.put_inner_thoughts_in_kwargs,
|
||||
requires_approval_tools=requires_approval_tools,
|
||||
run_id=self.run_id,
|
||||
step_id=step_id,
|
||||
)
|
||||
elif self.llm_config.model_endpoint_type == ProviderType.openai:
|
||||
# For non-v1 agents, always use Chat Completions streaming interface
|
||||
@@ -70,6 +71,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
|
||||
tools=tools,
|
||||
requires_approval_tools=requires_approval_tools,
|
||||
run_id=self.run_id,
|
||||
step_id=step_id,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Streaming not supported for provider {self.llm_config.model_endpoint_type}")
|
||||
|
||||
@@ -52,6 +52,7 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
|
||||
self.interface = SimpleAnthropicStreamingInterface(
|
||||
requires_approval_tools=requires_approval_tools,
|
||||
run_id=self.run_id,
|
||||
step_id=step_id,
|
||||
)
|
||||
elif self.llm_config.model_endpoint_type == ProviderType.openai:
|
||||
# Decide interface based on payload shape
|
||||
@@ -65,6 +66,7 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
|
||||
tools=tools,
|
||||
requires_approval_tools=requires_approval_tools,
|
||||
run_id=self.run_id,
|
||||
step_id=step_id,
|
||||
)
|
||||
else:
|
||||
self.interface = SimpleOpenAIStreamingInterface(
|
||||
@@ -74,6 +76,7 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
|
||||
requires_approval_tools=requires_approval_tools,
|
||||
model=self.llm_config.model,
|
||||
run_id=self.run_id,
|
||||
step_id=step_id,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Streaming not supported for provider {self.llm_config.model_endpoint_type}")
|
||||
|
||||
@@ -82,6 +82,7 @@ class EphemeralSummaryAgent(BaseAgent):
|
||||
message_creates=[system_message_create] + input_messages,
|
||||
agent_id=self.agent_id,
|
||||
timezone=agent_state.timezone,
|
||||
run_id=None, # TODO: add this
|
||||
)
|
||||
|
||||
request_data = llm_client.build_request_data(agent_state.agent_type, messages, agent_state.llm_config, tools=[])
|
||||
|
||||
@@ -56,6 +56,7 @@ def _prepare_in_context_messages(
|
||||
agent_state: AgentState,
|
||||
message_manager: MessageManager,
|
||||
actor: User,
|
||||
run_id: str,
|
||||
) -> Tuple[List[Message], List[Message]]:
|
||||
"""
|
||||
Prepares in-context messages for an agent, based on the current state and a new user input.
|
||||
@@ -65,6 +66,7 @@ def _prepare_in_context_messages(
|
||||
agent_state (AgentState): The current state of the agent, including message buffer config.
|
||||
message_manager (MessageManager): The manager used to retrieve and create messages.
|
||||
actor (User): The user performing the action, used for access control and attribution.
|
||||
run_id (str): The run ID associated with this message processing.
|
||||
|
||||
Returns:
|
||||
Tuple[List[Message], List[Message]]: A tuple containing:
|
||||
@@ -81,7 +83,9 @@ def _prepare_in_context_messages(
|
||||
|
||||
# Create a new user message from the input and store it
|
||||
new_in_context_messages = message_manager.create_many_messages(
|
||||
create_input_messages(input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, actor=actor),
|
||||
create_input_messages(
|
||||
input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, run_id=run_id, actor=actor
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
@@ -93,6 +97,7 @@ async def _prepare_in_context_messages_async(
|
||||
agent_state: AgentState,
|
||||
message_manager: MessageManager,
|
||||
actor: User,
|
||||
run_id: str,
|
||||
) -> Tuple[List[Message], List[Message]]:
|
||||
"""
|
||||
Prepares in-context messages for an agent, based on the current state and a new user input.
|
||||
@@ -103,6 +108,7 @@ async def _prepare_in_context_messages_async(
|
||||
agent_state (AgentState): The current state of the agent, including message buffer config.
|
||||
message_manager (MessageManager): The manager used to retrieve and create messages.
|
||||
actor (User): The user performing the action, used for access control and attribution.
|
||||
run_id (str): The run ID associated with this message processing.
|
||||
|
||||
Returns:
|
||||
Tuple[List[Message], List[Message]]: A tuple containing:
|
||||
@@ -119,7 +125,9 @@ async def _prepare_in_context_messages_async(
|
||||
|
||||
# Create a new user message from the input and store it
|
||||
new_in_context_messages = await message_manager.create_many_messages_async(
|
||||
create_input_messages(input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, actor=actor),
|
||||
create_input_messages(
|
||||
input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, run_id=run_id, actor=actor
|
||||
),
|
||||
actor=actor,
|
||||
project_id=agent_state.project_id,
|
||||
)
|
||||
@@ -132,6 +140,7 @@ async def _prepare_in_context_messages_no_persist_async(
|
||||
agent_state: AgentState,
|
||||
message_manager: MessageManager,
|
||||
actor: User,
|
||||
run_id: Optional[str] = None,
|
||||
) -> Tuple[List[Message], List[Message]]:
|
||||
"""
|
||||
Prepares in-context messages for an agent, based on the current state and a new user input.
|
||||
@@ -141,6 +150,7 @@ async def _prepare_in_context_messages_no_persist_async(
|
||||
agent_state (AgentState): The current state of the agent, including message buffer config.
|
||||
message_manager (MessageManager): The manager used to retrieve and create messages.
|
||||
actor (User): The user performing the action, used for access control and attribution.
|
||||
run_id (str): The run ID associated with this message processing.
|
||||
|
||||
Returns:
|
||||
Tuple[List[Message], List[Message]]: A tuple containing:
|
||||
@@ -176,7 +186,7 @@ async def _prepare_in_context_messages_no_persist_async(
|
||||
|
||||
# Create a new user message from the input but dont store it yet
|
||||
new_in_context_messages = create_input_messages(
|
||||
input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, actor=actor
|
||||
input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, run_id=run_id, actor=actor
|
||||
)
|
||||
|
||||
return current_in_context_messages, new_in_context_messages
|
||||
|
||||
@@ -616,7 +616,7 @@ class LettaAgentBatch(BaseAgent):
|
||||
self, agent_state: AgentState, input_messages: List[MessageCreate]
|
||||
) -> List[Message]:
|
||||
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
|
||||
input_messages, agent_state, self.message_manager, self.actor
|
||||
input_messages, agent_state, self.message_manager, self.actor, run_id=None
|
||||
)
|
||||
|
||||
in_context_messages = await self._rebuild_memory_async(current_in_context_messages + new_in_context_messages, agent_state)
|
||||
|
||||
@@ -142,7 +142,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
"""
|
||||
request = {}
|
||||
in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
|
||||
input_messages, self.agent_state, self.message_manager, self.actor
|
||||
input_messages, self.agent_state, self.message_manager, self.actor, None
|
||||
)
|
||||
response = self._step(
|
||||
run_id=None,
|
||||
@@ -185,7 +185,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
request_span = self._request_checkpoint_start(request_start_timestamp_ns=request_start_timestamp_ns)
|
||||
|
||||
in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
|
||||
input_messages, self.agent_state, self.message_manager, self.actor
|
||||
input_messages, self.agent_state, self.message_manager, self.actor, run_id
|
||||
)
|
||||
in_context_messages = in_context_messages + input_messages_to_persist
|
||||
response_letta_messages = []
|
||||
@@ -283,7 +283,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
|
||||
try:
|
||||
in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
|
||||
input_messages, self.agent_state, self.message_manager, self.actor
|
||||
input_messages, self.agent_state, self.message_manager, self.actor, run_id
|
||||
)
|
||||
in_context_messages = in_context_messages + input_messages_to_persist
|
||||
for i in range(max_steps):
|
||||
|
||||
@@ -84,7 +84,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
request_span = self._request_checkpoint_start(request_start_timestamp_ns=request_start_timestamp_ns)
|
||||
|
||||
in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
|
||||
input_messages, self.agent_state, self.message_manager, self.actor
|
||||
input_messages, self.agent_state, self.message_manager, self.actor, run_id
|
||||
)
|
||||
in_context_messages = in_context_messages + input_messages_to_persist
|
||||
response_letta_messages = []
|
||||
@@ -180,7 +180,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
|
||||
try:
|
||||
in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
|
||||
input_messages, self.agent_state, self.message_manager, self.actor
|
||||
input_messages, self.agent_state, self.message_manager, self.actor, run_id
|
||||
)
|
||||
in_context_messages = in_context_messages + input_messages_to_persist
|
||||
for i in range(max_steps):
|
||||
|
||||
@@ -13,6 +13,7 @@ def convert_message_creates_to_messages(
|
||||
message_creates: list[MessageCreate],
|
||||
agent_id: str,
|
||||
timezone: str,
|
||||
run_id: str,
|
||||
wrap_user_message: bool = True,
|
||||
wrap_system_message: bool = True,
|
||||
) -> list[Message]:
|
||||
@@ -21,6 +22,7 @@ def convert_message_creates_to_messages(
|
||||
message_create=create,
|
||||
agent_id=agent_id,
|
||||
timezone=timezone,
|
||||
run_id=run_id,
|
||||
wrap_user_message=wrap_user_message,
|
||||
wrap_system_message=wrap_system_message,
|
||||
)
|
||||
@@ -32,6 +34,7 @@ def _convert_message_create_to_message(
|
||||
message_create: MessageCreate,
|
||||
agent_id: str,
|
||||
timezone: str,
|
||||
run_id: str,
|
||||
wrap_user_message: bool = True,
|
||||
wrap_system_message: bool = True,
|
||||
) -> Message:
|
||||
@@ -81,4 +84,5 @@ def _convert_message_create_to_message(
|
||||
sender_id=message_create.sender_id,
|
||||
group_id=message_create.group_id,
|
||||
batch_item_id=message_create.batch_item_id,
|
||||
run_id=run_id,
|
||||
)
|
||||
|
||||
@@ -67,10 +67,12 @@ class AnthropicStreamingInterface:
|
||||
put_inner_thoughts_in_kwarg: bool = False,
|
||||
requires_approval_tools: list = [],
|
||||
run_id: str | None = None,
|
||||
step_id: str | None = None,
|
||||
):
|
||||
self.json_parser: JSONParser = PydanticJSONParser()
|
||||
self.use_assistant_message = use_assistant_message
|
||||
self.run_id = run_id
|
||||
self.step_id = step_id
|
||||
|
||||
# Premake IDs for database writes
|
||||
self.letta_message_id = Message.generate_id()
|
||||
@@ -299,6 +301,7 @@ class AnthropicStreamingInterface:
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
self.reasoning_messages.append(hidden_reasoning_message)
|
||||
prev_message_type = hidden_reasoning_message.message_type
|
||||
@@ -345,6 +348,7 @@ class AnthropicStreamingInterface:
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
prev_message_type = reasoning_message.message_type
|
||||
@@ -373,6 +377,7 @@ class AnthropicStreamingInterface:
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
prev_message_type = reasoning_message.message_type
|
||||
@@ -495,6 +500,7 @@ class AnthropicStreamingInterface:
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
prev_message_type = reasoning_message.message_type
|
||||
@@ -516,6 +522,7 @@ class AnthropicStreamingInterface:
|
||||
signature=delta.signature,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
prev_message_type = reasoning_message.message_type
|
||||
@@ -550,9 +557,11 @@ class SimpleAnthropicStreamingInterface:
|
||||
self,
|
||||
requires_approval_tools: list = [],
|
||||
run_id: str | None = None,
|
||||
step_id: str | None = None,
|
||||
):
|
||||
self.json_parser: JSONParser = PydanticJSONParser()
|
||||
self.run_id = run_id
|
||||
self.step_id = step_id
|
||||
|
||||
# Premake IDs for database writes
|
||||
self.letta_message_id = Message.generate_id()
|
||||
@@ -764,6 +773,7 @@ class SimpleAnthropicStreamingInterface:
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
else:
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
@@ -774,6 +784,7 @@ class SimpleAnthropicStreamingInterface:
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
yield tool_call_msg
|
||||
@@ -795,6 +806,7 @@ class SimpleAnthropicStreamingInterface:
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
|
||||
self.reasoning_messages.append(hidden_reasoning_message)
|
||||
@@ -819,6 +831,7 @@ class SimpleAnthropicStreamingInterface:
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
# self.assistant_messages.append(assistant_msg)
|
||||
self.reasoning_messages.append(assistant_msg)
|
||||
@@ -842,6 +855,7 @@ class SimpleAnthropicStreamingInterface:
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
else:
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
@@ -852,6 +866,7 @@ class SimpleAnthropicStreamingInterface:
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
|
||||
yield tool_call_msg
|
||||
@@ -872,6 +887,7 @@ class SimpleAnthropicStreamingInterface:
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
prev_message_type = reasoning_message.message_type
|
||||
@@ -894,6 +910,7 @@ class SimpleAnthropicStreamingInterface:
|
||||
signature=delta.signature,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
prev_message_type = reasoning_message.message_type
|
||||
|
||||
@@ -77,12 +77,14 @@ class OpenAIStreamingInterface:
|
||||
put_inner_thoughts_in_kwarg: bool = True,
|
||||
requires_approval_tools: list = [],
|
||||
run_id: str | None = None,
|
||||
step_id: str | None = None,
|
||||
):
|
||||
self.use_assistant_message = use_assistant_message
|
||||
self.assistant_message_tool_name = DEFAULT_MESSAGE_TOOL
|
||||
self.assistant_message_tool_kwarg = DEFAULT_MESSAGE_TOOL_KWARG
|
||||
self.put_inner_thoughts_in_kwarg = put_inner_thoughts_in_kwarg
|
||||
self.run_id = run_id
|
||||
self.step_id = step_id
|
||||
|
||||
self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
|
||||
self.function_args_reader = JSONInnerThoughtsExtractor(wait_for_first_key=put_inner_thoughts_in_kwarg)
|
||||
@@ -247,6 +249,7 @@ class OpenAIStreamingInterface:
|
||||
hidden_reasoning=None,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
yield hidden_message
|
||||
prev_message_type = hidden_message.message_type
|
||||
@@ -287,6 +290,7 @@ class OpenAIStreamingInterface:
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
prev_message_type = reasoning_message.message_type
|
||||
yield reasoning_message
|
||||
@@ -329,6 +333,7 @@ class OpenAIStreamingInterface:
|
||||
),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
else:
|
||||
tool_call_msg = ToolCallMessage(
|
||||
@@ -341,6 +346,7 @@ class OpenAIStreamingInterface:
|
||||
),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
yield tool_call_msg
|
||||
@@ -387,6 +393,7 @@ class OpenAIStreamingInterface:
|
||||
content=extracted,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
prev_message_type = assistant_message.message_type
|
||||
yield assistant_message
|
||||
@@ -413,6 +420,7 @@ class OpenAIStreamingInterface:
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
else:
|
||||
tool_call_msg = ToolCallMessage(
|
||||
@@ -426,6 +434,7 @@ class OpenAIStreamingInterface:
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
yield tool_call_msg
|
||||
@@ -448,6 +457,7 @@ class OpenAIStreamingInterface:
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
else:
|
||||
tool_call_msg = ToolCallMessage(
|
||||
@@ -461,6 +471,7 @@ class OpenAIStreamingInterface:
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
yield tool_call_msg
|
||||
@@ -482,8 +493,10 @@ class SimpleOpenAIStreamingInterface:
|
||||
requires_approval_tools: list = [],
|
||||
model: str = None,
|
||||
run_id: str | None = None,
|
||||
step_id: str | None = None,
|
||||
):
|
||||
self.run_id = run_id
|
||||
self.step_id = step_id
|
||||
# Premake IDs for database writes
|
||||
self.letta_message_id = Message.generate_id()
|
||||
|
||||
@@ -579,6 +592,7 @@ class SimpleOpenAIStreamingInterface:
|
||||
hidden_reasoning=None,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
self.content_messages.append(hidden_message)
|
||||
prev_message_type = hidden_message.message_type
|
||||
@@ -650,6 +664,7 @@ class SimpleOpenAIStreamingInterface:
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
self.content_messages.append(assistant_msg)
|
||||
prev_message_type = assistant_msg.message_type
|
||||
@@ -699,6 +714,7 @@ class SimpleOpenAIStreamingInterface:
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
else:
|
||||
tool_call_msg = ToolCallMessage(
|
||||
@@ -712,6 +728,7 @@ class SimpleOpenAIStreamingInterface:
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
message_index += 1 # Increment for the next message
|
||||
@@ -731,6 +748,7 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
requires_approval_tools: list = [],
|
||||
model: str = None,
|
||||
run_id: str | None = None,
|
||||
step_id: str | None = None,
|
||||
):
|
||||
self.is_openai_proxy = is_openai_proxy
|
||||
self.messages = messages
|
||||
@@ -741,6 +759,7 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
# ID responses used
|
||||
self.message_id = None
|
||||
self.run_id = run_id
|
||||
self.step_id = step_id
|
||||
|
||||
# Premake IDs for database writes
|
||||
self.letta_message_id = Message.generate_id()
|
||||
@@ -894,6 +913,7 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
source="reasoner_model",
|
||||
reasoning=concat_summary,
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
prev_message_type = "reasoning_message"
|
||||
else:
|
||||
@@ -919,6 +939,7 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
tool_call_id=call_id,
|
||||
),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
prev_message_type = "tool_call_message"
|
||||
else:
|
||||
@@ -934,6 +955,7 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
tool_call_id=call_id,
|
||||
),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
prev_message_type = "tool_call_message"
|
||||
|
||||
@@ -951,6 +973,7 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
date=datetime.now(timezone.utc),
|
||||
content=content_item.text,
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
prev_message_type = "assistant_message"
|
||||
else:
|
||||
@@ -1004,6 +1027,7 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
source="reasoner_model",
|
||||
reasoning=delta,
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
prev_message_type = "reasoning_message"
|
||||
else:
|
||||
@@ -1047,6 +1071,7 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
date=datetime.now(timezone.utc),
|
||||
content=delta,
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
prev_message_type = "assistant_message"
|
||||
else:
|
||||
@@ -1082,6 +1107,7 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
tool_call_id=None,
|
||||
),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
prev_message_type = "approval_request_message"
|
||||
else:
|
||||
@@ -1097,6 +1123,7 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
tool_call_id=None,
|
||||
),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
prev_message_type = "tool_call_message"
|
||||
|
||||
|
||||
@@ -154,7 +154,7 @@ def capture_sentry_exception(e: BaseException):
|
||||
sentry_sdk.capture_exception(e)
|
||||
|
||||
|
||||
def create_input_messages(input_messages: List[MessageCreate], agent_id: str, timezone: str, actor: User) -> List[Message]:
|
||||
def create_input_messages(input_messages: List[MessageCreate], agent_id: str, timezone: str, run_id: str, actor: User) -> List[Message]:
|
||||
"""
|
||||
Converts a user input message into the internal structured format.
|
||||
|
||||
@@ -162,7 +162,9 @@ def create_input_messages(input_messages: List[MessageCreate], agent_id: str, ti
|
||||
we should unify this when it's clear what message attributes we need.
|
||||
"""
|
||||
|
||||
messages = convert_message_creates_to_messages(input_messages, agent_id, timezone, wrap_user_message=False, wrap_system_message=False)
|
||||
messages = convert_message_creates_to_messages(
|
||||
input_messages, agent_id, timezone, run_id, wrap_user_message=False, wrap_system_message=False
|
||||
)
|
||||
return messages
|
||||
|
||||
|
||||
@@ -348,6 +350,7 @@ def create_letta_messages_from_llm_response(
|
||||
actor=actor,
|
||||
timezone=timezone,
|
||||
heartbeat_reason=heartbeat_reason,
|
||||
run_id=run_id,
|
||||
)
|
||||
messages.append(heartbeat_system_message)
|
||||
|
||||
@@ -365,6 +368,7 @@ def create_heartbeat_system_message(
|
||||
actor: User,
|
||||
llm_batch_item_id: Optional[str] = None,
|
||||
heartbeat_reason: Optional[str] = None,
|
||||
run_id: Optional[str] = None,
|
||||
) -> Message:
|
||||
if heartbeat_reason:
|
||||
text_content = heartbeat_reason
|
||||
@@ -380,6 +384,7 @@ def create_heartbeat_system_message(
|
||||
tool_call_id=None,
|
||||
created_at=get_utc_time(),
|
||||
batch_item_id=llm_batch_item_id,
|
||||
run_id=run_id,
|
||||
)
|
||||
return heartbeat_system_message
|
||||
|
||||
|
||||
@@ -189,6 +189,7 @@ class Summarizer:
|
||||
# We already packed, don't pack again
|
||||
wrap_user_message=False,
|
||||
wrap_system_message=False,
|
||||
run_id=None, # TODO: add this
|
||||
)[0]
|
||||
|
||||
# Create the message in the DB
|
||||
|
||||
@@ -232,6 +232,13 @@ TESTED_LLM_CONFIGS = [
|
||||
]
|
||||
|
||||
|
||||
def assert_first_message_is_user_message(messages: List[Any]) -> None:
    """
    Asserts that the first message is a user message.

    Args:
        messages: Messages to inspect; only the first element is checked.

    Raises:
        AssertionError: If `messages` is empty or its first element is not a
            `UserMessage`.
    """
    # Fail with a readable assertion instead of an IndexError when nothing came back.
    assert messages, "Expected at least one message, got none"
    first_message = messages[0]
    assert isinstance(first_message, UserMessage), (
        f"Expected first message to be a UserMessage, got {type(first_message).__name__}"
    )
|
||||
|
||||
|
||||
def assert_greeting_with_assistant_message_response(
|
||||
messages: List[Any],
|
||||
llm_config: LLMConfig,
|
||||
@@ -283,6 +290,24 @@ def assert_greeting_with_assistant_message_response(
|
||||
assert messages[index].step_count > 0
|
||||
|
||||
|
||||
def assert_contains_run_id(messages: List[Any]) -> None:
    """
    Asserts that every message exposing a `run_id` attribute has it set.

    Messages without a `run_id` attribute are skipped, so an empty list (or a
    list in which no message has the attribute) passes.

    Args:
        messages: Messages or stream chunks to inspect.

    Raises:
        AssertionError: If a message has a `run_id` attribute whose value is None.
    """
    for index, message in enumerate(messages):
        if hasattr(message, "run_id"):
            # Include position and type so a failing chunk can be located quickly.
            assert message.run_id is not None, (
                f"Message at index {index} ({type(message).__name__}) is missing run_id"
            )
|
||||
|
||||
|
||||
def assert_contains_step_id(messages: List[Any]) -> None:
    """
    Asserts that every message exposing a `step_id` attribute has it set.

    Messages without a `step_id` attribute are skipped, so an empty list (or a
    list in which no message has the attribute) passes.

    Args:
        messages: Messages or stream chunks to inspect.

    Raises:
        AssertionError: If a message has a `step_id` attribute whose value is None.
    """
    for index, message in enumerate(messages):
        if hasattr(message, "step_id"):
            # Include position and type so a failing chunk can be located quickly.
            assert message.step_id is not None, (
                f"Message at index {index} ({type(message).__name__}) is missing step_id"
            )
|
||||
|
||||
|
||||
def assert_greeting_no_reasoning_response(
|
||||
messages: List[Any],
|
||||
streaming: bool = False,
|
||||
@@ -410,6 +435,7 @@ def assert_tool_call_response(
|
||||
and getattr(messages[2], "message_type", None) == "tool_return_message"
|
||||
):
|
||||
return
|
||||
|
||||
try:
|
||||
assert len(messages) == expected_message_count, messages
|
||||
except:
|
||||
@@ -804,8 +830,10 @@ def test_greeting_with_assistant_message(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
)
|
||||
assert_contains_run_id(response.messages)
|
||||
assert_greeting_with_assistant_message_response(response.messages, llm_config=llm_config)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_first_message_is_user_message(messages_from_db)
|
||||
assert_greeting_with_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config)
|
||||
|
||||
|
||||
@@ -1024,9 +1052,13 @@ def test_step_streaming_greeting_with_assistant_message(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
)
|
||||
messages = accumulate_chunks(list(response))
|
||||
chunks = list(response)
|
||||
assert_contains_step_id(chunks)
|
||||
assert_contains_run_id(chunks)
|
||||
messages = accumulate_chunks(chunks)
|
||||
assert_greeting_with_assistant_message_response(messages, streaming=True, llm_config=llm_config)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_contains_run_id(messages_from_db)
|
||||
assert_greeting_with_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config)
|
||||
|
||||
|
||||
@@ -1514,7 +1546,8 @@ def test_async_greeting_with_assistant_message(
|
||||
messages = client.runs.messages.list(run_id=run.id)
|
||||
usage = client.runs.usage.retrieve(run_id=run.id)
|
||||
|
||||
assert_greeting_with_assistant_message_response(messages, llm_config=llm_config)
|
||||
# TODO: add results API test later
|
||||
assert_greeting_with_assistant_message_response(messages, from_db=True, llm_config=llm_config) # TODO: remove from_db=True later
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_greeting_with_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config)
|
||||
|
||||
@@ -1595,7 +1628,8 @@ def test_async_tool_call(
|
||||
)
|
||||
run = wait_for_run_completion(client, run.id)
|
||||
messages = client.runs.messages.list(run_id=run.id)
|
||||
assert_tool_call_response(messages, llm_config=llm_config)
|
||||
# TODO: add test for response api
|
||||
assert_tool_call_response(messages, from_db=True, llm_config=llm_config) # NOTE: skip first message which is the user message
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_tool_call_response(messages_from_db, from_db=True, llm_config=llm_config)
|
||||
|
||||
@@ -1725,7 +1759,7 @@ def test_async_greeting_with_callback_url(
|
||||
|
||||
# Validate job completed successfully
|
||||
messages = client.runs.messages.list(run_id=run.id)
|
||||
assert_greeting_with_assistant_message_response(messages, llm_config=llm_config)
|
||||
assert_greeting_with_assistant_message_response(messages, from_db=True, llm_config=llm_config)
|
||||
|
||||
# Validate callback was received
|
||||
assert server.wait_for_callback(timeout=15), "Callback was not received within timeout"
|
||||
|
||||
Reference in New Issue
Block a user