feat: multi-agent (#1243)

This commit is contained in:
cthomas
2025-03-12 22:51:55 -07:00
committed by GitHub
parent 885d27719a
commit 5304831a8e
24 changed files with 1565 additions and 5 deletions

View File

@@ -0,0 +1,62 @@
"""add groups for multi agent
Revision ID: 77de976590ae
Revises: 167491cfb7a8
Create Date: 2025-03-12 14:01:58.034385
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "77de976590ae"
down_revision: Union[str, None] = "167491cfb7a8"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the multi-agent group tables and link messages to groups."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Parent table describing a multi-agent group and its management pattern.
    op.create_table(
        "groups",
        sa.Column("id", sa.String(), nullable=False),
        sa.Column("description", sa.String(), nullable=False),
        sa.Column("manager_type", sa.String(), nullable=False),
        # Pattern-specific settings; nullable because not every manager type uses them.
        sa.Column("manager_agent_id", sa.String(), nullable=True),
        sa.Column("termination_token", sa.String(), nullable=True),
        sa.Column("max_turns", sa.Integer(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
        sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
        sa.Column("_created_by_id", sa.String(), nullable=True),
        sa.Column("_last_updated_by_id", sa.String(), nullable=True),
        sa.Column("organization_id", sa.String(), nullable=False),
        # RESTRICT: an agent cannot be deleted while it still manages a group.
        sa.ForeignKeyConstraint(["manager_agent_id"], ["agents.id"], ondelete="RESTRICT"),
        sa.ForeignKeyConstraint(
            ["organization_id"],
            ["organizations.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    # Many-to-many association between groups and their participant agents.
    op.create_table(
        "groups_agents",
        sa.Column("group_id", sa.String(), nullable=False),
        sa.Column("agent_id", sa.String(), nullable=False),
        sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"),
        sa.ForeignKeyConstraint(["group_id"], ["groups.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("group_id", "agent_id"),
    )
    # Tag each message with the group it was sent in; nullable for non-group messages.
    op.add_column("messages", sa.Column("group_id", sa.String(), nullable=True))
    # ### end Alembic commands ###
def downgrade() -> None:
    """Reverse upgrade(): drop the message link column, then tables in FK-safe order."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("messages", "group_id")
    # Drop the association table before the parent "groups" table it references.
    op.drop_table("groups_agents")
    op.drop_table("groups")
    # ### end Alembic commands ###

View File

@@ -1,6 +1,7 @@
# Fix: `from letta_client import Letta` was imported twice; keep a single,
# conventionally grouped import block (stdlib first, then third-party).
from pprint import pprint

from letta_client import Letta

client = Letta(base_url="http://localhost:8283")
mcp_server_name = "everything"

View File

@@ -95,6 +95,7 @@ class Agent(BaseAgent):
first_message_verify_mono: bool = True, # TODO move to config?
# MCP sessions, state held in-memory in the server
mcp_clients: Optional[Dict[str, BaseMCPClient]] = None,
save_last_response: bool = False,
):
assert isinstance(agent_state.memory, Memory), f"Memory object is not of type Memory: {type(agent_state.memory)}"
# Hold a copy of the state that was used to init the agent
@@ -149,6 +150,10 @@ class Agent(BaseAgent):
# Load last function response from message history
self.last_function_response = self.load_last_function_response()
# Save last responses in memory
self.save_last_response = save_last_response
self.last_response_messages = []
# Logger that the Agent specifically can use, will also report the agent_state ID with the logs
self.logger = get_logger(agent_state.id)
@@ -926,6 +931,9 @@ class Agent(BaseAgent):
else:
all_new_messages = all_response_messages
if self.save_last_response:
self.last_response_messages = all_response_messages
# Check the memory pressure and potentially issue a memory pressure warning
current_total_tokens = response.usage.total_tokens
active_memory_warning = False
@@ -1052,6 +1060,7 @@ class Agent(BaseAgent):
else:
logger.error(f"step() failed with an unrecognized exception: '{str(e)}'")
traceback.print_exc()
raise e
def step_user_message(self, user_message_str: str, **kwargs) -> AgentStepResponse:

View File

