feat: add run_id to input messages and step_id to messages (#5099)

This commit is contained in:
Sarah Wooders
2025-10-02 17:44:14 -07:00
committed by Caren Thomas
parent 7c03288c05
commit ef07e03ee3
13 changed files with 119 additions and 15 deletions

View File

@@ -59,6 +59,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
put_inner_thoughts_in_kwarg=self.llm_config.put_inner_thoughts_in_kwargs,
requires_approval_tools=requires_approval_tools,
run_id=self.run_id,
step_id=step_id,
)
elif self.llm_config.model_endpoint_type == ProviderType.openai:
# For non-v1 agents, always use Chat Completions streaming interface
@@ -70,6 +71,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
tools=tools,
requires_approval_tools=requires_approval_tools,
run_id=self.run_id,
step_id=step_id,
)
else:
raise ValueError(f"Streaming not supported for provider {self.llm_config.model_endpoint_type}")

View File

@@ -52,6 +52,7 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
self.interface = SimpleAnthropicStreamingInterface(
requires_approval_tools=requires_approval_tools,
run_id=self.run_id,
step_id=step_id,
)
elif self.llm_config.model_endpoint_type == ProviderType.openai:
# Decide interface based on payload shape
@@ -65,6 +66,7 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
tools=tools,
requires_approval_tools=requires_approval_tools,
run_id=self.run_id,
step_id=step_id,
)
else:
self.interface = SimpleOpenAIStreamingInterface(
@@ -74,6 +76,7 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
requires_approval_tools=requires_approval_tools,
model=self.llm_config.model,
run_id=self.run_id,
step_id=step_id,
)
else:
raise ValueError(f"Streaming not supported for provider {self.llm_config.model_endpoint_type}")

View File

@@ -82,6 +82,7 @@ class EphemeralSummaryAgent(BaseAgent):
message_creates=[system_message_create] + input_messages,
agent_id=self.agent_id,
timezone=agent_state.timezone,
run_id=None, # TODO: add this
)
request_data = llm_client.build_request_data(agent_state.agent_type, messages, agent_state.llm_config, tools=[])

View File

@@ -56,6 +56,7 @@ def _prepare_in_context_messages(
agent_state: AgentState,
message_manager: MessageManager,
actor: User,
run_id: str,
) -> Tuple[List[Message], List[Message]]:
"""
Prepares in-context messages for an agent, based on the current state and a new user input.
@@ -65,6 +66,7 @@ def _prepare_in_context_messages(
agent_state (AgentState): The current state of the agent, including message buffer config.
message_manager (MessageManager): The manager used to retrieve and create messages.
actor (User): The user performing the action, used for access control and attribution.
run_id (str): The run ID associated with this message processing.
Returns:
Tuple[List[Message], List[Message]]: A tuple containing:
@@ -81,7 +83,9 @@ def _prepare_in_context_messages(
# Create a new user message from the input and store it
new_in_context_messages = message_manager.create_many_messages(
create_input_messages(input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, actor=actor),
create_input_messages(
input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, run_id=run_id, actor=actor
),
actor=actor,
)
@@ -93,6 +97,7 @@ async def _prepare_in_context_messages_async(
agent_state: AgentState,
message_manager: MessageManager,
actor: User,
run_id: str,
) -> Tuple[List[Message], List[Message]]:
"""
Prepares in-context messages for an agent, based on the current state and a new user input.
@@ -103,6 +108,7 @@ async def _prepare_in_context_messages_async(
agent_state (AgentState): The current state of the agent, including message buffer config.
message_manager (MessageManager): The manager used to retrieve and create messages.
actor (User): The user performing the action, used for access control and attribution.
run_id (str): The run ID associated with this message processing.
Returns:
Tuple[List[Message], List[Message]]: A tuple containing:
@@ -119,7 +125,9 @@ async def _prepare_in_context_messages_async(
# Create a new user message from the input and store it
new_in_context_messages = await message_manager.create_many_messages_async(
create_input_messages(input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, actor=actor),
create_input_messages(
input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, run_id=run_id, actor=actor
),
actor=actor,
project_id=agent_state.project_id,
)
@@ -132,6 +140,7 @@ async def _prepare_in_context_messages_no_persist_async(
agent_state: AgentState,
message_manager: MessageManager,
actor: User,
run_id: Optional[str] = None,
) -> Tuple[List[Message], List[Message]]:
"""
Prepares in-context messages for an agent, based on the current state and a new user input.
@@ -141,6 +150,7 @@ async def _prepare_in_context_messages_no_persist_async(
agent_state (AgentState): The current state of the agent, including message buffer config.
message_manager (MessageManager): The manager used to retrieve and create messages.
actor (User): The user performing the action, used for access control and attribution.
run_id (str): The run ID associated with this message processing.
Returns:
Tuple[List[Message], List[Message]]: A tuple containing:
@@ -176,7 +186,7 @@ async def _prepare_in_context_messages_no_persist_async(
# Create a new user message from the input but dont store it yet
new_in_context_messages = create_input_messages(
input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, actor=actor
input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, run_id=run_id, actor=actor
)
return current_in_context_messages, new_in_context_messages

View File

@@ -616,7 +616,7 @@ class LettaAgentBatch(BaseAgent):
self, agent_state: AgentState, input_messages: List[MessageCreate]
) -> List[Message]:
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
input_messages, agent_state, self.message_manager, self.actor
input_messages, agent_state, self.message_manager, self.actor, run_id=None
)
in_context_messages = await self._rebuild_memory_async(current_in_context_messages + new_in_context_messages, agent_state)

View File

@@ -142,7 +142,7 @@ class LettaAgentV2(BaseAgentV2):
"""
request = {}
in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
input_messages, self.agent_state, self.message_manager, self.actor
input_messages, self.agent_state, self.message_manager, self.actor, None
)
response = self._step(
run_id=None,
@@ -185,7 +185,7 @@ class LettaAgentV2(BaseAgentV2):
request_span = self._request_checkpoint_start(request_start_timestamp_ns=request_start_timestamp_ns)
in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
input_messages, self.agent_state, self.message_manager, self.actor
input_messages, self.agent_state, self.message_manager, self.actor, run_id
)
in_context_messages = in_context_messages + input_messages_to_persist
response_letta_messages = []
@@ -283,7 +283,7 @@ class LettaAgentV2(BaseAgentV2):
try:
in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
input_messages, self.agent_state, self.message_manager, self.actor
input_messages, self.agent_state, self.message_manager, self.actor, run_id
)
in_context_messages = in_context_messages + input_messages_to_persist
for i in range(max_steps):

