From bafc47c655151ff513bf2668e0fee8230c8a7d4b Mon Sep 17 00:00:00 2001 From: cthomas Date: Mon, 12 May 2025 15:58:52 -0700 Subject: [PATCH] test: add additional new agent messaging tests (#2120) --- letta/agent.py | 31 ++--- letta/llm_api/anthropic.py | 6 +- letta/llm_api/openai.py | 6 +- letta/schemas/message.py | 2 +- .../rest_api/chat_completions_interface.py | 1 + letta/server/rest_api/interface.py | 70 +++++++--- letta/streaming_interface.py | 2 + tests/integration_test_send_message.py | 128 ++++++++++++++---- 8 files changed, 184 insertions(+), 62 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index d0c9ac0f..50a8ed20 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -133,7 +133,6 @@ class Agent(BaseAgent): # Different interfaces can handle events differently # e.g., print in CLI vs send a discord message with a discord bot self.interface = interface - self.chunk_index = 0 # Create the persistence manager object based on the AgentState info self.message_manager = MessageManager() @@ -248,11 +247,9 @@ class Agent(BaseAgent): group_id=group_id, ) messages.append(new_message) - self.interface.function_message(f"Error: {error_msg}", msg_obj=new_message, chunk_index=self.chunk_index) - self.chunk_index += 1 + self.interface.function_message(f"Error: {error_msg}", msg_obj=new_message, chunk_index=0) if include_function_failed_message: - self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=new_message, chunk_index=self.chunk_index) - self.chunk_index += 1 + self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=new_message) # Return updated messages return messages @@ -422,6 +419,7 @@ class Agent(BaseAgent): messages = [] # append these to the history when done function_name = None function_args = {} + chunk_index = 0 # Step 2: check if LLM wanted to call a function if response_message.function_call or (response_message.tool_calls is not None and len(response_message.tool_calls) > 0): @@ -465,8 +463,8 @@ class Agent(BaseAgent): nonnull_content = False if response_message.content or response_message.reasoning_content or response_message.redacted_reasoning_content: # The content if then internal monologue, not chat - self.interface.internal_monologue(response_message.content, msg_obj=messages[-1], chunk_index=self.chunk_index) - self.chunk_index += 1 + self.interface.internal_monologue(response_message.content, msg_obj=messages[-1], chunk_index=chunk_index) + chunk_index += 1 # Flag to avoid printing a duplicate if inner thoughts get popped from the function call nonnull_content = True @@ -515,8 +513,8 @@ class Agent(BaseAgent): response_message.content = function_args.pop("inner_thoughts") # The content if then internal monologue, not chat if response_message.content and not nonnull_content: - self.interface.internal_monologue(response_message.content, msg_obj=messages[-1], chunk_index=self.chunk_index) - self.chunk_index += 1 + self.interface.internal_monologue(response_message.content, msg_obj=messages[-1], chunk_index=chunk_index) + chunk_index += 1 # (Still parsing function args) # Handle requests for immediate heartbeat @@ -542,8 +540,8 @@ class Agent(BaseAgent): # handle cases where we return a json message if "message" in function_args: function_args["message"] = str(function_args.get("message", "")) - self.interface.function_message(f"Running {function_name}({function_args})", msg_obj=messages[-1], chunk_index=self.chunk_index) - self.chunk_index += 1 + self.interface.function_message(f"Running {function_name}({function_args})", msg_obj=messages[-1], chunk_index=chunk_index) + chunk_index = 0 # reset chunk index after assistant message try: # handle tool execution (sandbox) and state updates log_telemetry( @@ -667,10 +665,9 @@ class Agent(BaseAgent): group_id=group_id, ) ) # extend conversation with function response - self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1], chunk_index=self.chunk_index) - self.chunk_index += 1 - self.interface.function_message(f"Success: {function_response_string}", msg_obj=messages[-1], chunk_index=self.chunk_index) - self.chunk_index += 1 + self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1], chunk_index=chunk_index) + self.interface.function_message(f"Success: {function_response_string}", msg_obj=messages[-1], chunk_index=chunk_index) + chunk_index += 1 self.last_function_response = function_response else: @@ -685,8 +682,8 @@ class Agent(BaseAgent): group_id=group_id, ) ) # extend conversation with assistant's reply - self.interface.internal_monologue(response_message.content, msg_obj=messages[-1], chunk_index=self.chunk_index) - self.chunk_index += 1 + self.interface.internal_monologue(response_message.content, msg_obj=messages[-1], chunk_index=chunk_index) + chunk_index += 1 heartbeat_request = False function_failed = False diff --git a/letta/llm_api/anthropic.py b/letta/llm_api/anthropic.py index 88cf0e79..89329d01 100644 --- a/letta/llm_api/anthropic.py +++ b/letta/llm_api/anthropic.py @@ -997,10 +997,12 @@ def anthropic_chat_completions_process_stream( expect_reasoning_content=extended_thinking, name=name, message_index=message_idx, + prev_message_type=prev_message_type, ) - if message_type != prev_message_type and message_type is not None: + if message_type != prev_message_type and message_type is not None and prev_message_type is not None: message_idx += 1 - prev_message_type = message_type + if message_type is not None: + prev_message_type = message_type elif isinstance(stream_interface, AgentRefreshStreamingInterface): stream_interface.process_refresh(chat_completion_response) else: diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index 6a0f182b..ac39ddc6 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -326,10 +326,12 @@ def openai_chat_completions_process_stream( expect_reasoning_content=expect_reasoning_content, name=name, message_index=message_idx, + prev_message_type=prev_message_type, ) - if message_type != prev_message_type and message_type is not None: + if message_type != prev_message_type and message_type is not None and prev_message_type is not None: message_idx += 1 - prev_message_type = message_type + if message_type is not None: + prev_message_type = message_type elif isinstance(stream_interface, AgentRefreshStreamingInterface): stream_interface.process_refresh(chat_completion_response) else: diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 182f608e..e3f6a433 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -404,7 +404,7 @@ class Message(BaseMessage): stdout=self.tool_returns[0].stdout if self.tool_returns else None, stderr=self.tool_returns[0].stderr if self.tool_returns else None, name=self.name, - otid=self.id.replace("message-", ""), + otid=Message.generate_otid_from_id(self.id, len(messages)), sender_id=self.sender_id, step_id=self.step_id, ) diff --git a/letta/server/rest_api/chat_completions_interface.py b/letta/server/rest_api/chat_completions_interface.py index 9b05ca84..76373043 100644 --- a/letta/server/rest_api/chat_completions_interface.py +++ b/letta/server/rest_api/chat_completions_interface.py @@ -162,6 +162,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface): expect_reasoning_content: bool = False, name: Optional[str] = None, message_index: int = 0, + prev_message_type: Optional[str] = None, ) -> None: """ Called externally with a ChatCompletionChunkResponse. Transforms diff --git a/letta/server/rest_api/interface.py b/letta/server/rest_api/interface.py index 66cc77a9..33085408 100644 --- a/letta/server/rest_api/interface.py +++ b/letta/server/rest_api/interface.py @@ -472,6 +472,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): expect_reasoning_content: bool = False, name: Optional[str] = None, message_index: int = 0, + prev_message_type: Optional[str] = None, ) -> Optional[Union[ReasoningMessage, ToolCallMessage, AssistantMessage]]: """ Example data from non-streaming response looks like: @@ -488,7 +489,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface): choice = chunk.choices[0] message_delta = choice.delta - otid = Message.generate_otid_from_id(message_id, message_index) if ( message_delta.content is None @@ -503,6 +503,8 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # inner thoughts if expect_reasoning_content and message_delta.reasoning_content is not None: + if prev_message_type and prev_message_type != "reasoning_message": + message_index += 1 processed_chunk = ReasoningMessage( id=message_id, date=message_date, @@ -510,16 +512,18 @@ class StreamingServerInterface(AgentChunkStreamingInterface): signature=message_delta.reasoning_content_signature, source="reasoner_model" if message_delta.reasoning_content else "non_reasoner_model", name=name, - otid=otid, + otid=Message.generate_otid_from_id(message_id, message_index), ) elif expect_reasoning_content and message_delta.redacted_reasoning_content is not None: + if prev_message_type and prev_message_type != "hidden_reasoning_message": + message_index += 1 processed_chunk = HiddenReasoningMessage( id=message_id, date=message_date, hidden_reasoning=message_delta.redacted_reasoning_content, state="redacted", name=name, - otid=otid, + otid=Message.generate_otid_from_id(message_id, message_index), ) elif expect_reasoning_content and message_delta.content is not None: # "ignore" content if we expect reasoning content @@ -537,6 +541,8 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # NOTE: this is hardcoded for our DeepSeek API integration json_reasoning_content = parse_json(self.expect_reasoning_content_buffer) + if prev_message_type and prev_message_type != "tool_call_message": + message_index += 1 processed_chunk = ToolCallMessage( id=message_id, date=message_date, @@ -546,7 +552,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): tool_call_id=None, ), name=name, - otid=otid, + otid=Message.generate_otid_from_id(message_id, message_index), ) except json.JSONDecodeError as e: @@ -576,12 +582,14 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # print(f"Hiding content delta stream: '{message_delta.content}'") # return None elif message_delta.content is not None: + if prev_message_type and prev_message_type != "reasoning_message": + message_index += 1 processed_chunk = ReasoningMessage( id=message_id, date=message_date, reasoning=message_delta.content, name=name, - otid=otid, + otid=Message.generate_otid_from_id(message_id, message_index), ) # tool calls @@ -629,7 +637,15 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # TODO: Assumes consistent state and that prev_content is subset of new_content diff = new_content.replace(prev_content, "", 1) self.current_json_parse_result = parsed_args - processed_chunk = AssistantMessage(id=message_id, date=message_date, content=diff, name=name, otid=otid) + if prev_message_type and prev_message_type != "assistant_message": + message_index += 1 + processed_chunk = AssistantMessage( + id=message_id, + date=message_date, + content=diff, + name=name, + otid=Message.generate_otid_from_id(message_id, message_index), + ) else: return None @@ -653,6 +669,8 @@ class StreamingServerInterface(AgentChunkStreamingInterface): processed_chunk = None print("skipping empty chunk...") else: + if prev_message_type and prev_message_type != "tool_call_message": + message_index += 1 processed_chunk = ToolCallMessage( id=message_id, date=message_date, @@ -662,7 +680,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): tool_call_id=tool_call_delta.get("id"), ), name=name, - otid=otid, + otid=Message.generate_otid_from_id(message_id, message_index), ) elif self.inner_thoughts_in_kwargs and tool_call.function: @@ -694,12 +712,14 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # If we have inner thoughts, we should output them as a chunk if updates_inner_thoughts: + if prev_message_type and prev_message_type != "reasoning_message": + message_index += 1 processed_chunk = ReasoningMessage( id=message_id, date=message_date, reasoning=updates_inner_thoughts, name=name, - otid=otid, + otid=Message.generate_otid_from_id(message_id, message_index), ) # Additionally inner thoughts may stream back with a chunk of main JSON # In that case, since we can only return a chunk at a time, we should buffer it @@ -727,6 +747,8 @@ class StreamingServerInterface(AgentChunkStreamingInterface): self.prev_assistant_message_id = self.function_id_buffer else: + if prev_message_type and prev_message_type != "tool_call_message": + message_index += 1 processed_chunk = ToolCallMessage( id=message_id, date=message_date, @@ -736,7 +758,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): tool_call_id=self.function_id_buffer, ), name=name, - otid=otid, + otid=Message.generate_otid_from_id(message_id, message_index), ) # Record what the last function name we flushed was @@ -789,12 +811,14 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # In this case, we should release the buffer + new data at once combined_chunk = self.function_args_buffer + updates_main_json + if prev_message_type and prev_message_type != "assistant_message": + message_index += 1 processed_chunk = AssistantMessage( id=message_id, date=message_date, content=combined_chunk, name=name, - otid=otid, + otid=Message.generate_otid_from_id(message_id, message_index), ) # Store the ID of the tool call so allow skipping the corresponding response if self.function_id_buffer: @@ -818,8 +842,14 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # TODO: Assumes consistent state and that prev_content is subset of new_content diff = new_content.replace(prev_content, "", 1) self.current_json_parse_result = parsed_args + if prev_message_type and prev_message_type != "assistant_message": + message_index += 1 processed_chunk = AssistantMessage( - id=message_id, date=message_date, content=diff, name=name, otid=otid + id=message_id, + date=message_date, + content=diff, + name=name, + otid=Message.generate_otid_from_id(message_id, message_index), ) else: return None @@ -836,6 +866,8 @@ class StreamingServerInterface(AgentChunkStreamingInterface): if self.function_args_buffer: # In this case, we should release the buffer + new data at once combined_chunk = self.function_args_buffer + updates_main_json + if prev_message_type and prev_message_type != "tool_call_message": + message_index += 1 processed_chunk = ToolCallMessage( id=message_id, date=message_date, @@ -845,13 +877,15 @@ class StreamingServerInterface(AgentChunkStreamingInterface): tool_call_id=self.function_id_buffer, ), name=name, - otid=otid, + otid=Message.generate_otid_from_id(message_id, message_index), ) # clear buffer self.function_args_buffer = None self.function_id_buffer = None else: # If there's no buffer to clear, just output a new chunk with new data + if prev_message_type and prev_message_type != "tool_call_message": + message_index += 1 processed_chunk = ToolCallMessage( id=message_id, date=message_date, @@ -861,7 +895,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): tool_call_id=self.function_id_buffer, ), name=name, - otid=otid, + otid=Message.generate_otid_from_id(message_id, message_index), ) self.function_id_buffer = None @@ -982,6 +1016,8 @@ class StreamingServerInterface(AgentChunkStreamingInterface): processed_chunk = None print("skipping empty chunk...") else: + if prev_message_type and prev_message_type != "tool_call_message": + message_index += 1 processed_chunk = ToolCallMessage( id=message_id, date=message_date, @@ -991,7 +1027,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): tool_call_id=tool_call_delta.get("id"), ), name=name, - otid=otid, + otid=Message.generate_otid_from_id(message_id, message_index), ) elif choice.finish_reason is not None: @@ -1074,6 +1110,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): expect_reasoning_content: bool = False, name: Optional[str] = None, message_index: int = 0, + prev_message_type: Optional[str] = None, ): """Process a streaming chunk from an OpenAI-compatible server. @@ -1101,6 +1138,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): expect_reasoning_content=expect_reasoning_content, name=name, message_index=message_index, + prev_message_type=prev_message_type, ) if processed_chunk is None: return @@ -1303,7 +1341,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): stdout=msg_obj.tool_returns[0].stdout if msg_obj.tool_returns else None, stderr=msg_obj.tool_returns[0].stderr if msg_obj.tool_returns else None, name=msg_obj.name, - otid=Message.generate_otid_from_id(msg_obj.id, chunk_index), + otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None, ) elif msg.startswith("Error: "): @@ -1319,7 +1357,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): stdout=msg_obj.tool_returns[0].stdout if msg_obj.tool_returns else None, stderr=msg_obj.tool_returns[0].stderr if msg_obj.tool_returns else None, name=msg_obj.name, - otid=Message.generate_otid_from_id(msg_obj.id, chunk_index), + otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None, ) else: diff --git a/letta/streaming_interface.py b/letta/streaming_interface.py index 4b2c3330..7533e25f 100644 --- a/letta/streaming_interface.py +++ b/letta/streaming_interface.py @@ -56,6 +56,7 @@ class AgentChunkStreamingInterface(ABC): expect_reasoning_content: bool = False, name: Optional[str] = None, message_index: int = 0, + prev_message_type: Optional[str] = None, ): """Process a streaming chunk from an OpenAI-compatible server""" raise NotImplementedError @@ -108,6 +109,7 @@ class StreamingCLIInterface(AgentChunkStreamingInterface): expect_reasoning_content: bool = False, name: Optional[str] = None, message_index: int = 0, + prev_message_type: Optional[str] = None, ): assert len(chunk.choices) == 1, chunk diff --git a/tests/integration_test_send_message.py b/tests/integration_test_send_message.py index cb8b5d3b..d30a6418 100644 --- a/tests/integration_test_send_message.py +++ b/tests/integration_test_send_message.py @@ -2,13 +2,14 @@ import json import os import threading import time +import uuid from typing import Any, Dict, List import pytest import requests from dotenv import load_dotenv -from letta_client import AsyncLetta, Letta, Run -from letta_client.types import AssistantMessage, LettaUsageStatistics, ReasoningMessage, ToolCallMessage, ToolReturnMessage +from letta_client import AsyncLetta, Letta, MessageCreate, Run +from letta_client.types import AssistantMessage, LettaUsageStatistics, ReasoningMessage, ToolCallMessage, ToolReturnMessage, UserMessage from letta.schemas.agent import AgentState from letta.schemas.llm_config import LLMConfig @@ -120,9 +121,10 @@ def roll_dice(num_sides: int) -> int: return random.randint(1, num_sides) -USER_MESSAGE_GREETING: List[Dict[str, str]] = [{"role": "user", "content": "Hi there."}] -USER_MESSAGE_TOOL_CALL: List[Dict[str, str]] = [ - {"role": "user", "content": "Call the roll_dice tool with 16 sides and tell me the outcome."} +USER_MESSAGE_OTID = str(uuid.uuid4()) +USER_MESSAGE_GREETING: List[MessageCreate] = [MessageCreate(role="user", content="Hi there.", otid=USER_MESSAGE_OTID)] +USER_MESSAGE_TOOL_CALL: List[MessageCreate] = [ + MessageCreate(role="user", content="Call the roll_dice tool with 16 sides and tell me the outcome.", otid=USER_MESSAGE_OTID) ] all_configs = [ "openai-gpt-4o-mini.json", @@ -139,54 +141,123 @@ filenames = [requested] if requested else all_configs TESTED_LLM_CONFIGS: List[LLMConfig] = [get_llm_config(fn) for fn in filenames] -def assert_greeting_with_assistant_message_response(messages: List[Any], streaming: bool = False) -> None: +def assert_greeting_with_assistant_message_response( + messages: List[Any], + streaming: bool = False, + from_db: bool = False, +) -> None: """ Asserts that the messages list follows the expected sequence: ReasoningMessage -> AssistantMessage. """ - expected_message_count = 3 if streaming else 2 + expected_message_count = 3 if streaming or from_db else 2 assert len(messages) == expected_message_count - assert isinstance(messages[0], ReasoningMessage) - assert isinstance(messages[1], AssistantMessage) + index = 0 + if from_db: + assert isinstance(messages[index], UserMessage) + assert messages[index].otid == USER_MESSAGE_OTID + index += 1 + + # Agent Step 1 + assert isinstance(messages[index], ReasoningMessage) + assert messages[index].otid and messages[index].otid[-1] == "0" + index += 1 + + assert isinstance(messages[index], AssistantMessage) + assert messages[index].otid and messages[index].otid[-1] == "1" + index += 1 if streaming: - assert isinstance(messages[2], LettaUsageStatistics) + assert isinstance(messages[index], LettaUsageStatistics) -def assert_greeting_without_assistant_message_response(messages: List[Any], streaming: bool = False) -> None: +def assert_greeting_without_assistant_message_response( + messages: List[Any], + streaming: bool = False, + from_db: bool = False, +) -> None: """ Asserts that the messages list follows the expected sequence: ReasoningMessage -> ToolCallMessage -> ToolReturnMessage. """ - expected_message_count = 4 if streaming else 3 + expected_message_count = 4 if streaming or from_db else 3 assert len(messages) == expected_message_count - assert isinstance(messages[0], ReasoningMessage) - assert isinstance(messages[1], ToolCallMessage) - assert isinstance(messages[2], ToolReturnMessage) + index = 0 + if from_db: + assert isinstance(messages[index], UserMessage) + assert messages[index].otid == USER_MESSAGE_OTID + index += 1 + + # Agent Step 1 + assert isinstance(messages[index], ReasoningMessage) + assert messages[index].otid and messages[index].otid[-1] == "0" + index += 1 + + assert isinstance(messages[index], ToolCallMessage) + assert messages[index].otid and messages[index].otid[-1] == "1" + index += 1 + + # Agent Step 2 + assert isinstance(messages[index], ToolReturnMessage) + assert messages[index].otid and messages[index].otid[-1] == "0" + index += 1 if streaming: - assert isinstance(messages[3], LettaUsageStatistics) + assert isinstance(messages[index], LettaUsageStatistics) -def assert_tool_call_response(messages: List[Any], streaming: bool = False) -> None: +def assert_tool_call_response( + messages: List[Any], + streaming: bool = False, + from_db: bool = False, +) -> None: """ Asserts that the messages list follows the expected sequence: ReasoningMessage -> ToolCallMessage -> ToolReturnMessage -> ReasoningMessage -> AssistantMessage. """ - expected_message_count = 6 if streaming else 5 + expected_message_count = 6 if streaming else 7 if from_db else 5 assert len(messages) == expected_message_count - assert isinstance(messages[0], ReasoningMessage) - assert isinstance(messages[1], ToolCallMessage) - assert isinstance(messages[2], ToolReturnMessage) - assert isinstance(messages[3], ReasoningMessage) - assert isinstance(messages[4], AssistantMessage) + index = 0 + if from_db: + assert isinstance(messages[index], UserMessage) + assert messages[index].otid == USER_MESSAGE_OTID + index += 1 + + # Agent Step 1 + assert isinstance(messages[index], ReasoningMessage) + assert messages[index].otid and messages[index].otid[-1] == "0" + index += 1 + + assert isinstance(messages[index], ToolCallMessage) + assert messages[index].otid and messages[index].otid[-1] == "1" + index += 1 + + # Agent Step 2 + assert isinstance(messages[index], ToolReturnMessage) + assert messages[index].otid and messages[index].otid[-1] == "0" + index += 1 + + # Hidden User Message + if from_db: + assert isinstance(messages[index], UserMessage) + assert "request_heartbeat=true" in messages[index].content + index += 1 + + # Agent Step 3 + assert isinstance(messages[index], ReasoningMessage) + assert messages[index].otid and messages[index].otid[-1] == "0" + index += 1 + + assert isinstance(messages[index], AssistantMessage) + assert messages[index].otid and messages[index].otid[-1] == "1" + index += 1 if streaming: - assert isinstance(messages[5], LettaUsageStatistics) + assert isinstance(messages[index], LettaUsageStatistics) def accumulate_chunks(chunks: List[Any]) -> List[Any]: @@ -259,12 +330,15 @@ def test_greeting_with_assistant_message( Tests sending a message with a synchronous client. Verifies that the response messages follow the expected order. """ + last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1) agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = client.agents.messages.create( agent_id=agent_state.id, messages=USER_MESSAGE_GREETING, ) assert_greeting_with_assistant_message_response(response.messages) + messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id) + assert_greeting_with_assistant_message_response(messages_from_db, from_db=True) @pytest.mark.parametrize( @@ -282,6 +356,7 @@ def test_greeting_without_assistant_message( Tests sending a message with a synchronous client. Verifies that the response messages follow the expected order. """ + last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1) agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = client.agents.messages.create( agent_id=agent_state.id, @@ -289,6 +364,8 @@ def test_greeting_without_assistant_message( use_assistant_message=False, ) assert_greeting_without_assistant_message_response(response.messages) + messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id, use_assistant_message=False) + assert_greeting_without_assistant_message_response(messages_from_db, from_db=True) @pytest.mark.parametrize( @@ -308,12 +385,15 @@ def test_tool_call( """ dice_tool = client.tools.upsert_from_function(func=roll_dice) client.agents.tools.attach(agent_id=agent_state.id, tool_id=dice_tool.id) + last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1) agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = client.agents.messages.create( agent_id=agent_state.id, messages=USER_MESSAGE_TOOL_CALL, ) assert_tool_call_response(response.messages) + messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id) + assert_tool_call_response(messages_from_db, from_db=True) @pytest.mark.asyncio