@@ -0,0 +1,274 @@
from typing import List, Optional
from letta.agent import Agent, AgentState
from letta.interface import AgentInterface
from letta.orm import User
from letta.schemas.block import Block
from letta.schemas.letta_message import TextContent
from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.usage import LettaUsageStatistics
from letta.services.tool_manager import ToolManager
class DynamicMultiAgent(Agent):
    """Multi-agent group in which a manager agent dynamically picks the next speaker.

    Each turn, the manager agent is shown the chat history and asked to choose one
    participant; that participant then responds. The loop ends after ``max_turns``
    turns or when a response contains ``termination_token``.
    """

    def __init__(
        self,
        interface: AgentInterface,
        agent_state: AgentState,
        user: User = None,
        # custom
        group_id: str = "",
        agent_ids: Optional[List[str]] = None,
        description: str = "",
        max_turns: Optional[int] = None,
        termination_token: str = "DONE!",
    ):
        """
        Args:
            interface: Interface used to surface agent output.
            agent_state: State of the manager agent overseeing the group.
            user: Actor on whose behalf agents are loaded and messages persisted.
            group_id: ID of the multi-agent group that messages are tagged with.
            agent_ids: IDs of the participant agents (excluding the manager).
            description: Human-readable description of the group.
            max_turns: Maximum speaker turns; defaults to the number of participants.
            termination_token: Substring that ends the chat when found in a reply.
        """
        super().__init__(interface, agent_state, user)
        self.group_id = group_id
        # Fix: previous default was a shared mutable `[]`; use a None sentinel instead.
        self.agent_ids = agent_ids if agent_ids is not None else []
        self.description = description
        self.max_turns = max_turns or len(self.agent_ids)
        self.termination_token = termination_token
        self.tool_manager = ToolManager()

    def step(
        self,
        messages: List[MessageCreate],
        chaining: bool = True,
        max_chaining_steps: Optional[int] = None,
        put_inner_thoughts_first: bool = True,
        **kwargs,
    ) -> LettaUsageStatistics:
        """Run up to ``max_turns`` manager+participant turns and return aggregated usage."""
        total_usage = UsageStatistics()
        step_count = 0
        token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False
        metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None

        # Load the manager and all participants up front. message_index[agent_id]
        # tracks how much of chat_history has already been attributed to that agent.
        agents = {}
        message_index = {self.agent_state.id: 0}
        agents[self.agent_state.id] = self.load_manager_agent()
        for agent_id in self.agent_ids:
            agents[agent_id] = self.load_participant_agent(agent_id=agent_id)
            message_index[agent_id] = 0

        chat_history: List[Message] = []
        new_messages = messages
        speaker_id = None
        try:
            for _ in range(self.max_turns):
                # Ask the manager to pick the next speaker; the previous speaker is excluded.
                agent_id_options = [agent_id for agent_id in self.agent_ids if agent_id != speaker_id]
                manager_message = self.ask_manager_to_choose_participant_message(new_messages, chat_history, agent_id_options)
                manager_agent = agents[self.agent_state.id]
                usage_stats = manager_agent.step(
                    messages=[manager_message],
                    chaining=chaining,
                    max_chaining_steps=max_chaining_steps,
                    stream=token_streaming,
                    skip_verify=True,
                    metadata=metadata,
                    put_inner_thoughts_first=put_inner_thoughts_first,
                )
                responses = Message.to_letta_messages_from_list(manager_agent.last_response_messages)
                assistant_message = [response for response in responses if response.message_type == "assistant_message"][0]
                # Resolve the speaker by matching participant names in the manager's reply.
                # NOTE(review): if no name matches, speaker_id keeps its previous value
                # (None on the first turn, which would fail below) — confirm the manager
                # persona reliably produces a participant name.
                for name, agent_id in [(agents[agent_id].agent_state.name, agent_id) for agent_id in agent_id_options]:
                    if name.lower() in assistant_message.content.lower():
                        speaker_id = agent_id

                # Accumulate the manager's usage.
                total_usage.prompt_tokens += usage_stats.prompt_tokens
                total_usage.completion_tokens += usage_stats.completion_tokens
                total_usage.total_tokens += usage_stats.total_tokens
                step_count += 1

                # Re-key the not-yet-seen slice of history to the chosen speaker.
                for message in chat_history[message_index[speaker_id] :]:
                    message.id = Message.generate_id()
                    message.agent_id = speaker_id
                for message in new_messages:
                    chat_history.append(
                        Message(
                            agent_id=speaker_id,
                            role=message.role,
                            content=[TextContent(text=message.content)],
                            name=message.name,
                            model=None,
                            tool_calls=None,
                            tool_call_id=None,
                            group_id=self.group_id,
                        )
                    )

                # Load agent and perform step.
                participant_agent = agents[speaker_id]
                usage_stats = participant_agent.step(
                    messages=chat_history[message_index[speaker_id] :],
                    chaining=chaining,
                    max_chaining_steps=max_chaining_steps,
                    stream=token_streaming,
                    skip_verify=True,
                    metadata=metadata,
                    put_inner_thoughts_first=put_inner_thoughts_first,
                )

                # Parse new messages for next step.
                responses = Message.to_letta_messages_from_list(
                    participant_agent.last_response_messages,
                )
                assistant_messages = [response for response in responses if response.message_type == "assistant_message"]
                new_messages = [
                    MessageCreate(
                        role="system",
                        content=message.content,
                        name=participant_agent.agent_state.name,
                    )
                    for message in assistant_messages
                ]
                # Bug fix: index the current speaker, not the stale `agent_id` loop
                # variable (matches the round-robin implementation).
                message_index[speaker_id] = len(chat_history) + len(new_messages)

                # Accumulate the participant's usage.
                total_usage.prompt_tokens += usage_stats.prompt_tokens
                total_usage.completion_tokens += usage_stats.completion_tokens
                total_usage.total_tokens += usage_stats.total_tokens
                step_count += 1

                # Check for termination token.
                if any(self.termination_token in message.content for message in new_messages):
                    break

            # Persist remaining chat history.
            for message in new_messages:
                chat_history.append(
                    Message(
                        # Bug fix: attribute trailing messages to the last speaker
                        # rather than the stale `agent_id` loop variable.
                        agent_id=speaker_id,
                        role=message.role,
                        content=[TextContent(text=message.content)],
                        name=message.name,
                        model=None,
                        tool_calls=None,
                        tool_call_id=None,
                        group_id=self.group_id,
                    )
                )
            # Copy any unseen tail of the history into every non-speaker agent's
            # message stream so all participants share the same transcript.
            for agent_id, index in message_index.items():
                if agent_id == speaker_id:
                    continue
                for message in chat_history[index:]:
                    message.id = Message.generate_id()
                    message.agent_id = agent_id
                self.message_manager.create_many_messages(chat_history[index:], actor=self.user)
        finally:
            # Original re-raised via `except Exception as e: raise e` — a no-op; the
            # exception still propagates, and the interface is always flushed here.
            self.interface.step_yield()
            self.interface.step_complete()

        return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)

    def load_manager_agent(self) -> Agent:
        """Prepare the manager agent: one memory block per participant plus a manager persona."""
        for participant_agent_id in self.agent_ids:
            participant_agent_state = self.agent_manager.get_agent_by_id(agent_id=participant_agent_id, actor=self.user)
            participant_persona_block = participant_agent_state.memory.get_block(label="persona")
            # Copy each participant's persona into a block labeled with its agent ID
            # so the manager has context for speaker selection.
            new_block = self.block_manager.create_or_update_block(
                block=Block(
                    label=participant_agent_id,
                    value=participant_persona_block.value,
                ),
                actor=self.user,
            )
            self.agent_state = self.agent_manager.update_block_with_label(
                agent_id=self.agent_state.id,
                block_label=participant_agent_id,
                new_block_id=new_block.id,
                actor=self.user,
            )
        persona_block = self.agent_state.memory.get_block(label="persona")
        # NOTE(review): the persona reports `len(agent_ids) - 1` agents while
        # agent_ids appears to exclude the manager — confirm the intended count.
        group_chat_manager_persona = (
            f"You are overseeing a group chat with {len(self.agent_ids) - 1} agents and "
            f"one user. Description of the group: {self.description}\n"
            "On each turn, you will be provided with the chat history and latest message. "
            "Your task is to decide which participant should speak next in the chat based "
            "on the chat history. Each agent has a memory block labeled with their ID which "
            "holds info about them, and you should use this context to inform your decision."
        )
        self.agent_state.memory.update_block_value(label="persona", value=persona_block.value + group_chat_manager_persona)
        return Agent(
            agent_state=self.agent_state,
            interface=self.interface,
            user=self.user,
            save_last_response=True,
        )

    def load_participant_agent(self, agent_id: str) -> Agent:
        """Load a participant agent with a group-chat preamble prepended to its persona."""
        agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user)
        persona_block = agent_state.memory.get_block(label="persona")
        group_chat_participant_persona = (
            f"You are a participant in a group chat with {len(self.agent_ids) - 1} other "
            "agents and one user. Respond to new messages in the group chat when prompted. "
            f"Description of the group: {self.description}. About you: "
        )
        agent_state.memory.update_block_value(label="persona", value=group_chat_participant_persona + persona_block.value)
        return Agent(
            agent_state=agent_state,
            interface=self.interface,
            user=self.user,
            save_last_response=True,
        )

    # Removed: a large block of commented-out code (attach_choose_next_participant_tool)
    # that was dead weight kept in a module-level string literal.

    def ask_manager_to_choose_participant_message(
        self,
        new_messages: List[MessageCreate],
        chat_history: List[Message],
        agent_id_options: List[str],
    ) -> Message:
        """Build the user message asking the manager to pick the next speaker.

        The chat history plus the latest messages are flattened into a plain-text
        transcript embedded in the prompt.
        """
        chat_history = [f"{message.name or 'user'}: {message.content[0].text}" for message in chat_history]
        for message in new_messages:
            chat_history.append(f"{message.name or 'user'}: {message.content}")
        context_messages = "\n".join(chat_history)
        message_text = (
            "Choose the most suitable agent to reply to the latest message in the "
            f"group chat from the following options: {agent_id_options}. Do not "
            "respond to the messages yourself, your task is only to decide the "
            f"next speaker, not to participate. \nChat history:\n{context_messages}"
        )
        return Message(
            agent_id=self.agent_state.id,
            role="user",
            content=[TextContent(text=message_text)],
            name=None,
            model=None,
            tool_calls=None,
            tool_call_id=None,
            group_id=self.group_id,
        )

View File