View File

@@ -84,7 +84,7 @@ class LettaAgentV3(LettaAgentV2):
request_span = self._request_checkpoint_start(request_start_timestamp_ns=request_start_timestamp_ns)
in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
input_messages, self.agent_state, self.message_manager, self.actor
input_messages, self.agent_state, self.message_manager, self.actor, run_id
)
in_context_messages = in_context_messages + input_messages_to_persist
response_letta_messages = []
@@ -180,7 +180,7 @@ class LettaAgentV3(LettaAgentV2):
try:
in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
input_messages, self.agent_state, self.message_manager, self.actor
input_messages, self.agent_state, self.message_manager, self.actor, run_id
)
in_context_messages = in_context_messages + input_messages_to_persist
for i in range(max_steps):

View File

@@ -13,6 +13,7 @@ def convert_message_creates_to_messages(
message_creates: list[MessageCreate],
agent_id: str,
timezone: str,
run_id: str,
wrap_user_message: bool = True,
wrap_system_message: bool = True,
) -> list[Message]:
@@ -21,6 +22,7 @@ def convert_message_creates_to_messages(
message_create=create,
agent_id=agent_id,
timezone=timezone,
run_id=run_id,
wrap_user_message=wrap_user_message,
wrap_system_message=wrap_system_message,
)
@@ -32,6 +34,7 @@ def _convert_message_create_to_message(
message_create: MessageCreate,
agent_id: str,
timezone: str,
run_id: str,
wrap_user_message: bool = True,
wrap_system_message: bool = True,
) -> Message:
@@ -81,4 +84,5 @@ def _convert_message_create_to_message(
sender_id=message_create.sender_id,
group_id=message_create.group_id,
batch_item_id=message_create.batch_item_id,
run_id=run_id,
)

View File

