diff --git a/alembic/versions/77de976590ae_add_groups_for_multi_agent.py b/alembic/versions/77de976590ae_add_groups_for_multi_agent.py new file mode 100644 index 00000000..fdf446e2 --- /dev/null +++ b/alembic/versions/77de976590ae_add_groups_for_multi_agent.py @@ -0,0 +1,62 @@ +"""add groups for multi agent + +Revision ID: 77de976590ae +Revises: 167491cfb7a8 +Create Date: 2025-03-12 14:01:58.034385 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "77de976590ae" +down_revision: Union[str, None] = "167491cfb7a8" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "groups", + sa.Column("id", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=False), + sa.Column("manager_type", sa.String(), nullable=False), + sa.Column("manager_agent_id", sa.String(), nullable=True), + sa.Column("termination_token", sa.String(), nullable=True), + sa.Column("max_turns", sa.Integer(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.Column("organization_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint(["manager_agent_id"], ["agents.id"], ondelete="RESTRICT"), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "groups_agents", + sa.Column("group_id", sa.String(), nullable=False), + sa.Column("agent_id", 
sa.String(), nullable=False), + sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["group_id"], ["groups.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("group_id", "agent_id"), + ) + op.add_column("messages", sa.Column("group_id", sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("messages", "group_id") + op.drop_table("groups_agents") + op.drop_table("groups") + # ### end Alembic commands ### diff --git a/examples/mcp_example.py b/examples/mcp_example.py index e27621a3..a12c3faf 100644 --- a/examples/mcp_example.py +++ b/examples/mcp_example.py @@ -1,6 +1,7 @@ -from letta_client import Letta from pprint import pprint +from letta_client import Letta + client = Letta(base_url="http://localhost:8283") mcp_server_name = "everything" diff --git a/letta/agent.py b/letta/agent.py index 8f15faf5..daa9a6c8 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -95,6 +95,7 @@ class Agent(BaseAgent): first_message_verify_mono: bool = True, # TODO move to config? 
# MCP sessions, state held in-memory in the server mcp_clients: Optional[Dict[str, BaseMCPClient]] = None, + save_last_response: bool = False, ): assert isinstance(agent_state.memory, Memory), f"Memory object is not of type Memory: {type(agent_state.memory)}" # Hold a copy of the state that was used to init the agent @@ -149,6 +150,10 @@ class Agent(BaseAgent): # Load last function response from message history self.last_function_response = self.load_last_function_response() + # Save last responses in memory + self.save_last_response = save_last_response + self.last_response_messages = [] + # Logger that the Agent specifically can use, will also report the agent_state ID with the logs self.logger = get_logger(agent_state.id) @@ -926,6 +931,9 @@ class Agent(BaseAgent): else: all_new_messages = all_response_messages + if self.save_last_response: + self.last_response_messages = all_response_messages + # Check the memory pressure and potentially issue a memory pressure warning current_total_tokens = response.usage.total_tokens active_memory_warning = False @@ -1052,6 +1060,7 @@ class Agent(BaseAgent): else: logger.error(f"step() failed with an unrecognized exception: '{str(e)}'") + traceback.print_exc() raise e def step_user_message(self, user_message_str: str, **kwargs) -> AgentStepResponse: diff --git a/letta/dynamic_multi_agent.py b/letta/dynamic_multi_agent.py new file mode 100644 index 00000000..4b979ef8 --- /dev/null +++ b/letta/dynamic_multi_agent.py @@ -0,0 +1,274 @@ +from typing import List, Optional + +from letta.agent import Agent, AgentState +from letta.interface import AgentInterface +from letta.orm import User +from letta.schemas.block import Block +from letta.schemas.letta_message import TextContent +from letta.schemas.message import Message, MessageCreate +from letta.schemas.openai.chat_completion_response import UsageStatistics +from letta.schemas.usage import LettaUsageStatistics +from letta.services.tool_manager import ToolManager + + +class 
DynamicMultiAgent(Agent): + def __init__( + self, + interface: AgentInterface, + agent_state: AgentState, + user: User = None, + # custom + group_id: str = "", + agent_ids: List[str] = [], + description: str = "", + max_turns: Optional[int] = None, + termination_token: str = "DONE!", + ): + super().__init__(interface, agent_state, user) + self.group_id = group_id + self.agent_ids = agent_ids + self.description = description + self.max_turns = max_turns or len(agent_ids) + self.termination_token = termination_token + + self.tool_manager = ToolManager() + + def step( + self, + messages: List[MessageCreate], + chaining: bool = True, + max_chaining_steps: Optional[int] = None, + put_inner_thoughts_first: bool = True, + **kwargs, + ) -> LettaUsageStatistics: + total_usage = UsageStatistics() + step_count = 0 + + token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False + metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None + + agents = {} + message_index = {self.agent_state.id: 0} + agents[self.agent_state.id] = self.load_manager_agent() + for agent_id in self.agent_ids: + agents[agent_id] = self.load_participant_agent(agent_id=agent_id) + message_index[agent_id] = 0 + + chat_history: List[Message] = [] + new_messages = messages + speaker_id = None + try: + for _ in range(self.max_turns): + agent_id_options = [agent_id for agent_id in self.agent_ids if agent_id != speaker_id] + manager_message = self.ask_manager_to_choose_participant_message(new_messages, chat_history, agent_id_options) + manager_agent = agents[self.agent_state.id] + usage_stats = manager_agent.step( + messages=[manager_message], + chaining=chaining, + max_chaining_steps=max_chaining_steps, + stream=token_streaming, + skip_verify=True, + metadata=metadata, + put_inner_thoughts_first=put_inner_thoughts_first, + ) + responses = Message.to_letta_messages_from_list(manager_agent.last_response_messages) + assistant_message = 
[response for response in responses if response.message_type == "assistant_message"][0] + for name, agent_id in [(agents[agent_id].agent_state.name, agent_id) for agent_id in agent_id_options]: + if name.lower() in assistant_message.content.lower(): + speaker_id = agent_id + + # sum usage + total_usage.prompt_tokens += usage_stats.prompt_tokens + total_usage.completion_tokens += usage_stats.completion_tokens + total_usage.total_tokens += usage_stats.total_tokens + step_count += 1 + + # initialize input messages + for message in chat_history[message_index[speaker_id] :]: + message.id = Message.generate_id() + message.agent_id = speaker_id + + for message in new_messages: + chat_history.append( + Message( + agent_id=speaker_id, + role=message.role, + content=[TextContent(text=message.content)], + name=message.name, + model=None, + tool_calls=None, + tool_call_id=None, + group_id=self.group_id, + ) + ) + + # load agent and perform step + participant_agent = agents[speaker_id] + usage_stats = participant_agent.step( + messages=chat_history[message_index[speaker_id] :], + chaining=chaining, + max_chaining_steps=max_chaining_steps, + stream=token_streaming, + skip_verify=True, + metadata=metadata, + put_inner_thoughts_first=put_inner_thoughts_first, + ) + + # parse new messages for next step + responses = Message.to_letta_messages_from_list( + participant_agent.last_response_messages, + ) + + assistant_messages = [response for response in responses if response.message_type == "assistant_message"] + new_messages = [ + MessageCreate( + role="system", + content=message.content, + name=participant_agent.agent_state.name, + ) + for message in assistant_messages + ] + message_index[speaker_id] = len(chat_history) + len(new_messages) + + # sum usage + total_usage.prompt_tokens += usage_stats.prompt_tokens + total_usage.completion_tokens += usage_stats.completion_tokens + total_usage.total_tokens += usage_stats.total_tokens + step_count += 1 + + # check for termination token + if 
any(self.termination_token in message.content for message in new_messages): + break + + # persist remaining chat history + for message in new_messages: + chat_history.append( + Message( + agent_id=speaker_id, + role=message.role, + content=[TextContent(text=message.content)], + name=message.name, + model=None, + tool_calls=None, + tool_call_id=None, + group_id=self.group_id, + ) + ) + for agent_id, index in message_index.items(): + if agent_id == speaker_id: + continue + for message in chat_history[index:]: + message.id = Message.generate_id() + message.agent_id = agent_id + self.message_manager.create_many_messages(chat_history[index:], actor=self.user) + + except Exception as e: + raise e + finally: + self.interface.step_yield() + + self.interface.step_complete() + + return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count) + + def load_manager_agent(self) -> Agent: + for participant_agent_id in self.agent_ids: + participant_agent_state = self.agent_manager.get_agent_by_id(agent_id=participant_agent_id, actor=self.user) + participant_persona_block = participant_agent_state.memory.get_block(label="persona") + new_block = self.block_manager.create_or_update_block( + block=Block( + label=participant_agent_id, + value=participant_persona_block.value, + ), + actor=self.user, + ) + self.agent_state = self.agent_manager.update_block_with_label( + agent_id=self.agent_state.id, + block_label=participant_agent_id, + new_block_id=new_block.id, + actor=self.user, + ) + + persona_block = self.agent_state.memory.get_block(label="persona") + group_chat_manager_persona = ( + f"You are overseeing a group chat with {len(self.agent_ids) - 1} agents and " + f"one user. Description of the group: {self.description}\n" + "On each turn, you will be provided with the chat history and latest message. " + "Your task is to decide which participant should speak next in the chat based " + "on the chat history. 
Each agent has a memory block labeled with their ID which " + "holds info about them, and you should use this context to inform your decision." + ) + self.agent_state.memory.update_block_value(label="persona", value=persona_block.value + group_chat_manager_persona) + return Agent( + agent_state=self.agent_state, + interface=self.interface, + user=self.user, + save_last_response=True, + ) + + def load_participant_agent(self, agent_id: str) -> Agent: + agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user) + persona_block = agent_state.memory.get_block(label="persona") + group_chat_participant_persona = ( + f"You are a participant in a group chat with {len(self.agent_ids) - 1} other " + "agents and one user. Respond to new messages in the group chat when prompted. " + f"Description of the group: {self.description}. About you: " + ) + agent_state.memory.update_block_value(label="persona", value=group_chat_participant_persona + persona_block.value) + return Agent( + agent_state=agent_state, + interface=self.interface, + user=self.user, + save_last_response=True, + ) + + ''' + def attach_choose_next_participant_tool(self) -> AgentState: + def choose_next_participant(next_speaker_agent_id: str) -> str: + """ + Returns ID of the agent in the group chat that should reply to the latest message in the conversation. The agent ID will always be in the format: `agent-{UUID}`. + Args: + next_speaker_agent_id (str): The ID of the agent that is most suitable to be the next speaker. + Returns: + str: The ID of the agent that should be the next speaker. 
+ """ + return next_speaker_agent_id + source_code = parse_source_code(choose_next_participant) + tool = self.tool_manager.create_or_update_tool( + Tool( + source_type="python", + source_code=source_code, + name="choose_next_participant", + ), + actor=self.user, + ) + return self.agent_manager.attach_tool(agent_id=self.agent_state.id, tool_id=tool.id, actor=self.user) + ''' + + def ask_manager_to_choose_participant_message( + self, + new_messages: List[MessageCreate], + chat_history: List[Message], + agent_id_options: List[str], + ) -> Message: + chat_history = [f"{message.name or 'user'}: {message.content[0].text}" for message in chat_history] + for message in new_messages: + chat_history.append(f"{message.name or 'user'}: {message.content}") + context_messages = "\n".join(chat_history) + + message_text = ( + "Choose the most suitable agent to reply to the latest message in the " + f"group chat from the following options: {agent_id_options}. Do not " + "respond to the messages yourself, your task is only to decide the " + f"next speaker, not to participate. 
\nChat history:\n{context_messages}" + ) + return Message( + agent_id=self.agent_state.id, + role="user", + content=[TextContent(text=message_text)], + name=None, + model=None, + tool_calls=None, + tool_call_id=None, + group_id=self.group_id, + ) diff --git a/letta/functions/function_sets/multi_agent.py b/letta/functions/function_sets/multi_agent.py index 1f702b24..98f513ef 100644 --- a/letta/functions/function_sets/multi_agent.py +++ b/letta/functions/function_sets/multi_agent.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, List from letta.functions.helpers import ( _send_message_to_agents_matching_tags_async, + _send_message_to_all_agents_in_group_async, execute_send_message_to_agent, fire_and_forget_send_to_agent, ) @@ -86,3 +87,19 @@ def send_message_to_agents_matching_tags(self: "Agent", message: str, match_all: """ return asyncio.run(_send_message_to_agents_matching_tags_async(self, message, match_all, match_some)) + + +def send_message_to_all_agents_in_group(self: "Agent", message: str) -> List[str]: + """ + Sends a message to all agents within the same multi-agent group. + + Args: + message (str): The content of the message to be sent to each agent in the group. + + Returns: + List[str]: A list of responses from the agents in the group. Each + response corresponds to a single agent. A send that fails with an exception is reported + as that exception's message. 
+ """ + + return asyncio.run(_send_message_to_all_agents_in_group_async(self, message)) diff --git a/letta/functions/helpers.py b/letta/functions/helpers.py index 9f69280f..ebe78a6a 100644 --- a/letta/functions/helpers.py +++ b/letta/functions/helpers.py @@ -604,6 +604,47 @@ async def _send_message_to_agents_matching_tags_async( return final +async def _send_message_to_all_agents_in_group_async(sender_agent: "Agent", message: str) -> List[str]: + server = get_letta_server() + + augmented_message = ( + f"[Incoming message from agent with ID '{sender_agent.agent_state.id}' - to reply to this message, " + f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] " + f"{message}" + ) + + worker_agents_ids = sender_agent.agent_state.multi_agent_group.agent_ids + worker_agents = [server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=sender_agent.user) for agent_id in worker_agents_ids] + + # Create a system message + messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=sender_agent.agent_state.name)] + + # Possibly limit concurrency to avoid meltdown: + sem = asyncio.Semaphore(settings.multi_agent_concurrent_sends) + + async def _send_single(agent_state): + async with sem: + return await async_send_message_with_retries( + server=server, + sender_agent=sender_agent, + target_agent_id=agent_state.id, + messages=messages, + max_retries=3, + timeout=settings.multi_agent_send_message_timeout, + ) + + tasks = [asyncio.create_task(_send_single(agent_state)) for agent_state in worker_agents] + results = await asyncio.gather(*tasks, return_exceptions=True) + final = [] + for r in results: + if isinstance(r, Exception): + final.append(str(r)) + else: + final.append(r) + + return final + + def generate_model_from_args_json_schema(schema: Dict[str, Any]) -> Type[BaseModel]: """Creates a Pydantic model from a JSON schema. 
diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index 5963d36c..a43f6e0b 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -4,6 +4,8 @@ from letta.orm.base import Base from letta.orm.block import Block from letta.orm.blocks_agents import BlocksAgents from letta.orm.file import FileMetadata +from letta.orm.group import Group +from letta.orm.groups_agents import GroupsAgents from letta.orm.identities_agents import IdentitiesAgents from letta.orm.identities_blocks import IdentitiesBlocks from letta.orm.identity import Identity diff --git a/letta/orm/agent.py b/letta/orm/agent.py index 59f7f1ff..579f2279 100644 --- a/letta/orm/agent.py +++ b/letta/orm/agent.py @@ -128,11 +128,25 @@ class Agent(SqlalchemyBase, OrganizationMixin): back_populates="agents", passive_deletes=True, ) + groups: Mapped[List["Group"]] = relationship( + "Group", + secondary="groups_agents", + lazy="selectin", + back_populates="agents", + passive_deletes=True, + ) + multi_agent_group: Mapped["Group"] = relationship( + "Group", + lazy="joined", + viewonly=True, + back_populates="manager_agent", + ) def to_pydantic(self) -> PydanticAgentState: """converts to the basic pydantic model counterpart""" # add default rule for having send_message be a terminal tool tool_rules = self.tool_rules + multi_agent_group = self.multi_agent_group state = { "id": self.id, "organization_id": self.organization_id, @@ -159,6 +173,7 @@ class Agent(SqlalchemyBase, OrganizationMixin): "base_template_id": self.base_template_id, "identity_ids": [identity.id for identity in self.identities], "message_buffer_autoclear": self.message_buffer_autoclear, + "multi_agent_group": multi_agent_group, } return self.__pydantic_model__(**state) diff --git a/letta/orm/group.py b/letta/orm/group.py new file mode 100644 index 00000000..308c43d1 --- /dev/null +++ b/letta/orm/group.py @@ -0,0 +1,33 @@ +import uuid +from typing import List, Optional + +from sqlalchemy import ForeignKey, String +from sqlalchemy.orm 
import Mapped, mapped_column, relationship + +from letta.orm.mixins import OrganizationMixin +from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.group import Group as PydanticGroup + + +class Group(SqlalchemyBase, OrganizationMixin): + + __tablename__ = "groups" + __pydantic_model__ = PydanticGroup + + id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"group-{uuid.uuid4()}") + description: Mapped[str] = mapped_column(nullable=False, doc="") + manager_type: Mapped[str] = mapped_column(nullable=False, doc="") + manager_agent_id: Mapped[Optional[str]] = mapped_column(String, ForeignKey("agents.id", ondelete="RESTRICT"), nullable=True, doc="") + termination_token: Mapped[Optional[str]] = mapped_column(nullable=True, doc="") + max_turns: Mapped[Optional[int]] = mapped_column(nullable=True, doc="") + + # relationships + organization: Mapped["Organization"] = relationship("Organization", back_populates="groups") + agents: Mapped[List["Agent"]] = relationship( + "Agent", secondary="groups_agents", lazy="selectin", passive_deletes=True, back_populates="groups" + ) + manager_agent: Mapped["Agent"] = relationship("Agent", lazy="joined", back_populates="multi_agent_group") + + @property + def agent_ids(self) -> List[str]: + return [agent.id for agent in self.agents] diff --git a/letta/orm/groups_agents.py b/letta/orm/groups_agents.py new file mode 100644 index 00000000..375b7fe0 --- /dev/null +++ b/letta/orm/groups_agents.py @@ -0,0 +1,13 @@ +from sqlalchemy import ForeignKey, String +from sqlalchemy.orm import Mapped, mapped_column + +from letta.orm.base import Base + + +class GroupsAgents(Base): + """Agents may have one or many groups associated with them.""" + + __tablename__ = "groups_agents" + + group_id: Mapped[str] = mapped_column(String, ForeignKey("groups.id", ondelete="CASCADE"), primary_key=True) + agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), primary_key=True) diff --git 
a/letta/orm/message.py b/letta/orm/message.py index e94eaa33..145642aa 100644 --- a/letta/orm/message.py +++ b/letta/orm/message.py @@ -36,6 +36,7 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin): tool_returns: Mapped[List[ToolReturn]] = mapped_column( ToolReturnColumn, nullable=True, doc="Tool execution return information for prior tool calls" ) + group_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The multi-agent group that the message was sent in") # Relationships agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin") diff --git a/letta/orm/organization.py b/letta/orm/organization.py index fc8dcfc7..133b77d8 100644 --- a/letta/orm/organization.py +++ b/letta/orm/organization.py @@ -49,6 +49,7 @@ class Organization(SqlalchemyBase): agent_passages: Mapped[List["AgentPassage"]] = relationship("AgentPassage", back_populates="organization", cascade="all, delete-orphan") providers: Mapped[List["Provider"]] = relationship("Provider", back_populates="organization", cascade="all, delete-orphan") identities: Mapped[List["Identity"]] = relationship("Identity", back_populates="organization", cascade="all, delete-orphan") + groups: Mapped[List["Group"]] = relationship("Group", back_populates="organization", cascade="all, delete-orphan") @property def passages(self) -> List[Union["SourcePassage", "AgentPassage"]]: diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index 11ac3070..fd211b86 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -139,11 +139,11 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): else: # Match ANY tag - use join and filter query = ( - query.join(cls.tags).filter(cls.tags.property.mapper.class_.tag.in_(tags)).group_by(cls.id) + query.join(cls.tags).filter(cls.tags.property.mapper.class_.tag.in_(tags)).distinct(cls.id).order_by(cls.id) ) # Deduplicate results - # Group by primary key and all necessary columns to avoid JSON 
comparison - query = query.group_by(cls.id) + # select distinct primary key + query = query.distinct(cls.id).order_by(cls.id) if identifier_keys and hasattr(cls, "identities"): query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.identifier_key.in_(identifier_keys)) diff --git a/letta/round_robin_multi_agent.py b/letta/round_robin_multi_agent.py new file mode 100644 index 00000000..bca50f40 --- /dev/null +++ b/letta/round_robin_multi_agent.py @@ -0,0 +1,152 @@ +from typing import List, Optional + +from letta.agent import Agent, AgentState +from letta.interface import AgentInterface +from letta.orm import User +from letta.schemas.letta_message import TextContent +from letta.schemas.message import Message, MessageCreate +from letta.schemas.openai.chat_completion_response import UsageStatistics +from letta.schemas.usage import LettaUsageStatistics + + +class RoundRobinMultiAgent(Agent): + def __init__( + self, + interface: AgentInterface, + agent_state: AgentState, + user: User = None, + # custom + group_id: str = "", + agent_ids: List[str] = [], + description: str = "", + max_turns: Optional[int] = None, + ): + super().__init__(interface, agent_state, user) + self.group_id = group_id + self.agent_ids = agent_ids + self.description = description + self.max_turns = max_turns or len(agent_ids) + + def step( + self, + messages: List[MessageCreate], + chaining: bool = True, + max_chaining_steps: Optional[int] = None, + put_inner_thoughts_first: bool = True, + **kwargs, + ) -> LettaUsageStatistics: + total_usage = UsageStatistics() + step_count = 0 + + token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False + metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None + + agents = {} + for agent_id in self.agent_ids: + agents[agent_id] = self.load_participant_agent(agent_id=agent_id) + + message_index = {} + chat_history: List[Message] = [] + new_messages = messages + 
speaker_id = None + try: + for i in range(self.max_turns): + speaker_id = self.agent_ids[i % len(self.agent_ids)] + # initialize input messages + start_index = message_index[speaker_id] if speaker_id in message_index else 0 + for message in chat_history[start_index:]: + message.id = Message.generate_id() + message.agent_id = speaker_id + + for message in new_messages: + chat_history.append( + Message( + agent_id=speaker_id, + role=message.role, + content=[TextContent(text=message.content)], + name=message.name, + model=None, + tool_calls=None, + tool_call_id=None, + group_id=self.group_id, + ) + ) + + # load agent and perform step + participant_agent = agents[speaker_id] + usage_stats = participant_agent.step( + messages=chat_history[start_index:], + chaining=chaining, + max_chaining_steps=max_chaining_steps, + stream=token_streaming, + skip_verify=True, + metadata=metadata, + put_inner_thoughts_first=put_inner_thoughts_first, + ) + + # parse new messages for next step + responses = Message.to_letta_messages_from_list(participant_agent.last_response_messages) + assistant_messages = [response for response in responses if response.message_type == "assistant_message"] + new_messages = [ + MessageCreate( + role="system", + content=message.content, + name=participant_agent.agent_state.name, + ) + for message in assistant_messages + ] + message_index[speaker_id] = len(chat_history) + len(new_messages) + + # sum usage + total_usage.prompt_tokens += usage_stats.prompt_tokens + total_usage.completion_tokens += usage_stats.completion_tokens + total_usage.total_tokens += usage_stats.total_tokens + step_count += 1 + + # persist remaining chat history + for message in new_messages: + chat_history.append( + Message( + agent_id=speaker_id, + role=message.role, + content=[TextContent(text=message.content)], + name=message.name, + model=None, + tool_calls=None, + tool_call_id=None, + group_id=self.group_id, + ) + ) + for agent_id, index in message_index.items(): + if agent_id == 
speaker_id: + continue + for message in chat_history[index:]: + message.id = Message.generate_id() + message.agent_id = agent_id + self.message_manager.create_many_messages(chat_history[index:], actor=self.user) + + except Exception as e: + raise e + finally: + self.interface.step_yield() + + self.interface.step_complete() + + return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count) + + def load_participant_agent(self, agent_id: str) -> Agent: + agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user) + persona_block = agent_state.memory.get_block(label="persona") + group_chat_participant_persona = ( + "\n\n====Group Chat Contex====" + f"\nYou are speaking in a group chat with {len(self.agent_ids) - 1} other " + "agents and one user. Respond to new messages in the group chat when prompted. " + f"Description of the group: {self.description}" + ) + agent_state.memory.update_block_value(label="persona", value=persona_block.value + group_chat_participant_persona) + return Agent( + agent_state=agent_state, + interface=self.interface, + user=self.user, + save_last_response=True, + ) diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index 44e6ecf6..62a34f3e 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -7,6 +7,7 @@ from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE from letta.schemas.block import CreateBlock from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.environment_variables import AgentEnvironmentVariable +from letta.schemas.group import Group from letta.schemas.letta_base import OrmMetadataBase from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import Memory @@ -90,6 +91,8 @@ class AgentState(OrmMetadataBase, validate_assignment=True): description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). 
Not recommended unless you have an advanced use case.", ) + multi_agent_group: Optional[Group] = Field(None, description="The multi-agent group that this agent manages") + def get_agent_env_vars_as_dict(self) -> Dict[str, str]: # Get environment variables for this agent specifically per_agent_env_vars = {} diff --git a/letta/schemas/group.py b/letta/schemas/group.py new file mode 100644 index 00000000..254baa8b --- /dev/null +++ b/letta/schemas/group.py @@ -0,0 +1,65 @@ +from enum import Enum +from typing import Annotated, List, Literal, Optional, Union + +from pydantic import BaseModel, Field + +from letta.schemas.letta_base import LettaBase + + +class ManagerType(str, Enum): + round_robin = "round_robin" + supervisor = "supervisor" + dynamic = "dynamic" + swarm = "swarm" + + +class GroupBase(LettaBase): + __id_prefix__ = "group" + + +class Group(GroupBase): + id: str = Field(..., description="The id of the group. Assigned by the database.") + manager_type: ManagerType = Field(..., description="") + agent_ids: List[str] = Field(..., description="") + description: str = Field(..., description="") + # Pattern fields + manager_agent_id: Optional[str] = Field(None, description="") + termination_token: Optional[str] = Field(None, description="") + max_turns: Optional[int] = Field(None, description="") + + +class ManagerConfig(BaseModel): + manager_type: ManagerType = Field(..., description="") + + +class RoundRobinManager(ManagerConfig): + manager_type: Literal[ManagerType.round_robin] = Field(ManagerType.round_robin, description="") + max_turns: Optional[int] = Field(None, description="") + + +class SupervisorManager(ManagerConfig): + manager_type: Literal[ManagerType.supervisor] = Field(ManagerType.supervisor, description="") + manager_agent_id: str = Field(..., description="") + + +class DynamicManager(ManagerConfig): + manager_type: Literal[ManagerType.dynamic] = Field(ManagerType.dynamic, description="") + manager_agent_id: str = Field(..., description="") + 
termination_token: Optional[str] = Field("DONE!", description="") + max_turns: Optional[int] = Field(None, description="") + + +# class SwarmGroup(ManagerConfig): +# manager_type: Literal[ManagerType.swarm] = Field(ManagerType.swarm, description="") + + +ManagerConfigUnion = Annotated[ + Union[RoundRobinManager, SupervisorManager, DynamicManager], + Field(discriminator="manager_type"), +] + + +class GroupCreate(BaseModel): + agent_ids: List[str] = Field(..., description="") + description: str = Field(..., description="") + manager_config: Optional[ManagerConfigUnion] = Field(None, description="") diff --git a/letta/schemas/message.py b/letta/schemas/message.py index f182422a..18f715b8 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -129,6 +129,7 @@ class Message(BaseMessage): step_id: Optional[str] = Field(None, description="The id of the step that this message was created in.") otid: Optional[str] = Field(None, description="The offline threading id associated with this message") tool_returns: Optional[List[ToolReturn]] = Field(None, description="Tool execution return information for prior tool calls") + group_id: Optional[str] = Field(None, description="The multi-agent group that the message was sent in") # This overrides the optional base orm schema, created_at MUST exist on all messages objects created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.") diff --git a/letta/server/rest_api/routers/v1/__init__.py b/letta/server/rest_api/routers/v1/__init__.py index d4f86bf8..8b983fb7 100644 --- a/letta/server/rest_api/routers/v1/__init__.py +++ b/letta/server/rest_api/routers/v1/__init__.py @@ -1,5 +1,6 @@ from letta.server.rest_api.routers.v1.agents import router as agents_router from letta.server.rest_api.routers.v1.blocks import router as blocks_router +from letta.server.rest_api.routers.v1.groups import router as groups_router from letta.server.rest_api.routers.v1.health import 
from typing import Annotated, List, Optional

