From a4d7284ebec1ca6d4927559e8d3befc1377ee425 Mon Sep 17 00:00:00 2001 From: cthomas Date: Mon, 21 Apr 2025 13:49:46 -0700 Subject: [PATCH] feat: unify input message type on agent step (#1820) --- letta/agent.py | 74 ++++++----- letta/client/client.py | 4 +- letta/functions/function_sets/multi_agent.py | 4 +- letta/functions/helpers.py | 6 +- letta/groups/dynamic_multi_agent.py | 117 +++++++++--------- letta/groups/round_robin_multi_agent.py | 90 +++++++------- letta/groups/sleeptime_multi_agent.py | 46 ++++--- letta/groups/supervisor_multi_agent.py | 41 +++--- letta/helpers/message_helper.py | 1 + letta/schemas/message.py | 1 + .../chat_completions/chat_completions.py | 2 +- letta/server/rest_api/routers/v1/agents.py | 8 +- letta/server/rest_api/routers/v1/groups.py | 4 +- letta/server/server.py | 66 +++------- tests/integration_test_sleeptime_agent.py | 6 +- tests/test_agent_serialization.py | 24 ++-- tests/test_multi_agent.py | 10 +- 17 files changed, 244 insertions(+), 260 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index 589e9acc..7dba5ead 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -27,6 +27,7 @@ from letta.helpers import ToolRulesSolver from letta.helpers.composio_helpers import get_composio_api_key from letta.helpers.datetime_helpers import get_utc_time from letta.helpers.json_helpers import json_dumps, json_loads +from letta.helpers.message_helper import prepare_input_message_create from letta.interface import AgentInterface from letta.llm_api.helpers import calculate_summarizer_cutoff, get_token_counts_for_messages, is_context_overflow_error from letta.llm_api.llm_api_tools import create @@ -42,7 +43,7 @@ from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageRole from letta.schemas.letta_message_content import TextContent from letta.schemas.memory import ContextWindowOverview, Memory -from letta.schemas.message import Message, ToolReturn +from letta.schemas.message import Message, MessageCreate, ToolReturn from letta.schemas.openai.chat_completion_response import ChatCompletionResponse from letta.schemas.openai.chat_completion_response import Message as ChatCompletionMessage from letta.schemas.openai.chat_completion_response import UsageStatistics @@ -78,7 +79,7 @@ class BaseAgent(ABC): @abstractmethod def step( self, - messages: Union[Message, List[Message]], + input_messages: List[MessageCreate], ) -> LettaUsageStatistics: """ Top-level event message handler for the agent. @@ -691,7 +692,7 @@ class Agent(BaseAgent): @trace_method def step( self, - messages: Union[Message, List[Message]], + input_messages: List[MessageCreate], # additional args chaining: bool = True, max_chaining_steps: Optional[int] = None, @@ -704,7 +705,9 @@ class Agent(BaseAgent): # But just to be safe self.tool_rules_solver.clear_tool_history() - next_input_message = messages if isinstance(messages, list) else [messages] + # Convert MessageCreate objects to Message objects + message_objects = [prepare_input_message_create(m, self.agent_state.id, True, True) for m in input_messages] + next_input_messages = message_objects counter = 0 total_usage = UsageStatistics() step_count = 0 @@ -715,7 +718,7 @@ class Agent(BaseAgent): kwargs["step_count"] = step_count kwargs["last_function_failed"] = function_failed step_response = self.inner_step( - messages=next_input_message, + messages=next_input_messages, put_inner_thoughts_first=put_inner_thoughts_first, **kwargs, ) @@ -745,36 +748,42 @@ class Agent(BaseAgent): # Chain handlers elif token_warning and summarizer_settings.send_memory_warning_message: assert self.agent_state.created_by_id is not None - next_input_message = Message.dict_to_message( - agent_id=self.agent_state.id, - model=self.model, - openai_message_dict={ - "role": "user", # TODO: change to system? - "content": get_token_limit_warning(), - }, - ) + next_input_messages = [ + Message.dict_to_message( + agent_id=self.agent_state.id, + model=self.model, + openai_message_dict={ + "role": "user", # TODO: change to system? + "content": get_token_limit_warning(), + }, + ), + ] continue # always chain elif function_failed: assert self.agent_state.created_by_id is not None - next_input_message = Message.dict_to_message( - agent_id=self.agent_state.id, - model=self.model, - openai_message_dict={ - "role": "user", # TODO: change to system? - "content": get_heartbeat(FUNC_FAILED_HEARTBEAT_MESSAGE), - }, - ) + next_input_messages = [ + Message.dict_to_message( + agent_id=self.agent_state.id, + model=self.model, + openai_message_dict={ + "role": "user", # TODO: change to system? + "content": get_heartbeat(FUNC_FAILED_HEARTBEAT_MESSAGE), + }, + ) + ] continue # always chain elif heartbeat_request: assert self.agent_state.created_by_id is not None - next_input_message = Message.dict_to_message( - agent_id=self.agent_state.id, - model=self.model, - openai_message_dict={ - "role": "user", # TODO: change to system? - "content": get_heartbeat(REQ_HEARTBEAT_MESSAGE), - }, - ) + next_input_messages = [ + Message.dict_to_message( + agent_id=self.agent_state.id, + model=self.model, + openai_message_dict={ + "role": "user", # TODO: change to system? + "content": get_heartbeat(REQ_HEARTBEAT_MESSAGE), + }, + ) + ] continue # always chain # Letta no-op / yield else: @@ -788,7 +797,7 @@ class Agent(BaseAgent): def inner_step( self, - messages: Union[Message, List[Message]], + messages: List[Message], first_message: bool = False, first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS, skip_verify: bool = False, @@ -814,11 +823,8 @@ class Agent(BaseAgent): self.update_memory_if_changed(current_persisted_memory) # Step 1: add user message - if isinstance(messages, Message): - messages = [messages] - if not all(isinstance(m, Message) for m in messages): - raise ValueError(f"messages should be a Message or a list of Message, got {type(messages)}") + raise ValueError(f"messages should be a list of Message, got {[type(m) for m in messages]}") in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user) input_message_sequence = in_context_messages + messages diff --git a/letta/client/client.py b/letta/client/client.py index a119c882..7effa659 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -2661,7 +2661,7 @@ class LocalClient(AbstractClient): response (LettaResponse): Response from the agent """ self.interface.clear() - usage = self.server.send_messages(actor=self.user, agent_id=agent_id, messages=messages) + usage = self.server.send_messages(actor=self.user, agent_id=agent_id, input_messages=messages) # format messages return LettaResponse(messages=messages, usage=usage) @@ -2703,7 +2703,7 @@ class LocalClient(AbstractClient): usage = self.server.send_messages( actor=self.user, agent_id=agent_id, - messages=[MessageCreate(role=MessageRole(role), content=message, name=name)], + input_messages=[MessageCreate(role=MessageRole(role), content=message, name=name)], ) ## TODO: need to make sure date/timestamp is propely passed diff --git a/letta/functions/function_sets/multi_agent.py b/letta/functions/function_sets/multi_agent.py index 20db45df..2794cf78 100644 --- a/letta/functions/function_sets/multi_agent.py +++ b/letta/functions/function_sets/multi_agent.py @@ -9,7 +9,6 @@ from letta.functions.helpers import ( extract_send_message_from_steps_messages, fire_and_forget_send_to_agent, ) -from letta.helpers.message_helper import prepare_input_message_create from letta.schemas.enums import MessageRole from letta.schemas.message import MessageCreate from letta.server.rest_api.utils import get_letta_server @@ -109,11 +108,10 @@ def send_message_to_agents_matching_tags(self: "Agent", message: str, match_all: # Prepare the message messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=self.agent_state.name)] - input_messages = [prepare_input_message_create(m, agent_id) for m in messages] # Run .step() and return the response usage_stats = agent.step( - messages=input_messages, + input_messages=messages, chaining=True, max_chaining_steps=None, stream=False, diff --git a/letta/functions/helpers.py b/letta/functions/helpers.py index 2b2cf149..9b721985 100644 --- a/letta/functions/helpers.py +++ b/letta/functions/helpers.py @@ -352,7 +352,7 @@ async def send_message_to_agent_no_stream( server: "SyncServer", agent_id: str, actor: User, - messages: Union[List[Message], List[MessageCreate]], + messages: List[MessageCreate], metadata: Optional[dict] = None, ) -> LettaResponse: """ @@ -368,7 +368,7 @@ async def send_message_to_agent_no_stream( server.send_messages, actor=actor, agent_id=agent_id, - messages=messages, + input_messages=messages, interface=interface, metadata=metadata, ) @@ -478,7 +478,7 @@ def fire_and_forget_send_to_agent( await server.send_message_to_agent( agent_id=other_agent_id, actor=sender_agent.user, - messages=messages, + input_messages=messages, stream_steps=False, stream_tokens=False, use_assistant_message=True, diff --git a/letta/groups/dynamic_multi_agent.py b/letta/groups/dynamic_multi_agent.py index 9f0973ea..f89a6f3a 100644 --- a/letta/groups/dynamic_multi_agent.py +++ b/letta/groups/dynamic_multi_agent.py @@ -35,7 +35,7 @@ class DynamicMultiAgent(Agent): def step( self, - messages: List[MessageCreate], + input_messages: List[MessageCreate], chaining: bool = True, max_chaining_steps: Optional[int] = None, put_inner_thoughts_first: bool = True, @@ -43,27 +43,43 @@ class DynamicMultiAgent(Agent): ) -> LettaUsageStatistics: total_usage = UsageStatistics() step_count = 0 + speaker_id = None + # Load settings token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None - agents = {} + # Load agents and initialize chat history with indexing + agents = {self.agent_state.id: self.load_manager_agent()} message_index = {self.agent_state.id: 0} - agents[self.agent_state.id] = self.load_manager_agent() + chat_history: List[MessageCreate] = [] for agent_id in self.agent_ids: agents[agent_id] = self.load_participant_agent(agent_id=agent_id) message_index[agent_id] = 0 - chat_history: List[Message] = [] - new_messages = messages - speaker_id = None + # Prepare new messages + new_messages = [] + for message in input_messages: + if isinstance(message.content, str): + message.content = [TextContent(text=message.content)] + message.group_id = self.group_id + new_messages.append(message) + try: for _ in range(self.max_turns): + # Prepare manager message agent_id_options = [agent_id for agent_id in self.agent_ids if agent_id != speaker_id] - manager_message = self.ask_manager_to_choose_participant_message(new_messages, chat_history, agent_id_options) + manager_message = self.ask_manager_to_choose_participant_message( + manager_agent_id=self.agent_state.id, + new_messages=new_messages, + chat_history=chat_history, + agent_id_options=agent_id_options, + ) + + # Perform manager step manager_agent = agents[self.agent_state.id] usage_stats = manager_agent.step( - messages=[manager_message], + input_messages=[manager_message], chaining=chaining, max_chaining_steps=max_chaining_steps, stream=token_streaming, @@ -71,42 +87,27 @@ class DynamicMultiAgent(Agent): metadata=metadata, put_inner_thoughts_first=put_inner_thoughts_first, ) + + # Parse manager response responses = Message.to_letta_messages_from_list(manager_agent.last_response_messages) assistant_message = [response for response in responses if response.message_type == "assistant_message"][0] for name, agent_id in [(agents[agent_id].agent_state.name, agent_id) for agent_id in agent_id_options]: if name.lower() in assistant_message.content.lower(): speaker_id = agent_id - # sum usage + # Sum usage total_usage.prompt_tokens += usage_stats.prompt_tokens total_usage.completion_tokens += usage_stats.completion_tokens total_usage.total_tokens += usage_stats.total_tokens step_count += 1 - # initialize input messages - for message in chat_history[message_index[speaker_id] :]: - message.id = Message.generate_id() - message.agent_id = speaker_id + # Update chat history + chat_history.extend(new_messages) - for message in new_messages: - chat_history.append( - Message( - agent_id=speaker_id, - role=message.role, - content=[TextContent(text=message.content)], - name=message.name, - model=None, - tool_calls=None, - tool_call_id=None, - group_id=self.group_id, - otid=message.otid, - ) - ) - - # load agent and perform step + # Perform participant step participant_agent = agents[speaker_id] usage_stats = participant_agent.step( - messages=chat_history[message_index[speaker_id] :], + input_messages=chat_history[message_index[speaker_id] :], chaining=chaining, max_chaining_steps=max_chaining_steps, stream=token_streaming, @@ -115,54 +116,54 @@ class DynamicMultiAgent(Agent): put_inner_thoughts_first=put_inner_thoughts_first, ) - # parse new messages for next step + # Parse participant response responses = Message.to_letta_messages_from_list( participant_agent.last_response_messages, ) - assistant_messages = [response for response in responses if response.message_type == "assistant_message"] new_messages = [ MessageCreate( role="system", - content=message.content, + content=[TextContent(text=message.content)] if isinstance(message.content, str) else message.content, name=participant_agent.agent_state.name, otid=message.otid, + sender_id=participant_agent.agent_state.id, + group_id=self.group_id, ) for message in assistant_messages ] + + # Update message index message_index[speaker_id] = len(chat_history) + len(new_messages) - # sum usage + # Sum usage total_usage.prompt_tokens += usage_stats.prompt_tokens total_usage.completion_tokens += usage_stats.completion_tokens total_usage.total_tokens += usage_stats.total_tokens step_count += 1 - # check for termination token + # Check for termination token if any(self.termination_token in message.content for message in new_messages): break - # persist remaining chat history - for message in new_messages: - chat_history.append( - Message( - agent_id=agent_id, - role=message.role, - content=[TextContent(text=message.content)], - name=message.name, - model=None, - tool_calls=None, - tool_call_id=None, - group_id=self.group_id, - ) - ) + # Persist remaining chat history + chat_history.extend(new_messages) for agent_id, index in message_index.items(): if agent_id == speaker_id: continue + messages_to_persist = [] for message in chat_history[index:]: - message.id = Message.generate_id() - message.agent_id = agent_id - self.message_manager.create_many_messages(chat_history[index:], actor=self.user) + message_to_persist = Message( + role=message.role, + content=message.content, + name=message.name, + otid=message.otid, + sender_id=message.sender_id, + group_id=message.group_id, + agent_id=agent_id, + ) + messages_to_persist.append(message_to_persist) + self.message_manager.create_many_messages(messages_to_persist, actor=self.user) except Exception as e: raise e @@ -249,10 +250,11 @@ class DynamicMultiAgent(Agent): def ask_manager_to_choose_participant_message( self, + manager_agent_id: str, new_messages: List[MessageCreate], chat_history: List[Message], agent_id_options: List[str], - ) -> Message: + ) -> MessageCreate: text_chat_history = [f"{message.name or 'user'}: {message.content[0].text}" for message in chat_history] for message in new_messages: text_chat_history.append(f"{message.name or 'user'}: {message.content}") @@ -264,14 +266,11 @@ class DynamicMultiAgent(Agent): "respond to the messages yourself, your task is only to decide the " f"next speaker, not to participate. \nChat history:\n{context_messages}" ) - return Message( - agent_id=self.agent_state.id, + return MessageCreate( role="user", content=[TextContent(text=message_text)], name=None, - model=None, - tool_calls=None, - tool_call_id=None, - group_id=self.group_id, otid=Message.generate_otid(), + sender_id=manager_agent_id, + group_id=self.group_id, ) diff --git a/letta/groups/round_robin_multi_agent.py b/letta/groups/round_robin_multi_agent.py index 4a9bcaaa..9c7b319d 100644 --- a/letta/groups/round_robin_multi_agent.py +++ b/letta/groups/round_robin_multi_agent.py @@ -29,7 +29,7 @@ class RoundRobinMultiAgent(Agent): def step( self, - messages: List[MessageCreate], + input_messages: List[MessageCreate], chaining: bool = True, max_chaining_steps: Optional[int] = None, put_inner_thoughts_first: bool = True, @@ -37,46 +37,39 @@ class RoundRobinMultiAgent(Agent): ) -> LettaUsageStatistics: total_usage = UsageStatistics() step_count = 0 + speaker_id = None + # Load settings token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None - agents = {} + # Load agents and initialize chat history with indexing + agents, message_index = {}, {} + chat_history: List[MessageCreate] = [] for agent_id in self.agent_ids: agents[agent_id] = self.load_participant_agent(agent_id=agent_id) + message_index[agent_id] = 0 + + # Prepare new messages + new_messages = [] + for message in input_messages: + if isinstance(message.content, str): + message.content = [TextContent(text=message.content)] + message.group_id = self.group_id + new_messages.append(message) - message_index = {agent_id: 0 for agent_id in self.agent_ids} - chat_history: List[Message] = [] - new_messages = messages - speaker_id = None try: for i in range(self.max_turns): + # Select speaker speaker_id = self.agent_ids[i % len(self.agent_ids)] - # initialize input messages - start_index = message_index[speaker_id] if speaker_id in message_index else 0 - for message in chat_history[start_index:]: - message.id = Message.generate_id() - message.agent_id = speaker_id - for message in new_messages: - chat_history.append( - Message( - agent_id=speaker_id, - role=message.role, - content=[TextContent(text=message.content)], - name=message.name, - model=None, - tool_calls=None, - tool_call_id=None, - group_id=self.group_id, - otid=message.otid, - ) - ) + # Update chat history + chat_history.extend(new_messages) - # load agent and perform step + # Perform participant step participant_agent = agents[speaker_id] usage_stats = participant_agent.step( - messages=chat_history[start_index:], + input_messages=chat_history[message_index[speaker_id] :], chaining=chaining, max_chaining_steps=max_chaining_steps, stream=token_streaming, @@ -85,47 +78,48 @@ class RoundRobinMultiAgent(Agent): put_inner_thoughts_first=put_inner_thoughts_first, ) - # parse new messages for next step + # Parse participant response responses = Message.to_letta_messages_from_list(participant_agent.last_response_messages) assistant_messages = [response for response in responses if response.message_type == "assistant_message"] new_messages = [ MessageCreate( role="system", - content=message.content, - name=message.name, + content=[TextContent(text=message.content)] if isinstance(message.content, str) else message.content, + name=participant_agent.agent_state.name, otid=message.otid, + sender_id=participant_agent.agent_state.id, + group_id=self.group_id, ) for message in assistant_messages ] + + # Update message index message_index[speaker_id] = len(chat_history) + len(new_messages) - # sum usage + # Sum usage total_usage.prompt_tokens += usage_stats.prompt_tokens total_usage.completion_tokens += usage_stats.completion_tokens total_usage.total_tokens += usage_stats.total_tokens step_count += 1 - # persist remaining chat history - for message in new_messages: - chat_history.append( - Message( - agent_id=agent_id, - role=message.role, - content=[TextContent(text=message.content)], - name=message.name, - model=None, - tool_calls=None, - tool_call_id=None, - group_id=self.group_id, - ) - ) + # Persist remaining chat history + chat_history.extend(new_messages) for agent_id, index in message_index.items(): if agent_id == speaker_id: continue + messages_to_persist = [] for message in chat_history[index:]: - message.id = Message.generate_id() - message.agent_id = agent_id - self.message_manager.create_many_messages(chat_history[index:], actor=self.user) + message_to_persist = Message( + role=message.role, + content=message.content, + name=message.name, + otid=message.otid, + sender_id=message.sender_id, + group_id=self.group_id, + agent_id=agent_id, + ) + messages_to_persist.append(message_to_persist) + self.message_manager.create_many_messages(messages_to_persist, actor=self.user) except Exception as e: raise e diff --git a/letta/groups/sleeptime_multi_agent.py b/letta/groups/sleeptime_multi_agent.py index 0a6546c5..1eb1e8e3 100644 --- a/letta/groups/sleeptime_multi_agent.py +++ b/letta/groups/sleeptime_multi_agent.py @@ -143,8 +143,21 @@ class SleeptimeMultiAgent(Agent): group_id=self.group_id, ) ] + + # Convert Message objects to MessageCreate objects + message_creates = [ + MessageCreate( + role=m.role, + content=m.content[0].text if m.content and len(m.content) == 1 else m.content, + name=m.name, + otid=m.otid, + sender_id=m.sender_id, + ) + for m in participant_agent_messages + ] + result = participant_agent.step( - messages=participant_agent_messages, + input_messages=message_creates, chaining=chaining, max_chaining_steps=max_chaining_steps, stream=token_streaming, @@ -173,7 +186,7 @@ class SleeptimeMultiAgent(Agent): def step( self, - messages: List[MessageCreate], + input_messages: List[MessageCreate], chaining: bool = True, max_chaining_steps: Optional[int] = None, put_inner_thoughts_first: bool = True, @@ -181,33 +194,28 @@ class SleeptimeMultiAgent(Agent): ) -> LettaUsageStatistics: run_ids = [] + # Load settings token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None - messages = [ - Message( - id=Message.generate_id(), - agent_id=self.agent_state.id, - role=message.role, - content=[TextContent(text=message.content)] if isinstance(message.content, str) else message.content, - name=message.name, - model=None, - tool_calls=None, - tool_call_id=None, - group_id=self.group_id, - otid=message.otid, - ) - for message in messages - ] + # Prepare new messages + new_messages = [] + for message in input_messages: + if isinstance(message.content, str): + message.content = [TextContent(text=message.content)] + message.group_id = self.group_id + new_messages.append(message) try: + # Load main agent main_agent = Agent( agent_state=self.agent_state, interface=self.interface, user=self.user, ) + # Perform main agent step usage_stats = main_agent.step( - messages=messages, + input_messages=new_messages, chaining=chaining, max_chaining_steps=max_chaining_steps, stream=token_streaming, @@ -216,10 +224,12 @@ class SleeptimeMultiAgent(Agent): put_inner_thoughts_first=put_inner_thoughts_first, ) + # Update turns counter turns_counter = None if self.sleeptime_agent_frequency is not None and self.sleeptime_agent_frequency > 0: turns_counter = self.group_manager.bump_turns_counter(group_id=self.group_id, actor=self.user) + # Perform participant steps if self.sleeptime_agent_frequency is None or ( turns_counter is not None and turns_counter % self.sleeptime_agent_frequency == 0 ): diff --git a/letta/groups/supervisor_multi_agent.py b/letta/groups/supervisor_multi_agent.py index bdd8f4f9..ba8f6261 100644 --- a/letta/groups/supervisor_multi_agent.py +++ b/letta/groups/supervisor_multi_agent.py @@ -9,7 +9,7 @@ from letta.interface import AgentInterface from letta.orm import User from letta.orm.enums import ToolType from letta.schemas.letta_message_content import TextContent -from letta.schemas.message import Message, MessageCreate +from letta.schemas.message import MessageCreate from letta.schemas.tool import Tool from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule from letta.schemas.usage import LettaUsageStatistics @@ -37,17 +37,18 @@ class SupervisorMultiAgent(Agent): def step( self, - messages: List[MessageCreate], + input_messages: List[MessageCreate], chaining: bool = True, max_chaining_steps: Optional[int] = None, put_inner_thoughts_first: bool = True, assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL, **kwargs, ) -> LettaUsageStatistics: + # Load settings token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None - # add multi agent tool + # Prepare supervisor agent if self.tool_manager.get_tool_by_name(tool_name="send_message_to_all_agents_in_group", actor=self.user) is None: multi_agent_tool = Tool( name=send_message_to_all_agents_in_group.__name__, @@ -64,7 +65,6 @@ class SupervisorMultiAgent(Agent): ) self.agent_state = self.agent_manager.attach_tool(agent_id=self.agent_state.id, tool_id=multi_agent_tool.id, actor=self.user) - # override tool rules old_tool_rules = self.agent_state.tool_rules self.agent_state.tool_rules = [ InitToolRule( @@ -79,24 +79,25 @@ class SupervisorMultiAgent(Agent): ), ] - supervisor_messages = [ - Message( - agent_id=self.agent_state.id, - role="user", - content=[TextContent(text=message.content)], - name=None, - model=None, - tool_calls=None, - tool_call_id=None, - group_id=self.group_id, - otid=message.otid, - ) - for message in messages - ] + # Prepare new messages + new_messages = [] + for message in input_messages: + if isinstance(message.content, str): + message.content = [TextContent(text=message.content)] + message.group_id = self.group_id + new_messages.append(message) + try: - supervisor_agent = Agent(agent_state=self.agent_state, interface=self.interface, user=self.user) + # Load supervisor agent + supervisor_agent = Agent( + agent_state=self.agent_state, + interface=self.interface, + user=self.user, + ) + + # Perform supervisor step usage_stats = supervisor_agent.step( - messages=supervisor_messages, + input_messages=new_messages, chaining=chaining, max_chaining_steps=max_chaining_steps, stream=token_streaming, diff --git a/letta/helpers/message_helper.py b/letta/helpers/message_helper.py index 5f5b6c04..41d2b8f6 100644 --- a/letta/helpers/message_helper.py +++ b/letta/helpers/message_helper.py @@ -40,4 +40,5 @@ def prepare_input_message_create( tool_call_id=None, otid=message.otid, sender_id=message.sender_id, + group_id=message.group_id, ) diff --git a/letta/schemas/message.py b/letta/schemas/message.py index bc5b6141..dfc36fe2 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -82,6 +82,7 @@ class MessageCreate(BaseModel): name: Optional[str] = Field(None, description="The name of the participant.") otid: Optional[str] = Field(None, description="The offline threading id associated with this message") sender_id: Optional[str] = Field(None, description="The id of the sender of the message, can be an identity id or agent id") + group_id: Optional[str] = Field(None, description="The multi-agent group that the message was sent in") def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]: data = super().model_dump(**kwargs) diff --git a/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py b/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py index d3715062..3ebdb3af 100644 --- a/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +++ b/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py @@ -111,7 +111,7 @@ async def send_message_to_agent_chat_completions( server.send_messages, actor=actor, agent_id=letta_agent.agent_state.id, - messages=messages, + input_messages=messages, interface=streaming_interface, put_inner_thoughts_first=False, ) diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 72d8dc68..e8571fa5 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -412,7 +412,7 @@ def list_blocks( """ actor = server.user_manager.get_user_or_default(user_id=actor_id) try: - agent = server.agent_manager.get_agent_by_id(agent_id, actor=actor) + agent = server.agent_manager.get_agent_by_id(agent_id, actor) return agent.memory.blocks except NoResultFound as e: raise HTTPException(status_code=404, detail=str(e)) @@ -640,7 +640,7 @@ async def send_message( result = await server.send_message_to_agent( agent_id=agent_id, actor=actor, - messages=request.messages, + input_messages=request.messages, stream_steps=False, stream_tokens=False, # Support for AssistantMessage @@ -703,7 +703,7 @@ async def send_message_streaming( result = await server.send_message_to_agent( agent_id=agent_id, actor=actor, - messages=request.messages, + input_messages=request.messages, stream_steps=True, stream_tokens=request.stream_tokens, # Support for AssistantMessage @@ -730,7 +730,7 @@ async def process_message_background( result = await server.send_message_to_agent( agent_id=agent_id, actor=actor, - messages=messages, + input_messages=messages, stream_steps=False, # NOTE(matt) stream_tokens=False, use_assistant_message=use_assistant_message, diff --git a/letta/server/rest_api/routers/v1/groups.py b/letta/server/rest_api/routers/v1/groups.py index 158a36a6..59cbbc18 100644 --- a/letta/server/rest_api/routers/v1/groups.py +++ b/letta/server/rest_api/routers/v1/groups.py @@ -128,7 +128,7 @@ async def send_group_message( result = await server.send_group_message_to_agent( group_id=group_id, actor=actor, - messages=request.messages, + input_messages=request.messages, stream_steps=False, stream_tokens=False, # Support for AssistantMessage @@ -167,7 +167,7 @@ async def send_group_message_streaming( result = await server.send_group_message_to_agent( group_id=group_id, actor=actor, - messages=request.messages, + input_messages=request.messages, stream_steps=True, stream_tokens=request.stream_tokens, # Support for AssistantMessage diff --git a/letta/server/server.py b/letta/server/server.py index d8b9fddd..36842be7 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -28,7 +28,6 @@ from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerCo from letta.groups.helpers import load_multi_agent from letta.helpers.datetime_helpers import get_utc_time from letta.helpers.json_helpers import json_dumps, json_loads -from letta.helpers.message_helper import prepare_input_message_create # TODO use custom interface from letta.interface import AgentInterface # abstract @@ -148,7 +147,7 @@ class Server(object): raise NotImplementedError @abstractmethod - def send_messages(self, user_id: str, agent_id: str, messages: Union[MessageCreate, List[Message]]) -> None: + def send_messages(self, user_id: str, agent_id: str, input_messages: List[MessageCreate]) -> None: """Send a list of messages to the agent""" raise NotImplementedError @@ -372,19 +371,13 @@ class SyncServer(Server): self, actor: User, agent_id: str, - input_messages: Union[Message, List[Message]], + input_messages: List[MessageCreate], interface: Union[AgentInterface, None] = None, # needed to getting responses put_inner_thoughts_first: bool = True, # timestamp: Optional[datetime], ) -> LettaUsageStatistics: """Send the input message through the agent""" # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user - # Input validation - if isinstance(input_messages, Message): - input_messages = [input_messages] - if not all(isinstance(m, Message) for m in input_messages): - raise ValueError(f"messages should be a Message or a list of Message, got {type(input_messages)}") - logger.debug(f"Got input messages: {input_messages}") letta_agent = None try: @@ -400,8 +393,9 @@ class SyncServer(Server): metadata = interface.metadata if hasattr(interface, "metadata") else None else: metadata = None + usage_stats = letta_agent.step( - messages=input_messages, + input_messages=input_messages, chaining=self.chaining, max_chaining_steps=self.max_chaining_steps, stream=token_streaming, @@ -572,23 +566,14 @@ class SyncServer(Server): ) # NOTE: eventually deprecate and only allow passing Message types - # Convert to a Message object - if timestamp: - message = Message( - agent_id=agent_id, - role="user", - content=[TextContent(text=packaged_user_message)], - created_at=timestamp, - ) - else: - message = Message( - agent_id=agent_id, - role="user", - content=[TextContent(text=packaged_user_message)], - ) + message = MessageCreate( + agent_id=agent_id, + role="user", + content=[TextContent(text=packaged_user_message)], + ) # Run the agent state forward - usage = self._step(actor=actor, agent_id=agent_id, input_messages=message) + usage = self._step(actor=actor, agent_id=agent_id, input_messages=[message]) return usage def system_message( @@ -660,23 +645,14 @@ class SyncServer(Server): self, actor: User, agent_id: str, - messages: Union[List[MessageCreate], List[Message]], + input_messages: List[MessageCreate], wrap_user_message: bool = True, wrap_system_message: bool = True, interface: Union[AgentInterface, ChatCompletionsStreamingInterface, None] = None, # needed for responses metadata: Optional[dict] = None, # Pass through metadata to interface put_inner_thoughts_first: bool = True, ) -> LettaUsageStatistics: - """Send a list of messages to the agent. - - If messages are of type MessageCreate, convert them to Message objects before sending. - """ - if all(isinstance(m, MessageCreate) for m in messages): - message_objects = [prepare_input_message_create(m, agent_id, wrap_user_message, wrap_system_message) for m in messages] - elif all(isinstance(m, Message) for m in messages): - message_objects = messages - else: - raise ValueError(f"All messages must be of type Message or MessageCreate, got {[type(m) for m in messages]}") + """Send a list of messages to the agent.""" # Store metadata in interface if provided if metadata and hasattr(interface, "metadata"): @@ -686,7 +662,7 @@ class SyncServer(Server): return self._step( actor=actor, agent_id=agent_id, - input_messages=message_objects, + input_messages=input_messages, interface=interface, put_inner_thoughts_first=put_inner_thoughts_first, ) @@ -1018,12 +994,8 @@ class SyncServer(Server): agent = self.load_agent(agent_id=sleeptime_agent.id, actor=actor) for passage in self.list_data_source_passages(source_id=source.id, user_id=actor.id): agent.step( - messages=[ - Message( - role="user", - content=[TextContent(text=passage.text)], - agent_id=sleeptime_agent.id, - ), + input_messages=[ + MessageCreate(role="user", content=passage.text), ] ) self.agent_manager.delete_agent(agent_id=sleeptime_agent.id, actor=actor) @@ -1563,7 +1535,7 @@ class SyncServer(Server): agent_id: str, actor: User, # role: MessageRole, - messages: Union[List[Message], List[MessageCreate]], + input_messages: List[MessageCreate], stream_steps: bool, stream_tokens: bool, # related to whether or not we return `LettaMessage`s or `Message`s @@ -1643,7 +1615,7 @@ class SyncServer(Server): self.send_messages, actor=actor, agent_id=agent_id, - messages=messages, + input_messages=input_messages, interface=streaming_interface, metadata=metadata, ) @@ -1697,7 +1669,7 @@ class SyncServer(Server): self, group_id: str, actor: User, - messages: Union[List[Message], List[MessageCreate]], + input_messages: Union[List[Message], List[MessageCreate]], stream_steps: bool, stream_tokens: bool, chat_completion_mode: bool = False, @@ -1747,7 +1719,7 @@ class SyncServer(Server): task = asyncio.create_task( asyncio.to_thread( letta_multi_agent.step, - messages=messages, + input_messages=input_messages, chaining=self.chaining, max_chaining_steps=self.max_chaining_steps, ) diff --git a/tests/integration_test_sleeptime_agent.py b/tests/integration_test_sleeptime_agent.py index 0752a681..3d215baa 100644 --- a/tests/integration_test_sleeptime_agent.py +++ b/tests/integration_test_sleeptime_agent.py @@ -132,7 +132,7 @@ async def test_sleeptime_group_chat(server, actor): response = await server.send_message_to_agent( agent_id=main_agent.id, actor=actor, - messages=[ + input_messages=[ MessageCreate( role="user", content=text, @@ -206,7 +206,7 @@ async def test_sleeptime_removes_redundant_information(server, actor): _ = await server.send_message_to_agent( agent_id=main_agent.id, actor=actor, - messages=[ + input_messages=[ MessageCreate( role="user", content=test_message, @@ -270,7 +270,7 @@ async def test_sleeptime_edit(server, actor): _ = await server.send_message_to_agent( agent_id=sleeptime_agent.id, actor=actor, - messages=[ + input_messages=[ MessageCreate( role="user", content="Messi has now moved to playing for Inter Miami", diff --git a/tests/test_agent_serialization.py b/tests/test_agent_serialization.py index 529b7a48..7599e02e 100644 --- a/tests/test_agent_serialization.py +++ b/tests/test_agent_serialization.py @@ -454,7 +454,7 @@ def test_agent_serialize_with_user_messages(local_client, server, serialize_test """Test deserializing JSON into an Agent instance.""" append_copy_suffix = False server.send_messages( - actor=default_user, agent_id=serialize_test_agent.id, messages=[MessageCreate(role=MessageRole.user, content="hello")] + actor=default_user, agent_id=serialize_test_agent.id, input_messages=[MessageCreate(role=MessageRole.user, content="hello")] ) result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user) @@ -470,10 +470,12 @@ def test_agent_serialize_with_user_messages(local_client, server, serialize_test # Make sure both agents can receive messages after server.send_messages( - actor=default_user, agent_id=serialize_test_agent.id, messages=[MessageCreate(role=MessageRole.user, content="and hello again")] + actor=default_user, + agent_id=serialize_test_agent.id, + input_messages=[MessageCreate(role=MessageRole.user, content="and hello again")], ) server.send_messages( - actor=other_user, agent_id=agent_copy.id, messages=[MessageCreate(role=MessageRole.user, content="and hello again")] + actor=other_user, agent_id=agent_copy.id, input_messages=[MessageCreate(role=MessageRole.user, content="and hello again")] ) @@ -483,7 +485,7 @@ def test_agent_serialize_tool_calls(disable_e2b_api_key, local_client, server, s server.send_messages( actor=default_user, agent_id=serialize_test_agent.id, - messages=[MessageCreate(role=MessageRole.user, content="What's the weather like in San Francisco?")], + input_messages=[MessageCreate(role=MessageRole.user, content="What's the weather like in San Francisco?")], ) result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user) @@ -501,12 +503,12 @@ def test_agent_serialize_tool_calls(disable_e2b_api_key, local_client, server, s original_agent_response = server.send_messages( actor=default_user, agent_id=serialize_test_agent.id, - messages=[MessageCreate(role=MessageRole.user, content="What's the weather like in Seattle?")], + input_messages=[MessageCreate(role=MessageRole.user, content="What's the weather like in Seattle?")], ) copy_agent_response = server.send_messages( actor=other_user, agent_id=agent_copy.id, - messages=[MessageCreate(role=MessageRole.user, content="What's the weather like in Seattle?")], + input_messages=[MessageCreate(role=MessageRole.user, content="What's the weather like in Seattle?")], ) assert original_agent_response.completion_tokens > 0 and original_agent_response.step_count > 0 @@ -519,12 +521,12 @@ def test_agent_serialize_update_blocks(disable_e2b_api_key, local_client, server server.send_messages( actor=default_user, agent_id=serialize_test_agent.id, - messages=[MessageCreate(role=MessageRole.user, content="Append 'banana' to core_memory.")], + input_messages=[MessageCreate(role=MessageRole.user, content="Append 'banana' to core_memory.")], ) server.send_messages( actor=default_user, agent_id=serialize_test_agent.id, - messages=[MessageCreate(role=MessageRole.user, content="What do you think about that?")], + input_messages=[MessageCreate(role=MessageRole.user, content="What do you think about that?")], ) result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user) @@ -543,12 +545,12 @@ def test_agent_serialize_update_blocks(disable_e2b_api_key, local_client, server original_agent_response = server.send_messages( actor=default_user, agent_id=serialize_test_agent.id, - messages=[MessageCreate(role=MessageRole.user, content="Hi")], + input_messages=[MessageCreate(role=MessageRole.user, content="Hi")], ) copy_agent_response = server.send_messages( actor=other_user, agent_id=agent_copy.id, - messages=[MessageCreate(role=MessageRole.user, content="Hi")], + input_messages=[MessageCreate(role=MessageRole.user, content="Hi")], ) assert original_agent_response.completion_tokens > 0 and original_agent_response.step_count > 0 @@ -635,5 +637,5 @@ def test_upload_agentfile_from_disk(server, disable_e2b_api_key, fastapi_client, server.send_messages( actor=other_user, agent_id=copied_agent_id, - messages=[MessageCreate(role=MessageRole.user, content="Hello there!")], + input_messages=[MessageCreate(role=MessageRole.user, content="Hello there!")], ) diff --git a/tests/test_multi_agent.py b/tests/test_multi_agent.py index acbee470..f0dc5e68 100644 --- a/tests/test_multi_agent.py +++ b/tests/test_multi_agent.py @@ -158,7 +158,7 @@ async def test_empty_group(server, actor): await server.send_group_message_to_agent( group_id=group.id, actor=actor, - messages=[ + input_messages=[ MessageCreate( role="user", content="what is everyone up to for the holidays?", @@ -246,7 +246,7 @@ async def test_round_robin(server, actor, participant_agents): response = await server.send_group_message_to_agent( group_id=group.id, actor=actor, - messages=[ + input_messages=[ MessageCreate( role="user", content="what is everyone up to for the holidays?", @@ -301,7 +301,7 @@ async def test_round_robin(server, actor, participant_agents): response = await server.send_group_message_to_agent( group_id=group.id, actor=actor, - messages=[ + input_messages=[ MessageCreate( role="user", content="what is everyone up to for the holidays?", @@ -367,7 +367,7 @@ async def test_supervisor(server, actor, participant_agents): response = await server.send_group_message_to_agent( group_id=group.id, actor=actor, - messages=[ + input_messages=[ MessageCreate( role="user", content="ask everyone what they like to do for fun and then come up with an activity for everyone to do together.", @@ -449,7 +449,7 @@ async def test_dynamic_group_chat(server, actor, manager_agent, participant_agen response = await server.send_group_message_to_agent( group_id=group.id, actor=actor, - messages=[ + input_messages=[ MessageCreate(role="user", content="what is everyone up to for the holidays?"), ], stream_steps=False,