@@ -67,10 +67,12 @@ class AnthropicStreamingInterface:
put_inner_thoughts_in_kwarg: bool = False,
requires_approval_tools: list = [],
run_id: str | None = None,
step_id: str | None = None,
):
self.json_parser: JSONParser = PydanticJSONParser()
self.use_assistant_message = use_assistant_message
self.run_id = run_id
self.step_id = step_id
# Premake IDs for database writes
self.letta_message_id = Message.generate_id()
@@ -299,6 +301,7 @@ class AnthropicStreamingInterface:
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
self.reasoning_messages.append(hidden_reasoning_message)
prev_message_type = hidden_reasoning_message.message_type
@@ -345,6 +348,7 @@ class AnthropicStreamingInterface:
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
self.reasoning_messages.append(reasoning_message)
prev_message_type = reasoning_message.message_type
@@ -373,6 +377,7 @@ class AnthropicStreamingInterface:
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
self.reasoning_messages.append(reasoning_message)
prev_message_type = reasoning_message.message_type
@@ -495,6 +500,7 @@ class AnthropicStreamingInterface:
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
self.reasoning_messages.append(reasoning_message)
prev_message_type = reasoning_message.message_type
@@ -516,6 +522,7 @@ class AnthropicStreamingInterface:
signature=delta.signature,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
self.reasoning_messages.append(reasoning_message)
prev_message_type = reasoning_message.message_type
@@ -550,9 +557,11 @@ class SimpleAnthropicStreamingInterface:
self,
requires_approval_tools: list = [],
run_id: str | None = None,
step_id: str | None = None,
):
self.json_parser: JSONParser = PydanticJSONParser()
self.run_id = run_id
self.step_id = step_id
# Premake IDs for database writes
self.letta_message_id = Message.generate_id()
@@ -764,6 +773,7 @@ class SimpleAnthropicStreamingInterface:
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
else:
if prev_message_type and prev_message_type != "tool_call_message":
@@ -774,6 +784,7 @@ class SimpleAnthropicStreamingInterface:
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
prev_message_type = tool_call_msg.message_type
yield tool_call_msg
@@ -795,6 +806,7 @@ class SimpleAnthropicStreamingInterface:
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
self.reasoning_messages.append(hidden_reasoning_message)
@@ -819,6 +831,7 @@ class SimpleAnthropicStreamingInterface:
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
# self.assistant_messages.append(assistant_msg)
self.reasoning_messages.append(assistant_msg)
@@ -842,6 +855,7 @@ class SimpleAnthropicStreamingInterface:
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
else:
if prev_message_type and prev_message_type != "tool_call_message":
@@ -852,6 +866,7 @@ class SimpleAnthropicStreamingInterface:
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
yield tool_call_msg
@@ -872,6 +887,7 @@ class SimpleAnthropicStreamingInterface:
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
self.reasoning_messages.append(reasoning_message)
prev_message_type = reasoning_message.message_type
@@ -894,6 +910,7 @@ class SimpleAnthropicStreamingInterface:
signature=delta.signature,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
self.reasoning_messages.append(reasoning_message)
prev_message_type = reasoning_message.message_type

View File