from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, status
from fastapi.responses import JSONResponse
from pydantic import Field

from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.orm.errors import NoResultFound
from letta.schemas.group import Group, GroupCreate, ManagerType
from letta.schemas.letta_message import LettaMessageUnion
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
from letta.schemas.letta_response import LettaResponse
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer

router = APIRouter(prefix="/groups", tags=["groups"])


@router.get("/", response_model=List[Group], operation_id="list_groups")
def list_groups(
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),
    manager_type: Optional[ManagerType] = Query(None, description="Search groups by manager type"),
    before: Optional[str] = Query(None, description="Cursor for pagination"),
    after: Optional[str] = Query(None, description="Cursor for pagination"),
    limit: Optional[int] = Query(None, description="Limit for pagination"),
    project_id: Optional[str] = Query(None, description="Search groups by project id"),
):
    """
    Fetch all multi-agent groups matching query.
    """
    actor = server.user_manager.get_user_or_default(user_id=actor_id)
    return server.group_manager.list_groups(
        project_id=project_id,
        manager_type=manager_type,
        before=before,
        after=after,
        limit=limit,
        actor=actor,
    )


# NOTE: this module previously registered two handlers named `create_group` on
# POST "/" with the same operation_id. Only one can be served and duplicate
# operation ids break the generated OpenAPI schema, so they are merged here.
@router.post("/", response_model=Group, operation_id="create_group")
def create_group(
    group: GroupCreate = Body(...),
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
    x_project: Optional[str] = Header(None, alias="X-Project"),  # Only handled by next js middleware
):
    """
    Create a multi-agent group with a specified management pattern. When no
    management config is specified, this endpoint will use round robin for
    speaker selection.
    """
    try:
        actor = server.user_manager.get_user_or_default(user_id=actor_id)
        return server.group_manager.create_group(group, actor=actor)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.put("/", response_model=Group, operation_id="upsert_group")
