test: add additional new agent messaging tests (#2120)
This commit is contained in:
@@ -133,7 +133,6 @@ class Agent(BaseAgent):
|
||||
# Different interfaces can handle events differently
|
||||
# e.g., print in CLI vs send a discord message with a discord bot
|
||||
self.interface = interface
|
||||
self.chunk_index = 0
|
||||
|
||||
# Create the persistence manager object based on the AgentState info
|
||||
self.message_manager = MessageManager()
|
||||
@@ -248,11 +247,9 @@ class Agent(BaseAgent):
|
||||
group_id=group_id,
|
||||
)
|
||||
messages.append(new_message)
|
||||
self.interface.function_message(f"Error: {error_msg}", msg_obj=new_message, chunk_index=self.chunk_index)
|
||||
self.chunk_index += 1
|
||||
self.interface.function_message(f"Error: {error_msg}", msg_obj=new_message, chunk_index=0)
|
||||
if include_function_failed_message:
|
||||
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=new_message, chunk_index=self.chunk_index)
|
||||
self.chunk_index += 1
|
||||
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=new_message)
|
||||
|
||||
# Return updated messages
|
||||
return messages
|
||||
@@ -422,6 +419,7 @@ class Agent(BaseAgent):
|
||||
messages = [] # append these to the history when done
|
||||
function_name = None
|
||||
function_args = {}
|
||||
chunk_index = 0
|
||||
|
||||
# Step 2: check if LLM wanted to call a function
|
||||
if response_message.function_call or (response_message.tool_calls is not None and len(response_message.tool_calls) > 0):
|
||||
@@ -465,8 +463,8 @@ class Agent(BaseAgent):
|
||||
nonnull_content = False
|
||||
if response_message.content or response_message.reasoning_content or response_message.redacted_reasoning_content:
|
||||
# The content if then internal monologue, not chat
|
||||
self.interface.internal_monologue(response_message.content, msg_obj=messages[-1], chunk_index=self.chunk_index)
|
||||
self.chunk_index += 1
|
||||
self.interface.internal_monologue(response_message.content, msg_obj=messages[-1], chunk_index=chunk_index)
|
||||
chunk_index += 1
|
||||
# Flag to avoid printing a duplicate if inner thoughts get popped from the function call
|
||||
nonnull_content = True
|
||||
|
||||
@@ -515,8 +513,8 @@ class Agent(BaseAgent):
|
||||
response_message.content = function_args.pop("inner_thoughts")
|
||||
# The content if then internal monologue, not chat
|
||||
if response_message.content and not nonnull_content:
|
||||
self.interface.internal_monologue(response_message.content, msg_obj=messages[-1], chunk_index=self.chunk_index)
|
||||
self.chunk_index += 1
|
||||
self.interface.internal_monologue(response_message.content, msg_obj=messages[-1], chunk_index=chunk_index)
|
||||
chunk_index += 1
|
||||
|
||||
# (Still parsing function args)
|
||||
# Handle requests for immediate heartbeat
|
||||
@@ -542,8 +540,8 @@ class Agent(BaseAgent):
|
||||
# handle cases where we return a json message
|
||||
if "message" in function_args:
|
||||
function_args["message"] = str(function_args.get("message", ""))
|
||||
self.interface.function_message(f"Running {function_name}({function_args})", msg_obj=messages[-1], chunk_index=self.chunk_index)
|
||||
self.chunk_index += 1
|
||||
self.interface.function_message(f"Running {function_name}({function_args})", msg_obj=messages[-1], chunk_index=chunk_index)
|
||||
chunk_index = 0 # reset chunk index after assistant message
|
||||
try:
|
||||
# handle tool execution (sandbox) and state updates
|
||||
log_telemetry(
|
||||
@@ -667,10 +665,9 @@ class Agent(BaseAgent):
|
||||
group_id=group_id,
|
||||
)
|
||||
) # extend conversation with function response
|
||||
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1], chunk_index=self.chunk_index)
|
||||
self.chunk_index += 1
|
||||
self.interface.function_message(f"Success: {function_response_string}", msg_obj=messages[-1], chunk_index=self.chunk_index)
|
||||
self.chunk_index += 1
|
||||
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1], chunk_index=chunk_index)
|
||||
self.interface.function_message(f"Success: {function_response_string}", msg_obj=messages[-1], chunk_index=chunk_index)
|
||||
chunk_index += 1
|
||||
self.last_function_response = function_response
|
||||
|
||||
else:
|
||||
@@ -685,8 +682,8 @@ class Agent(BaseAgent):
|
||||
group_id=group_id,
|
||||
)
|
||||
) # extend conversation with assistant's reply
|
||||
self.interface.internal_monologue(response_message.content, msg_obj=messages[-1], chunk_index=self.chunk_index)
|
||||
self.chunk_index += 1
|
||||
self.interface.internal_monologue(response_message.content, msg_obj=messages[-1], chunk_index=chunk_index)
|
||||
chunk_index += 1
|
||||
heartbeat_request = False
|
||||
function_failed = False
|
||||
|
||||
|
||||
@@ -997,10 +997,12 @@ def anthropic_chat_completions_process_stream(
|
||||
expect_reasoning_content=extended_thinking,
|
||||
name=name,
|
||||
message_index=message_idx,
|
||||
prev_message_type=prev_message_type,
|
||||
)
|
||||
if message_type != prev_message_type and message_type is not None:
|
||||
if message_type != prev_message_type and message_type is not None and prev_message_type is not None:
|
||||
message_idx += 1
|
||||
prev_message_type = message_type
|
||||
if message_type is not None:
|
||||
prev_message_type = message_type
|
||||
elif isinstance(stream_interface, AgentRefreshStreamingInterface):
|
||||
stream_interface.process_refresh(chat_completion_response)
|
||||
else:
|
||||
|
||||
@@ -326,10 +326,12 @@ def openai_chat_completions_process_stream(
|
||||
expect_reasoning_content=expect_reasoning_content,
|
||||
name=name,
|
||||
message_index=message_idx,
|
||||
prev_message_type=prev_message_type,
|
||||
)
|
||||
if message_type != prev_message_type and message_type is not None:
|
||||
if message_type != prev_message_type and message_type is not None and prev_message_type is not None:
|
||||
message_idx += 1
|
||||
prev_message_type = message_type
|
||||
if message_type is not None:
|
||||
prev_message_type = message_type
|
||||
elif isinstance(stream_interface, AgentRefreshStreamingInterface):
|
||||
stream_interface.process_refresh(chat_completion_response)
|
||||
else:
|
||||
|
||||
@@ -404,7 +404,7 @@ class Message(BaseMessage):
|
||||
stdout=self.tool_returns[0].stdout if self.tool_returns else None,
|
||||
stderr=self.tool_returns[0].stderr if self.tool_returns else None,
|
||||
name=self.name,
|
||||
otid=self.id.replace("message-", ""),
|
||||
otid=Message.generate_otid_from_id(self.id, len(messages)),
|
||||
sender_id=self.sender_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
|
||||
@@ -162,6 +162,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
|
||||
expect_reasoning_content: bool = False,
|
||||
name: Optional[str] = None,
|
||||
message_index: int = 0,
|
||||
prev_message_type: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Called externally with a ChatCompletionChunkResponse. Transforms
|
||||
|
||||
@@ -472,6 +472,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
expect_reasoning_content: bool = False,
|
||||
name: Optional[str] = None,
|
||||
message_index: int = 0,
|
||||
prev_message_type: Optional[str] = None,
|
||||
) -> Optional[Union[ReasoningMessage, ToolCallMessage, AssistantMessage]]:
|
||||
"""
|
||||
Example data from non-streaming response looks like:
|
||||
@@ -488,7 +489,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
|
||||
choice = chunk.choices[0]
|
||||
message_delta = choice.delta
|
||||
otid = Message.generate_otid_from_id(message_id, message_index)
|
||||
|
||||
if (
|
||||
message_delta.content is None
|
||||
@@ -503,6 +503,8 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
|
||||
# inner thoughts
|
||||
if expect_reasoning_content and message_delta.reasoning_content is not None:
|
||||
if prev_message_type and prev_message_type != "reasoning_message":
|
||||
message_index += 1
|
||||
processed_chunk = ReasoningMessage(
|
||||
id=message_id,
|
||||
date=message_date,
|
||||
@@ -510,16 +512,18 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
signature=message_delta.reasoning_content_signature,
|
||||
source="reasoner_model" if message_delta.reasoning_content else "non_reasoner_model",
|
||||
name=name,
|
||||
otid=otid,
|
||||
otid=Message.generate_otid_from_id(message_id, message_index),
|
||||
)
|
||||
elif expect_reasoning_content and message_delta.redacted_reasoning_content is not None:
|
||||
if prev_message_type and prev_message_type != "hidden_reasoning_message":
|
||||
message_index += 1
|
||||
processed_chunk = HiddenReasoningMessage(
|
||||
id=message_id,
|
||||
date=message_date,
|
||||
hidden_reasoning=message_delta.redacted_reasoning_content,
|
||||
state="redacted",
|
||||
name=name,
|
||||
otid=otid,
|
||||
otid=Message.generate_otid_from_id(message_id, message_index),
|
||||
)
|
||||
elif expect_reasoning_content and message_delta.content is not None:
|
||||
# "ignore" content if we expect reasoning content
|
||||
@@ -537,6 +541,8 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# NOTE: this is hardcoded for our DeepSeek API integration
|
||||
json_reasoning_content = parse_json(self.expect_reasoning_content_buffer)
|
||||
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
processed_chunk = ToolCallMessage(
|
||||
id=message_id,
|
||||
date=message_date,
|
||||
@@ -546,7 +552,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
tool_call_id=None,
|
||||
),
|
||||
name=name,
|
||||
otid=otid,
|
||||
otid=Message.generate_otid_from_id(message_id, message_index),
|
||||
)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
@@ -576,12 +582,14 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# print(f"Hiding content delta stream: '{message_delta.content}'")
|
||||
# return None
|
||||
elif message_delta.content is not None:
|
||||
if prev_message_type and prev_message_type != "reasoning_message":
|
||||
message_index += 1
|
||||
processed_chunk = ReasoningMessage(
|
||||
id=message_id,
|
||||
date=message_date,
|
||||
reasoning=message_delta.content,
|
||||
name=name,
|
||||
otid=otid,
|
||||
otid=Message.generate_otid_from_id(message_id, message_index),
|
||||
)
|
||||
|
||||
# tool calls
|
||||
@@ -629,7 +637,15 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# TODO: Assumes consistent state and that prev_content is subset of new_content
|
||||
diff = new_content.replace(prev_content, "", 1)
|
||||
self.current_json_parse_result = parsed_args
|
||||
processed_chunk = AssistantMessage(id=message_id, date=message_date, content=diff, name=name, otid=otid)
|
||||
if prev_message_type and prev_message_type != "assistant_message":
|
||||
message_index += 1
|
||||
processed_chunk = AssistantMessage(
|
||||
id=message_id,
|
||||
date=message_date,
|
||||
content=diff,
|
||||
name=name,
|
||||
otid=Message.generate_otid_from_id(message_id, message_index),
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -653,6 +669,8 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
processed_chunk = None
|
||||
print("skipping empty chunk...")
|
||||
else:
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
processed_chunk = ToolCallMessage(
|
||||
id=message_id,
|
||||
date=message_date,
|
||||
@@ -662,7 +680,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
tool_call_id=tool_call_delta.get("id"),
|
||||
),
|
||||
name=name,
|
||||
otid=otid,
|
||||
otid=Message.generate_otid_from_id(message_id, message_index),
|
||||
)
|
||||
|
||||
elif self.inner_thoughts_in_kwargs and tool_call.function:
|
||||
@@ -694,12 +712,14 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
|
||||
# If we have inner thoughts, we should output them as a chunk
|
||||
if updates_inner_thoughts:
|
||||
if prev_message_type and prev_message_type != "reasoning_message":
|
||||
message_index += 1
|
||||
processed_chunk = ReasoningMessage(
|
||||
id=message_id,
|
||||
date=message_date,
|
||||
reasoning=updates_inner_thoughts,
|
||||
name=name,
|
||||
otid=otid,
|
||||
otid=Message.generate_otid_from_id(message_id, message_index),
|
||||
)
|
||||
# Additionally inner thoughts may stream back with a chunk of main JSON
|
||||
# In that case, since we can only return a chunk at a time, we should buffer it
|
||||
@@ -727,6 +747,8 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
self.prev_assistant_message_id = self.function_id_buffer
|
||||
|
||||
else:
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
processed_chunk = ToolCallMessage(
|
||||
id=message_id,
|
||||
date=message_date,
|
||||
@@ -736,7 +758,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
tool_call_id=self.function_id_buffer,
|
||||
),
|
||||
name=name,
|
||||
otid=otid,
|
||||
otid=Message.generate_otid_from_id(message_id, message_index),
|
||||
)
|
||||
|
||||
# Record what the last function name we flushed was
|
||||
@@ -789,12 +811,14 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# In this case, we should release the buffer + new data at once
|
||||
combined_chunk = self.function_args_buffer + updates_main_json
|
||||
|
||||
if prev_message_type and prev_message_type != "assistant_message":
|
||||
message_index += 1
|
||||
processed_chunk = AssistantMessage(
|
||||
id=message_id,
|
||||
date=message_date,
|
||||
content=combined_chunk,
|
||||
name=name,
|
||||
otid=otid,
|
||||
otid=Message.generate_otid_from_id(message_id, message_index),
|
||||
)
|
||||
# Store the ID of the tool call so allow skipping the corresponding response
|
||||
if self.function_id_buffer:
|
||||
@@ -818,8 +842,14 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# TODO: Assumes consistent state and that prev_content is subset of new_content
|
||||
diff = new_content.replace(prev_content, "", 1)
|
||||
self.current_json_parse_result = parsed_args
|
||||
if prev_message_type and prev_message_type != "assistant_message":
|
||||
message_index += 1
|
||||
processed_chunk = AssistantMessage(
|
||||
id=message_id, date=message_date, content=diff, name=name, otid=otid
|
||||
id=message_id,
|
||||
date=message_date,
|
||||
content=diff,
|
||||
name=name,
|
||||
otid=Message.generate_otid_from_id(message_id, message_index),
|
||||
)
|
||||
else:
|
||||
return None
|
||||
@@ -836,6 +866,8 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
if self.function_args_buffer:
|
||||
# In this case, we should release the buffer + new data at once
|
||||
combined_chunk = self.function_args_buffer + updates_main_json
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
processed_chunk = ToolCallMessage(
|
||||
id=message_id,
|
||||
date=message_date,
|
||||
@@ -845,13 +877,15 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
tool_call_id=self.function_id_buffer,
|
||||
),
|
||||
name=name,
|
||||
otid=otid,
|
||||
otid=Message.generate_otid_from_id(message_id, message_index),
|
||||
)
|
||||
# clear buffer
|
||||
self.function_args_buffer = None
|
||||
self.function_id_buffer = None
|
||||
else:
|
||||
# If there's no buffer to clear, just output a new chunk with new data
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
processed_chunk = ToolCallMessage(
|
||||
id=message_id,
|
||||
date=message_date,
|
||||
@@ -861,7 +895,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
tool_call_id=self.function_id_buffer,
|
||||
),
|
||||
name=name,
|
||||
otid=otid,
|
||||
otid=Message.generate_otid_from_id(message_id, message_index),
|
||||
)
|
||||
self.function_id_buffer = None
|
||||
|
||||
@@ -982,6 +1016,8 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
processed_chunk = None
|
||||
print("skipping empty chunk...")
|
||||
else:
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
processed_chunk = ToolCallMessage(
|
||||
id=message_id,
|
||||
date=message_date,
|
||||
@@ -991,7 +1027,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
tool_call_id=tool_call_delta.get("id"),
|
||||
),
|
||||
name=name,
|
||||
otid=otid,
|
||||
otid=Message.generate_otid_from_id(message_id, message_index),
|
||||
)
|
||||
|
||||
elif choice.finish_reason is not None:
|
||||
@@ -1074,6 +1110,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
expect_reasoning_content: bool = False,
|
||||
name: Optional[str] = None,
|
||||
message_index: int = 0,
|
||||
prev_message_type: Optional[str] = None,
|
||||
):
|
||||
"""Process a streaming chunk from an OpenAI-compatible server.
|
||||
|
||||
@@ -1101,6 +1138,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
expect_reasoning_content=expect_reasoning_content,
|
||||
name=name,
|
||||
message_index=message_index,
|
||||
prev_message_type=prev_message_type,
|
||||
)
|
||||
if processed_chunk is None:
|
||||
return
|
||||
@@ -1303,7 +1341,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
stdout=msg_obj.tool_returns[0].stdout if msg_obj.tool_returns else None,
|
||||
stderr=msg_obj.tool_returns[0].stderr if msg_obj.tool_returns else None,
|
||||
name=msg_obj.name,
|
||||
otid=Message.generate_otid_from_id(msg_obj.id, chunk_index),
|
||||
otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None,
|
||||
)
|
||||
|
||||
elif msg.startswith("Error: "):
|
||||
@@ -1319,7 +1357,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
stdout=msg_obj.tool_returns[0].stdout if msg_obj.tool_returns else None,
|
||||
stderr=msg_obj.tool_returns[0].stderr if msg_obj.tool_returns else None,
|
||||
name=msg_obj.name,
|
||||
otid=Message.generate_otid_from_id(msg_obj.id, chunk_index),
|
||||
otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@@ -56,6 +56,7 @@ class AgentChunkStreamingInterface(ABC):
|
||||
expect_reasoning_content: bool = False,
|
||||
name: Optional[str] = None,
|
||||
message_index: int = 0,
|
||||
prev_message_type: Optional[str] = None,
|
||||
):
|
||||
"""Process a streaming chunk from an OpenAI-compatible server"""
|
||||
raise NotImplementedError
|
||||
@@ -108,6 +109,7 @@ class StreamingCLIInterface(AgentChunkStreamingInterface):
|
||||
expect_reasoning_content: bool = False,
|
||||
name: Optional[str] = None,
|
||||
message_index: int = 0,
|
||||
prev_message_type: Optional[str] = None,
|
||||
):
|
||||
assert len(chunk.choices) == 1, chunk
|
||||
|
||||
|
||||
@@ -2,13 +2,14 @@ import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import AsyncLetta, Letta, Run
|
||||
from letta_client.types import AssistantMessage, LettaUsageStatistics, ReasoningMessage, ToolCallMessage, ToolReturnMessage
|
||||
from letta_client import AsyncLetta, Letta, MessageCreate, Run
|
||||
from letta_client.types import AssistantMessage, LettaUsageStatistics, ReasoningMessage, ToolCallMessage, ToolReturnMessage, UserMessage
|
||||
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
@@ -120,9 +121,10 @@ def roll_dice(num_sides: int) -> int:
|
||||
return random.randint(1, num_sides)
|
||||
|
||||
|
||||
USER_MESSAGE_GREETING: List[Dict[str, str]] = [{"role": "user", "content": "Hi there."}]
|
||||
USER_MESSAGE_TOOL_CALL: List[Dict[str, str]] = [
|
||||
{"role": "user", "content": "Call the roll_dice tool with 16 sides and tell me the outcome."}
|
||||
USER_MESSAGE_OTID = str(uuid.uuid4())
|
||||
USER_MESSAGE_GREETING: List[MessageCreate] = [MessageCreate(role="user", content="Hi there.", otid=USER_MESSAGE_OTID)]
|
||||
USER_MESSAGE_TOOL_CALL: List[MessageCreate] = [
|
||||
MessageCreate(role="user", content="Call the roll_dice tool with 16 sides and tell me the outcome.", otid=USER_MESSAGE_OTID)
|
||||
]
|
||||
all_configs = [
|
||||
"openai-gpt-4o-mini.json",
|
||||
@@ -139,54 +141,123 @@ filenames = [requested] if requested else all_configs
|
||||
TESTED_LLM_CONFIGS: List[LLMConfig] = [get_llm_config(fn) for fn in filenames]
|
||||
|
||||
|
||||
def assert_greeting_with_assistant_message_response(messages: List[Any], streaming: bool = False) -> None:
|
||||
def assert_greeting_with_assistant_message_response(
|
||||
messages: List[Any],
|
||||
streaming: bool = False,
|
||||
from_db: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Asserts that the messages list follows the expected sequence:
|
||||
ReasoningMessage -> AssistantMessage.
|
||||
"""
|
||||
expected_message_count = 3 if streaming else 2
|
||||
expected_message_count = 3 if streaming or from_db else 2
|
||||
assert len(messages) == expected_message_count
|
||||
|
||||
assert isinstance(messages[0], ReasoningMessage)
|
||||
assert isinstance(messages[1], AssistantMessage)
|
||||
index = 0
|
||||
if from_db:
|
||||
assert isinstance(messages[index], UserMessage)
|
||||
assert messages[index].otid == USER_MESSAGE_OTID
|
||||
index += 1
|
||||
|
||||
# Agent Step 1
|
||||
assert isinstance(messages[index], ReasoningMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "0"
|
||||
index += 1
|
||||
|
||||
assert isinstance(messages[index], AssistantMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "1"
|
||||
index += 1
|
||||
|
||||
if streaming:
|
||||
assert isinstance(messages[2], LettaUsageStatistics)
|
||||
assert isinstance(messages[index], LettaUsageStatistics)
|
||||
|
||||
|
||||
def assert_greeting_without_assistant_message_response(messages: List[Any], streaming: bool = False) -> None:
|
||||
def assert_greeting_without_assistant_message_response(
|
||||
messages: List[Any],
|
||||
streaming: bool = False,
|
||||
from_db: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Asserts that the messages list follows the expected sequence:
|
||||
ReasoningMessage -> ToolCallMessage -> ToolReturnMessage.
|
||||
"""
|
||||
expected_message_count = 4 if streaming else 3
|
||||
expected_message_count = 4 if streaming or from_db else 3
|
||||
assert len(messages) == expected_message_count
|
||||
|
||||
assert isinstance(messages[0], ReasoningMessage)
|
||||
assert isinstance(messages[1], ToolCallMessage)
|
||||
assert isinstance(messages[2], ToolReturnMessage)
|
||||
index = 0
|
||||
if from_db:
|
||||
assert isinstance(messages[index], UserMessage)
|
||||
assert messages[index].otid == USER_MESSAGE_OTID
|
||||
index += 1
|
||||
|
||||
# Agent Step 1
|
||||
assert isinstance(messages[index], ReasoningMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "0"
|
||||
index += 1
|
||||
|
||||
assert isinstance(messages[index], ToolCallMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "1"
|
||||
index += 1
|
||||
|
||||
# Agent Step 2
|
||||
assert isinstance(messages[index], ToolReturnMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "0"
|
||||
index += 1
|
||||
|
||||
if streaming:
|
||||
assert isinstance(messages[3], LettaUsageStatistics)
|
||||
assert isinstance(messages[index], LettaUsageStatistics)
|
||||
|
||||
|
||||
def assert_tool_call_response(messages: List[Any], streaming: bool = False) -> None:
|
||||
def assert_tool_call_response(
|
||||
messages: List[Any],
|
||||
streaming: bool = False,
|
||||
from_db: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Asserts that the messages list follows the expected sequence:
|
||||
ReasoningMessage -> ToolCallMessage -> ToolReturnMessage ->
|
||||
ReasoningMessage -> AssistantMessage.
|
||||
"""
|
||||
expected_message_count = 6 if streaming else 5
|
||||
expected_message_count = 6 if streaming else 7 if from_db else 5
|
||||
assert len(messages) == expected_message_count
|
||||
|
||||
assert isinstance(messages[0], ReasoningMessage)
|
||||
assert isinstance(messages[1], ToolCallMessage)
|
||||
assert isinstance(messages[2], ToolReturnMessage)
|
||||
assert isinstance(messages[3], ReasoningMessage)
|
||||
assert isinstance(messages[4], AssistantMessage)
|
||||
index = 0
|
||||
if from_db:
|
||||
assert isinstance(messages[index], UserMessage)
|
||||
assert messages[index].otid == USER_MESSAGE_OTID
|
||||
index += 1
|
||||
|
||||
# Agent Step 1
|
||||
assert isinstance(messages[index], ReasoningMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "0"
|
||||
index += 1
|
||||
|
||||
assert isinstance(messages[index], ToolCallMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "1"
|
||||
index += 1
|
||||
|
||||
# Agent Step 2
|
||||
assert isinstance(messages[index], ToolReturnMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "0"
|
||||
index += 1
|
||||
|
||||
# Hidden User Message
|
||||
if from_db:
|
||||
assert isinstance(messages[index], UserMessage)
|
||||
assert "request_heartbeat=true" in messages[index].content
|
||||
index += 1
|
||||
|
||||
# Agent Step 3
|
||||
assert isinstance(messages[index], ReasoningMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "0"
|
||||
index += 1
|
||||
|
||||
assert isinstance(messages[index], AssistantMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "1"
|
||||
index += 1
|
||||
|
||||
if streaming:
|
||||
assert isinstance(messages[5], LettaUsageStatistics)
|
||||
assert isinstance(messages[index], LettaUsageStatistics)
|
||||
|
||||
|
||||
def accumulate_chunks(chunks: List[Any]) -> List[Any]:
|
||||
@@ -259,12 +330,15 @@ def test_greeting_with_assistant_message(
|
||||
Tests sending a message with a synchronous client.
|
||||
Verifies that the response messages follow the expected order.
|
||||
"""
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_GREETING,
|
||||
)
|
||||
assert_greeting_with_assistant_message_response(response.messages)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_greeting_with_assistant_message_response(messages_from_db, from_db=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -282,6 +356,7 @@ def test_greeting_without_assistant_message(
|
||||
Tests sending a message with a synchronous client.
|
||||
Verifies that the response messages follow the expected order.
|
||||
"""
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
@@ -289,6 +364,8 @@ def test_greeting_without_assistant_message(
|
||||
use_assistant_message=False,
|
||||
)
|
||||
assert_greeting_without_assistant_message_response(response.messages)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id, use_assistant_message=False)
|
||||
assert_greeting_without_assistant_message_response(messages_from_db, from_db=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -308,12 +385,15 @@ def test_tool_call(
|
||||
"""
|
||||
dice_tool = client.tools.upsert_from_function(func=roll_dice)
|
||||
client.agents.tools.attach(agent_id=agent_state.id, tool_id=dice_tool.id)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_TOOL_CALL,
|
||||
)
|
||||
assert_tool_call_response(response.messages)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_tool_call_response(messages_from_db, from_db=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user