@@ -77,12 +77,14 @@ class OpenAIStreamingInterface:
put_inner_thoughts_in_kwarg: bool = True,
requires_approval_tools: list = [],
run_id: str | None = None,
step_id: str | None = None,
):
self.use_assistant_message = use_assistant_message
self.assistant_message_tool_name = DEFAULT_MESSAGE_TOOL
self.assistant_message_tool_kwarg = DEFAULT_MESSAGE_TOOL_KWARG
self.put_inner_thoughts_in_kwarg = put_inner_thoughts_in_kwarg
self.run_id = run_id
self.step_id = step_id
self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
self.function_args_reader = JSONInnerThoughtsExtractor(wait_for_first_key=put_inner_thoughts_in_kwarg)
@@ -247,6 +249,7 @@ class OpenAIStreamingInterface:
hidden_reasoning=None,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
yield hidden_message
prev_message_type = hidden_message.message_type
@@ -287,6 +290,7 @@ class OpenAIStreamingInterface:
# name=name,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
prev_message_type = reasoning_message.message_type
yield reasoning_message
@@ -329,6 +333,7 @@ class OpenAIStreamingInterface:
),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
else:
tool_call_msg = ToolCallMessage(
@@ -341,6 +346,7 @@ class OpenAIStreamingInterface:
),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
prev_message_type = tool_call_msg.message_type
yield tool_call_msg
@@ -387,6 +393,7 @@ class OpenAIStreamingInterface:
content=extracted,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
prev_message_type = assistant_message.message_type
yield assistant_message
@@ -413,6 +420,7 @@ class OpenAIStreamingInterface:
# name=name,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
else:
tool_call_msg = ToolCallMessage(
@@ -426,6 +434,7 @@ class OpenAIStreamingInterface:
# name=name,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
prev_message_type = tool_call_msg.message_type
yield tool_call_msg
@@ -448,6 +457,7 @@ class OpenAIStreamingInterface:
# name=name,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
else:
tool_call_msg = ToolCallMessage(
@@ -461,6 +471,7 @@ class OpenAIStreamingInterface:
# name=name,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
prev_message_type = tool_call_msg.message_type
yield tool_call_msg
@@ -482,8 +493,10 @@ class SimpleOpenAIStreamingInterface:
requires_approval_tools: list = [],
model: str = None,
run_id: str | None = None,
step_id: str | None = None,
):
self.run_id = run_id
self.step_id = step_id
# Premake IDs for database writes
self.letta_message_id = Message.generate_id()
@@ -579,6 +592,7 @@ class SimpleOpenAIStreamingInterface:
hidden_reasoning=None,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
self.content_messages.append(hidden_message)
prev_message_type = hidden_message.message_type
@@ -650,6 +664,7 @@ class SimpleOpenAIStreamingInterface:
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
self.content_messages.append(assistant_msg)
prev_message_type = assistant_msg.message_type
@@ -699,6 +714,7 @@ class SimpleOpenAIStreamingInterface:
# name=name,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
else:
tool_call_msg = ToolCallMessage(
@@ -712,6 +728,7 @@ class SimpleOpenAIStreamingInterface:
# name=name,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
prev_message_type = tool_call_msg.message_type
message_index += 1 # Increment for the next message
@@ -731,6 +748,7 @@ class SimpleOpenAIResponsesStreamingInterface:
requires_approval_tools: list = [],
model: str = None,
run_id: str | None = None,
step_id: str | None = None,
):
self.is_openai_proxy = is_openai_proxy
self.messages = messages
@@ -741,6 +759,7 @@ class SimpleOpenAIResponsesStreamingInterface:
# ID responses used
self.message_id = None
self.run_id = run_id
self.step_id = step_id
# Premake IDs for database writes
self.letta_message_id = Message.generate_id()
@@ -894,6 +913,7 @@ class SimpleOpenAIResponsesStreamingInterface:
source="reasoner_model",
reasoning=concat_summary,
run_id=self.run_id,
step_id=self.step_id,
)
prev_message_type = "reasoning_message"
else:
@@ -919,6 +939,7 @@ class SimpleOpenAIResponsesStreamingInterface:
tool_call_id=call_id,
),
run_id=self.run_id,
step_id=self.step_id,
)
prev_message_type = "tool_call_message"
else:
@@ -934,6 +955,7 @@ class SimpleOpenAIResponsesStreamingInterface:
tool_call_id=call_id,
),
run_id=self.run_id,
step_id=self.step_id,
)
prev_message_type = "tool_call_message"
@@ -951,6 +973,7 @@ class SimpleOpenAIResponsesStreamingInterface:
date=datetime.now(timezone.utc),
content=content_item.text,
run_id=self.run_id,
step_id=self.step_id,
)
prev_message_type = "assistant_message"
else:
@@ -1004,6 +1027,7 @@ class SimpleOpenAIResponsesStreamingInterface:
source="reasoner_model",
reasoning=delta,
run_id=self.run_id,
step_id=self.step_id,
)
prev_message_type = "reasoning_message"
else:
@@ -1047,6 +1071,7 @@ class SimpleOpenAIResponsesStreamingInterface:
date=datetime.now(timezone.utc),
content=delta,
run_id=self.run_id,
step_id=self.step_id,
)
prev_message_type = "assistant_message"
else:
@@ -1082,6 +1107,7 @@ class SimpleOpenAIResponsesStreamingInterface:
tool_call_id=None,
),
run_id=self.run_id,
step_id=self.step_id,
)
prev_message_type = "approval_request_message"
else:
@@ -1097,6 +1123,7 @@ class SimpleOpenAIResponsesStreamingInterface:
tool_call_id=None,
),
run_id=self.run_id,
step_id=self.step_id,
)
prev_message_type = "tool_call_message"

View File

@@ -154,7 +154,7 @@ def capture_sentry_exception(e: BaseException):
sentry_sdk.capture_exception(e)
def create_input_messages(input_messages: List[MessageCreate], agent_id: str, timezone: str, actor: User) -> List[Message]:
def create_input_messages(input_messages: List[MessageCreate], agent_id: str, timezone: str, run_id: str, actor: User) -> List[Message]:
"""
Converts a user input message into the internal structured format.
@@ -162,7 +162,9 @@ def create_input_messages(input_messages: List[MessageCreate], agent_id: str, ti
we should unify this when it's clear what message attributes we need.
"""
messages = convert_message_creates_to_messages(input_messages, agent_id, timezone, wrap_user_message=False, wrap_system_message=False)
messages = convert_message_creates_to_messages(
input_messages, agent_id, timezone, run_id, wrap_user_message=False, wrap_system_message=False
)
return messages
@@ -348,6 +350,7 @@ def create_letta_messages_from_llm_response(
actor=actor,
timezone=timezone,
heartbeat_reason=heartbeat_reason,
run_id=run_id,
)
messages.append(heartbeat_system_message)
@@ -365,6 +368,7 @@ def create_heartbeat_system_message(
actor: User,
llm_batch_item_id: Optional[str] = None,
heartbeat_reason: Optional[str] = None,
run_id: Optional[str] = None,
) -> Message:
if heartbeat_reason:
text_content = heartbeat_reason
@@ -380,6 +384,7 @@ def create_heartbeat_system_message(
tool_call_id=None,
created_at=get_utc_time(),
batch_item_id=llm_batch_item_id,
run_id=run_id,
)
return heartbeat_system_message