def upsert_group(
    group: GroupCreate = Body(...),
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
    x_project: Optional[str] = Header(None, alias="X-Project"),  # Only handled by next js middleware
):
    """
    Create a multi-agent group with the specified configuration.

    Note: this currently always creates a new group; it does not update an
    existing one.
    """
    try:
        actor = server.user_manager.get_user_or_default(user_id=actor_id)
        return server.group_manager.create_group(group, actor=actor)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.delete("/{group_id}", response_model=None, operation_id="delete_group")
def delete_group(
    group_id: str,
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    """
    Delete a multi-agent group.
    """
    actor = server.user_manager.get_user_or_default(user_id=actor_id)
    try:
        server.group_manager.delete_group(group_id=group_id, actor=actor)
        return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Group id={group_id} successfully deleted"})
    except NoResultFound:
        raise HTTPException(status_code=404, detail=f"Group id={group_id} not found for user_id={actor.id}.")


@router.post(
    "/{group_id}/messages",
    response_model=LettaResponse,
    operation_id="send_group_message",
)
async def send_group_message(
    # Must be named `group_id` to bind to the `{group_id}` path segment; it was
    # previously declared as `agent_id` while the body used `group_id`.
    group_id: str,
    server: SyncServer = Depends(get_letta_server),
    request: LettaRequest = Body(...),
    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    """
    Process a user message and return the group's response.
    This endpoint accepts a message from a user and processes it through agents in the group based on the specified pattern.
    """
    actor = server.user_manager.get_user_or_default(user_id=actor_id)
    result = await server.send_group_message_to_agent(
        group_id=group_id,
        actor=actor,
        messages=request.messages,
        stream_steps=False,
        stream_tokens=False,
        # Support for AssistantMessage
        use_assistant_message=request.use_assistant_message,
        assistant_message_tool_name=request.assistant_message_tool_name,
        assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
    )
    return result


@router.post(
    "/{group_id}/messages/stream",
    response_model=None,
    operation_id="send_group_message_streaming",
    responses={
        200: {
            "description": "Successful response",
            "content": {
                "text/event-stream": {"description": "Server-Sent Events stream"},
            },
        }
    },
)
async def send_group_message_streaming(
    group_id: str,
    server: SyncServer = Depends(get_letta_server),
    request: LettaStreamingRequest = Body(...),
    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    """
    Process a user message and return the group's responses.
    This endpoint accepts a message from a user and processes it through agents in the group based on the specified pattern.
    It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True.
    """
    actor = server.user_manager.get_user_or_default(user_id=actor_id)
    result = await server.send_group_message_to_agent(
        group_id=group_id,
        actor=actor,
        messages=request.messages,
        stream_steps=True,
        stream_tokens=request.stream_tokens,
        # Support for AssistantMessage
        use_assistant_message=request.use_assistant_message,
        assistant_message_tool_name=request.assistant_message_tool_name,
        assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
    )
    return result


# Explicit JSON schema so the OpenAPI document references LettaMessageUnion by
# name instead of inlining the union.
GroupMessagesResponse = Annotated[
    List[LettaMessageUnion], Field(json_schema_extra={"type": "array", "items": {"$ref": "#/components/schemas/LettaMessageUnion"}})
]


@router.get("/{group_id}/messages", response_model=GroupMessagesResponse, operation_id="list_group_messages")
def list_group_messages(
    group_id: str,
    server: "SyncServer" = Depends(get_letta_server),
    after: Optional[str] = Query(None, description="Message after which to retrieve the returned messages."),
    before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."),
    limit: int = Query(10, description="Maximum number of messages to retrieve."),
    use_assistant_message: bool = Query(True, description="Whether to use assistant messages"),
    assistant_message_tool_name: str = Query(DEFAULT_MESSAGE_TOOL, description="The name of the designated message tool."),
    assistant_message_tool_kwarg: str = Query(DEFAULT_MESSAGE_TOOL_KWARG, description="The name of the message argument."),
    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    """
    Retrieve message history for a group.
    """
    actor = server.user_manager.get_user_or_default(user_id=actor_id)

    return server.group_manager.list_group_messages(
        group_id=group_id,
        before=before,
        after=after,
        limit=limit,
        actor=actor,
        use_assistant_message=use_assistant_message,
        assistant_message_tool_name=assistant_message_tool_name,
        assistant_message_tool_kwarg=assistant_message_tool_kwarg,
    )


# Disabled endpoint, kept as a module-level string until group-scoped message
# deletion exists (see the TODO inside).
'''
@router.patch("/{group_id}/reset-messages", response_model=None, operation_id="reset_group_messages")
def reset_group_messages(
    group_id: str,
    add_default_initial_messages: bool = Query(default=False, description="If true, adds the default initial messages after resetting."),
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    """
    Resets the messages for all agents that are part of the multi-agent group.
    TODO: only delete group messages not all messages!
    """
    actor = server.user_manager.get_user_or_default(user_id=actor_id)
    group = server.group_manager.retrieve_group(group_id=group_id, actor=actor)
    agent_ids = group.agent_ids
    if group.manager_agent_id:
        agent_ids.append(group.manager_agent_id)
    for agent_id in agent_ids:
        server.agent_manager.reset_messages(
            agent_id=agent_id,
            actor=actor,
            add_default_initial_messages=add_default_initial_messages,
        )
'''
from letta.interface import CLIInterface # for printing to terminal from letta.log import get_logger from letta.offline_memory_agent import OfflineMemoryAgent from letta.orm.errors import NoResultFound +from letta.round_robin_multi_agent import RoundRobinMultiAgent from letta.schemas.agent import AgentState, AgentType, CreateAgent from letta.schemas.block import BlockUpdate from letta.schemas.embedding_config import EmbeddingConfig @@ -44,6 +46,7 @@ from letta.schemas.embedding_config import EmbeddingConfig # openai schemas from letta.schemas.enums import JobStatus, MessageStreamStatus from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate +from letta.schemas.group import Group, ManagerType from letta.schemas.job import Job, JobUpdate from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, ToolReturnMessage from letta.schemas.letta_response import LettaResponse @@ -80,6 +83,7 @@ from letta.server.rest_api.interface import StreamingServerInterface from letta.server.rest_api.utils import sse_async_generator from letta.services.agent_manager import AgentManager from letta.services.block_manager import BlockManager +from letta.services.group_manager import GroupManager from letta.services.identity_manager import IdentityManager from letta.services.job_manager import JobManager from letta.services.message_manager import MessageManager @@ -94,6 +98,7 @@ from letta.services.tool_execution_sandbox import ToolExecutionSandbox from letta.services.tool_manager import ToolManager from letta.services.user_manager import UserManager from letta.settings import model_settings, settings, tool_settings +from letta.supervisor_multi_agent import SupervisorMultiAgent from letta.tracing import trace_method from letta.utils import get_friendly_error_msg @@ -207,6 +212,7 @@ class SyncServer(Server): self.provider_manager = ProviderManager() self.step_manager = StepManager() self.identity_manager = IdentityManager() + self.group_manager = 
GroupManager() # Managers that interface with parallelism self.per_agent_lock_manager = PerAgentLockManager() @@ -353,6 +359,8 @@ class SyncServer(Server): agent_lock = self.per_agent_lock_manager.get_lock(agent_id) with agent_lock: agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor) + if agent_state.multi_agent_group: + return self.load_multi_agent(agent_state.multi_agent_group, actor, interface, agent_state) interface = interface or self.default_interface_factory() if agent_state.agent_type == AgentType.memgpt_agent: @@ -364,6 +372,46 @@ class SyncServer(Server): return agent + def load_multi_agent( + self, group: Group, actor: User, interface: Union[AgentInterface, None] = None, agent_state: Optional[AgentState] = None + ) -> Agent: + match group.manager_type: + case ManagerType.round_robin: + agent_state = agent_state or self.agent_manager.get_agent_by_id(agent_id=group.agent_ids[0], actor=actor) + return RoundRobinMultiAgent( + agent_state=agent_state, + interface=interface, + user=actor, + group_id=group.id, + agent_ids=group.agent_ids, + description=group.description, + max_turns=group.max_turns, + ) + case ManagerType.dynamic: + agent_state = agent_state or self.agent_manager.get_agent_by_id(agent_id=group.manager_agent_id, actor=actor) + return DynamicMultiAgent( + agent_state=agent_state, + interface=interface, + user=actor, + group_id=group.id, + agent_ids=group.agent_ids, + description=group.description, + max_turns=group.max_turns, + termination_token=group.termination_token, + ) + case ManagerType.supervisor: + agent_state = agent_state or self.agent_manager.get_agent_by_id(agent_id=group.manager_agent_id, actor=actor) + return SupervisorMultiAgent( + agent_state=agent_state, + interface=interface, + user=actor, + group_id=group.id, + agent_ids=group.agent_ids, + description=group.description, + ) + case _: + raise ValueError(f"Type {group.manager_type} is not supported.") + def _step( self, actor: User, @@ -1403,3 
    @trace_method
    async def send_group_message_to_agent(
        self,
        group_id: str,
        actor: User,
        messages: Union[List[Message], List[MessageCreate]],
        stream_steps: bool,
        stream_tokens: bool,
        chat_completion_mode: bool = False,
        # Support for AssistantMessage
        use_assistant_message: bool = True,
        assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
        assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
        metadata: Optional[dict] = None,
    ) -> Union[StreamingResponse, LettaResponse]:
        """Send messages to a multi-agent group and return its output.

        Loads the group, wraps it in the matching multi-agent class, runs its
        `step` in a worker thread and either streams the results or buffers them.

        Args:
            group_id: Id of the group to message.
            actor: User performing the request.
            messages: Messages to pass to the group's `step`.
            stream_steps: If True, return a server-sent-events StreamingResponse;
                otherwise buffer everything into a LettaResponse.
            stream_tokens: Request token-level streaming; requires `stream_steps`.
                Silently downgraded to False for unsupported model endpoints.
            chat_completion_mode: Stored on the streaming interface.
            use_assistant_message, assistant_message_tool_name,
            assistant_message_tool_kwarg: Forwarded to the streaming interface.
            metadata: Attached to the streaming interface when it has a
                `metadata` attribute.

        Raises:
            HTTPException: 400 when `stream_tokens` is set without `stream_steps`;
                500 wrapping any other exception raised while processing.
        """
        include_final_message = True
        if not stream_steps and stream_tokens:
            raise HTTPException(status_code=400, detail="stream_steps must be 'true' if stream_tokens is 'true'")

        try:
            # fetch the group
            group = self.group_manager.retrieve_group(group_id=group_id, actor=actor)
            # No interface is passed here; a fresh streaming interface is attached below.
            letta_multi_agent = self.load_multi_agent(group=group, actor=actor)

            llm_config = letta_multi_agent.agent_state.llm_config
            supports_token_streaming = ["openai", "anthropic", "deepseek"]
            # Fall back to step streaming when the endpoint cannot stream tokens.
            if stream_tokens and (
                llm_config.model_endpoint_type not in supports_token_streaming or "inference.memgpt.ai" in llm_config.model_endpoint
            ):
                warnings.warn(
                    f"Token streaming is only supported for models with type {' or '.join(supports_token_streaming)} in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False."
                )
                stream_tokens = False

            # Create a new interface per request
            letta_multi_agent.interface = StreamingServerInterface(
                use_assistant_message=use_assistant_message,
                assistant_message_tool_name=assistant_message_tool_name,
                assistant_message_tool_kwarg=assistant_message_tool_kwarg,
                inner_thoughts_in_kwargs=(
                    llm_config.put_inner_thoughts_in_kwargs if llm_config.put_inner_thoughts_in_kwargs is not None else False
                ),
            )
            streaming_interface = letta_multi_agent.interface
            if not isinstance(streaming_interface, StreamingServerInterface):
                raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}")
            streaming_interface.streaming_mode = stream_tokens
            streaming_interface.streaming_chat_completion_mode = chat_completion_mode
            if metadata and hasattr(streaming_interface, "metadata"):
                streaming_interface.metadata = metadata

            # The stream must be started before the step is scheduled so that
            # output produced by the worker thread is captured.
            streaming_interface.stream_start()
            # `step` is synchronous; run it in a thread and keep the task so the
            # usage statistics can be awaited later.
            task = asyncio.create_task(
                asyncio.to_thread(
                    letta_multi_agent.step,
                    messages=messages,
                    chaining=self.chaining,
                    max_chaining_steps=self.max_chaining_steps,
                )
            )

            if stream_steps:
                # return a stream
                return StreamingResponse(
                    sse_async_generator(
                        streaming_interface.get_generator(),
                        usage_task=task,
                        finish_message=include_final_message,
                    ),
                    media_type="text/event-stream",
                )

            else:
                # buffer the stream, then return the list
                generated_stream = []
                async for message in streaming_interface.get_generator():
                    assert (
                        isinstance(message, LettaMessage)
                        or isinstance(message, LegacyLettaMessage)
                        or isinstance(message, MessageStreamStatus)
                    ), type(message)
                    generated_stream.append(message)
                    # The `done` status marks the end of the buffered stream.
                    if message == MessageStreamStatus.done:
                        break

                # Get rid of the stream status messages
                filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)]
                usage = await task

                # By default the stream will be messages of type LettaMessage or LettaLegacyMessage
                # If we want to convert these to Message, we can use the attached IDs
                # NOTE: we will need to de-duplicate the Message IDs though (since Assistant->Inner+Func_Call)
                # TODO: eventually update the interface to use `Message` and `MessageChunk` (new) inside the deque instead
                return LettaResponse(messages=filtered_stream, usage=usage)
        except HTTPException:
            # Already an HTTP error; pass it through unchanged.
            raise
        except Exception as e:
            print(e)
            import traceback

            traceback.print_exc()
            raise HTTPException(status_code=500, detail=f"{e}")
db_context + + self.session_maker = db_context + + @enforce_types + def list_groups( + self, + project_id: Optional[str] = None, + manager_type: Optional[ManagerType] = None, + before: Optional[str] = None, + after: Optional[str] = None, + limit: Optional[int] = 50, + actor: PydanticUser = None, + ) -> list[PydanticGroup]: + with self.session_maker() as session: + filters = {"organization_id": actor.organization_id} + if project_id: + filters["project_id"] = project_id + if manager_type: + filters["manager_type"] = manager_type + groups = GroupModel.list( + db_session=session, + before=before, + after=after, + limit=limit, + **filters, + ) + return [group.to_pydantic() for group in groups] + + @enforce_types + def retrieve_group(self, group_id: str, actor: PydanticUser) -> PydanticGroup: + with self.session_maker() as session: + group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) + return group.to_pydantic() + + @enforce_types + def create_group(self, group: GroupCreate, actor: PydanticUser) -> PydanticGroup: + with self.session_maker() as session: + new_group = GroupModel() + new_group.organization_id = actor.organization_id + new_group.description = group.description + self._process_agent_relationship(session=session, group=new_group, agent_ids=group.agent_ids, allow_partial=False) + if group.manager_config is None: + new_group.manager_type = ManagerType.round_robin + else: + match group.manager_config.manager_type: + case ManagerType.round_robin: + new_group.manager_type = ManagerType.round_robin + new_group.max_turns = group.manager_config.max_turns + case ManagerType.dynamic: + new_group.manager_type = ManagerType.dynamic + new_group.manager_agent_id = group.manager_config.manager_agent_id + new_group.max_turns = group.manager_config.max_turns + new_group.termination_token = group.manager_config.termination_token + case ManagerType.supervisor: + new_group.manager_type = ManagerType.supervisor + new_group.manager_agent_id = 
group.manager_config.manager_agent_id + case _: + raise ValueError(f"Unsupported manager type: {group.manager_config.manager_type}") + new_group.create(session, actor=actor) + return new_group.to_pydantic() + + @enforce_types + def delete_group(self, group_id: str, actor: PydanticUser) -> None: + with self.session_maker() as session: + # Retrieve the agent + group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) + group.hard_delete(session) + + @enforce_types + def list_group_messages( + self, + group_id: Optional[str] = None, + before: Optional[str] = None, + after: Optional[str] = None, + limit: Optional[int] = 50, + actor: PydanticUser = None, + use_assistant_message: bool = True, + assistant_message_tool_name: str = "send_message", + assistant_message_tool_kwarg: str = "message", + ) -> list[PydanticGroup]: + with self.session_maker() as session: + group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) + agent_id = group.manager_agent_id if group.manager_agent_id else group.agent_ids[0] + + filters = { + "organization_id": actor.organization_id, + "group_id": group_id, + "agent_id": agent_id, + } + messages = MessageModel.list( + db_session=session, + before=before, + after=after, + limit=limit, + **filters, + ) + + messages = PydanticMessage.to_letta_messages_from_list( + messages=messages, + use_assistant_message=use_assistant_message, + assistant_message_tool_name=assistant_message_tool_name, + assistant_message_tool_kwarg=assistant_message_tool_kwarg, + ) + + return messages + + def _process_agent_relationship(self, session: Session, group: GroupModel, agent_ids: List[str], allow_partial=False, replace=True): + current_relationship = getattr(group, "agents", []) + if not agent_ids: + if replace: + setattr(group, "agents", []) + return + + # Retrieve models for the provided IDs + found_items = session.query(AgentModel).filter(AgentModel.id.in_(agent_ids)).all() + + # Validate all items are found if allow_partial 
from typing import List, Optional