@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, List
from letta.functions.helpers import (
_send_message_to_agents_matching_tags_async,
_send_message_to_all_agents_in_group_async,
execute_send_message_to_agent,
fire_and_forget_send_to_agent,
)
@@ -86,3 +87,19 @@ def send_message_to_agents_matching_tags(self: "Agent", message: str, match_all:
"""
return asyncio.run(_send_message_to_agents_matching_tags_async(self, message, match_all, match_some))
def send_message_to_all_agents_in_group(self: "Agent", message: str) -> List[str]:
    """
    Sends a message to all agents within the same multi-agent group.

    Args:
        message (str): The content of the message to be sent to each matching agent.

    Returns:
        List[str]: A list of responses from the agents that matched the filtering criteria. Each
        response corresponds to a single agent. Agents that do not respond will not have an entry
        in the returned list.
    """
    # Docstring intentionally unchanged: tool docstrings may be parsed into schemas.
    # Bridge the async broadcast helper into this synchronous tool entry point.
    broadcast = _send_message_to_all_agents_in_group_async(self, message)
    return asyncio.run(broadcast)

View File

@@ -604,6 +604,47 @@ async def _send_message_to_agents_matching_tags_async(
return final
async def _send_message_to_all_agents_in_group_async(sender_agent: "Agent", message: str) -> List[str]:
    """Broadcast a message to every agent in the sender's multi-agent group.

    Sends concurrently (bounded by a semaphore) and returns one entry per target:
    either the agent's response or the stringified exception for failed sends.
    """
    server = get_letta_server()

    augmented_message = (
        f"[Incoming message from agent with ID '{sender_agent.agent_state.id}' - to reply to this message, "
        f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] "
        f"{message}"
    )

    group_agent_ids = sender_agent.agent_state.multi_agent_group.agent_ids
    target_states = [server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=sender_agent.user) for agent_id in group_agent_ids]

    # A single system message, attributed to the sender, delivered to every target.
    messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=sender_agent.agent_state.name)]

    # Bound concurrency so a large group does not overwhelm the server.
    semaphore = asyncio.Semaphore(settings.multi_agent_concurrent_sends)

    async def _dispatch(agent_state):
        async with semaphore:
            return await async_send_message_with_retries(
                server=server,
                sender_agent=sender_agent,
                target_agent_id=agent_state.id,
                messages=messages,
                max_retries=3,
                timeout=settings.multi_agent_send_message_timeout,
            )

    pending = [asyncio.create_task(_dispatch(state)) for state in target_states]
    results = await asyncio.gather(*pending, return_exceptions=True)
    # Failures become their string form instead of aborting the whole broadcast.
    return [str(result) if isinstance(result, Exception) else result for result in results]
def generate_model_from_args_json_schema(schema: Dict[str, Any]) -> Type[BaseModel]:
"""Creates a Pydantic model from a JSON schema.

View File

@@ -4,6 +4,8 @@ from letta.orm.base import Base
from letta.orm.block import Block
from letta.orm.blocks_agents import BlocksAgents
from letta.orm.file import FileMetadata
from letta.orm.group import Group
from letta.orm.groups_agents import GroupsAgents
from letta.orm.identities_agents import IdentitiesAgents
from letta.orm.identities_blocks import IdentitiesBlocks
from letta.orm.identity import Identity

View File

@@ -128,11 +128,25 @@ class Agent(SqlalchemyBase, OrganizationMixin):
back_populates="agents",
passive_deletes=True,
)
groups: Mapped[List["Group"]] = relationship(
"Group",
secondary="groups_agents",
lazy="selectin",
back_populates="agents",
passive_deletes=True,
)
multi_agent_group: Mapped["Group"] = relationship(
"Group",
lazy="joined",
viewonly=True,
back_populates="manager_agent",
)
def to_pydantic(self) -> PydanticAgentState:
"""converts to the basic pydantic model counterpart"""
# add default rule for having send_message be a terminal tool
tool_rules = self.tool_rules
multi_agent_group = self.multi_agent_group
state = {
"id": self.id,
"organization_id": self.organization_id,
@@ -159,6 +173,7 @@ class Agent(SqlalchemyBase, OrganizationMixin):
"base_template_id": self.base_template_id,
"identity_ids": [identity.id for identity in self.identities],
"message_buffer_autoclear": self.message_buffer_autoclear,
"multi_agent_group": multi_agent_group,
}
return self.__pydantic_model__(**state)

33
letta/orm/group.py Normal file
View File

@@ -0,0 +1,33 @@
import uuid
from typing import List, Optional
from sqlalchemy import ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.group import Group as PydanticGroup
class Group(SqlalchemyBase, OrganizationMixin):
    """ORM model for a multi-agent group and its speaker-management configuration."""

    __tablename__ = "groups"
    __pydantic_model__ = PydanticGroup

    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"group-{uuid.uuid4()}")
    # Fix: the `doc=` strings below were all empty; filled in from how the columns
    # are used by the multi-agent implementations.
    description: Mapped[str] = mapped_column(nullable=False, doc="Human-readable description of the group")
    manager_type: Mapped[str] = mapped_column(nullable=False, doc="Speaker-management pattern (e.g. round_robin, supervisor, dynamic)")
    manager_agent_id: Mapped[Optional[str]] = mapped_column(
        String, ForeignKey("agents.id", ondelete="RESTRICT"), nullable=True, doc="Agent that manages the group, for patterns that use one"
    )
    termination_token: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Token that ends the conversation when present in a response")
    max_turns: Mapped[Optional[int]] = mapped_column(nullable=True, doc="Maximum number of speaker turns per step")

    # relationships
    organization: Mapped["Organization"] = relationship("Organization", back_populates="groups")
    agents: Mapped[List["Agent"]] = relationship(
        "Agent", secondary="groups_agents", lazy="selectin", passive_deletes=True, back_populates="groups"
    )
    manager_agent: Mapped["Agent"] = relationship("Agent", lazy="joined", back_populates="multi_agent_group")

    @property
    def agent_ids(self) -> List[str]:
        """IDs of the participant agents linked through the groups_agents table."""
        return [agent.id for agent in self.agents]

View File

@@ -0,0 +1,13 @@
from sqlalchemy import ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column
from letta.orm.base import Base
class GroupsAgents(Base):
    """Agents may have one or many groups associated with them."""

    __tablename__ = "groups_agents"

    # Composite primary key: one row per (group, agent) membership pair; rows are
    # removed automatically when either side is deleted (ondelete CASCADE).
    group_id: Mapped[str] = mapped_column(String, ForeignKey("groups.id", ondelete="CASCADE"), primary_key=True)
    agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), primary_key=True)

View File

@@ -36,6 +36,7 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
tool_returns: Mapped[List[ToolReturn]] = mapped_column(
ToolReturnColumn, nullable=True, doc="Tool execution return information for prior tool calls"
)
group_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The multi-agent group that the message was sent in")
# Relationships
agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin")

View File

@@ -49,6 +49,7 @@ class Organization(SqlalchemyBase):
agent_passages: Mapped[List["AgentPassage"]] = relationship("AgentPassage", back_populates="organization", cascade="all, delete-orphan")
providers: Mapped[List["Provider"]] = relationship("Provider", back_populates="organization", cascade="all, delete-orphan")
identities: Mapped[List["Identity"]] = relationship("Identity", back_populates="organization", cascade="all, delete-orphan")
groups: Mapped[List["Group"]] = relationship("Group", back_populates="organization", cascade="all, delete-orphan")
@property
def passages(self) -> List[Union["SourcePassage", "AgentPassage"]]:

View File

@@ -139,11 +139,11 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
else:
# Match ANY tag - use join and filter
query = (
query.join(cls.tags).filter(cls.tags.property.mapper.class_.tag.in_(tags)).group_by(cls.id)
query.join(cls.tags).filter(cls.tags.property.mapper.class_.tag.in_(tags)).distinct(cls.id).order_by(cls.id)
) # Deduplicate results
# Group by primary key and all necessary columns to avoid JSON comparison
query = query.group_by(cls.id)
# select distinct primary key
query = query.distinct(cls.id).order_by(cls.id)
if identifier_keys and hasattr(cls, "identities"):
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.identifier_key.in_(identifier_keys))

View File

@@ -0,0 +1,152 @@
from typing import List, Optional
from letta.agent import Agent, AgentState
from letta.interface import AgentInterface
from letta.orm import User
from letta.schemas.letta_message import TextContent
from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.usage import LettaUsageStatistics
class RoundRobinMultiAgent(Agent):
    """Multi-agent group in which participants speak in fixed rotation order.

    Turn ``i`` is taken by ``agent_ids[i % len(agent_ids)]``; the loop runs for
    ``max_turns`` turns (default: one turn per participant).
    """

    def __init__(
        self,
        interface: AgentInterface,
        agent_state: AgentState,
        user: User = None,
        # custom
        group_id: str = "",
        agent_ids: Optional[List[str]] = None,
        description: str = "",
        max_turns: Optional[int] = None,
    ):
        """
        Args:
            interface: Interface used to surface agent output.
            agent_state: State of the agent this group runner is attached to.
            user: Actor on whose behalf agents are loaded and messages persisted.
            group_id: ID of the multi-agent group that messages are tagged with.
            agent_ids: IDs of the participant agents, in speaking order.
            description: Human-readable description of the group.
            max_turns: Maximum speaker turns; defaults to the number of participants.
        """
        super().__init__(interface, agent_state, user)
        self.group_id = group_id
        # Fix: previous default was a shared mutable `[]`; use a None sentinel instead.
        self.agent_ids = agent_ids if agent_ids is not None else []
        self.description = description
        self.max_turns = max_turns or len(self.agent_ids)

    def step(
        self,
        messages: List[MessageCreate],
        chaining: bool = True,
        max_chaining_steps: Optional[int] = None,
        put_inner_thoughts_first: bool = True,
        **kwargs,
    ) -> LettaUsageStatistics:
        """Run up to ``max_turns`` round-robin turns and return aggregated usage."""
        total_usage = UsageStatistics()
        step_count = 0
        token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False
        metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None

        agents = {}
        for agent_id in self.agent_ids:
            agents[agent_id] = self.load_participant_agent(agent_id=agent_id)

        # message_index[agent_id] marks how much of chat_history that agent has seen.
        message_index = {}
        chat_history: List[Message] = []
        new_messages = messages
        speaker_id = None
        try:
            for i in range(self.max_turns):
                speaker_id = self.agent_ids[i % len(self.agent_ids)]
                # Re-key the unseen slice of history to the current speaker.
                start_index = message_index.get(speaker_id, 0)
                for message in chat_history[start_index:]:
                    message.id = Message.generate_id()
                    message.agent_id = speaker_id
                for message in new_messages:
                    chat_history.append(
                        Message(
                            agent_id=speaker_id,
                            role=message.role,
                            content=[TextContent(text=message.content)],
                            name=message.name,
                            model=None,
                            tool_calls=None,
                            tool_call_id=None,
                            group_id=self.group_id,
                        )
                    )

                # Load agent and perform step.
                participant_agent = agents[speaker_id]
                usage_stats = participant_agent.step(
                    messages=chat_history[start_index:],
                    chaining=chaining,
                    max_chaining_steps=max_chaining_steps,
                    stream=token_streaming,
                    skip_verify=True,
                    metadata=metadata,
                    put_inner_thoughts_first=put_inner_thoughts_first,
                )

                # Parse new messages for next step.
                responses = Message.to_letta_messages_from_list(participant_agent.last_response_messages)
                assistant_messages = [response for response in responses if response.message_type == "assistant_message"]
                new_messages = [
                    MessageCreate(
                        role="system",
                        content=message.content,
                        name=participant_agent.agent_state.name,
                    )
                    for message in assistant_messages
                ]
                message_index[speaker_id] = len(chat_history) + len(new_messages)

                # Accumulate usage for this turn.
                total_usage.prompt_tokens += usage_stats.prompt_tokens
                total_usage.completion_tokens += usage_stats.completion_tokens
                total_usage.total_tokens += usage_stats.total_tokens
                step_count += 1

            # Persist remaining chat history.
            for message in new_messages:
                chat_history.append(
                    Message(
                        # Bug fix: attribute trailing messages to the last speaker
                        # rather than the stale `agent_id` loop variable.
                        agent_id=speaker_id,
                        role=message.role,
                        content=[TextContent(text=message.content)],
                        name=message.name,
                        model=None,
                        tool_calls=None,
                        tool_call_id=None,
                        group_id=self.group_id,
                    )
                )
            # Copy any unseen tail of the history into every non-speaker agent's
            # message stream so all participants share the same transcript.
            for agent_id, index in message_index.items():
                if agent_id == speaker_id:
                    continue
                for message in chat_history[index:]:
                    message.id = Message.generate_id()
                    message.agent_id = agent_id
                self.message_manager.create_many_messages(chat_history[index:], actor=self.user)
        finally:
            # Original re-raised via `except Exception as e: raise e` — a no-op; the
            # exception still propagates, and the interface is always flushed here.
            self.interface.step_yield()
            self.interface.step_complete()

        return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)

    def load_participant_agent(self, agent_id: str) -> Agent:
        """Load a participant agent with a group-chat preamble appended to its persona."""
        agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user)
        persona_block = agent_state.memory.get_block(label="persona")
        # Fix: "Contex" typo in the persona header.
        group_chat_participant_persona = (
            "\n\n====Group Chat Context===="
            f"\nYou are speaking in a group chat with {len(self.agent_ids) - 1} other "
            "agents and one user. Respond to new messages in the group chat when prompted. "
            f"Description of the group: {self.description}"
        )
        agent_state.memory.update_block_value(label="persona", value=persona_block.value + group_chat_participant_persona)
        return Agent(
            agent_state=agent_state,
            interface=self.interface,
            user=self.user,
            save_last_response=True,
        )

View File

@@ -7,6 +7,7 @@ from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
from letta.schemas.block import CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.environment_variables import AgentEnvironmentVariable
from letta.schemas.group import Group
from letta.schemas.letta_base import OrmMetadataBase
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import Memory
@@ -90,6 +91,8 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.",
)
multi_agent_group: Optional[Group] = Field(None, description="The multi-agent group that this agent manages")
def get_agent_env_vars_as_dict(self) -> Dict[str, str]:
# Get environment variables for this agent specifically
per_agent_env_vars = {}

65
letta/schemas/group.py Normal file
View File

@@ -0,0 +1,65 @@
from enum import Enum
from typing import Annotated, List, Literal, Optional, Union
from pydantic import BaseModel, Field
from letta.schemas.letta_base import LettaBase
class ManagerType(str, Enum):
    """Speaker-management patterns supported for multi-agent groups."""

    round_robin = "round_robin"  # participants speak in fixed rotation order
    supervisor = "supervisor"  # a designated manager agent coordinates the group
    dynamic = "dynamic"  # a manager agent chooses the next speaker each turn
    swarm = "swarm"  # NOTE(review): no config class exists for swarm yet — presumably unimplemented
class GroupBase(LettaBase):
    """Base schema for group models; supplies the ``group-`` id prefix."""

    __id_prefix__ = "group"
class Group(GroupBase):
    """A multi-agent group and its speaker-management configuration."""

    # Fix: all `description=` strings below were empty; populated so the generated
    # OpenAPI schema documents each field.
    id: str = Field(..., description="The id of the group. Assigned by the database.")
    manager_type: ManagerType = Field(..., description="The management pattern used to select the next speaker.")
    agent_ids: List[str] = Field(..., description="The ids of the agents participating in the group.")
    description: str = Field(..., description="A human-readable description of the group.")
    # Pattern fields
    manager_agent_id: Optional[str] = Field(None, description="The id of the manager agent, for patterns that use one.")
    termination_token: Optional[str] = Field(None, description="The token that ends the conversation when present in a response.")
    max_turns: Optional[int] = Field(None, description="The maximum number of speaker turns per step.")
class ManagerConfig(BaseModel):
    """Base speaker-management config; concrete subtypes are discriminated by ``manager_type``."""

    # Fix: empty `description=` populated for OpenAPI documentation.
    manager_type: ManagerType = Field(..., description="The management pattern used to select the next speaker.")
class RoundRobinManager(ManagerConfig):
    """Config for round-robin speaker selection: participants speak in fixed rotation."""

    # Fix: empty `description=` strings populated for OpenAPI documentation.
    manager_type: Literal[ManagerType.round_robin] = Field(ManagerType.round_robin, description="Round-robin speaker selection.")
    max_turns: Optional[int] = Field(None, description="The maximum number of speaker turns per step.")
class SupervisorManager(ManagerConfig):
    """Config for supervisor-managed groups: a designated agent coordinates the group."""

    # Fix: empty `description=` strings populated for OpenAPI documentation.
    manager_type: Literal[ManagerType.supervisor] = Field(ManagerType.supervisor, description="Supervisor-managed speaker selection.")
    manager_agent_id: str = Field(..., description="The id of the agent that supervises the group.")
class DynamicManager(ManagerConfig):
    """Config for dynamic groups: a manager agent chooses the next speaker each turn."""

    # Fix: empty `description=` strings populated for OpenAPI documentation.
    manager_type: Literal[ManagerType.dynamic] = Field(ManagerType.dynamic, description="Dynamic, manager-chosen speaker selection.")
    manager_agent_id: str = Field(..., description="The id of the agent that picks the next speaker.")
    termination_token: Optional[str] = Field("DONE!", description="The token that ends the conversation when present in a response.")
    max_turns: Optional[int] = Field(None, description="The maximum number of speaker turns per step.")
# class SwarmGroup(ManagerConfig):
# manager_type: Literal[ManagerType.swarm] = Field(ManagerType.swarm, description="")
# Discriminated union over the concrete manager configs; pydantic picks the
# subtype from the `manager_type` field of the payload.
ManagerConfigUnion = Annotated[
    Union[RoundRobinManager, SupervisorManager, DynamicManager],
    Field(discriminator="manager_type"),
]
class GroupCreate(BaseModel):
    """Request payload for creating a multi-agent group."""

    # Fix: empty `description=` strings populated for OpenAPI documentation.
    agent_ids: List[str] = Field(..., description="The ids of the agents participating in the group.")
    description: str = Field(..., description="A human-readable description of the group.")
    manager_config: Optional[ManagerConfigUnion] = Field(None, description="The speaker-management config; round robin is used when omitted.")

View File

@@ -129,6 +129,7 @@ class Message(BaseMessage):
step_id: Optional[str] = Field(None, description="The id of the step that this message was created in.")
otid: Optional[str] = Field(None, description="The offline threading id associated with this message")
tool_returns: Optional[List[ToolReturn]] = Field(None, description="Tool execution return information for prior tool calls")
group_id: Optional[str] = Field(None, description="The multi-agent group that the message was sent in")
# This overrides the optional base orm schema, created_at MUST exist on all messages objects
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")

View File

@@ -1,5 +1,6 @@
from letta.server.rest_api.routers.v1.agents import router as agents_router
from letta.server.rest_api.routers.v1.blocks import router as blocks_router
from letta.server.rest_api.routers.v1.groups import router as groups_router
from letta.server.rest_api.routers.v1.health import router as health_router
from letta.server.rest_api.routers.v1.identities import router as identities_router
from letta.server.rest_api.routers.v1.jobs import router as jobs_router
@@ -17,6 +18,7 @@ ROUTERS = [
tools_router,
sources_router,
agents_router,
groups_router,
identities_router,
llm_router,
blocks_router,

View File

@@ -0,0 +1,233 @@
from typing import Annotated, List, Optional

from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, status
from fastapi.responses import JSONResponse
from pydantic import Field

from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.orm.errors import NoResultFound
from letta.schemas.group import Group, GroupCreate, ManagerType
from letta.schemas.letta_message import LettaMessageUnion
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
from letta.schemas.letta_response import LettaResponse
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
router = APIRouter(prefix="/groups", tags=["groups"])
@router.post("/", response_model=Group, operation_id="create_group")
async def create_group(
    server: SyncServer = Depends(get_letta_server),
    request: GroupCreate = Body(...),
    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    """
    Create a multi-agent group with a specified management pattern. When no
    management config is specified, this endpoint will use round robin for
    speaker selection.
    """
    # NOTE(review): a second, synchronous `create_group` handler is registered
    # below on the same path with the same operation_id — duplicate operation_ids
    # break OpenAPI generation; one of the two handlers should be removed.
    actor = server.user_manager.get_user_or_default(user_id=actor_id)
    return server.group_manager.create_group(request, actor=actor)
@router.get("/", response_model=List[Group], operation_id="list_groups")
def list_groups(
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),
    manager_type: Optional[ManagerType] = Query(None, description="Search groups by manager type"),
    before: Optional[str] = Query(None, description="Cursor for pagination"),
    after: Optional[str] = Query(None, description="Cursor for pagination"),
    limit: Optional[int] = Query(None, description="Limit for pagination"),
    project_id: Optional[str] = Query(None, description="Search groups by project id"),
):
    """
    Fetch all multi-agent groups matching query.

    Supports cursor pagination (`before`/`after` + `limit`) and optional
    filtering by manager type and project id.
    """
    actor = server.user_manager.get_user_or_default(user_id=actor_id)
    return server.group_manager.list_groups(
        project_id=project_id,
        manager_type=manager_type,
        before=before,
        after=after,
        limit=limit,
        actor=actor,
    )
@router.post("/", response_model=Group, operation_id="create_group")
def create_group(
    group: GroupCreate = Body(...),
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
    x_project: Optional[str] = Header(None, alias="X-Project"),  # Only handled by next js middleware
):
    """
    Create a new multi-agent group with the specified configuration.
    """
    # NOTE(review): duplicates the async `create_group` route/operation_id defined
    # earlier in this module — confirm which handler is intended and delete the other.
    try:
        actor = server.user_manager.get_user_or_default(user_id=actor_id)
        return server.group_manager.create_group(group, actor=actor)
    except Exception as e:
        # Surface any failure as a 500 with the error text.
        raise HTTPException(status_code=500, detail=str(e))
@router.put("/", response_model=Group, operation_id="upsert_group")
def upsert_group(
group: GroupCreate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
x_project: Optional[str] = Header(None, alias="X-Project"), # Only handled by next js middleware
):
"""
Create a new multi-agent group with the specified configuration.
"""
try:
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.group_manager.create_group(group, actor=actor)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/{group_id}", response_model=None, operation_id="delete_group")
def delete_group(
group_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Delete a multi-agent group.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
try:
server.group_manager.delete_group(group_id=group_id, actor=actor)
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Group id={group_id} successfully deleted"})
except NoResultFound:
raise HTTPException(status_code=404, detail=f"Group id={group_id} not found for user_id={actor.id}.")
@router.post(
"/{group_id}/messages",
response_model=LettaResponse,
operation_id="send_group_message",
)
async def send_group_message(
agent_id: str,
server: SyncServer = Depends(get_letta_server),
request: LettaRequest = Body(...),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Process a user message and return the group's response.
This endpoint accepts a message from a user and processes it through through agents in the group based on the specified pattern
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
result = await server.send_group_message_to_agent(
group_id=group_id,
actor=actor,
messages=request.messages,
stream_steps=False,
stream_tokens=False,
# Support for AssistantMessage
use_assistant_message=request.use_assistant_message,
assistant_message_tool_name=request.assistant_message_tool_name,
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
)
return result
@router.post(
"/{group_id}/messages/stream",
response_model=None,
operation_id="send_group_message_streaming",
responses={
200: {
"description": "Successful response",
"content": {
"text/event-stream": {"description": "Server-Sent Events stream"},
},
}
},
)
async def send_group_message_streaming(
group_id: str,
server: SyncServer = Depends(get_letta_server),
request: LettaStreamingRequest = Body(...),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Process a user message and return the group's responses.
This endpoint accepts a message from a user and processes it through agents in the group based on the specified pattern.
It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
result = await server.send_group_message_to_agent(
group_id=group_id,
actor=actor,
messages=request.messages,
stream_steps=True,
stream_tokens=request.stream_tokens,
# Support for AssistantMessage
use_assistant_message=request.use_assistant_message,
assistant_message_tool_name=request.assistant_message_tool_name,
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
)
return result
# Response alias for the group message-history endpoint: a list of
# LettaMessageUnion, with explicit json_schema_extra so the OpenAPI schema
# renders the union items via a $ref instead of an anonymous inline schema.
GroupMessagesResponse = Annotated[
    List[LettaMessageUnion], Field(json_schema_extra={"type": "array", "items": {"$ref": "#/components/schemas/LettaMessageUnion"}})
]
@router.get("/{group_id}/messages", response_model=GroupMessagesResponse, operation_id="list_group_messages")
def list_group_messages(
group_id: str,
server: "SyncServer" = Depends(get_letta_server),
after: Optional[str] = Query(None, description="Message after which to retrieve the returned messages."),
before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."),
limit: int = Query(10, description="Maximum number of messages to retrieve."),
use_assistant_message: bool = Query(True, description="Whether to use assistant messages"),
assistant_message_tool_name: str = Query(DEFAULT_MESSAGE_TOOL, description="The name of the designated message tool."),
assistant_message_tool_kwarg: str = Query(DEFAULT_MESSAGE_TOOL_KWARG, description="The name of the message argument."),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Retrieve message history for an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.group_manager.list_group_messages(
group_id=group_id,
before=before,
after=after,
limit=limit,
actor=actor,
use_assistant_message=use_assistant_message,
assistant_message_tool_name=assistant_message_tool_name,
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
)
'''
@router.patch("/{group_id}/reset-messages", response_model=None, operation_id="reset_group_messages")
def reset_group_messages(
group_id: str,
add_default_initial_messages: bool = Query(default=False, description="If true, adds the default initial messages after resetting."),
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Resets the messages for all agents that are part of the multi-agent group.
TODO: only delete group messages not all messages!
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
group = server.group_manager.retrieve_group(group_id=group_id, actor=actor)
agent_ids = group.agent_ids
if group.manager_agent_id:
agent_ids.append(group.manager_agent_id)
for agent_id in agent_ids:
server.agent_manager.reset_messages(
agent_id=agent_id,
actor=actor,
add_default_initial_messages=add_default_initial_messages,
)
'''

View File

@@ -19,6 +19,7 @@ import letta.system as system
from letta.agent import Agent, save_agent
from letta.config import LettaConfig
from letta.data_sources.connectors import DataConnector, load_data
from letta.dynamic_multi_agent import DynamicMultiAgent
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.json_helpers import json_dumps, json_loads
from letta.helpers.mcp_helpers import (
@@ -37,6 +38,7 @@ from letta.interface import CLIInterface # for printing to terminal
from letta.log import get_logger
from letta.offline_memory_agent import OfflineMemoryAgent
from letta.orm.errors import NoResultFound
from letta.round_robin_multi_agent import RoundRobinMultiAgent
from letta.schemas.agent import AgentState, AgentType, CreateAgent
from letta.schemas.block import BlockUpdate
from letta.schemas.embedding_config import EmbeddingConfig
@@ -44,6 +46,7 @@ from letta.schemas.embedding_config import EmbeddingConfig
# openai schemas
from letta.schemas.enums import JobStatus, MessageStreamStatus
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate
from letta.schemas.group import Group, ManagerType
from letta.schemas.job import Job, JobUpdate
from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, ToolReturnMessage
from letta.schemas.letta_response import LettaResponse
@@ -80,6 +83,7 @@ from letta.server.rest_api.interface import StreamingServerInterface
from letta.server.rest_api.utils import sse_async_generator
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.group_manager import GroupManager
from letta.services.identity_manager import IdentityManager
from letta.services.job_manager import JobManager
from letta.services.message_manager import MessageManager
@@ -94,6 +98,7 @@ from letta.services.tool_execution_sandbox import ToolExecutionSandbox
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
from letta.settings import model_settings, settings, tool_settings
from letta.supervisor_multi_agent import SupervisorMultiAgent
from letta.tracing import trace_method
from letta.utils import get_friendly_error_msg
@@ -207,6 +212,7 @@ class SyncServer(Server):
self.provider_manager = ProviderManager()
self.step_manager = StepManager()
self.identity_manager = IdentityManager()
self.group_manager = GroupManager()
# Managers that interface with parallelism
self.per_agent_lock_manager = PerAgentLockManager()
@@ -353,6 +359,8 @@ class SyncServer(Server):
agent_lock = self.per_agent_lock_manager.get_lock(agent_id)
with agent_lock:
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
if agent_state.multi_agent_group:
return self.load_multi_agent(agent_state.multi_agent_group, actor, interface, agent_state)
interface = interface or self.default_interface_factory()
if agent_state.agent_type == AgentType.memgpt_agent:
@@ -364,6 +372,46 @@ class SyncServer(Server):
return agent
def load_multi_agent(
    self, group: Group, actor: User, interface: Union[AgentInterface, None] = None, agent_state: Optional[AgentState] = None
) -> Agent:
    """Instantiate the multi-agent wrapper that corresponds to the group's manager type.

    When no agent_state is supplied, the anchor agent is loaded from the group:
    the first participant for round-robin, the manager agent otherwise.
    """
    manager_type = group.manager_type

    if manager_type == ManagerType.round_robin:
        if agent_state is None:
            agent_state = self.agent_manager.get_agent_by_id(agent_id=group.agent_ids[0], actor=actor)
        return RoundRobinMultiAgent(
            agent_state=agent_state,
            interface=interface,
            user=actor,
            group_id=group.id,
            agent_ids=group.agent_ids,
            description=group.description,
            max_turns=group.max_turns,
        )

    if manager_type == ManagerType.dynamic:
        if agent_state is None:
            agent_state = self.agent_manager.get_agent_by_id(agent_id=group.manager_agent_id, actor=actor)
        return DynamicMultiAgent(
            agent_state=agent_state,
            interface=interface,
            user=actor,
            group_id=group.id,
            agent_ids=group.agent_ids,
            description=group.description,
            max_turns=group.max_turns,
            termination_token=group.termination_token,
        )

    if manager_type == ManagerType.supervisor:
        if agent_state is None:
            agent_state = self.agent_manager.get_agent_by_id(agent_id=group.manager_agent_id, actor=actor)
        return SupervisorMultiAgent(
            agent_state=agent_state,
            interface=interface,
            user=actor,
            group_id=group.id,
            agent_ids=group.agent_ids,
            description=group.description,
        )

    raise ValueError(f"Type {group.manager_type} is not supported.")
def _step(
self,
actor: User,
@@ -1403,3 +1451,106 @@ class SyncServer(Server):
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"{e}")
@trace_method
async def send_group_message_to_agent(
    self,
    group_id: str,
    actor: User,
    messages: Union[List[Message], List[MessageCreate]],
    stream_steps: bool,
    stream_tokens: bool,
    chat_completion_mode: bool = False,
    # Support for AssistantMessage
    use_assistant_message: bool = True,
    assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
    assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
    metadata: Optional[dict] = None,
) -> Union[StreamingResponse, LettaResponse]:
    """Dispatch user messages to a multi-agent group and return its output.

    Loads the group, wraps it in the matching multi-agent implementation, and
    runs one step in a background thread while the streaming interface relays
    the generated messages. Returns an SSE StreamingResponse when
    ``stream_steps`` is true; otherwise the stream is buffered into a
    LettaResponse.

    Raises:
        HTTPException: 400 when ``stream_tokens`` is requested without
            ``stream_steps``; 500 on any other failure.
    """
    include_final_message = True
    if not stream_steps and stream_tokens:
        # Token streaming only makes sense as a refinement of step streaming.
        raise HTTPException(status_code=400, detail="stream_steps must be 'true' if stream_tokens is 'true'")
    try:
        # fetch the group
        group = self.group_manager.retrieve_group(group_id=group_id, actor=actor)
        letta_multi_agent = self.load_multi_agent(group=group, actor=actor)

        # Downgrade to non-token streaming for providers that cannot stream tokens.
        llm_config = letta_multi_agent.agent_state.llm_config
        supports_token_streaming = ["openai", "anthropic", "deepseek"]
        if stream_tokens and (
            llm_config.model_endpoint_type not in supports_token_streaming or "inference.memgpt.ai" in llm_config.model_endpoint
        ):
            warnings.warn(
                f"Token streaming is only supported for models with type {' or '.join(supports_token_streaming)} in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False."
            )
            stream_tokens = False

        # Create a new interface per request
        letta_multi_agent.interface = StreamingServerInterface(
            use_assistant_message=use_assistant_message,
            assistant_message_tool_name=assistant_message_tool_name,
            assistant_message_tool_kwarg=assistant_message_tool_kwarg,
            inner_thoughts_in_kwargs=(
                llm_config.put_inner_thoughts_in_kwargs if llm_config.put_inner_thoughts_in_kwargs is not None else False
            ),
        )
        streaming_interface = letta_multi_agent.interface
        if not isinstance(streaming_interface, StreamingServerInterface):
            raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}")
        streaming_interface.streaming_mode = stream_tokens
        streaming_interface.streaming_chat_completion_mode = chat_completion_mode
        if metadata and hasattr(streaming_interface, "metadata"):
            streaming_interface.metadata = metadata

        # Run the (blocking) multi-agent step in a worker thread so this
        # coroutine can keep draining the interface's message generator.
        streaming_interface.stream_start()
        task = asyncio.create_task(
            asyncio.to_thread(
                letta_multi_agent.step,
                messages=messages,
                chaining=self.chaining,
                max_chaining_steps=self.max_chaining_steps,
            )
        )

        if stream_steps:
            # return a stream
            return StreamingResponse(
                sse_async_generator(
                    streaming_interface.get_generator(),
                    usage_task=task,
                    finish_message=include_final_message,
                ),
                media_type="text/event-stream",
            )
        else:
            # buffer the stream, then return the list
            generated_stream = []
            async for message in streaming_interface.get_generator():
                assert (
                    isinstance(message, LettaMessage)
                    or isinstance(message, LegacyLettaMessage)
                    or isinstance(message, MessageStreamStatus)
                ), type(message)
                generated_stream.append(message)
                if message == MessageStreamStatus.done:
                    break

            # Get rid of the stream status messages
            filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)]
            usage = await task

            # By default the stream will be messages of type LettaMessage or LettaLegacyMessage
            # If we want to convert these to Message, we can use the attached IDs
            # NOTE: we will need to de-duplicate the Messsage IDs though (since Assistant->Inner+Func_Call)
            # TODO: eventually update the interface to use `Message` and `MessageChunk` (new) inside the deque instead
            return LettaResponse(messages=filtered_stream, usage=usage)

    except HTTPException:
        raise
    except Exception as e:
        print(e)
        import traceback

        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"{e}")

View File

@@ -399,7 +399,7 @@ class AgentManager:
# Ensures agents match at least one tag in match_some
query = query.join(AgentsTags).where(AgentsTags.tag.in_(match_some))
query = query.group_by(AgentModel.id).limit(limit)
query = query.distinct(AgentModel.id).order_by(AgentModel.id).limit(limit)
return list(session.execute(query).scalars())
@@ -434,6 +434,7 @@ class AgentManager:
with self.session_maker() as session:
# Retrieve the agent
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
# TODO check if it is managing a group
agent.hard_delete(session)
@enforce_types

View File

@@ -0,0 +1,147 @@
from typing import List, Optional

from sqlalchemy.orm import Session

from letta.orm.agent import Agent as AgentModel
from letta.orm.errors import NoResultFound
from letta.orm.group import Group as GroupModel
from letta.orm.message import Message as MessageModel
from letta.schemas.group import Group as PydanticGroup
from letta.schemas.group import GroupCreate, ManagerType
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.user import User as PydanticUser
from letta.utils import enforce_types
class GroupManager:
def __init__(self):
from letta.server.db import db_context
self.session_maker = db_context
@enforce_types
def list_groups(
self,
project_id: Optional[str] = None,
manager_type: Optional[ManagerType] = None,
before: Optional[str] = None,
after: Optional[str] = None,
limit: Optional[int] = 50,
actor: PydanticUser = None,
) -> list[PydanticGroup]:
with self.session_maker() as session:
filters = {"organization_id": actor.organization_id}
if project_id:
filters["project_id"] = project_id
if manager_type:
filters["manager_type"] = manager_type
groups = GroupModel.list(
db_session=session,
before=before,
after=after,
limit=limit,
**filters,
)
return [group.to_pydantic() for group in groups]
@enforce_types
def retrieve_group(self, group_id: str, actor: PydanticUser) -> PydanticGroup:
with self.session_maker() as session:
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
return group.to_pydantic()
@enforce_types
def create_group(self, group: GroupCreate, actor: PydanticUser) -> PydanticGroup:
with self.session_maker() as session:
new_group = GroupModel()
new_group.organization_id = actor.organization_id
new_group.description = group.description
self._process_agent_relationship(session=session, group=new_group, agent_ids=group.agent_ids, allow_partial=False)
if group.manager_config is None:
new_group.manager_type = ManagerType.round_robin
else:
match group.manager_config.manager_type:
case ManagerType.round_robin:
new_group.manager_type = ManagerType.round_robin
new_group.max_turns = group.manager_config.max_turns
case ManagerType.dynamic:
new_group.manager_type = ManagerType.dynamic
new_group.manager_agent_id = group.manager_config.manager_agent_id
new_group.max_turns = group.manager_config.max_turns
new_group.termination_token = group.manager_config.termination_token
case ManagerType.supervisor:
new_group.manager_type = ManagerType.supervisor
new_group.manager_agent_id = group.manager_config.manager_agent_id
case _:
raise ValueError(f"Unsupported manager type: {group.manager_config.manager_type}")
new_group.create(session, actor=actor)
return new_group.to_pydantic()
@enforce_types
def delete_group(self, group_id: str, actor: PydanticUser) -> None:
with self.session_maker() as session:
# Retrieve the agent
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
group.hard_delete(session)
@enforce_types
def list_group_messages(
self,
group_id: Optional[str] = None,
before: Optional[str] = None,
after: Optional[str] = None,
limit: Optional[int] = 50,
actor: PydanticUser = None,
use_assistant_message: bool = True,
assistant_message_tool_name: str = "send_message",
assistant_message_tool_kwarg: str = "message",
) -> list[PydanticGroup]:
with self.session_maker() as session:
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
agent_id = group.manager_agent_id if group.manager_agent_id else group.agent_ids[0]
filters = {
"organization_id": actor.organization_id,
"group_id": group_id,
"agent_id": agent_id,
}
messages = MessageModel.list(
db_session=session,
before=before,
after=after,
limit=limit,
**filters,
)
messages = PydanticMessage.to_letta_messages_from_list(
messages=messages,
use_assistant_message=use_assistant_message,
assistant_message_tool_name=assistant_message_tool_name,
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
)
return messages
def _process_agent_relationship(self, session: Session, group: GroupModel, agent_ids: List[str], allow_partial=False, replace=True):
current_relationship = getattr(group, "agents", [])
if not agent_ids:
if replace:
setattr(group, "agents", [])
return
# Retrieve models for the provided IDs
found_items = session.query(AgentModel).filter(AgentModel.id.in_(agent_ids)).all()
# Validate all items are found if allow_partial is False
if not allow_partial and len(found_items) != len(agent_ids):
missing = set(agent_ids) - {item.id for item in found_items}
raise NoResultFound(f"Items not found in agents: {missing}")
if replace:
# Replace the relationship
setattr(group, "agents", found_items)
else:
# Extend the relationship (only add new items)
current_ids = {item.id for item in current_relationship}
new_items = [item for item in found_items if item.id not in current_ids]
current_relationship.extend(new_items)

View File

@@ -0,0 +1,103 @@
from typing import List, Optional
from letta.agent import Agent, AgentState
from letta.constants import DEFAULT_MESSAGE_TOOL
from letta.functions.function_sets.multi_agent import send_message_to_all_agents_in_group
from letta.interface import AgentInterface
from letta.orm import User
from letta.orm.enums import ToolType
from letta.schemas.letta_message import TextContent
from letta.schemas.message import Message, MessageCreate
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
from letta.schemas.usage import LettaUsageStatistics
from letta.services.agent_manager import AgentManager
from letta.services.tool_manager import ToolManager
from tests.helpers.utils import create_tool_from_func
class SupervisorMultiAgent(Agent):
    """Multi-agent wrapper in which a single supervisor agent broadcasts the user
    message to every agent in the group (via the
    ``send_message_to_all_agents_in_group`` tool) and then replies to the user."""

    def __init__(
        self,
        interface: AgentInterface,
        agent_state: AgentState,
        user: User = None,
        # custom
        group_id: str = "",
        agent_ids: Optional[List[str]] = None,
        description: str = "",
    ):
        super().__init__(interface, agent_state, user)
        self.group_id = group_id
        # BUG FIX: `agent_ids` previously defaulted to a mutable `[]`, which is
        # shared across every instance constructed without the argument.
        self.agent_ids = agent_ids if agent_ids is not None else []
        self.description = description
        self.agent_manager = AgentManager()
        self.tool_manager = ToolManager()

    def step(
        self,
        messages: List[MessageCreate],
        chaining: bool = True,
        max_chaining_steps: Optional[int] = None,
        put_inner_thoughts_first: bool = True,
        assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
        **kwargs,
    ) -> LettaUsageStatistics:
        """Run one supervisor turn: broadcast to the group, then answer the user.

        Returns the usage statistics from the underlying supervisor agent step.
        """
        token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False
        metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None

        # add multi agent tool (lazily, only the first time it is missing)
        # NOTE(review): `create_tool_from_func` is imported from
        # `tests.helpers.utils`, so production code depends on the test package —
        # this helper should move into the main package.
        if self.tool_manager.get_tool_by_name(tool_name="send_message_to_all_agents_in_group", actor=self.user) is None:
            multi_agent_tool = create_tool_from_func(send_message_to_all_agents_in_group)
            multi_agent_tool.tool_type = ToolType.LETTA_MULTI_AGENT_CORE
            multi_agent_tool = self.tool_manager.create_or_update_tool(
                pydantic_tool=multi_agent_tool,
                actor=self.user,
            )
            self.agent_state = self.agent_manager.attach_tool(agent_id=self.agent_state.id, tool_id=multi_agent_tool.id, actor=self.user)

        # override tool rules: broadcast first, then terminate on the
        # user-facing message tool.
        self.agent_state.tool_rules = [
            InitToolRule(
                tool_name="send_message_to_all_agents_in_group",
            ),
            TerminalToolRule(
                tool_name=assistant_message_tool_name,
            ),
            ChildToolRule(
                tool_name="send_message_to_all_agents_in_group",
                children=[assistant_message_tool_name],
            ),
        ]

        # Wrap the incoming messages as user messages tagged with this group id.
        supervisor_messages = [
            Message(
                agent_id=self.agent_state.id,
                role="user",
                content=[TextContent(text=message.content)],
                name=None,
                model=None,
                tool_calls=None,
                tool_call_id=None,
                group_id=self.group_id,
            )
            for message in messages
        ]

        # IDIOM FIX: the old `except Exception as e: raise e` only rewrote the
        # traceback origin; a bare try/finally re-raises the original exception
        # unchanged while still yielding the interface step.
        try:
            supervisor_agent = Agent(agent_state=self.agent_state, interface=self.interface, user=self.user)
            usage_stats = supervisor_agent.step(
                messages=supervisor_messages,
                chaining=chaining,
                max_chaining_steps=max_chaining_steps,
                stream=token_streaming,
                skip_verify=True,
                metadata=metadata,
                put_inner_thoughts_first=put_inner_thoughts_first,
            )
        finally:
            self.interface.step_yield()

        self.interface.step_complete()
        return usage_stats

233
tests/test_multi_agent.py Normal file
View File

@@ -0,0 +1,233 @@
import pytest
from sqlalchemy import delete
from letta.config import LettaConfig
from letta.orm import Provider, Step
from letta.schemas.agent import CreateAgent
from letta.schemas.block import CreateBlock
from letta.schemas.group import DynamicManager, GroupCreate, SupervisorManager
from letta.schemas.message import MessageCreate
from letta.server.server import SyncServer
@pytest.fixture(scope="module")
def server():
config = LettaConfig.load()
print("CONFIG PATH", config.config_path)
config.save()
server = SyncServer()
return server
@pytest.fixture(scope="module")
def org_id(server):
org = server.organization_manager.create_default_organization()
yield org.id
# cleanup
with server.organization_manager.session_maker() as session:
session.execute(delete(Step))
session.execute(delete(Provider))
session.commit()
server.organization_manager.delete_organization_by_id(org.id)
@pytest.fixture(scope="module")
def actor(server, org_id):
user = server.user_manager.create_default_user()
yield user
# cleanup
server.user_manager.delete_user_by_id(user.id)
@pytest.fixture(scope="module")
def participant_agent_ids(server, actor):
agent_fred = server.create_agent(
request=CreateAgent(
name="fred",
memory_blocks=[
CreateBlock(
label="persona",
value="Your name is fred and you like to ski and have been wanting to go on a ski trip soon. You are speaking in a group chat with other agent pals where you participate in friendly banter.",
),
],
model="openai/gpt-4o-mini",
embedding="openai/text-embedding-ada-002",
),
actor=actor,
)
agent_velma = server.create_agent(
request=CreateAgent(
name="velma",
memory_blocks=[
CreateBlock(
label="persona",
value="Your name is velma and you like tropical locations. You are speaking in a group chat with other agent friends and you love to include everyone.",
),
],
model="openai/gpt-4o-mini",
embedding="openai/text-embedding-ada-002",
),
actor=actor,
)
agent_daphne = server.create_agent(
request=CreateAgent(
name="daphne",
memory_blocks=[
CreateBlock(
label="persona",
value="Your name is daphne and you love traveling abroad. You are speaking in a group chat with other agent friends and you love to keep in touch with them.",
),
],
model="openai/gpt-4o-mini",
embedding="openai/text-embedding-ada-002",
),
actor=actor,
)
agent_shaggy = server.create_agent(
request=CreateAgent(
name="shaggy",
memory_blocks=[
CreateBlock(
label="persona",
value="Your name is shaggy and your best friend is your dog, scooby. You are speaking in a group chat with other agent friends and you like to solve mysteries with them.",
),
],
model="openai/gpt-4o-mini",
embedding="openai/text-embedding-ada-002",
),
actor=actor,
)
yield [agent_fred.id, agent_velma.id, agent_daphne.id, agent_shaggy.id]
# cleanup
server.agent_manager.delete_agent(agent_fred.id, actor=actor)
server.agent_manager.delete_agent(agent_velma.id, actor=actor)
server.agent_manager.delete_agent(agent_daphne.id, actor=actor)
server.agent_manager.delete_agent(agent_shaggy.id, actor=actor)
@pytest.fixture(scope="module")
def manager_agent_id(server, actor):
agent_scooby = server.create_agent(
request=CreateAgent(
name="scooby",
memory_blocks=[
CreateBlock(
label="persona",
value="You are a puppy operations agent for Letta and you help run multi-agent group chats. Your job is to get to know the agents in your group and pick who is best suited to speak next in the conversation.",
),
CreateBlock(
label="human",
value="",
),
],
model="openai/gpt-4o-mini",
embedding="openai/text-embedding-ada-002",
),
actor=actor,
)
yield agent_scooby.id
# cleanup
server.agent_manager.delete_agent(agent_scooby.id, actor=actor)
@pytest.mark.asyncio
async def test_round_robin(server, actor, participant_agent_ids):
    """Round-robin group (no manager config): every participant takes exactly
    one turn, producing two messages per step."""
    group = server.group_manager.create_group(
        group=GroupCreate(
            description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.",
            agent_ids=participant_agent_ids,
        ),
        actor=actor,
    )

    response = await server.send_group_message_to_agent(
        group_id=group.id,
        actor=actor,
        messages=[
            MessageCreate(
                role="user",
                content="what is everyone up to for the holidays?",
            ),
        ],
        stream_steps=False,
        stream_tokens=False,
    )
    # one step per participant; two messages (reasoning + reply) per step
    assert response.usage.step_count == len(participant_agent_ids)
    assert len(response.messages) == response.usage.step_count * 2

    server.group_manager.delete_group(group_id=group.id, actor=actor)
@pytest.mark.asyncio
async def test_supervisor(server, actor, manager_agent_id, participant_agent_ids):
    """Supervisor group: one broadcast tool call to all participants, then a
    final assistant reply to the user."""
    group = server.group_manager.create_group(
        group=GroupCreate(
            description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.",
            agent_ids=participant_agent_ids,
            manager_config=SupervisorManager(
                manager_agent_id=manager_agent_id,
            ),
        ),
        actor=actor,
    )

    response = await server.send_group_message_to_agent(
        group_id=group.id,
        actor=actor,
        messages=[
            MessageCreate(
                role="user",
                content="ask everyone what they like to do for fun and then come up with an activity for everyone to do together.",
            ),
        ],
        stream_steps=False,
        stream_tokens=False,
    )
    assert response.usage.step_count == 2
    assert len(response.messages) == 5

    # verify tool call
    assert response.messages[0].message_type == "reasoning_message"
    assert (
        response.messages[1].message_type == "tool_call_message"
        and response.messages[1].tool_call.name == "send_message_to_all_agents_in_group"
    )
    # SAFETY FIX: parse the tool return with ast.literal_eval instead of eval —
    # the payload is model-generated text and must never be executed as code.
    import ast

    assert response.messages[2].message_type == "tool_return_message" and len(ast.literal_eval(response.messages[2].tool_return)) == len(
        participant_agent_ids
    )
    assert response.messages[3].message_type == "reasoning_message"
    assert response.messages[4].message_type == "assistant_message"

    server.group_manager.delete_group(group_id=group.id, actor=actor)
@pytest.mark.asyncio
async def test_dynamic_group_chat(server, actor, manager_agent_id, participant_agent_ids):
    """Dynamic group: the manager agent picks the next speaker, so the run takes
    two steps per participant (manager turn + participant turn)."""
    group = server.group_manager.create_group(
        group=GroupCreate(
            description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.",
            agent_ids=participant_agent_ids,
            manager_config=DynamicManager(
                manager_agent_id=manager_agent_id,
            ),
        ),
        actor=actor,
    )

    response = await server.send_group_message_to_agent(
        group_id=group.id,
        actor=actor,
        messages=[
            MessageCreate(role="user", content="what is everyone up to for the holidays?"),
        ],
        stream_steps=False,
        stream_tokens=False,
    )
    # manager + participant step per speaker; two messages per step
    assert response.usage.step_count == len(participant_agent_ids) * 2
    assert len(response.messages) == response.usage.step_count * 2

    server.group_manager.delete_group(group_id=group.id, actor=actor)