View File

@@ -189,6 +189,7 @@ class Summarizer:
# We already packed, don't pack again
wrap_user_message=False,
wrap_system_message=False,
run_id=None, # TODO: add this
)[0]
# Create the message in the DB

View File

@@ -232,6 +232,13 @@ TESTED_LLM_CONFIGS = [
]
def assert_first_message_is_user_message(messages: List[Any]) -> None:
"""
Asserts that the first message is a user message.
"""
assert isinstance(messages[0], UserMessage)
def assert_greeting_with_assistant_message_response(
messages: List[Any],
llm_config: LLMConfig,
@@ -283,6 +290,24 @@ def assert_greeting_with_assistant_message_response(
assert messages[index].step_count > 0
def assert_contains_run_id(messages: List[Any]) -> None:
"""
Asserts that the messages list contains a run_id.
"""
for message in messages:
if hasattr(message, "run_id"):
assert message.run_id is not None
def assert_contains_step_id(messages: List[Any]) -> None:
"""
Asserts that the messages list contains a step_id.
"""
for message in messages:
if hasattr(message, "step_id"):
assert message.step_id is not None
def assert_greeting_no_reasoning_response(
messages: List[Any],
streaming: bool = False,
@@ -410,6 +435,7 @@ def assert_tool_call_response(
and getattr(messages[2], "message_type", None) == "tool_return_message"
):
return
try:
assert len(messages) == expected_message_count, messages
except:
@@ -804,8 +830,10 @@ def test_greeting_with_assistant_message(
agent_id=agent_state.id,
messages=USER_MESSAGE_FORCE_REPLY,
)
assert_contains_run_id(response.messages)
assert_greeting_with_assistant_message_response(response.messages, llm_config=llm_config)
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
assert_first_message_is_user_message(messages_from_db)
assert_greeting_with_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config)
@@ -1024,9 +1052,13 @@ def test_step_streaming_greeting_with_assistant_message(
agent_id=agent_state.id,
messages=USER_MESSAGE_FORCE_REPLY,
)
messages = accumulate_chunks(list(response))
chunks = list(response)
assert_contains_step_id(chunks)
assert_contains_run_id(chunks)
messages = accumulate_chunks(chunks)
assert_greeting_with_assistant_message_response(messages, streaming=True, llm_config=llm_config)
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
assert_contains_run_id(messages_from_db)
assert_greeting_with_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config)
@@ -1514,7 +1546,8 @@ def test_async_greeting_with_assistant_message(
messages = client.runs.messages.list(run_id=run.id)
usage = client.runs.usage.retrieve(run_id=run.id)
assert_greeting_with_assistant_message_response(messages, llm_config=llm_config)
# TODO: add results API test later
assert_greeting_with_assistant_message_response(messages, from_db=True, llm_config=llm_config) # TODO: remove from_db=True later
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
assert_greeting_with_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config)
@@ -1595,7 +1628,8 @@ def test_async_tool_call(
)
run = wait_for_run_completion(client, run.id)
messages = client.runs.messages.list(run_id=run.id)
assert_tool_call_response(messages, llm_config=llm_config)
# TODO: add test for response api
assert_tool_call_response(messages, from_db=True, llm_config=llm_config) # NOTE: skip first message which is the user message
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
assert_tool_call_response(messages_from_db, from_db=True, llm_config=llm_config)
@@ -1725,7 +1759,7 @@ def test_async_greeting_with_callback_url(
# Validate job completed successfully
messages = client.runs.messages.list(run_id=run.id)
assert_greeting_with_assistant_message_response(messages, llm_config=llm_config)
assert_greeting_with_assistant_message_response(messages, from_db=True, llm_config=llm_config)
# Validate callback was received
assert server.wait_for_callback(timeout=15), "Callback was not received within timeout"