from letta.agent import Agent, AgentState
from letta.constants import DEFAULT_MESSAGE_TOOL
from letta.functions.function_sets.multi_agent import send_message_to_all_agents_in_group
from letta.interface import AgentInterface
from letta.orm import User
from letta.orm.enums import ToolType
from letta.schemas.letta_message import TextContent
from letta.schemas.message import Message, MessageCreate
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
from letta.schemas.usage import LettaUsageStatistics
from letta.services.agent_manager import AgentManager
from letta.services.tool_manager import ToolManager


class SupervisorMultiAgent(Agent):
    """Agent wrapper that runs a group under the supervisor pattern.

    The supervisor agent is forced to call `send_message_to_all_agents_in_group`
    first and then to answer through the assistant-message tool.
    """

    def __init__(
        self,
        interface: AgentInterface,
        agent_state: AgentState,
        user: User = None,
        # custom
        group_id: str = "",
        agent_ids: Optional[List[str]] = None,
        description: str = "",
    ):
        """Store the group configuration next to the regular agent state.

        Args:
            interface: Interface used to emit step output.
            agent_state: State of the supervisor agent.
            user: User the agent acts for.
            group_id: Id of the managed group; stamped on outgoing messages.
            agent_ids: Ids of the participant agents. Defaults to an empty list.
            description: Description of the group.
        """
        super().__init__(interface, agent_state, user)
        self.group_id = group_id
        # A fresh list per instance; a `[]` default would be shared by all instances.
        self.agent_ids = agent_ids if agent_ids is not None else []
        self.description = description
        self.agent_manager = AgentManager()
        self.tool_manager = ToolManager()

    def step(
        self,
        messages: List[MessageCreate],
        chaining: bool = True,
        max_chaining_steps: Optional[int] = None,
        put_inner_thoughts_first: bool = True,
        assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
        **kwargs,
    ) -> LettaUsageStatistics:
        """Run one supervisor turn over `messages` and return the usage statistics.

        Args:
            messages: User messages to deliver to the supervisor agent.
            chaining: Forwarded to the underlying agent step.
            max_chaining_steps: Forwarded to the underlying agent step.
            put_inner_thoughts_first: Forwarded to the underlying agent step.
            assistant_message_tool_name: Tool the supervisor must finish with.
            **kwargs: Accepted for signature compatibility; unused.
        """
        token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False
        metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None

        # add multi agent tool
        if self.tool_manager.get_tool_by_name(tool_name="send_message_to_all_agents_in_group", actor=self.user) is None:
            # NOTE(review): production code depends on a helper from the test
            # package. It is imported lazily so that importing this module does
            # not require `tests`, but the helper should move into the library.
            from tests.helpers.utils import create_tool_from_func

            multi_agent_tool = create_tool_from_func(send_message_to_all_agents_in_group)
            multi_agent_tool.tool_type = ToolType.LETTA_MULTI_AGENT_CORE
            multi_agent_tool = self.tool_manager.create_or_update_tool(
                pydantic_tool=multi_agent_tool,
                actor=self.user,
            )
            # NOTE(review): the tool is only attached when it had to be created;
            # confirm an already existing tool is attached to this agent elsewhere.
            self.agent_state = self.agent_manager.attach_tool(agent_id=self.agent_state.id, tool_id=multi_agent_tool.id, actor=self.user)

        # override tool rules
        self.agent_state.tool_rules = [
            InitToolRule(
                tool_name="send_message_to_all_agents_in_group",
            ),
            TerminalToolRule(
                tool_name=assistant_message_tool_name,
            ),
            ChildToolRule(
                tool_name="send_message_to_all_agents_in_group",
                children=[assistant_message_tool_name],
            ),
        ]

        # Re-wrap the incoming messages as user messages tagged with the group id.
        supervisor_messages = [
            Message(
                agent_id=self.agent_state.id,
                role="user",
                content=[TextContent(text=message.content)],
                name=None,
                model=None,
                tool_calls=None,
                tool_call_id=None,
                group_id=self.group_id,
            )
            for message in messages
        ]
        try:
            supervisor_agent = Agent(agent_state=self.agent_state, interface=self.interface, user=self.user)
            usage_stats = supervisor_agent.step(
                messages=supervisor_messages,
                chaining=chaining,
                max_chaining_steps=max_chaining_steps,
                stream=token_streaming,
                skip_verify=True,
                metadata=metadata,
                put_inner_thoughts_first=put_inner_thoughts_first,
            )
        finally:
            # Always yield the step, even on failure; the exception propagates unchanged.
            self.interface.step_yield()

        self.interface.step_complete()

        return usage_stats
