Files
letta-server/letta/groups/supervisor_multi_agent.py
2025-10-07 17:50:45 -07:00

121 lines
4.5 KiB
Python

from typing import List, Optional
from letta.agents.base_agent import BaseAgent
from letta.constants import DEFAULT_MESSAGE_TOOL
from letta.functions.function_sets.multi_agent import send_message_to_all_agents_in_group
from letta.functions.functions import parse_source_code
from letta.functions.schema_generator import generate_schema
from letta.interface import AgentInterface
from letta.orm import User
from letta.schemas.agent import AgentState
from letta.schemas.enums import ToolType
from letta.schemas.letta_message_content import TextContent
from letta.schemas.message import MessageCreate
from letta.schemas.tool import Tool
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
from letta.schemas.usage import LettaUsageStatistics
from letta.services.agent_manager import AgentManager
from letta.services.tool_manager import ToolManager
class SupervisorMultiAgent(BaseAgent):
    """Agent wrapper for a supervisor-style multi-agent group (as named).

    Stores the group's ``group_id``, the member ``agent_ids`` and a free-form
    ``description``. It also creates an ``AgentManager`` and a ``ToolManager``
    for use by the instance.

    NOTE(review): this class does not define its own ``step``. The supervisor
    step implementation that follows it in this file is commented out. Whether
    ``BaseAgent`` supplies a usable ``step`` is not visible here — TODO confirm.
    """

    def __init__(
        self,
        interface: AgentInterface,
        agent_state: AgentState,
        user: User,
        # custom
        group_id: str = "",
        agent_ids: Optional[List[str]] = None,
        description: str = "",
    ) -> None:
        """Initialize the supervisor group agent.

        Args:
            interface: Agent interface, forwarded to ``BaseAgent.__init__``.
            agent_state: Agent state, forwarded to ``BaseAgent.__init__``.
            user: User, forwarded to ``BaseAgent.__init__``.
            group_id: Identifier of the group; defaults to ``""``.
            agent_ids: Ids of the member agents. ``None`` (the default) yields
                a fresh empty list for this instance. A list passed in is
                stored as-is (not copied).
            description: Free-form description; defaults to ``""``.
        """
        super().__init__(interface, agent_state, user)
        self.group_id = group_id
        # Build a new list per instance. A literal `[]` default would be a
        # single object shared (and mutated) across every instance created
        # without explicit agent_ids.
        self.agent_ids = [] if agent_ids is None else agent_ids
        self.description = description
        self.agent_manager = AgentManager()
        self.tool_manager = ToolManager()
#
# def step(
# self,
# input_messages: List[MessageCreate],
# chaining: bool = True,
# max_chaining_steps: Optional[int] = None,
# put_inner_thoughts_first: bool = True,
# assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
# **kwargs,
# ) -> LettaUsageStatistics:
# # Load settings
# token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False
# metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None
#
# # Prepare supervisor agent
# if self.tool_manager.get_tool_by_name(tool_name="send_message_to_all_agents_in_group", actor=self.user) is None:
# multi_agent_tool = Tool(
# name=send_message_to_all_agents_in_group.__name__,
# description="",
# source_type="python",
# tags=[],
# source_code=parse_source_code(send_message_to_all_agents_in_group),
# json_schema=generate_schema(send_message_to_all_agents_in_group, None),
# )
# multi_agent_tool.tool_type = ToolType.LETTA_MULTI_AGENT_CORE
# multi_agent_tool = self.tool_manager.create_or_update_tool(
# pydantic_tool=multi_agent_tool,
# actor=self.user,
# )
# self.agent_state = self.agent_manager.attach_tool(agent_id=self.agent_state.id, tool_id=multi_agent_tool.id, actor=self.user)
#
# old_tool_rules = self.agent_state.tool_rules
# self.agent_state.tool_rules = [
# InitToolRule(
# tool_name="send_message_to_all_agents_in_group",
# ),
# TerminalToolRule(
# tool_name=assistant_message_tool_name,
# ),
# ChildToolRule(
# tool_name="send_message_to_all_agents_in_group",
# children=[assistant_message_tool_name],
# ),
# ]
#
# # Prepare new messages
# new_messages = []
# for message in input_messages:
# if isinstance(message.content, str):
# message.content = [TextContent(text=message.content)]
# message.group_id = self.group_id
# new_messages.append(message)
#
# try:
# # Load supervisor agent
# supervisor_agent = Agent(
# agent_state=self.agent_state,
# interface=self.interface,
# user=self.user,
# )
#
# # Perform supervisor step
# usage_stats = supervisor_agent.step(
# input_messages=new_messages,
# chaining=chaining,
# max_chaining_steps=max_chaining_steps,
# stream=token_streaming,
# skip_verify=True,
# metadata=metadata,
# put_inner_thoughts_first=put_inner_thoughts_first,
# )
# except Exception as e:
# raise e
# finally:
# self.interface.step_yield()
# self.agent_state.tool_rules = old_tool_rules
#
# self.interface.step_complete()
#
# return usage_stats
#