* auto fixes * auto fix pt2 and transitive deps and undefined var checking locals() * manual fixes (ignored or letta-code fixed) * fix circular import
111 lines
3.9 KiB
Python
111 lines
3.9 KiB
Python
from typing import List, Optional

from letta.agents.base_agent import BaseAgent
from letta.interface import AgentInterface
from letta.orm import User
from letta.schemas.agent import AgentState
from letta.services.agent_manager import AgentManager
from letta.services.tool_manager import ToolManager
|
|
|
|
|
|
class SupervisorMultiAgent(BaseAgent):
    """Agent that carries the configuration of a multi-agent group.

    The constructor forwards ``interface``, ``agent_state`` and ``user`` to
    ``BaseAgent`` and records the group settings together with freshly
    created ``AgentManager`` and ``ToolManager`` instances. No ``step``
    implementation is active in this class as written; the previous one is
    commented out below.
    """

    def __init__(
        self,
        interface: AgentInterface,
        agent_state: AgentState,
        user: User,
        # custom
        group_id: str = "",
        agent_ids: Optional[List[str]] = None,
        description: str = "",
    ):
        """Initialize the base agent and store the group settings.

        Args:
            interface: Passed through unchanged to ``BaseAgent.__init__``.
            agent_state: Passed through unchanged to ``BaseAgent.__init__``.
            user: Passed through unchanged to ``BaseAgent.__init__``.
            group_id: Stored as ``self.group_id``; defaults to ``""``.
            agent_ids: Stored as ``self.agent_ids``. When omitted (or
                ``None``) a new empty list is created for this instance.
                A list supplied by the caller is stored by reference,
                not copied.
            description: Stored as ``self.description``; defaults to ``""``.
        """
        super().__init__(interface, agent_state, user)

        self.group_id = group_id
        # The default used to be a ``[]`` literal, which Python evaluates once
        # at definition time; every instance created without ``agent_ids``
        # then shared (and could mutate) the same list object.
        self.agent_ids = agent_ids if agent_ids is not None else []
        self.description = description

        self.agent_manager = AgentManager()
        self.tool_manager = ToolManager()
|
|
|
|
|
|
#
|
|
# def step(
|
|
# self,
|
|
# input_messages: List[MessageCreate],
|
|
# chaining: bool = True,
|
|
# max_chaining_steps: Optional[int] = None,
|
|
# put_inner_thoughts_first: bool = True,
|
|
# assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
|
|
# **kwargs,
|
|
# ) -> LettaUsageStatistics:
|
|
# # Load settings
|
|
# token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False
|
|
# metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None
|
|
#
|
|
# # Prepare supervisor agent
|
|
# if self.tool_manager.get_tool_by_name(tool_name="send_message_to_all_agents_in_group", actor=self.user) is None:
|
|
# multi_agent_tool = Tool(
|
|
# name=send_message_to_all_agents_in_group.__name__,
|
|
# description="",
|
|
# source_type="python",
|
|
# tags=[],
|
|
# source_code=parse_source_code(send_message_to_all_agents_in_group),
|
|
# json_schema=generate_schema(send_message_to_all_agents_in_group, None),
|
|
# )
|
|
# multi_agent_tool.tool_type = ToolType.LETTA_MULTI_AGENT_CORE
|
|
# multi_agent_tool = self.tool_manager.create_or_update_tool(
|
|
# pydantic_tool=multi_agent_tool,
|
|
# actor=self.user,
|
|
# )
|
|
# self.agent_state = self.agent_manager.attach_tool(agent_id=self.agent_state.id, tool_id=multi_agent_tool.id, actor=self.user)
|
|
#
|
|
# old_tool_rules = self.agent_state.tool_rules
|
|
# self.agent_state.tool_rules = [
|
|
# InitToolRule(
|
|
# tool_name="send_message_to_all_agents_in_group",
|
|
# ),
|
|
# TerminalToolRule(
|
|
# tool_name=assistant_message_tool_name,
|
|
# ),
|
|
# ChildToolRule(
|
|
# tool_name="send_message_to_all_agents_in_group",
|
|
# children=[assistant_message_tool_name],
|
|
# ),
|
|
# ]
|
|
#
|
|
# # Prepare new messages
|
|
# new_messages = []
|
|
# for message in input_messages:
|
|
# if isinstance(message.content, str):
|
|
# message.content = [TextContent(text=message.content)]
|
|
# message.group_id = self.group_id
|
|
# new_messages.append(message)
|
|
#
|
|
# try:
|
|
# # Load supervisor agent
|
|
# supervisor_agent = Agent(
|
|
# agent_state=self.agent_state,
|
|
# interface=self.interface,
|
|
# user=self.user,
|
|
# )
|
|
#
|
|
# # Perform supervisor step
|
|
# usage_stats = supervisor_agent.step(
|
|
# input_messages=new_messages,
|
|
# chaining=chaining,
|
|
# max_chaining_steps=max_chaining_steps,
|
|
# stream=token_streaming,
|
|
# skip_verify=True,
|
|
# metadata=metadata,
|
|
# put_inner_thoughts_first=put_inner_thoughts_first,
|
|
# )
|
|
# except Exception as e:
|
|
# raise e
|
|
# finally:
|
|
# self.interface.step_yield()
|
|
# self.agent_state.tool_rules = old_tool_rules
|
|
#
|
|
# self.interface.step_complete()
|
|
#
|
|
# return usage_stats
|
|
#
|