diff --git a/alembic/versions/a66510f83fc2_add_ordered_agent_ids_to_groups.py b/alembic/versions/a66510f83fc2_add_ordered_agent_ids_to_groups.py new file mode 100644 index 00000000..b769627b --- /dev/null +++ b/alembic/versions/a66510f83fc2_add_ordered_agent_ids_to_groups.py @@ -0,0 +1,31 @@ +"""add ordered agent ids to groups + +Revision ID: a66510f83fc2 +Revises: bdddd421ec41 +Create Date: 2025-03-27 11:11:51.709498 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "a66510f83fc2" +down_revision: Union[str, None] = "bdddd421ec41" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("groups", sa.Column("agent_ids", sa.JSON(), nullable=False)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("groups", "agent_ids") + # ### end Alembic commands ### diff --git a/letta/agent.py b/letta/agent.py index 3c6141e6..631020b6 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -220,6 +220,7 @@ class Agent(BaseAgent): messages: List[Message], tool_returns: Optional[List[ToolReturn]] = None, include_function_failed_message: bool = False, + group_id: Optional[str] = None, ) -> List[Message]: """ Handle error from function call response @@ -240,7 +241,9 @@ class Agent(BaseAgent): "content": function_response, "tool_call_id": tool_call_id, }, + name=self.agent_state.name, tool_returns=tool_returns, + group_id=group_id, ) messages.append(new_message) self.interface.function_message(f"Error: {error_msg}", msg_obj=new_message) @@ -329,6 +332,7 @@ class Agent(BaseAgent): stream=stream, stream_interface=self.interface, put_inner_thoughts_first=put_inner_thoughts_first, + name=self.agent_state.name, ) log_telemetry(self.logger, "_get_ai_reply create finish") @@ -372,6 +376,7 @@ class Agent(BaseAgent): # and now we want to use it in the creation of the Message object # TODO figure out a cleaner way to do this response_message_id: Optional[str] = None, + group_id: Optional[str] = None, ) -> Tuple[List[Message], bool, bool]: """Handles parsing and function execution""" log_telemetry(self.logger, "_handle_ai_response start") @@ -417,6 +422,8 @@ class Agent(BaseAgent): user_id=self.agent_state.created_by_id, model=self.model, openai_message_dict=response_message.model_dump(), + name=self.agent_state.name, + group_id=group_id, ) ) # extend conversation with assistant's reply self.logger.debug(f"Function call message: {messages[-1]}") @@ -449,7 +456,7 @@ class Agent(BaseAgent): error_msg = f"No function named {function_name}" function_response = "None" # more like "never ran?" messages = self._handle_function_error_response( - error_msg, tool_call_id, function_name, function_args, function_response, messages + error_msg, tool_call_id, function_name, function_args, function_response, messages, group_id=group_id ) return messages, False, True # force a heartbeat to allow agent to handle error @@ -464,7 +471,7 @@ class Agent(BaseAgent): error_msg = f"Error parsing JSON for function '{function_name}' arguments: {function_call.arguments}" function_response = "None" # more like "never ran?" messages = self._handle_function_error_response( - error_msg, tool_call_id, function_name, function_args, function_response, messages + error_msg, tool_call_id, function_name, function_args, function_response, messages, group_id=group_id ) return messages, False, True # force a heartbeat to allow agent to handle error @@ -535,6 +542,7 @@ class Agent(BaseAgent): function_response, messages, [tool_return], + group_id=group_id, ) return messages, False, True # force a heartbeat to allow agent to handle error @@ -571,6 +579,7 @@ class Agent(BaseAgent): messages, [ToolReturn(status="error", stderr=[error_msg_user])], include_function_failed_message=True, + group_id=group_id, ) return messages, False, True # force a heartbeat to allow agent to handle error @@ -595,6 +604,7 @@ class Agent(BaseAgent): messages, [tool_return], include_function_failed_message=True, + group_id=group_id, ) return messages, False, True # force a heartbeat to allow agent to handle error @@ -620,7 +630,9 @@ class Agent(BaseAgent): "content": function_response, "tool_call_id": tool_call_id, }, + name=self.agent_state.name, tool_returns=[tool_return] if sandbox_run_result else None, + group_id=group_id, ) ) # extend conversation with function response self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1]) @@ -636,6 +648,8 @@ class Agent(BaseAgent): user_id=self.agent_state.created_by_id, model=self.model, openai_message_dict=response_message.model_dump(), + name=self.agent_state.name, + group_id=group_id, ) ) # extend conversation with assistant's reply self.interface.internal_monologue(response_message.content, msg_obj=messages[-1]) @@ -799,7 +813,11 @@ class Agent(BaseAgent): in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user) input_message_sequence = in_context_messages + messages - if len(input_message_sequence) > 1 and input_message_sequence[-1].role != "user": + if ( + len(input_message_sequence) > 1 + and input_message_sequence[-1].role != "user" + and input_message_sequence[-1].group_id is None + ): self.logger.warning(f"{CLI_WARNING_PREFIX}Attempting to run ChatCompletion without user as the last message in the queue") # Step 2: send the conversation and available functions to the LLM @@ -832,6 +850,7 @@ class Agent(BaseAgent): # TODO this is kind of hacky, find a better way to handle this # the only time we set up message creation ahead of time is when streaming is on response_message_id=response.id if stream else None, + group_id=input_message_sequence[-1].group_id, ) # Step 6: extend the message history diff --git a/letta/dynamic_multi_agent.py b/letta/dynamic_multi_agent.py index 93599324..c807efa7 100644 --- a/letta/dynamic_multi_agent.py +++ b/letta/dynamic_multi_agent.py @@ -16,7 +16,7 @@ class DynamicMultiAgent(Agent): self, interface: AgentInterface, agent_state: AgentState, - user: User = None, + user: User, # custom group_id: str = "", agent_ids: List[str] = [], @@ -128,7 +128,7 @@ class DynamicMultiAgent(Agent): ) for message in assistant_messages ] - message_index[agent_id] = len(chat_history) + len(new_messages) + message_index[speaker_id] = len(chat_history) + len(new_messages) # sum usage total_usage.prompt_tokens += usage_stats.prompt_tokens @@ -251,10 +251,10 @@ class DynamicMultiAgent(Agent): chat_history: List[Message], agent_id_options: List[str], ) -> Message: - chat_history = [f"{message.name or 'user'}: {message.content[0].text}" for message in chat_history] + text_chat_history = [f"{message.name or 'user'}: {message.content[0].text}" for message in chat_history] for message in new_messages: - chat_history.append(f"{message.name or 'user'}: {message.content}") - context_messages = "\n".join(chat_history) + text_chat_history.append(f"{message.name or 'user'}: {message.content}") + context_messages = "\n".join(text_chat_history) message_text = ( "Choose the most suitable agent to reply to the latest message in the " diff --git a/letta/llm_api/anthropic.py b/letta/llm_api/anthropic.py index 50f789ad..4cf3f420 100644 --- a/letta/llm_api/anthropic.py +++ b/letta/llm_api/anthropic.py @@ -859,6 +859,7 @@ def anthropic_chat_completions_process_stream( create_message_id: bool = True, create_message_datetime: bool = True, betas: List[str] = ["tools-2024-04-04"], + name: Optional[str] = None, ) -> ChatCompletionResponse: """Process a streaming completion response from Anthropic, similar to OpenAI's streaming. @@ -951,6 +952,7 @@ def anthropic_chat_completions_process_stream( # if extended_thinking is on, then reasoning_content will be flowing as chunks # TODO handle emitting redacted reasoning content (e.g. as concat?) expect_reasoning_content=extended_thinking, + name=name, ) elif isinstance(stream_interface, AgentRefreshStreamingInterface): stream_interface.process_refresh(chat_completion_response) diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index 05ce7b5e..f489b873 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -140,6 +140,7 @@ def create( stream_interface: Optional[Union[AgentRefreshStreamingInterface, AgentChunkStreamingInterface]] = None, model_settings: Optional[dict] = None, # TODO: eventually pass from server put_inner_thoughts_first: bool = True, + name: Optional[str] = None, ) -> ChatCompletionResponse: """Return response to chat completion with backoff""" from letta.utils import printd @@ -206,6 +207,7 @@ def create( api_key=api_key, chat_completion_request=data, stream_interface=stream_interface, + name=name, ) else: # Client did not request token streaming (expect a blocking backend response) data.stream = False @@ -255,6 +257,7 @@ def create( api_key=api_key, chat_completion_request=data, stream_interface=stream_interface, + name=name, ) else: # Client did not request token streaming (expect a blocking backend response) data.stream = False @@ -359,6 +362,7 @@ def create( stream_interface=stream_interface, extended_thinking=llm_config.enable_reasoner, max_reasoning_tokens=llm_config.max_reasoning_tokens, + name=name, ) else: @@ -531,6 +535,7 @@ def create( api_key=model_settings.deepseek_api_key, chat_completion_request=data, stream_interface=stream_interface, + name=name, ) else: # Client did not request token streaming (expect a blocking backend response) data.stream = False diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index 948730b6..a99ffb78 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -185,6 +185,7 @@ def openai_chat_completions_process_stream( # however, we don't necessarily want to put these # expect_reasoning_content: bool = False, expect_reasoning_content: bool = True, + name: Optional[str] = None, ) -> ChatCompletionResponse: """Process a streaming completion response, and return a ChatCompletionRequest at the end. @@ -272,6 +273,7 @@ def openai_chat_completions_process_stream( message_id=chat_completion_response.id if create_message_id else chat_completion_chunk.id, message_date=chat_completion_response.created if create_message_datetime else chat_completion_chunk.created, expect_reasoning_content=expect_reasoning_content, + name=name, ) elif isinstance(stream_interface, AgentRefreshStreamingInterface): stream_interface.process_refresh(chat_completion_response) diff --git a/letta/orm/group.py b/letta/orm/group.py index 308c43d1..3599386b 100644 --- a/letta/orm/group.py +++ b/letta/orm/group.py @@ -1,7 +1,7 @@ import uuid from typing import List, Optional -from sqlalchemy import ForeignKey, String +from sqlalchemy import JSON, ForeignKey, String from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.mixins import OrganizationMixin @@ -23,11 +23,8 @@ class Group(SqlalchemyBase, OrganizationMixin): # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="groups") + agent_ids: Mapped[List[str]] = mapped_column(JSON, nullable=False, doc="Ordered list of agent IDs in this group") agents: Mapped[List["Agent"]] = relationship( "Agent", secondary="groups_agents", lazy="selectin", passive_deletes=True, back_populates="groups" ) manager_agent: Mapped["Agent"] = relationship("Agent", lazy="joined", back_populates="multi_agent_group") - - @property - def agent_ids(self) -> List[str]: - return [agent.id for agent in self.agents] diff --git a/letta/round_robin_multi_agent.py b/letta/round_robin_multi_agent.py index 1796b882..9bb62146 100644 --- a/letta/round_robin_multi_agent.py +++ b/letta/round_robin_multi_agent.py @@ -14,7 +14,7 @@ class RoundRobinMultiAgent(Agent): self, interface: AgentInterface, agent_state: AgentState, - user: User = None, + user: User, # custom group_id: str = "", agent_ids: List[str] = [], @@ -45,7 +45,7 @@ class RoundRobinMultiAgent(Agent): for agent_id in self.agent_ids: agents[agent_id] = self.load_participant_agent(agent_id=agent_id) - message_index = {} + message_index = {agent_id: 0 for agent_id in self.agent_ids} chat_history: List[Message] = [] new_messages = messages speaker_id = None @@ -91,7 +91,7 @@ class RoundRobinMultiAgent(Agent): MessageCreate( role="system", content=message.content, - name=participant_agent.agent_state.name, + name=message.name, ) for message in assistant_messages ] @@ -138,10 +138,21 @@ class RoundRobinMultiAgent(Agent): agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user) persona_block = agent_state.memory.get_block(label="persona") group_chat_participant_persona = ( - "\n\n====Group Chat Contex====" - f"\nYou are speaking in a group chat with {len(self.agent_ids) - 1} other " - "agents and one user. Respond to new messages in the group chat when prompted. " - f"Description of the group: {self.description}" + f"%%% GROUP CHAT CONTEXT %%% " + f"You are speaking in a group chat with {len(self.agent_ids)} other participants. " + f"Group Description: {self.description} " + "INTERACTION GUIDELINES:\n" + "1. Be aware that others can see your messages - communicate as if in a real group conversation\n" + "2. Acknowledge and build upon others' contributions when relevant\n" + "3. Stay on topic while adding your unique perspective based on your role and personality\n" + "4. Be concise but engaging - give others space to contribute\n" + "5. Maintain your character's personality while being collaborative\n" + "6. Feel free to ask questions to other participants to encourage discussion\n" + "7. If someone addresses you directly, acknowledge their message\n" + "8. Share relevant experiences or knowledge that adds value to the conversation\n\n" + "Remember: This is a natural group conversation. Interact as you would in a real group setting, " + "staying true to your character while fostering meaningful dialogue. " + "%%% END GROUP CHAT CONTEXT %%%" ) agent_state.memory.update_block_value(label="persona", value=persona_block.value + group_chat_participant_persona) return Agent( diff --git a/letta/schemas/group.py b/letta/schemas/group.py index 254baa8b..e3d5fadf 100644 --- a/letta/schemas/group.py +++ b/letta/schemas/group.py @@ -62,4 +62,10 @@ ManagerConfigUnion = Annotated[ class GroupCreate(BaseModel): agent_ids: List[str] = Field(..., description="") description: str = Field(..., description="") + manager_config: ManagerConfigUnion = Field(RoundRobinManager(), description="") + + +class GroupUpdate(BaseModel): + agent_ids: Optional[List[str]] = Field(None, description="") + description: Optional[str] = Field(None, description="") manager_config: Optional[ManagerConfigUnion] = Field(None, description="") diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 3e736600..65e5d84b 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -226,6 +226,7 @@ class Message(BaseMessage): id=self.id, date=self.created_at, reasoning=self.content[0].text, + name=self.name, ) ) # Otherwise, we may have a list of multiple types @@ -239,6 +240,7 @@ class Message(BaseMessage): id=self.id, date=self.created_at, reasoning=content_part.text, + name=self.name, ) ) elif isinstance(content_part, ReasoningContent): @@ -250,6 +252,7 @@ class Message(BaseMessage): reasoning=content_part.reasoning, source="reasoner_model", # TODO do we want to tag like this? signature=content_part.signature, + name=self.name, ) ) elif isinstance(content_part, RedactedReasoningContent): @@ -260,6 +263,7 @@ class Message(BaseMessage): date=self.created_at, state="redacted", hidden_reasoning=content_part.data, + name=self.name, ) ) else: @@ -282,6 +286,7 @@ class Message(BaseMessage): id=self.id, date=self.created_at, content=message_string, + name=self.name, ) ) else: @@ -294,6 +299,7 @@ class Message(BaseMessage): arguments=tool_call.function.arguments, tool_call_id=tool_call.id, ), + name=self.name, ) ) elif self.role == MessageRole.tool: @@ -334,6 +340,7 @@ class Message(BaseMessage): tool_call_id=self.tool_call_id, stdout=self.tool_returns[0].stdout if self.tool_returns else None, stderr=self.tool_returns[0].stderr if self.tool_returns else None, + name=self.name, ) ) elif self.role == MessageRole.user: @@ -349,6 +356,7 @@ class Message(BaseMessage): id=self.id, date=self.created_at, content=message_str or text_content, + name=self.name, ) ) elif self.role == MessageRole.system: @@ -363,6 +371,7 @@ class Message(BaseMessage): id=self.id, date=self.created_at, content=text_content, + name=self.name, ) ) else: @@ -379,6 +388,8 @@ class Message(BaseMessage): allow_functions_style: bool = False, # allow deprecated functions style? created_at: Optional[datetime] = None, id: Optional[str] = None, + name: Optional[str] = None, + group_id: Optional[str] = None, tool_returns: Optional[List[ToolReturn]] = None, ): """Convert a ChatCompletion message object into a Message object (synced to DB)""" @@ -426,12 +437,13 @@ class Message(BaseMessage): # standard fields expected in an OpenAI ChatCompletion message object role=MessageRole.tool, # NOTE content=content, - name=openai_message_dict["name"] if "name" in openai_message_dict else None, + name=name, tool_calls=openai_message_dict["tool_calls"] if "tool_calls" in openai_message_dict else None, tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None, created_at=created_at, id=str(id), tool_returns=tool_returns, + group_id=group_id, ) else: return Message( @@ -440,11 +452,12 @@ class Message(BaseMessage): # standard fields expected in an OpenAI ChatCompletion message object role=MessageRole.tool, # NOTE content=content, - name=openai_message_dict["name"] if "name" in openai_message_dict else None, + name=name, tool_calls=openai_message_dict["tool_calls"] if "tool_calls" in openai_message_dict else None, tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None, created_at=created_at, tool_returns=tool_returns, + group_id=group_id, ) elif "function_call" in openai_message_dict and openai_message_dict["function_call"] is not None: @@ -473,12 +486,13 @@ class Message(BaseMessage): # standard fields expected in an OpenAI ChatCompletion message object role=MessageRole(openai_message_dict["role"]), content=content, - name=openai_message_dict["name"] if "name" in openai_message_dict else None, + name=name, tool_calls=tool_calls, tool_call_id=None, # NOTE: None, since this field is only non-null for role=='tool' created_at=created_at, id=str(id), tool_returns=tool_returns, + group_id=group_id, ) else: return Message( @@ -492,6 +506,7 @@ class Message(BaseMessage): tool_call_id=None, # NOTE: None, since this field is only non-null for role=='tool' created_at=created_at, tool_returns=tool_returns, + group_id=group_id, ) else: @@ -520,12 +535,13 @@ class Message(BaseMessage): # standard fields expected in an OpenAI ChatCompletion message object role=MessageRole(openai_message_dict["role"]), content=content, - name=openai_message_dict["name"] if "name" in openai_message_dict else None, + name=name, tool_calls=tool_calls, tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None, created_at=created_at, id=str(id), tool_returns=tool_returns, + group_id=group_id, ) else: return Message( @@ -534,11 +550,12 @@ class Message(BaseMessage): # standard fields expected in an OpenAI ChatCompletion message object role=MessageRole(openai_message_dict["role"]), content=content, - name=openai_message_dict["name"] if "name" in openai_message_dict else None, + name=name, tool_calls=tool_calls, tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None, created_at=created_at, tool_returns=tool_returns, + group_id=group_id, ) def to_openai_dict_search_results(self, max_tool_id_length: int = TOOL_CALL_ID_MAX_LEN) -> dict: @@ -579,9 +596,6 @@ class Message(BaseMessage): "content": text_content, "role": self.role, } - # Optional field, do not include if null - if self.name is not None: - openai_message["name"] = self.name elif self.role == "user": assert all([v is not None for v in [text_content, self.role]]), vars(self) @@ -589,9 +603,6 @@ class Message(BaseMessage): "content": text_content, "role": self.role, } - # Optional field, do not include if null - if self.name is not None: - openai_message["name"] = self.name elif self.role == "assistant": assert self.tool_calls is not None or text_content is not None @@ -599,9 +610,7 @@ class Message(BaseMessage): "content": None if put_inner_thoughts_in_kwargs else text_content, "role": self.role, } - # Optional fields, do not include if null - if self.name is not None: - openai_message["name"] = self.name + if self.tool_calls is not None: if put_inner_thoughts_in_kwargs: # put the inner thoughts inside the tool call before casting to a dict diff --git a/letta/server/rest_api/interface.py b/letta/server/rest_api/interface.py index fefb2077..405eb476 100644 --- a/letta/server/rest_api/interface.py +++ b/letta/server/rest_api/interface.py @@ -465,6 +465,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # if we expect `reasoning_content``, then that's what gets mapped to ReasoningMessage # and `content` needs to be handled outside the interface expect_reasoning_content: bool = False, + name: Optional[str] = None, ) -> Optional[Union[ReasoningMessage, ToolCallMessage, AssistantMessage]]: """ Example data from non-streaming response looks like: @@ -497,6 +498,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): reasoning=message_delta.reasoning_content, signature=message_delta.reasoning_content_signature, source="reasoner_model" if message_delta.reasoning_content_signature else "non_reasoner_model", + name=name, ) elif expect_reasoning_content and message_delta.redacted_reasoning_content is not None: processed_chunk = HiddenReasoningMessage( @@ -504,6 +506,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): date=message_date, hidden_reasoning=message_delta.redacted_reasoning_content, state="redacted", + name=name, ) elif expect_reasoning_content and message_delta.content is not None: # "ignore" content if we expect reasoning content @@ -530,6 +533,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): arguments=json.dumps(json_reasoning_content.get("arguments")), tool_call_id=None, ), + name=name, ) except json.JSONDecodeError as e: @@ -559,6 +563,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): id=message_id, date=message_date, reasoning=message_delta.content, + name=name, ) # tool calls @@ -607,7 +612,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # TODO: Assumes consistent state and that prev_content is subset of new_content diff = new_content.replace(prev_content, "", 1) self.current_json_parse_result = parsed_args - processed_chunk = AssistantMessage(id=message_id, date=message_date, content=diff) + processed_chunk = AssistantMessage(id=message_id, date=message_date, content=diff, name=name) else: return None @@ -639,6 +644,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): arguments=tool_call_delta.get("arguments"), tool_call_id=tool_call_delta.get("id"), ), + name=name, ) elif self.inner_thoughts_in_kwargs and tool_call.function: @@ -674,6 +680,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): id=message_id, date=message_date, reasoning=updates_inner_thoughts, + name=name, ) # Additionally inner thoughts may stream back with a chunk of main JSON # In that case, since we can only return a chunk at a time, we should buffer it @@ -709,6 +716,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): arguments=None, tool_call_id=self.function_id_buffer, ), + name=name, ) # Record what the last function name we flushed was @@ -765,6 +773,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): id=message_id, date=message_date, content=combined_chunk, + name=name, ) # Store the ID of the tool call so allow skipping the corresponding response if self.function_id_buffer: @@ -789,7 +798,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # TODO: Assumes consistent state and that prev_content is subset of new_content diff = new_content.replace(prev_content, "", 1) self.current_json_parse_result = parsed_args - processed_chunk = AssistantMessage(id=message_id, date=message_date, content=diff) + processed_chunk = AssistantMessage(id=message_id, date=message_date, content=diff, name=name) else: return None @@ -813,6 +822,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): arguments=combined_chunk, tool_call_id=self.function_id_buffer, ), + name=name, ) # clear buffer self.function_args_buffer = None @@ -827,6 +837,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): arguments=updates_main_json, tool_call_id=self.function_id_buffer, ), + name=name, ) self.function_id_buffer = None @@ -955,6 +966,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): arguments=tool_call_delta.get("arguments"), tool_call_id=tool_call_delta.get("id"), ), + name=name, ) elif choice.finish_reason is not None: @@ -1035,6 +1047,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): message_id: str, message_date: datetime, expect_reasoning_content: bool = False, + name: Optional[str] = None, ): """Process a streaming chunk from an OpenAI-compatible server. @@ -1060,6 +1073,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): message_id=message_id, message_date=message_date, expect_reasoning_content=expect_reasoning_content, + name=name, ) if processed_chunk is None: @@ -1087,6 +1101,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): id=msg_obj.id, date=msg_obj.created_at, reasoning=msg, + name=msg_obj.name, ) self._push_to_buffer(processed_chunk) @@ -1097,6 +1112,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): id=msg_obj.id, date=msg_obj.created_at, reasoning=content.text, + name=msg_obj.name, ) elif isinstance(content, ReasoningContent): processed_chunk = ReasoningMessage( @@ -1105,6 +1121,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): source="reasoner_model", reasoning=content.reasoning, signature=content.signature, + name=msg_obj.name, ) elif isinstance(content, RedactedReasoningContent): processed_chunk = HiddenReasoningMessage( @@ -1112,6 +1129,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): date=msg_obj.created_at, state="redacted", hidden_reasoning=content.data, + name=msg_obj.name, ) self._push_to_buffer(processed_chunk) @@ -1172,6 +1190,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): id=msg_obj.id, date=msg_obj.created_at, content=func_args["message"], + name=msg_obj.name, ) self._push_to_buffer(processed_chunk) except Exception as e: @@ -1194,6 +1213,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): id=msg_obj.id, date=msg_obj.created_at, content=func_args[self.assistant_message_tool_kwarg], + name=msg_obj.name, ) # Store the ID of the tool call so allow skipping the corresponding response self.prev_assistant_message_id = function_call.id @@ -1206,6 +1226,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): arguments=function_call.function.arguments, tool_call_id=function_call.id, ), + name=msg_obj.name, ) # processed_chunk = { @@ -1245,6 +1266,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): tool_call_id=msg_obj.tool_call_id, stdout=msg_obj.tool_returns[0].stdout if msg_obj.tool_returns else None, stderr=msg_obj.tool_returns[0].stderr if msg_obj.tool_returns else None, + name=msg_obj.name, ) elif msg.startswith("Error: "): @@ -1259,6 +1281,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): tool_call_id=msg_obj.tool_call_id, stdout=msg_obj.tool_returns[0].stdout if msg_obj.tool_returns else None, stderr=msg_obj.tool_returns[0].stderr if msg_obj.tool_returns else None, + name=msg_obj.name, ) else: diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 6a05566c..1dc6c632 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -527,6 +527,7 @@ def list_messages( after: Optional[str] = Query(None, description="Message after which to retrieve the returned messages."), before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."), limit: int = Query(10, description="Maximum number of messages to retrieve."), + group_id: Optional[str] = Query(None, description="Group ID to filter messages by."), use_assistant_message: bool = Query(True, description="Whether to use assistant messages"), assistant_message_tool_name: str = Query(DEFAULT_MESSAGE_TOOL, description="The name of the designated message tool."), assistant_message_tool_kwarg: str = Query(DEFAULT_MESSAGE_TOOL_KWARG, description="The name of the message argument."), @@ -543,6 +544,7 @@ def list_messages( after=after, before=before, limit=limit, + group_id=group_id, reverse=True, return_message_object=False, use_assistant_message=use_assistant_message, diff --git a/letta/server/rest_api/routers/v1/groups.py b/letta/server/rest_api/routers/v1/groups.py index 9e02ddd7..575b6a54 100644 --- a/letta/server/rest_api/routers/v1/groups.py +++ b/letta/server/rest_api/routers/v1/groups.py @@ -1,11 +1,13 @@ from typing import Annotated, List, Optional -from fastapi import APIRouter, Body, Depends, Header, Query +from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, status +from fastapi.responses import JSONResponse from pydantic import Field from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG -from letta.schemas.group import Group, GroupCreate, ManagerType -from letta.schemas.letta_message import LettaMessageUnion +from letta.orm.errors import NoResultFound +from letta.schemas.group import Group, GroupCreate, GroupUpdate, ManagerType +from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest from letta.schemas.letta_response import LettaResponse from letta.server.rest_api.utils import get_letta_server @@ -14,21 +16,6 @@ from letta.server.server import SyncServer router = APIRouter(prefix="/groups", tags=["groups"]) -@router.post("/", response_model=Group, operation_id="create_group") -async def create_group( - server: SyncServer = Depends(get_letta_server), - request: GroupCreate = Body(...), - actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present -): - """ - Create a multi-agent group with a specified management pattern. When no - management config is specified, this endpoint will use round robin for - speaker selection. - """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.group_manager.create_group(request, actor=actor) - - @router.get("/", response_model=List[Group], operation_id="list_groups") def list_groups( server: "SyncServer" = Depends(get_letta_server), @@ -53,6 +40,23 @@ def list_groups( ) +@router.get("/{group_id}", response_model=Group, operation_id="retrieve_group") +def retrieve_group( + group_id: str, + server: "SyncServer" = Depends(get_letta_server), + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present +): + """ + Retrieve the group by id. + """ + actor = server.user_manager.get_user_or_default(user_id=actor_id) + + try: + return server.group_manager.retrieve_group(group_id=group_id, actor=actor) + except NoResultFound as e: + raise HTTPException(status_code=404, detail=str(e)) + + @router.post("/", response_model=Group, operation_id="create_group") def create_group( group: GroupCreate = Body(...), @@ -70,9 +74,10 @@ def create_group( raise HTTPException(status_code=500, detail=str(e)) -@router.put("/", response_model=Group, operation_id="upsert_group") -def upsert_group( - group: GroupCreate = Body(...), +@router.put("/{group_id}", response_model=Group, operation_id="modify_group") +def modify_group( + group_id: str, + group: GroupUpdate = Body(...), server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present x_project: Optional[str] = Header(None, alias="X-Project"), # Only handled by next js middleware @@ -82,7 +87,7 @@ def upsert_group( """ try: actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.group_manager.create_group(group, actor=actor) + return server.group_manager.modify_group(group_id=group_id, group_update=group, actor=actor) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -110,7 +115,7 @@ def delete_group( operation_id="send_group_message", ) async def send_group_message( - agent_id: str, + group_id: str, server: SyncServer = Depends(get_letta_server), request: LettaRequest = Body(...), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present @@ -178,6 +183,22 @@ GroupMessagesResponse = Annotated[ ] +@router.patch("/{group_id}/messages/{message_id}", response_model=LettaMessageUnion, operation_id="modify_group_message") +def modify_group_message( + group_id: str, + message_id: str, + request: LettaMessageUpdateUnion = Body(...), + server: "SyncServer" = Depends(get_letta_server), + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present +): + """ + Update the details of a message associated with an agent. + """ + # TODO: support modifying tool calls/returns + actor = server.user_manager.get_user_or_default(user_id=actor_id) + return server.message_manager.update_message_by_letta_message(message_id=message_id, letta_message_update=request, actor=actor) + + @router.get("/{group_id}/messages", response_model=GroupMessagesResponse, operation_id="list_group_messages") def list_group_messages( group_id: str, @@ -194,40 +215,42 @@ def list_group_messages( Retrieve message history for an agent. """ actor = server.user_manager.get_user_or_default(user_id=actor_id) - - return server.group_manager.list_group_messages( - group_id=group_id, - before=before, - after=after, - limit=limit, - actor=actor, - use_assistant_message=use_assistant_message, - assistant_message_tool_name=assistant_message_tool_name, - assistant_message_tool_kwarg=assistant_message_tool_kwarg, - ) + group = server.group_manager.retrieve_group(group_id=group_id, actor=actor) + if group.manager_agent_id: + return server.get_agent_recall( + user_id=actor.id, + agent_id=group.manager_agent_id, + after=after, + before=before, + limit=limit, + group_id=group_id, + reverse=True, + return_message_object=False, + use_assistant_message=use_assistant_message, + assistant_message_tool_name=assistant_message_tool_name, + assistant_message_tool_kwarg=assistant_message_tool_kwarg, + ) + else: + return server.group_manager.list_group_messages( + group_id=group_id, + after=after, + before=before, + limit=limit, + actor=actor, + use_assistant_message=use_assistant_message, + assistant_message_tool_name=assistant_message_tool_name, + assistant_message_tool_kwarg=assistant_message_tool_kwarg, + ) -''' @router.patch("/{group_id}/reset-messages", response_model=None, operation_id="reset_group_messages") def reset_group_messages( group_id: str, - add_default_initial_messages: bool = Query(default=False, description="If true, adds the default initial messages after resetting."), server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ - Resets the messages for all agents that are part of the multi-agent group. - TODO: only delete group messages not all messages! + Delete the group messages for all agents that are part of the multi-agent group. """ actor = server.user_manager.get_user_or_default(user_id=actor_id) - group = server.group_manager.retrieve_group(group_id=group_id, actor=actor) - agent_ids = group.agent_ids - if group.manager_agent_id: - agent_ids.append(group.manager_agent_id) - for agent_id in agent_ids: - server.agent_manager.reset_messages( - agent_id=agent_id, - actor=actor, - add_default_initial_messages=add_default_initial_messages, - ) -''' + server.group_manager.reset_messages(group_id=group_id, actor=actor) diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index d8ca9804..72ec07eb 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -211,7 +211,6 @@ def create_tool_call_messages_from_openai_response( tool_calls=[], tool_call_id=tool_call_id, created_at=get_utc_time(), - name=function_name, ) messages.append(tool_message) diff --git a/letta/server/server.py b/letta/server/server.py index ae889275..636a5289 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -367,6 +367,9 @@ class SyncServer(Server): def load_multi_agent( self, group: Group, actor: User, interface: Union[AgentInterface, None] = None, agent_state: Optional[AgentState] = None ) -> Agent: + if len(group.agent_ids) == 0: + raise ValueError("Empty group: group must have at least one agent") + match group.manager_type: case ManagerType.round_robin: agent_state = agent_state or self.agent_manager.get_agent_by_id(agent_id=group.agent_ids[0], actor=actor) @@ -862,6 +865,7 @@ class SyncServer(Server): after: Optional[str] = None, before: Optional[str] = None, limit: Optional[int] = 100, + group_id: Optional[str] = None, reverse: Optional[bool] = False, return_message_object: bool = True, use_assistant_message: bool = True, @@ -879,6 +883,7 @@ class SyncServer(Server): before=before, limit=limit, ascending=not reverse, + group_id=group_id, ) if not return_message_object: @@ -1591,88 +1596,76 @@ class SyncServer(Server): ) -> Union[StreamingResponse, LettaResponse]: include_final_message = True if not stream_steps and stream_tokens: - raise HTTPException(status_code=400, detail="stream_steps must be 'true' if stream_tokens is 'true'") + raise ValueError("stream_steps must be 'true' if stream_tokens is 'true'") - try: - # fetch the group - group = self.group_manager.retrieve_group(group_id=group_id, actor=actor) - letta_multi_agent = self.load_multi_agent(group=group, actor=actor) + group = self.group_manager.retrieve_group(group_id=group_id, actor=actor) + letta_multi_agent = self.load_multi_agent(group=group, actor=actor) - llm_config = letta_multi_agent.agent_state.llm_config - supports_token_streaming = ["openai", "anthropic", "deepseek"] - if stream_tokens and ( - llm_config.model_endpoint_type not in supports_token_streaming or "inference.memgpt.ai" in llm_config.model_endpoint - ): - warnings.warn( - f"Token streaming is only supported for models with type {' or '.join(supports_token_streaming)} in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False." - ) - stream_tokens = False + llm_config = letta_multi_agent.agent_state.llm_config + supports_token_streaming = ["openai", "anthropic", "deepseek"] + if stream_tokens and ( + llm_config.model_endpoint_type not in supports_token_streaming or "inference.memgpt.ai" in llm_config.model_endpoint + ): + warnings.warn( + f"Token streaming is only supported for models with type {' or '.join(supports_token_streaming)} in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False." + ) + stream_tokens = False - # Create a new interface per request - letta_multi_agent.interface = StreamingServerInterface( - use_assistant_message=use_assistant_message, - assistant_message_tool_name=assistant_message_tool_name, - assistant_message_tool_kwarg=assistant_message_tool_kwarg, - inner_thoughts_in_kwargs=( - llm_config.put_inner_thoughts_in_kwargs if llm_config.put_inner_thoughts_in_kwargs is not None else False + # Create a new interface per request + letta_multi_agent.interface = StreamingServerInterface( + use_assistant_message=use_assistant_message, + assistant_message_tool_name=assistant_message_tool_name, + assistant_message_tool_kwarg=assistant_message_tool_kwarg, + inner_thoughts_in_kwargs=( + llm_config.put_inner_thoughts_in_kwargs if llm_config.put_inner_thoughts_in_kwargs is not None else False + ), + ) + streaming_interface = letta_multi_agent.interface + if not isinstance(streaming_interface, StreamingServerInterface): + raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}") + streaming_interface.streaming_mode = stream_tokens + streaming_interface.streaming_chat_completion_mode = chat_completion_mode + if metadata and hasattr(streaming_interface, "metadata"): + streaming_interface.metadata = metadata + + streaming_interface.stream_start() + task = asyncio.create_task( + asyncio.to_thread( + letta_multi_agent.step, + messages=messages, + chaining=self.chaining, + max_chaining_steps=self.max_chaining_steps, + ) + ) + + if stream_steps: + # return a stream + return StreamingResponse( + sse_async_generator( + streaming_interface.get_generator(), + usage_task=task, + finish_message=include_final_message, ), - ) - streaming_interface = letta_multi_agent.interface - if not isinstance(streaming_interface, StreamingServerInterface): - raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}") - streaming_interface.streaming_mode = stream_tokens - streaming_interface.streaming_chat_completion_mode = chat_completion_mode - if metadata and hasattr(streaming_interface, "metadata"): - streaming_interface.metadata = metadata - - streaming_interface.stream_start() - task = asyncio.create_task( - asyncio.to_thread( - letta_multi_agent.step, - messages=messages, - chaining=self.chaining, - max_chaining_steps=self.max_chaining_steps, - ) + media_type="text/event-stream", ) - if stream_steps: - # return a stream - return StreamingResponse( - sse_async_generator( - streaming_interface.get_generator(), - usage_task=task, - finish_message=include_final_message, - ), - media_type="text/event-stream", - ) + else: + # buffer the stream, then return the list + generated_stream = [] + async for message in streaming_interface.get_generator(): + assert ( + isinstance(message, LettaMessage) or isinstance(message, LegacyLettaMessage) or isinstance(message, MessageStreamStatus) + ), type(message) + generated_stream.append(message) + if message == MessageStreamStatus.done: + break - else: - # buffer the stream, then return the list - generated_stream = [] - async for message in streaming_interface.get_generator(): - assert ( - isinstance(message, LettaMessage) - or isinstance(message, LegacyLettaMessage) - or isinstance(message, MessageStreamStatus) - ), type(message) - generated_stream.append(message) - if message == MessageStreamStatus.done: - break + # Get rid of the stream status messages + filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)] + usage = await task - # Get rid of the stream status messages - filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)] - usage = await task - - # By default the stream will be messages of type LettaMessage or LettaLegacyMessage - # If we want to convert these to Message, we can use the attached IDs - # NOTE: we will need to de-duplicate the Messsage IDs though (since Assistant->Inner+Func_Call) - # TODO: eventually update the interface to use `Message` and `MessageChunk` (new) inside the deque instead - return LettaResponse(messages=filtered_stream, usage=usage) - except HTTPException: - raise - except Exception as e: - print(e) - import traceback - - traceback.print_exc() - raise HTTPException(status_code=500, detail=f"{e}") + # By default the stream will be messages of type LettaMessage or LettaLegacyMessage + # If we want to convert these to Message, we can use the attached IDs + # NOTE: we will need to de-duplicate the Messsage IDs though (since Assistant->Inner+Func_Call) + # TODO: eventually update the interface to use `Message` and `MessageChunk` (new) inside the deque instead + return LettaResponse(messages=filtered_stream, usage=usage) diff --git a/letta/services/group_manager.py b/letta/services/group_manager.py index c3576ee8..043f7059 100644 --- a/letta/services/group_manager.py +++ b/letta/services/group_manager.py @@ -7,7 +7,9 @@ from letta.orm.errors import NoResultFound from letta.orm.group import Group as GroupModel from letta.orm.message import Message as MessageModel from letta.schemas.group import Group as PydanticGroup -from letta.schemas.group import GroupCreate, ManagerType +from letta.schemas.group import GroupCreate, GroupUpdate, ManagerType +from letta.schemas.letta_message import LettaMessage +from letta.schemas.message import Message as PydanticMessage from letta.schemas.user import User as PydanticUser from letta.utils import enforce_types @@ -22,12 +24,12 @@ class GroupManager: @enforce_types def list_groups( self, + actor: PydanticUser, project_id: Optional[str] = None, manager_type: Optional[ManagerType] = None, before: Optional[str] = None, after: Optional[str] = None, limit: Optional[int] = 50, - actor: PydanticUser = None, ) -> list[PydanticGroup]: with self.session_maker() as session: filters = {"organization_id": actor.organization_id} @@ -56,27 +58,66 @@ class GroupManager: new_group = GroupModel() new_group.organization_id = actor.organization_id new_group.description = group.description + + match group.manager_config.manager_type: + case ManagerType.round_robin: + new_group.manager_type = ManagerType.round_robin + new_group.max_turns = group.manager_config.max_turns + case ManagerType.dynamic: + new_group.manager_type = ManagerType.dynamic + new_group.manager_agent_id = group.manager_config.manager_agent_id + new_group.max_turns = group.manager_config.max_turns + new_group.termination_token = group.manager_config.termination_token + case ManagerType.supervisor: + new_group.manager_type = ManagerType.supervisor + new_group.manager_agent_id = group.manager_config.manager_agent_id + case _: + raise ValueError(f"Unsupported manager type: {group.manager_config.manager_type}") + self._process_agent_relationship(session=session, group=new_group, agent_ids=group.agent_ids, allow_partial=False) - if group.manager_config is None: - new_group.manager_type = ManagerType.round_robin - else: - match group.manager_config.manager_type: - case ManagerType.round_robin: - new_group.manager_type = ManagerType.round_robin - new_group.max_turns = group.manager_config.max_turns - case ManagerType.dynamic: - new_group.manager_type = ManagerType.dynamic - new_group.manager_agent_id = group.manager_config.manager_agent_id - new_group.max_turns = group.manager_config.max_turns - new_group.termination_token = group.manager_config.termination_token - case ManagerType.supervisor: - new_group.manager_type = ManagerType.supervisor - new_group.manager_agent_id = group.manager_config.manager_agent_id - case _: - raise ValueError(f"Unsupported manager type: {group.manager_config.manager_type}") + new_group.create(session, actor=actor) return new_group.to_pydantic() + @enforce_types + def modify_group(self, group_id: str, group_update: GroupUpdate, actor: PydanticUser) -> PydanticGroup: + with self.session_maker() as session: + group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) + + max_turns = None + termination_token = None + manager_agent_id = None + if group_update.manager_config: + if group_update.manager_config.manager_type != group.manager_type: + raise ValueError(f"Cannot change group pattern after creation") + match group_update.manager_config.manager_type: + case ManagerType.round_robin: + max_turns = group_update.manager_config.max_turns + case ManagerType.dynamic: + manager_agent_id = group_update.manager_config.manager_agent_id + max_turns = group_update.manager_config.max_turns + termination_token = group_update.manager_config.termination_token + case ManagerType.supervisor: + manager_agent_id = group_update.manager_config.manager_agent_id + case _: + raise ValueError(f"Unsupported manager type: {group_update.manager_config.manager_type}") + + if max_turns: + group.max_turns = max_turns + if termination_token: + group.termination_token = termination_token + if manager_agent_id: + group.manager_agent_id = manager_agent_id + if group_update.description: + group.description = group_update.description + if group_update.agent_ids: + self._process_agent_relationship( + session=session, group=group, agent_ids=group_update.agent_ids, allow_partial=False, replace=True + ) + + group.update(session, actor=actor) + return group.to_pydantic() + @enforce_types def delete_group(self, group_id: str, actor: PydanticUser) -> None: with self.session_maker() as session: @@ -87,23 +128,19 @@ class GroupManager: @enforce_types def list_group_messages( self, + actor: PydanticUser, group_id: Optional[str] = None, before: Optional[str] = None, after: Optional[str] = None, limit: Optional[int] = 50, - actor: PydanticUser = None, use_assistant_message: bool = True, assistant_message_tool_name: str = "send_message", assistant_message_tool_kwarg: str = "message", - ) -> list[PydanticGroup]: + ) -> list[LettaMessage]: with self.session_maker() as session: - group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) - agent_id = group.manager_agent_id if group.manager_agent_id else group.agent_ids[0] - filters = { "organization_id": actor.organization_id, "group_id": group_id, - "agent_id": agent_id, } messages = MessageModel.list( db_session=session, @@ -114,21 +151,39 @@ class GroupManager: ) messages = PydanticMessage.to_letta_messages_from_list( - messages=messages, + messages=[msg.to_pydantic() for msg in messages], use_assistant_message=use_assistant_message, assistant_message_tool_name=assistant_message_tool_name, assistant_message_tool_kwarg=assistant_message_tool_kwarg, ) + # TODO: filter messages to return a clean conversation history + return messages + @enforce_types + def reset_messages(self, group_id: str, actor: PydanticUser) -> None: + with self.session_maker() as session: + # Ensure group is loadable by user + group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) + + # Delete all messages in the group + session.query(MessageModel).filter( + MessageModel.organization_id == actor.organization_id, MessageModel.group_id == group_id + ).delete(synchronize_session=False) + + session.commit() + def _process_agent_relationship(self, session: Session, group: GroupModel, agent_ids: List[str], allow_partial=False, replace=True): - current_relationship = getattr(group, "agents", []) if not agent_ids: if replace: setattr(group, "agents", []) + setattr(group, "agent_ids", []) return + if group.manager_type == ManagerType.dynamic and len(agent_ids) != len(set(agent_ids)): + raise ValueError("Duplicate agent ids found in list") + # Retrieve models for the provided IDs found_items = session.query(AgentModel).filter(AgentModel.id.in_(agent_ids)).all() @@ -137,11 +192,14 @@ class GroupManager: missing = set(agent_ids) - {item.id for item in found_items} raise NoResultFound(f"Items not found in agents: {missing}") + if group.manager_type == ManagerType.dynamic: + names = [item.name for item in found_items] + if len(names) != len(set(names)): + raise ValueError("Duplicate agent names found in the provided agent IDs.") + if replace: # Replace the relationship setattr(group, "agents", found_items) + setattr(group, "agent_ids", agent_ids) else: - # Extend the relationship (only add new items) - current_ids = {item.id for item in current_relationship} - new_items = [item for item in found_items if item.id not in current_ids] - current_relationship.extend(new_items) + raise ValueError("Extend relationship is not supported for groups.") diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index a343ed5a..364e76e6 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -264,6 +264,7 @@ class MessageManager: roles: Optional[Sequence[MessageRole]] = None, limit: Optional[int] = 50, ascending: bool = True, + group_id: Optional[str] = None, ) -> List[PydanticMessage]: """ Most performant query to list messages for an agent by directly querying the Message table. @@ -296,6 +297,10 @@ class MessageManager: # Build a query that directly filters the Message table by agent_id. query = session.query(MessageModel).filter(MessageModel.agent_id == agent_id) + # If group_id is provided, filter messages by group_id. + if group_id: + query = query.filter(MessageModel.group_id == group_id) + # If query_text is provided, filter messages using subquery. if query_text: content_element = func.json_array_elements(MessageModel.content).alias("content_element") diff --git a/letta/supervisor_multi_agent.py b/letta/supervisor_multi_agent.py index 23ce96e3..98c3d7ab 100644 --- a/letta/supervisor_multi_agent.py +++ b/letta/supervisor_multi_agent.py @@ -22,7 +22,7 @@ class SupervisorMultiAgent(Agent): self, interface: AgentInterface, agent_state: AgentState, - user: User = None, + user: User, # custom group_id: str = "", agent_ids: List[str] = [], @@ -65,6 +65,7 @@ class SupervisorMultiAgent(Agent): self.agent_state = self.agent_manager.attach_tool(agent_id=self.agent_state.id, tool_id=multi_agent_tool.id, actor=self.user) # override tool rules + old_tool_rules = self.agent_state.tool_rules self.agent_state.tool_rules = [ InitToolRule( tool_name="send_message_to_all_agents_in_group", @@ -106,6 +107,7 @@ class SupervisorMultiAgent(Agent): raise e finally: self.interface.step_yield() + self.agent_state.tool_rules = old_tool_rules self.interface.step_complete() diff --git a/tests/test_multi_agent.py b/tests/test_multi_agent.py index e44864dd..2214b143 100644 --- a/tests/test_multi_agent.py +++ b/tests/test_multi_agent.py @@ -5,7 +5,7 @@ from letta.config import LettaConfig from letta.orm import Provider, Step from letta.schemas.agent import CreateAgent from letta.schemas.block import CreateBlock -from letta.schemas.group import DynamicManager, GroupCreate, SupervisorManager +from letta.schemas.group import DynamicManager, GroupCreate, GroupUpdate, ManagerType, RoundRobinManager, SupervisorManager from letta.schemas.message import MessageCreate from letta.server.server import SyncServer @@ -45,7 +45,7 @@ def actor(server, org_id): @pytest.fixture(scope="module") -def participant_agent_ids(server, actor): +def participant_agents(server, actor): agent_fred = server.create_agent( request=CreateAgent( name="fred", @@ -102,7 +102,7 @@ def participant_agent_ids(server, actor): ), actor=actor, ) - yield [agent_fred.id, agent_velma.id, agent_daphne.id, agent_shaggy.id] + yield [agent_fred, agent_velma, agent_daphne, agent_shaggy] # cleanup server.agent_manager.delete_agent(agent_fred.id, actor=actor) @@ -112,7 +112,7 @@ def participant_agent_ids(server, actor): @pytest.fixture(scope="module") -def manager_agent_id(server, actor): +def manager_agent(server, actor): agent_scooby = server.create_agent( request=CreateAgent( name="scooby", @@ -131,22 +131,84 @@ def manager_agent_id(server, actor): ), actor=actor, ) - yield agent_scooby.id + yield agent_scooby # cleanup server.agent_manager.delete_agent(agent_scooby.id, actor=actor) @pytest.mark.asyncio -async def test_round_robin(server, actor, participant_agent_ids): +async def test_empty_group(server, actor): group = server.group_manager.create_group( group=GroupCreate( description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.", - agent_ids=participant_agent_ids, + agent_ids=[], ), actor=actor, ) + with pytest.raises(ValueError, match="Empty group"): + await server.send_group_message_to_agent( + group_id=group.id, + actor=actor, + messages=[ + MessageCreate( + role="user", + content="what is everyone up to for the holidays?", + ), + ], + stream_steps=False, + stream_tokens=False, + ) + server.group_manager.delete_group(group_id=group.id, actor=actor) + + +@pytest.mark.asyncio +async def test_modify_group_pattern(server, actor, participant_agents, manager_agent): + group = server.group_manager.create_group( + group=GroupCreate( + description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.", + agent_ids=[agent.id for agent in participant_agents], + ), + actor=actor, + ) + with pytest.raises(ValueError, match="Cannot change group pattern"): + server.group_manager.modify_group( + group_id=group.id, + group_update=GroupUpdate( + manager_config=DynamicManager( + manager_type=ManagerType.dynamic, + manager_agent_id=manager_agent.id, + ), + ), + actor=actor, + ) + + server.group_manager.delete_group(group_id=group.id, actor=actor) + + +@pytest.mark.asyncio +async def test_round_robin(server, actor, participant_agents): + description = ( + "This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries." + ) + group = server.group_manager.create_group( + group=GroupCreate( + description=description, + agent_ids=[agent.id for agent in participant_agents], + ), + actor=actor, + ) + + # verify group creation + assert group.manager_type == ManagerType.round_robin + assert group.description == description + assert group.agent_ids == [agent.id for agent in participant_agents] + assert group.max_turns == None + assert group.manager_agent_id is None + assert group.termination_token is None + try: + server.group_manager.reset_messages(group_id=group.id, actor=actor) response = await server.send_group_message_to_agent( group_id=group.id, actor=actor, @@ -159,15 +221,85 @@ async def test_round_robin(server, actor, participant_agent_ids): stream_steps=False, stream_tokens=False, ) - assert response.usage.step_count == len(participant_agent_ids) + assert response.usage.step_count == len(group.agent_ids) assert len(response.messages) == response.usage.step_count * 2 + for i, message in enumerate(response.messages): + assert message.message_type == "reasoning_message" if i % 2 == 0 else "assistant_message" + assert message.name == participant_agents[i // 2].name + + for agent_id in group.agent_ids: + agent_messages = server.get_agent_recall( + user_id=actor.id, + agent_id=agent_id, + group_id=group.id, + reverse=True, + return_message_object=False, + ) + assert len(agent_messages) == len(group.agent_ids) + 2 # add one for user message, one for reasoning message + + # TODO: filter this to return a clean conversation history + messages = server.group_manager.list_group_messages( + group_id=group.id, + actor=actor, + ) + assert len(messages) == (len(group.agent_ids) + 2) * len(group.agent_ids) + + max_turns = 3 + group = server.group_manager.modify_group( + group_id=group.id, + group_update=GroupUpdate( + agent_ids=[agent.id for agent in participant_agents][::-1], + manager_config=RoundRobinManager( + max_turns=max_turns, + ), + ), + actor=actor, + ) + assert group.manager_type == ManagerType.round_robin + assert group.description == description + assert group.agent_ids == [agent.id for agent in participant_agents][::-1] + assert group.max_turns == max_turns + assert group.manager_agent_id is None + assert group.termination_token is None + + server.group_manager.reset_messages(group_id=group.id, actor=actor) + + response = await server.send_group_message_to_agent( + group_id=group.id, + actor=actor, + messages=[ + MessageCreate( + role="user", + content="what is everyone up to for the holidays?", + ), + ], + stream_steps=False, + stream_tokens=False, + ) + assert response.usage.step_count == max_turns + assert len(response.messages) == max_turns * 2 + + for i, message in enumerate(response.messages): + assert message.message_type == "reasoning_message" if i % 2 == 0 else "assistant_message" + assert message.name == participant_agents[::-1][i // 2].name + + for i in range(len(group.agent_ids)): + agent_messages = server.get_agent_recall( + user_id=actor.id, + agent_id=group.agent_ids[i], + group_id=group.id, + reverse=True, + return_message_object=False, + ) + expected_message_count = max_turns + 1 if i >= max_turns else max_turns + 2 + assert len(agent_messages) == expected_message_count finally: server.group_manager.delete_group(group_id=group.id, actor=actor) @pytest.mark.asyncio -async def test_supervisor(server, actor, participant_agent_ids): +async def test_supervisor(server, actor, participant_agents): agent_scrappy = server.create_agent( request=CreateAgent( name="shaggy", @@ -186,10 +318,11 @@ async def test_supervisor(server, actor, participant_agent_ids): ), actor=actor, ) + group = server.group_manager.create_group( group=GroupCreate( description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.", - agent_ids=participant_agent_ids, + agent_ids=[agent.id for agent in participant_agents], manager_config=SupervisorManager( manager_agent_id=agent_scrappy.id, ), @@ -219,7 +352,7 @@ async def test_supervisor(server, actor, participant_agent_ids): and response.messages[1].tool_call.name == "send_message_to_all_agents_in_group" ) assert response.messages[2].message_type == "tool_return_message" and len(eval(response.messages[2].tool_return)) == len( - participant_agent_ids + participant_agents ) assert response.messages[3].message_type == "reasoning_message" assert response.messages[4].message_type == "assistant_message" @@ -230,13 +363,50 @@ async def test_supervisor(server, actor, participant_agent_ids): @pytest.mark.asyncio -async def test_dynamic_group_chat(server, actor, manager_agent_id, participant_agent_ids): +async def test_dynamic_group_chat(server, actor, manager_agent, participant_agents): + description = ( + "This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries." + ) + # error on duplicate agent in participant list + with pytest.raises(ValueError, match="Duplicate agent ids"): + server.group_manager.create_group( + group=GroupCreate( + description=description, + agent_ids=[agent.id for agent in participant_agents] + [participant_agents[0].id], + manager_config=DynamicManager( + manager_agent_id=manager_agent.id, + ), + ), + actor=actor, + ) + # error on duplicate agent names + duplicate_agent_shaggy = server.create_agent( + request=CreateAgent( + name="shaggy", + model="openai/gpt-4o-mini", + embedding="openai/text-embedding-ada-002", + ), + actor=actor, + ) + with pytest.raises(ValueError, match="Duplicate agent names"): + server.group_manager.create_group( + group=GroupCreate( + description=description, + agent_ids=[agent.id for agent in participant_agents] + [duplicate_agent_shaggy.id], + manager_config=DynamicManager( + manager_agent_id=manager_agent.id, + ), + ), + actor=actor, + ) + server.agent_manager.delete_agent(duplicate_agent_shaggy.id, actor=actor) + group = server.group_manager.create_group( group=GroupCreate( - description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.", - agent_ids=participant_agent_ids, + description=description, + agent_ids=[agent.id for agent in participant_agents], manager_config=DynamicManager( - manager_agent_id=manager_agent_id, + manager_agent_id=manager_agent.id, ), ), actor=actor, @@ -251,7 +421,7 @@ async def test_dynamic_group_chat(server, actor, manager_agent_id, participant_a stream_steps=False, stream_tokens=False, ) - assert response.usage.step_count == len(participant_agent_ids) * 2 + assert response.usage.step_count == len(participant_agents) * 2 assert len(response.messages) == response.usage.step_count * 2 finally: