feat: multi-agent (#1243)
This commit is contained in:
62
alembic/versions/77de976590ae_add_groups_for_multi_agent.py
Normal file
62
alembic/versions/77de976590ae_add_groups_for_multi_agent.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""add groups for multi agent
|
||||
|
||||
Revision ID: 77de976590ae
|
||||
Revises: 167491cfb7a8
|
||||
Create Date: 2025-03-12 14:01:58.034385
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "77de976590ae"
|
||||
down_revision: Union[str, None] = "167491cfb7a8"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the multi-agent `groups` tables and link messages to groups."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "groups",
        sa.Column("id", sa.String(), nullable=False),
        sa.Column("description", sa.String(), nullable=False),
        sa.Column("manager_type", sa.String(), nullable=False),
        sa.Column("manager_agent_id", sa.String(), nullable=True),
        sa.Column("termination_token", sa.String(), nullable=True),
        sa.Column("max_turns", sa.Integer(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
        sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
        sa.Column("_created_by_id", sa.String(), nullable=True),
        sa.Column("_last_updated_by_id", sa.String(), nullable=True),
        sa.Column("organization_id", sa.String(), nullable=False),
        # RESTRICT: an agent cannot be deleted while it is still the manager of a group.
        sa.ForeignKeyConstraint(["manager_agent_id"], ["agents.id"], ondelete="RESTRICT"),
        sa.ForeignKeyConstraint(
            ["organization_id"],
            ["organizations.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    # Association table: many-to-many between groups and their member agents.
    op.create_table(
        "groups_agents",
        sa.Column("group_id", sa.String(), nullable=False),
        sa.Column("agent_id", sa.String(), nullable=False),
        sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"),
        sa.ForeignKeyConstraint(["group_id"], ["groups.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("group_id", "agent_id"),
    )
    # Messages can optionally record which group conversation they were sent in.
    op.add_column("messages", sa.Column("group_id", sa.String(), nullable=True))
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Revert the multi-agent group schema (reverse order of `upgrade`)."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("messages", "group_id")
    op.drop_table("groups_agents")
    op.drop_table("groups")
    # ### end Alembic commands ###
|
||||
@@ -1,6 +1,7 @@
|
||||
from letta_client import Letta
|
||||
from pprint import pprint
|
||||
|
||||
from letta_client import Letta
|
||||
|
||||
client = Letta(base_url="http://localhost:8283")
|
||||
|
||||
mcp_server_name = "everything"
|
||||
|
||||
@@ -95,6 +95,7 @@ class Agent(BaseAgent):
|
||||
first_message_verify_mono: bool = True, # TODO move to config?
|
||||
# MCP sessions, state held in-memory in the server
|
||||
mcp_clients: Optional[Dict[str, BaseMCPClient]] = None,
|
||||
save_last_response: bool = False,
|
||||
):
|
||||
assert isinstance(agent_state.memory, Memory), f"Memory object is not of type Memory: {type(agent_state.memory)}"
|
||||
# Hold a copy of the state that was used to init the agent
|
||||
@@ -149,6 +150,10 @@ class Agent(BaseAgent):
|
||||
# Load last function response from message history
|
||||
self.last_function_response = self.load_last_function_response()
|
||||
|
||||
# Save last responses in memory
|
||||
self.save_last_response = save_last_response
|
||||
self.last_response_messages = []
|
||||
|
||||
# Logger that the Agent specifically can use, will also report the agent_state ID with the logs
|
||||
self.logger = get_logger(agent_state.id)
|
||||
|
||||
@@ -926,6 +931,9 @@ class Agent(BaseAgent):
|
||||
else:
|
||||
all_new_messages = all_response_messages
|
||||
|
||||
if self.save_last_response:
|
||||
self.last_response_messages = all_response_messages
|
||||
|
||||
# Check the memory pressure and potentially issue a memory pressure warning
|
||||
current_total_tokens = response.usage.total_tokens
|
||||
active_memory_warning = False
|
||||
@@ -1052,6 +1060,7 @@ class Agent(BaseAgent):
|
||||
|
||||
else:
|
||||
logger.error(f"step() failed with an unrecognized exception: '{str(e)}'")
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
|
||||
def step_user_message(self, user_message_str: str, **kwargs) -> AgentStepResponse:
|
||||
|
||||
274
letta/dynamic_multi_agent.py
Normal file
274
letta/dynamic_multi_agent.py
Normal file
@@ -0,0 +1,274 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from letta.agent import Agent, AgentState
|
||||
from letta.interface import AgentInterface
|
||||
from letta.orm import User
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.letta_message import TextContent
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.services.tool_manager import ToolManager
|
||||
|
||||
|
||||
class DynamicMultiAgent(Agent):
    """Multi-agent group whose next speaker is chosen dynamically by a manager agent.

    Each turn the manager agent (backed by this instance's ``agent_state``) is
    shown the chat history and asked to name the next participant; the named
    participant then takes a step and its assistant messages are fed back into
    the group chat. The loop ends after ``max_turns`` turns or as soon as a
    participant reply contains ``termination_token``.
    """

    def __init__(
        self,
        interface: AgentInterface,
        agent_state: AgentState,
        user: User = None,
        # custom
        group_id: str = "",
        agent_ids: Optional[List[str]] = None,
        description: str = "",
        max_turns: Optional[int] = None,
        termination_token: str = "DONE!",
    ):
        super().__init__(interface, agent_state, user)
        self.group_id = group_id
        # BUGFIX: the signature previously used a mutable default (agent_ids=[]),
        # which is shared across all instances; default to None and materialize here.
        self.agent_ids = agent_ids if agent_ids is not None else []
        self.description = description
        self.max_turns = max_turns or len(self.agent_ids)
        self.termination_token = termination_token

        self.tool_manager = ToolManager()

    def step(
        self,
        messages: List[MessageCreate],
        chaining: bool = True,
        max_chaining_steps: Optional[int] = None,
        put_inner_thoughts_first: bool = True,
        **kwargs,
    ) -> LettaUsageStatistics:
        """Run up to ``max_turns`` turns of the dynamic group chat.

        Args:
            messages: Initial user messages that open the group conversation.
            chaining / max_chaining_steps / put_inner_thoughts_first: forwarded
                to each underlying ``Agent.step`` call.

        Returns:
            Aggregated usage statistics across the manager and participant steps.
        """
        total_usage = UsageStatistics()
        step_count = 0

        token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False
        metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None

        # message_index holds, per agent, the index into chat_history up to
        # which that agent has already received/persisted messages.
        agents = {}
        message_index = {self.agent_state.id: 0}
        agents[self.agent_state.id] = self.load_manager_agent()
        for agent_id in self.agent_ids:
            agents[agent_id] = self.load_participant_agent(agent_id=agent_id)
            message_index[agent_id] = 0

        chat_history: List[Message] = []
        new_messages = messages
        speaker_id = None
        try:
            for _ in range(self.max_turns):
                # Never offer the previous speaker as an option for the next turn.
                agent_id_options = [agent_id for agent_id in self.agent_ids if agent_id != speaker_id]
                manager_message = self.ask_manager_to_choose_participant_message(new_messages, chat_history, agent_id_options)
                manager_agent = agents[self.agent_state.id]
                usage_stats = manager_agent.step(
                    messages=[manager_message],
                    chaining=chaining,
                    max_chaining_steps=max_chaining_steps,
                    stream=token_streaming,
                    skip_verify=True,
                    metadata=metadata,
                    put_inner_thoughts_first=put_inner_thoughts_first,
                )
                # The manager names the next speaker in free text; match agent
                # names against its reply.
                # NOTE(review): if no name matches, speaker_id keeps its prior
                # value (None on the first turn, which raises below) — confirm
                # that this fallback is intended.
                responses = Message.to_letta_messages_from_list(manager_agent.last_response_messages)
                assistant_message = [response for response in responses if response.message_type == "assistant_message"][0]
                for name, agent_id in [(agents[agent_id].agent_state.name, agent_id) for agent_id in agent_id_options]:
                    if name.lower() in assistant_message.content.lower():
                        speaker_id = agent_id

                # sum usage (manager step)
                total_usage.prompt_tokens += usage_stats.prompt_tokens
                total_usage.completion_tokens += usage_stats.completion_tokens
                total_usage.total_tokens += usage_stats.total_tokens
                step_count += 1

                # Re-stamp the history the speaker has not yet seen so the
                # copies belong to the speaker's own message stream.
                for message in chat_history[message_index[speaker_id] :]:
                    message.id = Message.generate_id()
                    message.agent_id = speaker_id

                for message in new_messages:
                    chat_history.append(
                        Message(
                            agent_id=speaker_id,
                            role=message.role,
                            content=[TextContent(text=message.content)],
                            name=message.name,
                            model=None,
                            tool_calls=None,
                            tool_call_id=None,
                            group_id=self.group_id,
                        )
                    )

                # load agent and perform step
                participant_agent = agents[speaker_id]
                usage_stats = participant_agent.step(
                    messages=chat_history[message_index[speaker_id] :],
                    chaining=chaining,
                    max_chaining_steps=max_chaining_steps,
                    stream=token_streaming,
                    skip_verify=True,
                    metadata=metadata,
                    put_inner_thoughts_first=put_inner_thoughts_first,
                )

                # parse new messages for next step
                responses = Message.to_letta_messages_from_list(
                    participant_agent.last_response_messages,
                )

                assistant_messages = [response for response in responses if response.message_type == "assistant_message"]
                new_messages = [
                    MessageCreate(
                        role="system",
                        content=message.content,
                        name=participant_agent.agent_state.name,
                    )
                    for message in assistant_messages
                ]
                # BUGFIX: advance the cursor for the agent that just spoke.
                # Previously this keyed on `agent_id`, a variable leaked from
                # the name-matching loop above (not necessarily the speaker).
                message_index[speaker_id] = len(chat_history) + len(new_messages)

                # sum usage (participant step)
                total_usage.prompt_tokens += usage_stats.prompt_tokens
                total_usage.completion_tokens += usage_stats.completion_tokens
                total_usage.total_tokens += usage_stats.total_tokens
                step_count += 1

                # check for termination token
                if any(self.termination_token in message.content for message in new_messages):
                    break

            # persist remaining chat history
            for message in new_messages:
                chat_history.append(
                    Message(
                        # BUGFIX: attribute trailing messages to the last speaker;
                        # previously this used the leaked `agent_id` loop variable.
                        agent_id=speaker_id,
                        role=message.role,
                        content=[TextContent(text=message.content)],
                        name=message.name,
                        model=None,
                        tool_calls=None,
                        tool_call_id=None,
                        group_id=self.group_id,
                    )
                )
            for agent_id, index in message_index.items():
                if agent_id == speaker_id:
                    continue
                # Give every other agent its own copy of the messages it has
                # not yet persisted (fresh ids, re-attributed).
                for message in chat_history[index:]:
                    message.id = Message.generate_id()
                    message.agent_id = agent_id
                self.message_manager.create_many_messages(chat_history[index:], actor=self.user)
        finally:
            # Always signal the interface, even when a step raised.
            # (A bare `except Exception as e: raise e` added nothing and was removed.)
            self.interface.step_yield()

        self.interface.step_complete()

        return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)

    def load_manager_agent(self) -> Agent:
        """Build the manager Agent, attaching one memory block per participant persona.

        Each participant's persona is copied into a block labeled with the
        participant's agent id so the manager has context for its choice.
        """
        for participant_agent_id in self.agent_ids:
            participant_agent_state = self.agent_manager.get_agent_by_id(agent_id=participant_agent_id, actor=self.user)
            participant_persona_block = participant_agent_state.memory.get_block(label="persona")
            new_block = self.block_manager.create_or_update_block(
                block=Block(
                    label=participant_agent_id,
                    value=participant_persona_block.value,
                ),
                actor=self.user,
            )
            self.agent_state = self.agent_manager.update_block_with_label(
                agent_id=self.agent_state.id,
                block_label=participant_agent_id,
                new_block_id=new_block.id,
                actor=self.user,
            )

        persona_block = self.agent_state.memory.get_block(label="persona")
        group_chat_manager_persona = (
            f"You are overseeing a group chat with {len(self.agent_ids) - 1} agents and "
            f"one user. Description of the group: {self.description}\n"
            "On each turn, you will be provided with the chat history and latest message. "
            "Your task is to decide which participant should speak next in the chat based "
            "on the chat history. Each agent has a memory block labeled with their ID which "
            "holds info about them, and you should use this context to inform your decision."
        )
        self.agent_state.memory.update_block_value(label="persona", value=persona_block.value + group_chat_manager_persona)
        return Agent(
            agent_state=self.agent_state,
            interface=self.interface,
            user=self.user,
            save_last_response=True,
        )

    def load_participant_agent(self, agent_id: str) -> Agent:
        """Load a participant agent and prepend group-chat context to its persona."""
        agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user)
        persona_block = agent_state.memory.get_block(label="persona")
        group_chat_participant_persona = (
            f"You are a participant in a group chat with {len(self.agent_ids) - 1} other "
            "agents and one user. Respond to new messages in the group chat when prompted. "
            f"Description of the group: {self.description}. About you: "
        )
        agent_state.memory.update_block_value(label="persona", value=group_chat_participant_persona + persona_block.value)
        return Agent(
            agent_state=agent_state,
            interface=self.interface,
            user=self.user,
            save_last_response=True,
        )

    # Dead code kept verbatim from upstream: a tool-based alternative for
    # choosing the next speaker. TODO(review): delete or revive.
    '''
    def attach_choose_next_participant_tool(self) -> AgentState:
        def choose_next_participant(next_speaker_agent_id: str) -> str:
            """
            Returns ID of the agent in the group chat that should reply to the latest message in the conversation. The agent ID will always be in the format: `agent-{UUID}`.

            Args:
                next_speaker_agent_id (str): The ID of the agent that is most suitable to be the next speaker.

            Returns:
                str: The ID of the agent that should be the next speaker.
            """
            return next_speaker_agent_id

        source_code = parse_source_code(choose_next_participant)
        tool = self.tool_manager.create_or_update_tool(
            Tool(
                source_type="python",
                source_code=source_code,
                name="choose_next_participant",
            ),
            actor=self.user,
        )
        return self.agent_manager.attach_tool(agent_id=self.agent_state.id, tool_id=tool.id, actor=self.user)
    '''

    def ask_manager_to_choose_participant_message(
        self,
        new_messages: List[MessageCreate],
        chat_history: List[Message],
        agent_id_options: List[str],
    ) -> Message:
        """Build the user-role prompt asking the manager to pick the next speaker."""
        # Render the transcript as "name: text" lines. (Fix: the original
        # rebound the `chat_history` parameter to a list of strings, shadowing
        # its own argument; use a fresh local instead.)
        transcript = [f"{message.name or 'user'}: {message.content[0].text}" for message in chat_history]
        for message in new_messages:
            transcript.append(f"{message.name or 'user'}: {message.content}")
        context_messages = "\n".join(transcript)

        message_text = (
            "Choose the most suitable agent to reply to the latest message in the "
            f"group chat from the following options: {agent_id_options}. Do not "
            "respond to the messages yourself, your task is only to decide the "
            f"next speaker, not to participate. \nChat history:\n{context_messages}"
        )
        return Message(
            agent_id=self.agent_state.id,
            role="user",
            content=[TextContent(text=message_text)],
            name=None,
            model=None,
            tool_calls=None,
            tool_call_id=None,
            group_id=self.group_id,
        )
|
||||
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, List
|
||||
|
||||
from letta.functions.helpers import (
|
||||
_send_message_to_agents_matching_tags_async,
|
||||
_send_message_to_all_agents_in_group_async,
|
||||
execute_send_message_to_agent,
|
||||
fire_and_forget_send_to_agent,
|
||||
)
|
||||
@@ -86,3 +87,19 @@ def send_message_to_agents_matching_tags(self: "Agent", message: str, match_all:
|
||||
"""
|
||||
|
||||
return asyncio.run(_send_message_to_agents_matching_tags_async(self, message, match_all, match_some))
|
||||
|
||||
|
||||
def send_message_to_all_agents_in_group(self: "Agent", message: str) -> List[str]:
    """
    Sends a message to all agents within the same multi-agent group.

    Args:
        message (str): The content of the message to be sent to each matching agent.

    Returns:
        List[str]: A list of responses from the agents that matched the filtering criteria. Each
        response corresponds to a single agent. Agents that do not respond will not have an entry
        in the returned list.
    """
    # NOTE(review): this docstring is presumably parsed into the tool schema
    # shown to the LLM (like the sibling tag-matching tool) — keep it stable.
    # asyncio.run starts a fresh event loop, so this sync wrapper must not be
    # called from inside an already-running loop.
    return asyncio.run(_send_message_to_all_agents_in_group_async(self, message))
|
||||
|
||||
@@ -604,6 +604,47 @@ async def _send_message_to_agents_matching_tags_async(
|
||||
return final
|
||||
|
||||
|
||||
async def _send_message_to_all_agents_in_group_async(sender_agent: "Agent", message: str) -> List[str]:
    """Fan *message* out concurrently to every agent in the sender's group.

    Deliveries run under a semaphore cap; a failed delivery contributes its
    exception's string form to the result instead of raising.
    """
    server = get_letta_server()

    augmented_message = (
        f"[Incoming message from agent with ID '{sender_agent.agent_state.id}' - to reply to this message, "
        f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] "
        f"{message}"
    )

    group_agent_ids = sender_agent.agent_state.multi_agent_group.agent_ids
    recipients = [server.agent_manager.get_agent_by_id(agent_id=target_id, actor=sender_agent.user) for target_id in group_agent_ids]

    # Create a system message
    messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=sender_agent.agent_state.name)]

    # Possibly limit concurrency to avoid meltdown:
    send_gate = asyncio.Semaphore(settings.multi_agent_concurrent_sends)

    async def _deliver(agent_state):
        async with send_gate:
            return await async_send_message_with_retries(
                server=server,
                sender_agent=sender_agent,
                target_agent_id=agent_state.id,
                messages=messages,
                max_retries=3,
                timeout=settings.multi_agent_send_message_timeout,
            )

    pending = [asyncio.create_task(_deliver(agent_state)) for agent_state in recipients]
    outcomes = await asyncio.gather(*pending, return_exceptions=True)

    # Stringify failures in place of their results.
    return [str(outcome) if isinstance(outcome, Exception) else outcome for outcome in outcomes]
|
||||
|
||||
|
||||
def generate_model_from_args_json_schema(schema: Dict[str, Any]) -> Type[BaseModel]:
|
||||
"""Creates a Pydantic model from a JSON schema.
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ from letta.orm.base import Base
|
||||
from letta.orm.block import Block
|
||||
from letta.orm.blocks_agents import BlocksAgents
|
||||
from letta.orm.file import FileMetadata
|
||||
from letta.orm.group import Group
|
||||
from letta.orm.groups_agents import GroupsAgents
|
||||
from letta.orm.identities_agents import IdentitiesAgents
|
||||
from letta.orm.identities_blocks import IdentitiesBlocks
|
||||
from letta.orm.identity import Identity
|
||||
|
||||
@@ -128,11 +128,25 @@ class Agent(SqlalchemyBase, OrganizationMixin):
|
||||
back_populates="agents",
|
||||
passive_deletes=True,
|
||||
)
|
||||
groups: Mapped[List["Group"]] = relationship(
|
||||
"Group",
|
||||
secondary="groups_agents",
|
||||
lazy="selectin",
|
||||
back_populates="agents",
|
||||
passive_deletes=True,
|
||||
)
|
||||
multi_agent_group: Mapped["Group"] = relationship(
|
||||
"Group",
|
||||
lazy="joined",
|
||||
viewonly=True,
|
||||
back_populates="manager_agent",
|
||||
)
|
||||
|
||||
def to_pydantic(self) -> PydanticAgentState:
|
||||
"""converts to the basic pydantic model counterpart"""
|
||||
# add default rule for having send_message be a terminal tool
|
||||
tool_rules = self.tool_rules
|
||||
multi_agent_group = self.multi_agent_group
|
||||
state = {
|
||||
"id": self.id,
|
||||
"organization_id": self.organization_id,
|
||||
@@ -159,6 +173,7 @@ class Agent(SqlalchemyBase, OrganizationMixin):
|
||||
"base_template_id": self.base_template_id,
|
||||
"identity_ids": [identity.id for identity in self.identities],
|
||||
"message_buffer_autoclear": self.message_buffer_autoclear,
|
||||
"multi_agent_group": multi_agent_group,
|
||||
}
|
||||
|
||||
return self.__pydantic_model__(**state)
|
||||
|
||||
33
letta/orm/group.py
Normal file
33
letta/orm/group.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.group import Group as PydanticGroup
|
||||
|
||||
|
||||
class Group(SqlalchemyBase, OrganizationMixin):
    """ORM model for a multi-agent group: member agents plus manager settings."""

    __tablename__ = "groups"
    __pydantic_model__ = PydanticGroup

    # Primary key with the "group-" prefix convention used by the schemas.
    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"group-{uuid.uuid4()}")
    description: Mapped[str] = mapped_column(nullable=False, doc="")
    # One of the ManagerType values ("round_robin", "supervisor", "dynamic", ...).
    manager_type: Mapped[str] = mapped_column(nullable=False, doc="")
    # RESTRICT: a manager agent cannot be deleted while it still manages a group.
    manager_agent_id: Mapped[Optional[str]] = mapped_column(String, ForeignKey("agents.id", ondelete="RESTRICT"), nullable=True, doc="")
    # Token that ends a dynamic group conversation when it appears in a reply.
    termination_token: Mapped[Optional[str]] = mapped_column(nullable=True, doc="")
    max_turns: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")

    # relationships
    organization: Mapped["Organization"] = relationship("Organization", back_populates="groups")
    # Many-to-many membership via the groups_agents association table.
    agents: Mapped[List["Agent"]] = relationship(
        "Agent", secondary="groups_agents", lazy="selectin", passive_deletes=True, back_populates="groups"
    )
    manager_agent: Mapped["Agent"] = relationship("Agent", lazy="joined", back_populates="multi_agent_group")

    @property
    def agent_ids(self) -> List[str]:
        # Convenience accessor used by the pydantic schema (Group.agent_ids).
        return [agent.id for agent in self.agents]
|
||||
13
letta/orm/groups_agents.py
Normal file
13
letta/orm/groups_agents.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from sqlalchemy import ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from letta.orm.base import Base
|
||||
|
||||
|
||||
class GroupsAgents(Base):
    """Agents may have one or many groups associated with them."""

    __tablename__ = "groups_agents"

    # Composite primary key: one row per (group, agent) membership.
    # CASCADE on both sides so memberships disappear with either endpoint.
    group_id: Mapped[str] = mapped_column(String, ForeignKey("groups.id", ondelete="CASCADE"), primary_key=True)
    agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), primary_key=True)
|
||||
@@ -36,6 +36,7 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
tool_returns: Mapped[List[ToolReturn]] = mapped_column(
|
||||
ToolReturnColumn, nullable=True, doc="Tool execution return information for prior tool calls"
|
||||
)
|
||||
group_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The multi-agent group that the message was sent in")
|
||||
|
||||
# Relationships
|
||||
agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin")
|
||||
|
||||
@@ -49,6 +49,7 @@ class Organization(SqlalchemyBase):
|
||||
agent_passages: Mapped[List["AgentPassage"]] = relationship("AgentPassage", back_populates="organization", cascade="all, delete-orphan")
|
||||
providers: Mapped[List["Provider"]] = relationship("Provider", back_populates="organization", cascade="all, delete-orphan")
|
||||
identities: Mapped[List["Identity"]] = relationship("Identity", back_populates="organization", cascade="all, delete-orphan")
|
||||
groups: Mapped[List["Group"]] = relationship("Group", back_populates="organization", cascade="all, delete-orphan")
|
||||
|
||||
@property
|
||||
def passages(self) -> List[Union["SourcePassage", "AgentPassage"]]:
|
||||
|
||||
@@ -139,11 +139,11 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
else:
|
||||
# Match ANY tag - use join and filter
|
||||
query = (
|
||||
query.join(cls.tags).filter(cls.tags.property.mapper.class_.tag.in_(tags)).group_by(cls.id)
|
||||
query.join(cls.tags).filter(cls.tags.property.mapper.class_.tag.in_(tags)).distinct(cls.id).order_by(cls.id)
|
||||
) # Deduplicate results
|
||||
|
||||
# Group by primary key and all necessary columns to avoid JSON comparison
|
||||
query = query.group_by(cls.id)
|
||||
# select distinct primary key
|
||||
query = query.distinct(cls.id).order_by(cls.id)
|
||||
|
||||
if identifier_keys and hasattr(cls, "identities"):
|
||||
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.identifier_key.in_(identifier_keys))
|
||||
|
||||
152
letta/round_robin_multi_agent.py
Normal file
152
letta/round_robin_multi_agent.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from letta.agent import Agent, AgentState
|
||||
from letta.interface import AgentInterface
|
||||
from letta.orm import User
|
||||
from letta.schemas.letta_message import TextContent
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
|
||||
|
||||
class RoundRobinMultiAgent(Agent):
    """Multi-agent group whose participants speak in a fixed rotating order.

    Turn ``i`` is taken by ``agent_ids[i % len(agent_ids)]``; each participant's
    assistant messages become the input for the next participant's turn.
    """

    def __init__(
        self,
        interface: AgentInterface,
        agent_state: AgentState,
        user: User = None,
        # custom
        group_id: str = "",
        agent_ids: Optional[List[str]] = None,
        description: str = "",
        max_turns: Optional[int] = None,
    ):
        super().__init__(interface, agent_state, user)
        self.group_id = group_id
        # BUGFIX: the signature previously used a mutable default (agent_ids=[]),
        # which is shared across all instances; default to None and materialize here.
        self.agent_ids = agent_ids if agent_ids is not None else []
        self.description = description
        self.max_turns = max_turns or len(self.agent_ids)

    def step(
        self,
        messages: List[MessageCreate],
        chaining: bool = True,
        max_chaining_steps: Optional[int] = None,
        put_inner_thoughts_first: bool = True,
        **kwargs,
    ) -> LettaUsageStatistics:
        """Run up to ``max_turns`` round-robin turns of the group chat.

        Returns the aggregated usage statistics across all participant steps.
        """
        total_usage = UsageStatistics()
        step_count = 0

        token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False
        metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None

        agents = {}
        for agent_id in self.agent_ids:
            agents[agent_id] = self.load_participant_agent(agent_id=agent_id)

        # message_index holds, per agent, the index into chat_history up to
        # which that agent has already received/persisted messages.
        message_index = {}
        chat_history: List[Message] = []
        new_messages = messages
        speaker_id = None
        try:
            for i in range(self.max_turns):
                speaker_id = self.agent_ids[i % len(self.agent_ids)]
                # initialize input messages (dict.get replaces the manual
                # `x if k in d else 0` membership test)
                start_index = message_index.get(speaker_id, 0)
                # Re-stamp history the speaker has not yet seen so the copies
                # belong to the speaker's own message stream.
                for message in chat_history[start_index:]:
                    message.id = Message.generate_id()
                    message.agent_id = speaker_id

                for message in new_messages:
                    chat_history.append(
                        Message(
                            agent_id=speaker_id,
                            role=message.role,
                            content=[TextContent(text=message.content)],
                            name=message.name,
                            model=None,
                            tool_calls=None,
                            tool_call_id=None,
                            group_id=self.group_id,
                        )
                    )

                # load agent and perform step
                participant_agent = agents[speaker_id]
                usage_stats = participant_agent.step(
                    messages=chat_history[start_index:],
                    chaining=chaining,
                    max_chaining_steps=max_chaining_steps,
                    stream=token_streaming,
                    skip_verify=True,
                    metadata=metadata,
                    put_inner_thoughts_first=put_inner_thoughts_first,
                )

                # parse new messages for next step
                responses = Message.to_letta_messages_from_list(participant_agent.last_response_messages)
                assistant_messages = [response for response in responses if response.message_type == "assistant_message"]
                new_messages = [
                    MessageCreate(
                        role="system",
                        content=message.content,
                        name=participant_agent.agent_state.name,
                    )
                    for message in assistant_messages
                ]
                message_index[speaker_id] = len(chat_history) + len(new_messages)

                # sum usage
                total_usage.prompt_tokens += usage_stats.prompt_tokens
                total_usage.completion_tokens += usage_stats.completion_tokens
                total_usage.total_tokens += usage_stats.total_tokens
                step_count += 1

            # persist remaining chat history
            for message in new_messages:
                chat_history.append(
                    Message(
                        # BUGFIX: attribute trailing messages to the last speaker;
                        # previously this used `agent_id`, a variable leaked from
                        # the agent-loading loop above.
                        agent_id=speaker_id,
                        role=message.role,
                        content=[TextContent(text=message.content)],
                        name=message.name,
                        model=None,
                        tool_calls=None,
                        tool_call_id=None,
                        group_id=self.group_id,
                    )
                )
            for agent_id, index in message_index.items():
                if agent_id == speaker_id:
                    continue
                # Give every other agent its own copy of the messages it has
                # not yet persisted (fresh ids, re-attributed).
                for message in chat_history[index:]:
                    message.id = Message.generate_id()
                    message.agent_id = agent_id
                self.message_manager.create_many_messages(chat_history[index:], actor=self.user)
        finally:
            # Always signal the interface, even when a step raised.
            # (A bare `except Exception as e: raise e` added nothing and was removed.)
            self.interface.step_yield()

        self.interface.step_complete()

        return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)

    def load_participant_agent(self, agent_id: str) -> Agent:
        """Load a participant agent and append group-chat context to its persona."""
        agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user)
        persona_block = agent_state.memory.get_block(label="persona")
        group_chat_participant_persona = (
            # BUGFIX: typo "Contex" -> "Context" in the prompt header.
            "\n\n====Group Chat Context===="
            f"\nYou are speaking in a group chat with {len(self.agent_ids) - 1} other "
            "agents and one user. Respond to new messages in the group chat when prompted. "
            f"Description of the group: {self.description}"
        )
        agent_state.memory.update_block_value(label="persona", value=persona_block.value + group_chat_participant_persona)
        return Agent(
            agent_state=agent_state,
            interface=self.interface,
            user=self.user,
            save_last_response=True,
        )
|
||||
@@ -7,6 +7,7 @@ from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
|
||||
from letta.schemas.block import CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.environment_variables import AgentEnvironmentVariable
|
||||
from letta.schemas.group import Group
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import Memory
|
||||
@@ -90,6 +91,8 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
|
||||
description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.",
|
||||
)
|
||||
|
||||
multi_agent_group: Optional[Group] = Field(None, description="The multi-agent group that this agent manages")
|
||||
|
||||
def get_agent_env_vars_as_dict(self) -> Dict[str, str]:
|
||||
# Get environment variables for this agent specifically
|
||||
per_agent_env_vars = {}
|
||||
|
||||
65
letta/schemas/group.py
Normal file
65
letta/schemas/group.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from enum import Enum
|
||||
from typing import Annotated, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
|
||||
|
||||
class ManagerType(str, Enum):
    """How a multi-agent group decides who speaks next."""

    round_robin = "round_robin"  # participants speak in fixed rotating order
    supervisor = "supervisor"  # a manager agent fans work out to participants
    dynamic = "dynamic"  # a manager agent picks the next speaker each turn
    swarm = "swarm"  # reserved; no manager config implemented yet (SwarmGroup is commented out below)
||||
|
||||
|
||||
class GroupBase(LettaBase):
|
||||
__id_prefix__ = "group"
|
||||
|
||||
|
||||
class Group(GroupBase):
|
||||
id: str = Field(..., description="The id of the group. Assigned by the database.")
|
||||
manager_type: ManagerType = Field(..., description="")
|
||||
agent_ids: List[str] = Field(..., description="")
|
||||
description: str = Field(..., description="")
|
||||
# Pattern fields
|
||||
manager_agent_id: Optional[str] = Field(None, description="")
|
||||
termination_token: Optional[str] = Field(None, description="")
|
||||
max_turns: Optional[int] = Field(None, description="")
|
||||
|
||||
|
||||
class ManagerConfig(BaseModel):
|
||||
manager_type: ManagerType = Field(..., description="")
|
||||
|
||||
|
||||
class RoundRobinManager(ManagerConfig):
|
||||
manager_type: Literal[ManagerType.round_robin] = Field(ManagerType.round_robin, description="")
|
||||
max_turns: Optional[int] = Field(None, description="")
|
||||
|
||||
|
||||
class SupervisorManager(ManagerConfig):
|
||||
manager_type: Literal[ManagerType.supervisor] = Field(ManagerType.supervisor, description="")
|
||||
manager_agent_id: str = Field(..., description="")
|
||||
|
||||
|
||||
class DynamicManager(ManagerConfig):
|
||||
manager_type: Literal[ManagerType.dynamic] = Field(ManagerType.dynamic, description="")
|
||||
manager_agent_id: str = Field(..., description="")
|
||||
termination_token: Optional[str] = Field("DONE!", description="")
|
||||
max_turns: Optional[int] = Field(None, description="")
|
||||
|
||||
|
||||
# class SwarmGroup(ManagerConfig):
|
||||
# manager_type: Literal[ManagerType.swarm] = Field(ManagerType.swarm, description="")
|
||||
|
||||
|
||||
ManagerConfigUnion = Annotated[
|
||||
Union[RoundRobinManager, SupervisorManager, DynamicManager],
|
||||
Field(discriminator="manager_type"),
|
||||
]
|
||||
|
||||
|
||||
class GroupCreate(BaseModel):
|
||||
agent_ids: List[str] = Field(..., description="")
|
||||
description: str = Field(..., description="")
|
||||
manager_config: Optional[ManagerConfigUnion] = Field(None, description="")
|
||||
@@ -129,6 +129,7 @@ class Message(BaseMessage):
|
||||
step_id: Optional[str] = Field(None, description="The id of the step that this message was created in.")
|
||||
otid: Optional[str] = Field(None, description="The offline threading id associated with this message")
|
||||
tool_returns: Optional[List[ToolReturn]] = Field(None, description="Tool execution return information for prior tool calls")
|
||||
group_id: Optional[str] = Field(None, description="The multi-agent group that the message was sent in")
|
||||
|
||||
# This overrides the optional base orm schema, created_at MUST exist on all messages objects
|
||||
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from letta.server.rest_api.routers.v1.agents import router as agents_router
|
||||
from letta.server.rest_api.routers.v1.blocks import router as blocks_router
|
||||
from letta.server.rest_api.routers.v1.groups import router as groups_router
|
||||
from letta.server.rest_api.routers.v1.health import router as health_router
|
||||
from letta.server.rest_api.routers.v1.identities import router as identities_router
|
||||
from letta.server.rest_api.routers.v1.jobs import router as jobs_router
|
||||
@@ -17,6 +18,7 @@ ROUTERS = [
|
||||
tools_router,
|
||||
sources_router,
|
||||
agents_router,
|
||||
groups_router,
|
||||
identities_router,
|
||||
llm_router,
|
||||
blocks_router,
|
||||
|
||||
233
letta/server/rest_api/routers/v1/groups.py
Normal file
233
letta/server/rest_api/routers/v1/groups.py
Normal file
@@ -0,0 +1,233 @@
|
||||
from typing import Annotated, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Header, Query
|
||||
from pydantic import Field
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.schemas.group import Group, GroupCreate, ManagerType
|
||||
from letta.schemas.letta_message import LettaMessageUnion
|
||||
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
router = APIRouter(prefix="/groups", tags=["groups"])
|
||||
|
||||
|
||||
@router.post("/", response_model=Group, operation_id="create_group")
|
||||
async def create_group(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
request: GroupCreate = Body(...),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Create a multi-agent group with a specified management pattern. When no
|
||||
management config is specified, this endpoint will use round robin for
|
||||
speaker selection.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.group_manager.create_group(request, actor=actor)
|
||||
|
||||
|
||||
@router.get("/", response_model=List[Group], operation_id="list_groups")
|
||||
def list_groups(
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),
    manager_type: Optional[ManagerType] = Query(None, description="Search groups by manager type"),
    before: Optional[str] = Query(None, description="Cursor for pagination"),
    after: Optional[str] = Query(None, description="Cursor for pagination"),
    limit: Optional[int] = Query(None, description="Limit for pagination"),
    project_id: Optional[str] = Query(None, description="Search groups by project id"),
):
    """
    Fetch all multi-agent groups matching query.
    """
    # Resolve the acting user, then delegate filtering and pagination to the manager.
    actor = server.user_manager.get_user_or_default(user_id=actor_id)
    return server.group_manager.list_groups(
        project_id=project_id, manager_type=manager_type, before=before, after=after, limit=limit, actor=actor
    )
|
||||
|
||||
|
||||
@router.post("/", response_model=Group, operation_id="create_group")
|
||||
def create_group(
|
||||
group: GroupCreate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
x_project: Optional[str] = Header(None, alias="X-Project"), # Only handled by next js middleware
|
||||
):
|
||||
"""
|
||||
Create a new multi-agent group with the specified configuration.
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.group_manager.create_group(group, actor=actor)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/", response_model=Group, operation_id="upsert_group")
|
||||
def upsert_group(
|
||||
group: GroupCreate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
x_project: Optional[str] = Header(None, alias="X-Project"), # Only handled by next js middleware
|
||||
):
|
||||
"""
|
||||
Create a new multi-agent group with the specified configuration.
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.group_manager.create_group(group, actor=actor)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{group_id}", response_model=None, operation_id="delete_group")
|
||||
def delete_group(
|
||||
group_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Delete a multi-agent group.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
try:
|
||||
server.group_manager.delete_group(group_id=group_id, actor=actor)
|
||||
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Group id={group_id} successfully deleted"})
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail=f"Group id={group_id} not found for user_id={actor.id}.")
|
||||
|
||||
|
||||
@router.post(
    "/{group_id}/messages",
    response_model=LettaResponse,
    operation_id="send_group_message",
)
async def send_group_message(
    group_id: str,  # FIX: was `agent_id`, but the path declares {group_id} and the body references group_id (NameError)
    server: SyncServer = Depends(get_letta_server),
    request: LettaRequest = Body(...),
    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    """
    Process a user message and return the group's response.
    This endpoint accepts a message from a user and processes it through agents in the group based on the specified pattern.
    """
    actor = server.user_manager.get_user_or_default(user_id=actor_id)
    # Non-streaming variant: buffer the whole run and return a LettaResponse.
    result = await server.send_group_message_to_agent(
        group_id=group_id,
        actor=actor,
        messages=request.messages,
        stream_steps=False,
        stream_tokens=False,
        # Support for AssistantMessage
        use_assistant_message=request.use_assistant_message,
        assistant_message_tool_name=request.assistant_message_tool_name,
        assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
    )
    return result
|
||||
|
||||
|
||||
@router.post(
    "/{group_id}/messages/stream",
    response_model=None,
    operation_id="send_group_message_streaming",
    responses={
        200: {
            "description": "Successful response",
            "content": {
                "text/event-stream": {"description": "Server-Sent Events stream"},
            },
        }
    },
)
async def send_group_message_streaming(
    group_id: str,
    server: SyncServer = Depends(get_letta_server),
    request: LettaStreamingRequest = Body(...),
    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    """
    Process a user message and return the group's responses.
    This endpoint accepts a message from a user and processes it through agents in the group based on the specified pattern.
    It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True.
    """
    actor = server.user_manager.get_user_or_default(user_id=actor_id)
    # Step streaming is always on for this endpoint; token streaming is opt-in
    # via the request body. The server returns a StreamingResponse (SSE).
    result = await server.send_group_message_to_agent(
        group_id=group_id,
        actor=actor,
        messages=request.messages,
        stream_steps=True,
        stream_tokens=request.stream_tokens,
        # Support for AssistantMessage
        use_assistant_message=request.use_assistant_message,
        assistant_message_tool_name=request.assistant_message_tool_name,
        assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
    )
    return result
|
||||
|
||||
|
||||
# OpenAPI-friendly alias: a list of LettaMessageUnion rendered with an explicit
# $ref so generated clients get the correct item type.
GroupMessagesResponse = Annotated[
    List[LettaMessageUnion], Field(json_schema_extra={"type": "array", "items": {"$ref": "#/components/schemas/LettaMessageUnion"}})
]


@router.get("/{group_id}/messages", response_model=GroupMessagesResponse, operation_id="list_group_messages")
def list_group_messages(
    group_id: str,
    server: "SyncServer" = Depends(get_letta_server),
    after: Optional[str] = Query(None, description="Message after which to retrieve the returned messages."),
    before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."),
    limit: int = Query(10, description="Maximum number of messages to retrieve."),
    use_assistant_message: bool = Query(True, description="Whether to use assistant messages"),
    assistant_message_tool_name: str = Query(DEFAULT_MESSAGE_TOOL, description="The name of the designated message tool."),
    assistant_message_tool_kwarg: str = Query(DEFAULT_MESSAGE_TOOL_KWARG, description="The name of the message argument."),
    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    """
    Retrieve message history for a multi-agent group.
    """
    actor = server.user_manager.get_user_or_default(user_id=actor_id)

    # Pagination and assistant-message rendering are handled by the group manager.
    return server.group_manager.list_group_messages(
        group_id=group_id,
        before=before,
        after=after,
        limit=limit,
        actor=actor,
        use_assistant_message=use_assistant_message,
        assistant_message_tool_name=assistant_message_tool_name,
        assistant_message_tool_kwarg=assistant_message_tool_kwarg,
    )
|
||||
|
||||
|
||||
# NOTE(review): dead code — this endpoint is disabled by wrapping it in a bare
# string literal. Either finish the TODO inside (delete only the group's
# messages, not all messages) and re-enable it, or remove the block entirely.
'''
@router.patch("/{group_id}/reset-messages", response_model=None, operation_id="reset_group_messages")
def reset_group_messages(
    group_id: str,
    add_default_initial_messages: bool = Query(default=False, description="If true, adds the default initial messages after resetting."),
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    """
    Resets the messages for all agents that are part of the multi-agent group.
    TODO: only delete group messages not all messages!
    """
    actor = server.user_manager.get_user_or_default(user_id=actor_id)
    group = server.group_manager.retrieve_group(group_id=group_id, actor=actor)
    agent_ids = group.agent_ids
    if group.manager_agent_id:
        agent_ids.append(group.manager_agent_id)
    for agent_id in agent_ids:
        server.agent_manager.reset_messages(
            agent_id=agent_id,
            actor=actor,
            add_default_initial_messages=add_default_initial_messages,
        )
'''
|
||||
@@ -19,6 +19,7 @@ import letta.system as system
|
||||
from letta.agent import Agent, save_agent
|
||||
from letta.config import LettaConfig
|
||||
from letta.data_sources.connectors import DataConnector, load_data
|
||||
from letta.dynamic_multi_agent import DynamicMultiAgent
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
from letta.helpers.mcp_helpers import (
|
||||
@@ -37,6 +38,7 @@ from letta.interface import CLIInterface # for printing to terminal
|
||||
from letta.log import get_logger
|
||||
from letta.offline_memory_agent import OfflineMemoryAgent
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.round_robin_multi_agent import RoundRobinMultiAgent
|
||||
from letta.schemas.agent import AgentState, AgentType, CreateAgent
|
||||
from letta.schemas.block import BlockUpdate
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
@@ -44,6 +46,7 @@ from letta.schemas.embedding_config import EmbeddingConfig
|
||||
# openai schemas
|
||||
from letta.schemas.enums import JobStatus, MessageStreamStatus
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate
|
||||
from letta.schemas.group import Group, ManagerType
|
||||
from letta.schemas.job import Job, JobUpdate
|
||||
from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, ToolReturnMessage
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
@@ -80,6 +83,7 @@ from letta.server.rest_api.interface import StreamingServerInterface
|
||||
from letta.server.rest_api.utils import sse_async_generator
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.group_manager import GroupManager
|
||||
from letta.services.identity_manager import IdentityManager
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
@@ -94,6 +98,7 @@ from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.services.user_manager import UserManager
|
||||
from letta.settings import model_settings, settings, tool_settings
|
||||
from letta.supervisor_multi_agent import SupervisorMultiAgent
|
||||
from letta.tracing import trace_method
|
||||
from letta.utils import get_friendly_error_msg
|
||||
|
||||
@@ -207,6 +212,7 @@ class SyncServer(Server):
|
||||
self.provider_manager = ProviderManager()
|
||||
self.step_manager = StepManager()
|
||||
self.identity_manager = IdentityManager()
|
||||
self.group_manager = GroupManager()
|
||||
|
||||
# Managers that interface with parallelism
|
||||
self.per_agent_lock_manager = PerAgentLockManager()
|
||||
@@ -353,6 +359,8 @@ class SyncServer(Server):
|
||||
agent_lock = self.per_agent_lock_manager.get_lock(agent_id)
|
||||
with agent_lock:
|
||||
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
if agent_state.multi_agent_group:
|
||||
return self.load_multi_agent(agent_state.multi_agent_group, actor, interface, agent_state)
|
||||
|
||||
interface = interface or self.default_interface_factory()
|
||||
if agent_state.agent_type == AgentType.memgpt_agent:
|
||||
@@ -364,6 +372,46 @@ class SyncServer(Server):
|
||||
|
||||
return agent
|
||||
|
||||
def load_multi_agent(
    self, group: Group, actor: User, interface: Union[AgentInterface, None] = None, agent_state: Optional[AgentState] = None
) -> Agent:
    """Instantiate the multi-agent wrapper implementing the group's management pattern.

    Round robin seeds the wrapper with the first participant's state; the
    supervisor and dynamic patterns seed it with the manager agent's state.
    Raises ValueError for unsupported manager types (e.g. swarm).
    """
    pattern = group.manager_type
    if pattern == ManagerType.round_robin:
        seed_state = agent_state or self.agent_manager.get_agent_by_id(agent_id=group.agent_ids[0], actor=actor)
        return RoundRobinMultiAgent(
            agent_state=seed_state,
            interface=interface,
            user=actor,
            group_id=group.id,
            agent_ids=group.agent_ids,
            description=group.description,
            max_turns=group.max_turns,
        )
    if pattern == ManagerType.dynamic:
        seed_state = agent_state or self.agent_manager.get_agent_by_id(agent_id=group.manager_agent_id, actor=actor)
        return DynamicMultiAgent(
            agent_state=seed_state,
            interface=interface,
            user=actor,
            group_id=group.id,
            agent_ids=group.agent_ids,
            description=group.description,
            max_turns=group.max_turns,
            termination_token=group.termination_token,
        )
    if pattern == ManagerType.supervisor:
        seed_state = agent_state or self.agent_manager.get_agent_by_id(agent_id=group.manager_agent_id, actor=actor)
        return SupervisorMultiAgent(
            agent_state=seed_state,
            interface=interface,
            user=actor,
            group_id=group.id,
            agent_ids=group.agent_ids,
            description=group.description,
        )
    raise ValueError(f"Type {group.manager_type} is not supported.")
|
||||
|
||||
def _step(
|
||||
self,
|
||||
actor: User,
|
||||
@@ -1403,3 +1451,106 @@ class SyncServer(Server):
|
||||
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
|
||||
@trace_method
async def send_group_message_to_agent(
    self,
    group_id: str,
    actor: User,
    messages: Union[List[Message], List[MessageCreate]],
    stream_steps: bool,
    stream_tokens: bool,
    chat_completion_mode: bool = False,
    # Support for AssistantMessage
    use_assistant_message: bool = True,
    assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
    assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
    metadata: Optional[dict] = None,
) -> Union[StreamingResponse, LettaResponse]:
    """Run a group step for the given messages and either stream the result or buffer it.

    When stream_steps is True, returns a StreamingResponse wrapping an SSE
    generator; otherwise drains the same generator into a LettaResponse.
    stream_tokens requires stream_steps (400 otherwise) and is silently
    downgraded when the group's model endpoint does not support it.
    Unexpected failures are re-raised as HTTP 500.
    """
    include_final_message = True
    if not stream_steps and stream_tokens:
        raise HTTPException(status_code=400, detail="stream_steps must be 'true' if stream_tokens is 'true'")

    try:
        # fetch the group
        group = self.group_manager.retrieve_group(group_id=group_id, actor=actor)
        letta_multi_agent = self.load_multi_agent(group=group, actor=actor)

        # Token streaming is only honored for endpoint types known to support it.
        llm_config = letta_multi_agent.agent_state.llm_config
        supports_token_streaming = ["openai", "anthropic", "deepseek"]
        if stream_tokens and (
            llm_config.model_endpoint_type not in supports_token_streaming or "inference.memgpt.ai" in llm_config.model_endpoint
        ):
            warnings.warn(
                f"Token streaming is only supported for models with type {' or '.join(supports_token_streaming)} in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False."
            )
            stream_tokens = False

        # Create a new interface per request
        letta_multi_agent.interface = StreamingServerInterface(
            use_assistant_message=use_assistant_message,
            assistant_message_tool_name=assistant_message_tool_name,
            assistant_message_tool_kwarg=assistant_message_tool_kwarg,
            inner_thoughts_in_kwargs=(
                llm_config.put_inner_thoughts_in_kwargs if llm_config.put_inner_thoughts_in_kwargs is not None else False
            ),
        )
        streaming_interface = letta_multi_agent.interface
        if not isinstance(streaming_interface, StreamingServerInterface):
            raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}")
        streaming_interface.streaming_mode = stream_tokens
        streaming_interface.streaming_chat_completion_mode = chat_completion_mode
        if metadata and hasattr(streaming_interface, "metadata"):
            streaming_interface.metadata = metadata

        # Kick off the (blocking) group step in a worker thread; the interface
        # is consumed concurrently while the step produces messages.
        streaming_interface.stream_start()
        task = asyncio.create_task(
            asyncio.to_thread(
                letta_multi_agent.step,
                messages=messages,
                chaining=self.chaining,
                max_chaining_steps=self.max_chaining_steps,
            )
        )

        if stream_steps:
            # return a stream
            return StreamingResponse(
                sse_async_generator(
                    streaming_interface.get_generator(),
                    usage_task=task,
                    finish_message=include_final_message,
                ),
                media_type="text/event-stream",
            )

        else:
            # buffer the stream, then return the list
            generated_stream = []
            async for message in streaming_interface.get_generator():
                assert (
                    isinstance(message, LettaMessage)
                    or isinstance(message, LegacyLettaMessage)
                    or isinstance(message, MessageStreamStatus)
                ), type(message)
                generated_stream.append(message)
                if message == MessageStreamStatus.done:
                    break

            # Get rid of the stream status messages
            filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)]
            usage = await task

            # By default the stream will be messages of type LettaMessage or LettaLegacyMessage
            # If we want to convert these to Message, we can use the attached IDs
            # NOTE: we will need to de-duplicate the Messsage IDs though (since Assistant->Inner+Func_Call)
            # TODO: eventually update the interface to use `Message` and `MessageChunk` (new) inside the deque instead
            return LettaResponse(messages=filtered_stream, usage=usage)
    except HTTPException:
        raise
    except Exception as e:
        print(e)
        import traceback

        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"{e}")
|
||||
|
||||
@@ -399,7 +399,7 @@ class AgentManager:
|
||||
# Ensures agents match at least one tag in match_some
|
||||
query = query.join(AgentsTags).where(AgentsTags.tag.in_(match_some))
|
||||
|
||||
query = query.group_by(AgentModel.id).limit(limit)
|
||||
query = query.distinct(AgentModel.id).order_by(AgentModel.id).limit(limit)
|
||||
|
||||
return list(session.execute(query).scalars())
|
||||
|
||||
@@ -434,6 +434,7 @@ class AgentManager:
|
||||
with self.session_maker() as session:
|
||||
# Retrieve the agent
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
# TODO check if it is managing a group
|
||||
agent.hard_delete(session)
|
||||
|
||||
@enforce_types
|
||||
|
||||
147
letta/services/group_manager.py
Normal file
147
letta/services/group_manager.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from letta.orm.agent import Agent as AgentModel
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.group import Group as GroupModel
|
||||
from letta.orm.message import Message as MessageModel
|
||||
from letta.schemas.group import Group as PydanticGroup
|
||||
from letta.schemas.group import GroupCreate, ManagerType
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
class GroupManager:
|
||||
|
||||
def __init__(self):
|
||||
from letta.server.db import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
@enforce_types
|
||||
def list_groups(
|
||||
self,
|
||||
project_id: Optional[str] = None,
|
||||
manager_type: Optional[ManagerType] = None,
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
actor: PydanticUser = None,
|
||||
) -> list[PydanticGroup]:
|
||||
with self.session_maker() as session:
|
||||
filters = {"organization_id": actor.organization_id}
|
||||
if project_id:
|
||||
filters["project_id"] = project_id
|
||||
if manager_type:
|
||||
filters["manager_type"] = manager_type
|
||||
groups = GroupModel.list(
|
||||
db_session=session,
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
**filters,
|
||||
)
|
||||
return [group.to_pydantic() for group in groups]
|
||||
|
||||
@enforce_types
|
||||
def retrieve_group(self, group_id: str, actor: PydanticUser) -> PydanticGroup:
|
||||
with self.session_maker() as session:
|
||||
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
|
||||
return group.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def create_group(self, group: GroupCreate, actor: PydanticUser) -> PydanticGroup:
|
||||
with self.session_maker() as session:
|
||||
new_group = GroupModel()
|
||||
new_group.organization_id = actor.organization_id
|
||||
new_group.description = group.description
|
||||
self._process_agent_relationship(session=session, group=new_group, agent_ids=group.agent_ids, allow_partial=False)
|
||||
if group.manager_config is None:
|
||||
new_group.manager_type = ManagerType.round_robin
|
||||
else:
|
||||
match group.manager_config.manager_type:
|
||||
case ManagerType.round_robin:
|
||||
new_group.manager_type = ManagerType.round_robin
|
||||
new_group.max_turns = group.manager_config.max_turns
|
||||
case ManagerType.dynamic:
|
||||
new_group.manager_type = ManagerType.dynamic
|
||||
new_group.manager_agent_id = group.manager_config.manager_agent_id
|
||||
new_group.max_turns = group.manager_config.max_turns
|
||||
new_group.termination_token = group.manager_config.termination_token
|
||||
case ManagerType.supervisor:
|
||||
new_group.manager_type = ManagerType.supervisor
|
||||
new_group.manager_agent_id = group.manager_config.manager_agent_id
|
||||
case _:
|
||||
raise ValueError(f"Unsupported manager type: {group.manager_config.manager_type}")
|
||||
new_group.create(session, actor=actor)
|
||||
return new_group.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def delete_group(self, group_id: str, actor: PydanticUser) -> None:
|
||||
with self.session_maker() as session:
|
||||
# Retrieve the agent
|
||||
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
|
||||
group.hard_delete(session)
|
||||
|
||||
@enforce_types
|
||||
def list_group_messages(
|
||||
self,
|
||||
group_id: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
actor: PydanticUser = None,
|
||||
use_assistant_message: bool = True,
|
||||
assistant_message_tool_name: str = "send_message",
|
||||
assistant_message_tool_kwarg: str = "message",
|
||||
) -> list[PydanticGroup]:
|
||||
with self.session_maker() as session:
|
||||
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
|
||||
agent_id = group.manager_agent_id if group.manager_agent_id else group.agent_ids[0]
|
||||
|
||||
filters = {
|
||||
"organization_id": actor.organization_id,
|
||||
"group_id": group_id,
|
||||
"agent_id": agent_id,
|
||||
}
|
||||
messages = MessageModel.list(
|
||||
db_session=session,
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
**filters,
|
||||
)
|
||||
|
||||
messages = PydanticMessage.to_letta_messages_from_list(
|
||||
messages=messages,
|
||||
use_assistant_message=use_assistant_message,
|
||||
assistant_message_tool_name=assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
def _process_agent_relationship(self, session: Session, group: GroupModel, agent_ids: List[str], allow_partial=False, replace=True):
|
||||
current_relationship = getattr(group, "agents", [])
|
||||
if not agent_ids:
|
||||
if replace:
|
||||
setattr(group, "agents", [])
|
||||
return
|
||||
|
||||
# Retrieve models for the provided IDs
|
||||
found_items = session.query(AgentModel).filter(AgentModel.id.in_(agent_ids)).all()
|
||||
|
||||
# Validate all items are found if allow_partial is False
|
||||
if not allow_partial and len(found_items) != len(agent_ids):
|
||||
missing = set(agent_ids) - {item.id for item in found_items}
|
||||
raise NoResultFound(f"Items not found in agents: {missing}")
|
||||
|
||||
if replace:
|
||||
# Replace the relationship
|
||||
setattr(group, "agents", found_items)
|
||||
else:
|
||||
# Extend the relationship (only add new items)
|
||||
current_ids = {item.id for item in current_relationship}
|
||||
new_items = [item for item in found_items if item.id not in current_ids]
|
||||
current_relationship.extend(new_items)
|
||||
103
letta/supervisor_multi_agent.py
Normal file
103
letta/supervisor_multi_agent.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from letta.agent import Agent, AgentState
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL
|
||||
from letta.functions.function_sets.multi_agent import send_message_to_all_agents_in_group
|
||||
from letta.interface import AgentInterface
|
||||
from letta.orm import User
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.schemas.letta_message import TextContent
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from tests.helpers.utils import create_tool_from_func
|
||||
|
||||
|
||||
class SupervisorMultiAgent(Agent):
|
||||
def __init__(
    self,
    interface: AgentInterface,
    agent_state: AgentState,
    user: User = None,
    # custom
    group_id: str = "",
    agent_ids: Optional[List[str]] = None,  # FIX: was a mutable default ([]) shared across instances
    description: str = "",
):
    """Wrap a supervisor agent that fans messages out to the group's worker agents.

    Args:
        interface: Interface used to surface the supervisor's output.
        agent_state: State of the supervisor agent itself.
        user: Acting user; forwarded to the base Agent.
        group_id: Id of the multi-agent group this supervisor manages.
        agent_ids: Ids of the worker agents in the group (defaults to empty).
        description: Human-readable description of the group.
    """
    super().__init__(interface, agent_state, user)
    self.group_id = group_id
    self.agent_ids = agent_ids if agent_ids is not None else []
    self.description = description
    self.agent_manager = AgentManager()
    self.tool_manager = ToolManager()
||||
|
||||
def step(
|
||||
self,
|
||||
messages: List[MessageCreate],
|
||||
chaining: bool = True,
|
||||
max_chaining_steps: Optional[int] = None,
|
||||
put_inner_thoughts_first: bool = True,
|
||||
assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
|
||||
**kwargs,
|
||||
) -> LettaUsageStatistics:
|
||||
token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False
|
||||
metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None
|
||||
|
||||
# add multi agent tool
|
||||
if self.tool_manager.get_tool_by_name(tool_name="send_message_to_all_agents_in_group", actor=self.user) is None:
|
||||
multi_agent_tool = create_tool_from_func(send_message_to_all_agents_in_group)
|
||||
multi_agent_tool.tool_type = ToolType.LETTA_MULTI_AGENT_CORE
|
||||
multi_agent_tool = self.tool_manager.create_or_update_tool(
|
||||
pydantic_tool=multi_agent_tool,
|
||||
actor=self.user,
|
||||
)
|
||||
self.agent_state = self.agent_manager.attach_tool(agent_id=self.agent_state.id, tool_id=multi_agent_tool.id, actor=self.user)
|
||||
|
||||
# override tool rules
|
||||
self.agent_state.tool_rules = [
|
||||
InitToolRule(
|
||||
tool_name="send_message_to_all_agents_in_group",
|
||||
),
|
||||
TerminalToolRule(
|
||||
tool_name=assistant_message_tool_name,
|
||||
),
|
||||
ChildToolRule(
|
||||
tool_name="send_message_to_all_agents_in_group",
|
||||
children=[assistant_message_tool_name],
|
||||
),
|
||||
]
|
||||
|
||||
supervisor_messages = [
|
||||
Message(
|
||||
agent_id=self.agent_state.id,
|
||||
role="user",
|
||||
content=[TextContent(text=message.content)],
|
||||
name=None,
|
||||
model=None,
|
||||
tool_calls=None,
|
||||
tool_call_id=None,
|
||||
group_id=self.group_id,
|
||||
)
|
||||
for message in messages
|
||||
]
|
||||
try:
|
||||
supervisor_agent = Agent(agent_state=self.agent_state, interface=self.interface, user=self.user)
|
||||
usage_stats = supervisor_agent.step(
|
||||
messages=supervisor_messages,
|
||||
chaining=chaining,
|
||||
max_chaining_steps=max_chaining_steps,
|
||||
stream=token_streaming,
|
||||
skip_verify=True,
|
||||
metadata=metadata,
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
finally:
|
||||
self.interface.step_yield()
|
||||
|
||||
self.interface.step_complete()
|
||||
|
||||
return usage_stats
|
||||
233
tests/test_multi_agent.py
Normal file
233
tests/test_multi_agent.py
Normal file
@@ -0,0 +1,233 @@
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.orm import Provider, Step
|
||||
from letta.schemas.agent import CreateAgent
|
||||
from letta.schemas.block import CreateBlock
|
||||
from letta.schemas.group import DynamicManager, GroupCreate, SupervisorManager
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def server():
    """Load and persist the Letta config, then return a fresh SyncServer."""
    config = LettaConfig.load()
    print("CONFIG PATH", config.config_path)

    config.save()

    return SyncServer()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def org_id(server):
    """Create the default organization and tear it (and dependent rows) down."""
    org = server.organization_manager.create_default_organization()

    yield org.id

    # cleanup: wipe dependent rows first, then remove the organization itself
    with server.organization_manager.session_maker() as session:
        for table in (Step, Provider):
            session.execute(delete(table))
        session.commit()
    server.organization_manager.delete_organization_by_id(org.id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def actor(server, org_id):
    """Yield the default user, removing it once the module's tests finish."""
    default_user = server.user_manager.create_default_user()

    yield default_user

    # cleanup
    server.user_manager.delete_user_by_id(default_user.id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def participant_agent_ids(server, actor):
    """Create the four participant agents used by the group-chat tests.

    Yields their ids in a stable order (fred, velma, daphne, shaggy) and
    deletes every agent on teardown.
    """
    # (name, persona) pairs — collapsed from four copy-pasted create_agent
    # calls into one data-driven loop; personas are byte-identical to the originals.
    personas = [
        (
            "fred",
            "Your name is fred and you like to ski and have been wanting to go on a ski trip soon. You are speaking in a group chat with other agent pals where you participate in friendly banter.",
        ),
        (
            "velma",
            "Your name is velma and you like tropical locations. You are speaking in a group chat with other agent friends and you love to include everyone.",
        ),
        (
            "daphne",
            "Your name is daphne and you love traveling abroad. You are speaking in a group chat with other agent friends and you love to keep in touch with them.",
        ),
        (
            "shaggy",
            "Your name is shaggy and your best friend is your dog, scooby. You are speaking in a group chat with other agent friends and you like to solve mysteries with them.",
        ),
    ]
    agents = [
        server.create_agent(
            request=CreateAgent(
                name=name,
                memory_blocks=[
                    CreateBlock(
                        label="persona",
                        value=persona,
                    ),
                ],
                model="openai/gpt-4o-mini",
                embedding="openai/text-embedding-ada-002",
            ),
            actor=actor,
        )
        for name, persona in personas
    ]

    yield [agent.id for agent in agents]

    # cleanup
    for agent in agents:
        server.agent_manager.delete_agent(agent.id, actor=actor)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def manager_agent_id(server, actor):
    """Create the 'scooby' manager agent; delete it when the module is done."""
    scooby_blocks = [
        CreateBlock(
            label="persona",
            value="You are a puppy operations agent for Letta and you help run multi-agent group chats. Your job is to get to know the agents in your group and pick who is best suited to speak next in the conversation.",
        ),
        CreateBlock(
            label="human",
            value="",
        ),
    ]
    agent_scooby = server.create_agent(
        request=CreateAgent(
            name="scooby",
            memory_blocks=scooby_blocks,
            model="openai/gpt-4o-mini",
            embedding="openai/text-embedding-ada-002",
        ),
        actor=actor,
    )

    yield agent_scooby.id

    # cleanup
    server.agent_manager.delete_agent(agent_scooby.id, actor=actor)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_round_robin(server, actor, participant_agent_ids):
    """Default (round-robin) groups give every participant exactly one turn.

    Expects one step per participant and two messages (reasoning + reply) per step.
    """
    group = server.group_manager.create_group(
        group=GroupCreate(
            description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.",
            agent_ids=participant_agent_ids,
        ),
        actor=actor,
    )
    # BUGFIX: delete the group in a finally block so a failing assertion does
    # not leak it into subsequent tests.
    try:
        response = await server.send_group_message_to_agent(
            group_id=group.id,
            actor=actor,
            messages=[
                MessageCreate(
                    role="user",
                    content="what is everyone up to for the holidays?",
                ),
            ],
            stream_steps=False,
            stream_tokens=False,
        )
        assert response.usage.step_count == len(participant_agent_ids)
        assert len(response.messages) == response.usage.step_count * 2
    finally:
        server.group_manager.delete_group(group_id=group.id, actor=actor)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_supervisor(server, actor, manager_agent_id, participant_agent_ids):
    """Supervisor groups broadcast once, then summarize.

    Expected transcript: reasoning -> broadcast tool call -> tool return with
    one entry per participant -> reasoning -> assistant summary (5 messages, 2 steps).
    """
    import ast

    group = server.group_manager.create_group(
        group=GroupCreate(
            description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.",
            agent_ids=participant_agent_ids,
            manager_config=SupervisorManager(
                manager_agent_id=manager_agent_id,
            ),
        ),
        actor=actor,
    )
    # BUGFIX: delete the group in a finally block so a failing assertion does
    # not leak it into subsequent tests.
    try:
        response = await server.send_group_message_to_agent(
            group_id=group.id,
            actor=actor,
            messages=[
                MessageCreate(
                    role="user",
                    content="ask everyone what they like to do for fun and then come up with an activity for everyone to do together.",
                ),
            ],
            stream_steps=False,
            stream_tokens=False,
        )
        assert response.usage.step_count == 2
        assert len(response.messages) == 5

        # verify tool call
        assert response.messages[0].message_type == "reasoning_message"
        assert (
            response.messages[1].message_type == "tool_call_message"
            and response.messages[1].tool_call.name == "send_message_to_all_agents_in_group"
        )
        # SECURITY: the tool return is model/tooling-produced text — parse it
        # with ast.literal_eval instead of eval so it is never executed.
        assert response.messages[2].message_type == "tool_return_message"
        assert len(ast.literal_eval(response.messages[2].tool_return)) == len(participant_agent_ids)
        assert response.messages[3].message_type == "reasoning_message"
        assert response.messages[4].message_type == "assistant_message"
    finally:
        server.group_manager.delete_group(group_id=group.id, actor=actor)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dynamic_group_chat(server, actor, manager_agent_id, participant_agent_ids):
    """Dynamic groups let the manager pick the next speaker.

    Expects two steps per participant (manager pick + participant reply) and
    two messages per step.
    """
    group = server.group_manager.create_group(
        group=GroupCreate(
            description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.",
            agent_ids=participant_agent_ids,
            manager_config=DynamicManager(
                manager_agent_id=manager_agent_id,
            ),
        ),
        actor=actor,
    )
    # BUGFIX: delete the group in a finally block so a failing assertion does
    # not leak it into subsequent tests.
    try:
        response = await server.send_group_message_to_agent(
            group_id=group.id,
            actor=actor,
            messages=[
                MessageCreate(role="user", content="what is everyone up to for the holidays?"),
            ],
            stream_steps=False,
            stream_tokens=False,
        )
        assert response.usage.step_count == len(participant_agent_ids) * 2
        assert len(response.messages) == response.usage.step_count * 2
    finally:
        server.group_manager.delete_group(group_id=group.id, actor=actor)
|
||||
Reference in New Issue
Block a user