import ast

import pytest
from sqlalchemy import delete

from letta.config import LettaConfig
from letta.orm import Provider, Step
from letta.schemas.agent import CreateAgent
from letta.schemas.block import CreateBlock
from letta.schemas.group import DynamicManager, GroupCreate, SupervisorManager
from letta.schemas.message import MessageCreate
from letta.server.server import SyncServer

GROUP_DESCRIPTION = (
    "This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries."
)

# name -> persona for the participant agents, in creation order
PARTICIPANT_PERSONAS = {
    "fred": "Your name is fred and you like to ski and have been wanting to go on a ski trip soon. You are speaking in a group chat with other agent pals where you participate in friendly banter.",
    "velma": "Your name is velma and you like tropical locations. You are speaking in a group chat with other agent friends and you love to include everyone.",
    "daphne": "Your name is daphne and you love traveling abroad. You are speaking in a group chat with other agent friends and you love to keep in touch with them.",
    "shaggy": "Your name is shaggy and your best friend is your dog, scooby. You are speaking in a group chat with other agent friends and you like to solve mysteries with them.",
}

MANAGER_PERSONA = "You are a puppy operations agent for Letta and you help run multi-agent group chats. Your job is to get to know the agents in your group and pick who is best suited to speak next in the conversation."


def _create_agent(server, actor, name, blocks):
    """Create a test agent with the given (label, value) memory blocks."""
    return server.create_agent(
        request=CreateAgent(
            name=name,
            memory_blocks=[CreateBlock(label=label, value=value) for label, value in blocks],
            model="openai/gpt-4o-mini",
            embedding="openai/text-embedding-ada-002",
        ),
        actor=actor,
    )


@pytest.fixture(scope="module")
def server():
    config = LettaConfig.load()
    print("CONFIG PATH", config.config_path)

    config.save()

    server = SyncServer()
    return server


@pytest.fixture(scope="module")
def org_id(server):
    org = server.organization_manager.create_default_organization()

    yield org.id

    # cleanup
    with server.organization_manager.session_maker() as session:
        session.execute(delete(Step))
        session.execute(delete(Provider))
        session.commit()
    server.organization_manager.delete_organization_by_id(org.id)


@pytest.fixture(scope="module")
def actor(server, org_id):
    user = server.user_manager.create_default_user()
    yield user

    # cleanup
    server.user_manager.delete_user_by_id(user.id)


@pytest.fixture(scope="module")
def participant_agent_ids(server, actor):
    agents = [_create_agent(server, actor, name, [("persona", persona)]) for name, persona in PARTICIPANT_PERSONAS.items()]
    yield [agent.id for agent in agents]

    # cleanup
    for agent in agents:
        server.agent_manager.delete_agent(agent.id, actor=actor)


@pytest.fixture(scope="module")
def manager_agent_id(server, actor):
    agent_scooby = _create_agent(server, actor, "scooby", [("persona", MANAGER_PERSONA), ("human", "")])
    yield agent_scooby.id

    # cleanup
    server.agent_manager.delete_agent(agent_scooby.id, actor=actor)


@pytest.mark.asyncio
async def test_round_robin(server, actor, participant_agent_ids):
    group = server.group_manager.create_group(
        group=GroupCreate(
            description=GROUP_DESCRIPTION,
            agent_ids=participant_agent_ids,
        ),
        actor=actor,
    )
    try:
        response = await server.send_group_message_to_agent(
            group_id=group.id,
            actor=actor,
            messages=[
                MessageCreate(
                    role="user",
                    content="what is everyone up to for the holidays?",
                ),
            ],
            stream_steps=False,
            stream_tokens=False,
        )
        assert response.usage.step_count == len(participant_agent_ids)
        assert len(response.messages) == response.usage.step_count * 2
    finally:
        # delete the group even when an assertion fails so it does not leak
        server.group_manager.delete_group(group_id=group.id, actor=actor)


@pytest.mark.asyncio
async def test_supervisor(server, actor, manager_agent_id, participant_agent_ids):
    group = server.group_manager.create_group(
        group=GroupCreate(
            description=GROUP_DESCRIPTION,
            agent_ids=participant_agent_ids,
            manager_config=SupervisorManager(
                manager_agent_id=manager_agent_id,
            ),
        ),
        actor=actor,
    )
    try:
        response = await server.send_group_message_to_agent(
            group_id=group.id,
            actor=actor,
            messages=[
                MessageCreate(
                    role="user",
                    content="ask everyone what they like to do for fun and then come up with an activity for everyone to do together.",
                ),
            ],
            stream_steps=False,
            stream_tokens=False,
        )
        assert response.usage.step_count == 2
        assert len(response.messages) == 5

        # verify tool call
        assert response.messages[0].message_type == "reasoning_message"
        assert (
            response.messages[1].message_type == "tool_call_message"
            and response.messages[1].tool_call.name == "send_message_to_all_agents_in_group"
        )
        assert response.messages[2].message_type == "tool_return_message"
        # literal_eval only parses Python literals; eval() would execute
        # whatever the tool (i.e. the model) returned
        tool_return = ast.literal_eval(response.messages[2].tool_return)
        assert len(tool_return) == len(participant_agent_ids)
        assert response.messages[3].message_type == "reasoning_message"
        assert response.messages[4].message_type == "assistant_message"
    finally:
        server.group_manager.delete_group(group_id=group.id, actor=actor)


@pytest.mark.asyncio
async def test_dynamic_group_chat(server, actor, manager_agent_id, participant_agent_ids):
    group = server.group_manager.create_group(
        group=GroupCreate(
            description=GROUP_DESCRIPTION,
            agent_ids=participant_agent_ids,
            manager_config=DynamicManager(
                manager_agent_id=manager_agent_id,
            ),
        ),
        actor=actor,
    )
    try:
        response = await server.send_group_message_to_agent(
            group_id=group.id,
            actor=actor,
            messages=[
                MessageCreate(role="user", content="what is everyone up to for the holidays?"),
            ],
            stream_steps=False,
            stream_tokens=False,
        )
        assert response.usage.step_count == len(participant_agent_ids) * 2
        assert len(response.messages) == response.usage.step_count * 2
    finally:
        server.group_manager.delete_group(group_id=group.id, actor=actor)